Compare commits

...

19 commits

Author SHA1 Message Date
ayoungbloodrbx
6b33251b89
Sync to upstream/release/667 (#1754)
After a very auspicious release last week, we have a new bevy of changes
for you!

## What's Changed

### Deprecated Attribute

This release includes an implementation of the `@deprecated` attribute
proposed in [this
RFC](https://rfcs.luau.org/syntax-attribute-functions-deprecated.html).
It relies on the new type solver to propagate deprecation information
from function and method AST nodes to the corresponding type objects.
These objects are queried by a linter pass when it encounters local,
global, or indexed variables, to issue deprecation warnings. Uses of
deprecated functions and methods in recursion are ignored. To support
deprecation of class methods, the parser has been extended to allow
attribute declarations on class methods. The implementation does not
support parameters, so it is not currently possible for users to
customize deprecation messages.

### General

- Add a limit for normalization of function types.

### New Type Solver

- Fix type checker to accept numbers as concat operands (Fixes #1671).
- Fix user-defined type functions failing when used inside type
aliases/nested calls (Fixes #1738, Fixes #1679).
- Improve constraint generation for overloaded functions (in part thanks
to @vvatheus in #1694).
- Improve type inference for indexers on table literals, especially when
passing table literals directly as a function call argument.
- Equate regular error type and intersection with a negation of an error
type.
- Avoid swapping types in 2-part union when RHS is optional.
- Use simplification when doing `~nil` refinements.
- `len<>` now works on metatables without `__len` function.

### AST

- Retain source information for `AstTypeUnion` and
`AstTypeIntersection`.

### Transpiler

- Print attributes on functions.

### Parser

- Allow types in indexers to begin with string literals by @jackdotink
in #1750.

### Autocomplete

- Evaluate user-defined type functions in ill-formed source code to
provide autocomplete.
- Fix the start location of functions that have attributes.
- Implement better fragment selection.

### Internal Contributors

Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Sora Kanosue <skanosue@roblox.com>
Co-authored-by: Talha Pathan <tpathan@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

**Full Changelog**:
https://github.com/luau-lang/luau/compare/0.666...0.667

---------

Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <61795485+vrn-sn@users.noreply.github.com>
Co-authored-by: Menarul Alam <malam@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Vighnesh <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
2025-03-28 16:15:46 -07:00
Jack
12dac2f1f4
fix parsing string union indexers (#1750)
Today this code results in a syntax error: `type foo = { ["bar" |
"baz"]: number }`. This is odd and I believe it is a bug. I have fixed
this so that it is now parsed as an indexer field with a union type.
This change should not affect the way any code is parsed today, and
allow types in indexers to begin with string literals.

---------

Co-authored-by: ariel <aweiss@hey.com>
2025-03-25 16:18:22 -07:00
Matheus
2621488abe
Fix singleton parameters in overloaded functions (#1694)
- Fixes #1691 
- Fixes #1589

---------

Co-authored-by: Math <175355178+maffeus@users.noreply.github.com>
Co-authored-by: ariel <aaronweiss@roblox.com>
Co-authored-by: Matheus <175355178+m4fh@users.noreply.github.com>
Co-authored-by: ariel <aweiss@hey.com>
2025-03-24 09:27:13 -07:00
Varun Saini
5f42e63a73
Sync to upstream/release/666 (#1747)
Another week, another release. Happy spring! 🌷 

## New Type Solver

- Add typechecking and autocomplete support for user-defined type
functions!
- Improve the display of type paths, making type mismatch errors far
more human-readable.
- Enhance various aspects of the `index` type function: support function
type metamethods, fix crashes involving cyclic metatables, and forward
`any` types through the type function.
- Fix incorrect subtyping results involving the `buffer` type.
- Fix crashes related to typechecking anonymous functions in nonstrict
mode.

## AST

- Retain source information for type packs, functions, and type
functions.
- Introduce `AstTypeOptional` to differentiate `T?` from `T | nil` in
the AST.
- Prevent the transpiler from advancing before tokens when the AST has
errors.

## Autocomplete

- Introduce demand-based cloning and better module isolation for
fragment autocomplete, leading to a substantial speedup in performance.
- Guard against recursive unions in `autocompleteProps`.

## Miscellaneous

- #1720 (thank you!)

## Internal Contributors

Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Talha Pathan <tpathan@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
2025-03-21 14:43:00 -07:00
Hunter Goldstein
e0b55a9cb1
Sync to upstream/release/665 (#1732)
Hello all! Another week, another Luau release!

# Change to `lua_setuserdatametatable`

This release fixes #1710: `lua_setuserdatametatable` is being changed so
that it _only_ operates on the top of the stack: the `idx` parameter is
being removed. Prior to this, `lua_setuserdatametable` would set the
metatable of the value in the stack at `idx`, but _always_ pop the top
of the stack. The old behavior is available in this release as
`lua_setuserdatametatable_DEPRECATED`.

# General

This release exposes a generalized implementation of require-by-string's
autocomplete logic. `FileResolver` can now be optionally constructed
with a `RequireSuggester`, which provides an interface for converting a
given module to a `RequireNode`. Consumers of this new API implement a
`RequireNode` to define how modules are represented in their embedded
context, and the new API manages the logic specific to
require-by-string, including providing suggestions for require aliases.
This enhancement moves toward integrating require-by-string's semantics
into the language itself, rather than merely providing a specification
for community members to implement themselves.

# New Type Solver
* Fixed a source of potential `Luau::follow detected a Type cycle`
internal compiler exceptions when assigning a global to itself.
* Fixed an issue whereby `*no-refine*` (a type which should not be
visible at the end of type checking) was not being properly elided,
causing inference of class-like tables to become unreadable / induce
crashes in autocomplete.
* Fixed a case of incomplete constraint solving when performing basic
math in a loop

# Fragment Autocomplete
* Fixed several crashes related to not properly filling in scope
information for the fragments
* Fixed a source of memory corruption by isolating the return type of a
fragment when it is type checked.
* Improved performance by opting not to clone persistent types for the
fragment (e.g.: built in types)
 
# Internal Contributors
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Talha Pathan <tpathan@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

---------

Co-authored-by: Varun Saini <61795485+vrn-sn@users.noreply.github.com>
Co-authored-by: Alexander Youngblood <ayoungblood@roblox.com>
Co-authored-by: Menarul Alam <malam@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Vighnesh <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
2025-03-14 13:11:24 -07:00
Kostadin
b0c3f40b0c
Add #include <stdint.h> to fix building with gcc 15 (#1720)
With gcc 15, the C++ Standard Library no longer includes other headers
that were internally used by the library. In Luau's case the missing
header is `<stdint.h>`

Downstream Gentoo bug: https://bugs.gentoo.org/938122
Signed-off-by: Kostadin Shishmanov <kostadinshishmanov@protonmail.com>

---------

Co-authored-by: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com>
2025-03-10 06:02:09 -07:00
vegorov-rbx
de9f5d6eb6
Sync to upstream/release/664 (#1715)
As always, a weekly Luau update!
This week we have further improvements to new type solver, fixing a few
of the popular issues reported. The fragment autocomplete is even more
stable and we believe it's ready for broader use.

Aside from that we have a few general fixes/improvements:
* Fixed data race when multi-threaded typechecking is used, appearing as
a random crash at the end of typechecking
* AST data is now available from `Luau::Module`

## New Type Solver

* Fixed type refinements made by function calls which could attach `nil`
as an option of a type before (Fixes #1528)
* Improved bidirectional typechecking in tables (Fixes #1596)
* Fixed normalization of negated types
* `getmetatable()` on `any` type should no longer report an error

## Fragment Autocomplete

* Fixed auto-complete suggestions being provided inside multiline
comments
* Fixed an assertion failure that could happen when old type solver was
used
* Fixed issues with missing suggestions when multiple statements are on
the same line
* Fixed memory safety issues

## Internal Contributors

Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
2025-03-07 10:07:27 -08:00
ariel
640ebbc0a5
Sync to upstream/release/663 (#1699)
Hey folks, another week means another Luau release! This one features a
number of bug fixes in the New Type Solver including improvements to
user-defined type functions and a bunch of work to untangle some of the
outstanding issues we've been seeing with constraint solving not
completing in real world use. We're also continuing to make progress on
crashes and other problems that affect the stability of fragment
autocomplete, as we work towards delivering consistent, low-latency
autocomplete for any editor environment.

## New Type Solver

- Fix a bug in user-defined type functions where `print` would
incorrectly insert `\1` a number of times.
- Fix a bug where attempting to refine an optional generic with a type
test will cause a false positive type error (fixes #1666)
- Fix a bug where the `refine` type family would not skip over
`*no-refine*` discriminants (partial resolution for #1424)
- Fix a constraint solving bug where recursive function calls would
consistently produce cyclic constraints leading to incomplete or
inaccurate type inference.
- Implement `readparent` and `writeparent` for class types in
user-defined type functions, replacing the incorrectly included `parent`
method.
- Add initial groundwork (under a debug flag) for eager free type
generalization, moving us towards further improvements to constraint
solving incomplete errors.

## Fragment Autocomplete

- Ease up some assertions to improve stability of mixed-mode use of the
two type solvers (i.e. using Fragment Autocomplete on a type graph
originally produced by the old type solver)
- Resolve a bug with type compatibility checks causing internal compiler
errors in autocomplete.

## Lexer and Parser

- Improve the accuracy of the roundtrippable AST parsing mode by
correctly placing closing parentheses on type groupings.
- Add a getter for `offset` in the Lexer by @aduermael in #1688
- Add a second entry point to the parser to parse an expression,
`parseExpr`

## Internal Contributors

Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: James McNellis <jmcnellis@roblox.com>
Co-authored-by: Talha Pathan <tpathan@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

---------

Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <61795485+vrn-sn@users.noreply.github.com>
Co-authored-by: Alexander Youngblood <ayoungblood@roblox.com>
Co-authored-by: Menarul Alam <malam@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Vighnesh <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
2025-02-28 14:42:30 -08:00
Adrien Duermael
6a21dba682
Lexer: add offset getter (#1688)
Added a getter for the Lexer's private `offset` to track its cursor
position in the buffer.
This helps me index tokens by buffer address in a project where I'm
rendering Luau code with [Dear ImGui](https://github.com/ocornut/imgui).
Would love this merged so I can use official Luau releases again!

Co-authored-by: Adrian Duermael <adrien@cu.bzh>
2025-02-24 16:25:57 -08:00
ramdoys
c1e2f650db
chore: update applicable .lua files to .luau (#1560)
Updates all of the APPLICABLE .lua files to the .luau file ending.

---------

Co-authored-by: ramdoys <ramdoysdirect@gmail.com>
2025-02-21 14:29:20 -08:00
vegorov-rbx
c2e72666d9
Sync to upstream/release/662 (#1681)
## What's new

This update brings improvements to the new type solver, roundtrippable
AST parsing mode and closes multiple issues reported in this repository.

* `require` dependency tracing for non-string requires now supports `()`
groups in expressions and types as well as an ability to type annotate a
value with a `typeof` of a different module path
* Fixed rare misaligned memory access in Compiler/Typechecker on 32 bit
platforms (Closes #1572)

## New Solver

* Fixed crash/UB in subtyping of type packs (Closes #1449)
* Fixed incorrect type errors when calling `debug.info` (Closes #1534
and Resolves #966)
* Fixed incorrect boolean and string equality comparison result in
user-defined type functions (Closes #1623)
* Fixed incorrect class types being produced in user-defined type
functions when multiple classes share the same name (Closes #1639)
* Improved bidirectional typechecking for table literals containing
elements that have not been solved yet (Closes #1641)

## Roundtrippable AST

* Added source information for `AstStatTypeAlias`
* Fixed an issue with `AstTypeGroup` node (added in #1643) producing
invalid AST json. Contained type is now named 'inner' instead of 'type'
* Fixed end location of the `do ... end` statement

---

Internal Contributors:

Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Talha Pathan <tpathan@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
2025-02-21 10:24:12 -08:00
ariel
86bf4ae42d
Revise some of the copytext in markdown files. (#1677)
We have a bunch of small grammatical nits and slightly awkward phrasings
present in our existing markdown files. This is a small pass over all of
them to fix those, and to provide some additional updated information
that has become more clear over time (like additional users of Luau, or
our leveraging something akin to the Minus 100 Points philosophy for
evaluating RFCs).

---------

Co-authored-by: Varun Saini <61795485+vrn-sn@users.noreply.github.com>
2025-02-20 13:56:00 -08:00
karl-police
29a5198055
Fix the order of the arguments being inputted into result.append in print() from the Type Function Runtime (#1676)
---

Co-authored-by: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com>
2025-02-20 12:32:47 -08:00
vegorov-rbx
14ccae9b44
Update minimal Ubuntu version in workflows from 20.04 to 22.04 (#1670)
Ubuntu 20.04 LTS is reaching the end of its standard support period.
GitHub runners are going to stop providing images for it soon.

This PR updates all workflows to use 22.04 instead of 20.04.
While this will impact the compatibility of prebuilt release binaries on
20.04, users of those systems can always build Luau from source.

`coverage` workflow is also updating from clang++10 to clang++ (defaults
to clang++14 today).
Line coverage with this switch drops from 89.65% to 87.45%, function
coverage remains 91.26%.
As a bonus, we now get branch coverage information.
2025-02-20 10:32:33 -08:00
nothing
9c198413ec
Analysis: Make typeof on a type userdata return "type" (#1568)
This change introduces a flag (`LuauUserTypeFunTypeofReturnsType`) that,
when enabled, sets `__type` on the type userdata's metatable to "type".
This behaviour was described in the user-defined type function RFC
(https://rfcs.luau.org/user-defined-type-functions.html), but seems to
have been missed; this change implements that behaviour.

Currently this does not change `typeof(t) == 'type'` emitting an unknown
type warning as I don't trust myself to implement it due to my general
lack of C++ knowledge; this can be worked on later.
2025-02-17 09:36:52 -08:00
dependabot[bot]
bd4fe54f4b
Bump jinja2 from 3.1.4 to 3.1.5 in /tools/fuzz (#1607) 2025-02-17 08:58:48 -08:00
Vighnesh-V
77642988c2
Sync to upstream/release/661 (#1664)
# General
- Additional logging enabled for fragment autocomplete.

## Roundtrippable AST
- Add a new `AstNode`, `AstGenericType`
- Retain source information for `AstExprTypeAssertion`
## New Type Solver
- New non-strict mode will report unknown symbol errors, e.g
```
foo = 5
local wrong1 = foob <- issue warning
```
- Fixed a bug where new non-strict mode failed to visit large parts of
the program.
- We now infer the types of unnanotated local variables in statements
with multiple assignments, e.g. `local x: "a", y, z = "a", f()`
- Fixed bugs in constraint dispatch ordering.
- Fixed a bug that caused an infinite loop between `Subtyping`,
`OverloadResolution`, and `Type Function Reduction`, by preventing calls
to `Type Function Reduction` being re-entrant.
- Fixed a crash in bidirectional type inference caused by asserting read
and write properties on a type that was readonly.

## Runtime
- Fix a stack overflow caused by `luaL_checkstack` consuming stack space
even if the function fails to reserve memory.
- Using '%c' with a 0 value in Luau string.format will append a '\0'.
Resolves https://github.com/luau-lang/luau/issues/1650

## Miscellaneous
- Miscellaneous small bugfixes for the new solver.

**Full Changelog**:
https://github.com/luau-lang/luau/compare/0.660...0.661
----
Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Talha Pathan <tpathan@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

---------

Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <61795485+vrn-sn@users.noreply.github.com>
Co-authored-by: Alexander Youngblood <ayoungblood@roblox.com>
Co-authored-by: Menarul Alam <malam@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
2025-02-14 13:57:46 -08:00
Aviral Goel
2e61028cba
Sync to upstream/release/660 (#1643)
# General

This release introduces initial work on a Roundtrippable AST for Luau,
and numerous fixes to the new type solver, runtime, and fragment
autocomplete.

## Roundtrippable AST

To support tooling around source code transformations, we are extending
the parser to retain source information so that we can re-emit the
initial source code exactly as the author wrote it. We have made
numerous changes to the Transpiler, added new AST types such as
`AstTypeGroup`, and added source information to AST nodes such as
`AstExprInterpString`, `AstExprIfElse`, `AstTypeTable`,
`AstTypeReference`, `AstTypeSingletonString`, and `AstTypeTypeof`.

## New Type Solver

* Implement `setmetatable` and `getmetatable` type functions.
* Fix handling of nested and recursive union type functions to prevent
the solver from getting stuck.
* Free types in both old and new solver now have an upper and lower
bound to resolve mixed mode usage of the solvers in fragment
autocomplete.
* Fix infinite recursion during normalization of cyclic tables.
* Add normalization support for intersections of subclasses with negated
superclasses.

## Runtime
* Fix compilation error in Luau buffer bit operations for big-endian
machines.

## Miscellaneous
* Add test and bugfixes to fragment autocomplete.
* Fixed `clang-tidy` warnings in `Simplify.cpp`.

**Full Changelog**:
https://github.com/luau-lang/luau/compare/0.659...0.660

---

Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Talha Pathan <tpathan@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

---------

Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <61795485+vrn-sn@users.noreply.github.com>
Co-authored-by: Alexander Youngblood <ayoungblood@roblox.com>
Co-authored-by: Menarul Alam <malam@roblox.com>
2025-02-07 16:17:11 -08:00
menarulalam
f8a1e0129d
Sync to upstream/release/659 (#1637)
## What's Changed

General performance improvements and bug fixes. `lua_clonetable` was
added too.

### General

## Runtime
- Improvements were made to Luau's performance, including a
`lua_clonetable` function and optimizations to string caching. Buffer
read/write operations were optimized for big-endian machines.
## New Solver
- Crashes related to duplicate keys in table literals, fragment AC
crashes, and potential hash collisions in the StringCache.
- We now handle user-defined type functions as opaque and track interior
free table types.
## Require By String
- Require-by-string path resolution was simplified.

**Full Changelog**:
https://github.com/luau-lang/luau/compare/0.658...0.659

---

Co-authored-by: Ariel Weiss <aaronweiss@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Yohoo Lin <yohoo@roblox.com>

---------

Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <61795485+vrn-sn@users.noreply.github.com>
Co-authored-by: Alexander Youngblood <ayoungblood@roblox.com>
2025-01-31 18:58:36 -08:00
233 changed files with 19171 additions and 6866 deletions

View file

@ -46,9 +46,9 @@ jobs:
- name: make cli - name: make cli
run: | run: |
make -j2 config=sanitize werror=1 luau luau-analyze luau-compile # match config with tests to improve build time make -j2 config=sanitize werror=1 luau luau-analyze luau-compile # match config with tests to improve build time
./luau tests/conformance/assert.lua ./luau tests/conformance/assert.luau
./luau-analyze tests/conformance/assert.lua ./luau-analyze tests/conformance/assert.luau
./luau-compile tests/conformance/assert.lua ./luau-compile tests/conformance/assert.luau
windows: windows:
runs-on: windows-latest runs-on: windows-latest
@ -81,12 +81,12 @@ jobs:
shell: bash # necessary for fail-fast shell: bash # necessary for fail-fast
run: | run: |
cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Debug # match config with tests to improve build time cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Debug # match config with tests to improve build time
Debug/luau tests/conformance/assert.lua Debug/luau tests/conformance/assert.luau
Debug/luau-analyze tests/conformance/assert.lua Debug/luau-analyze tests/conformance/assert.luau
Debug/luau-compile tests/conformance/assert.lua Debug/luau-compile tests/conformance/assert.luau
coverage: coverage:
runs-on: ubuntu-20.04 # needed for clang++-10 to avoid gcov compatibility issues runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: install - name: install
@ -94,7 +94,7 @@ jobs:
sudo apt install llvm sudo apt install llvm
- name: make coverage - name: make coverage
run: | run: |
CXX=clang++-10 make -j2 config=coverage native=1 coverage CXX=clang++ make -j2 config=coverage native=1 coverage
- name: upload coverage - name: upload coverage
uses: codecov/codecov-action@v3 uses: codecov/codecov-action@v3
with: with:

View file

@ -29,8 +29,8 @@ jobs:
build: build:
needs: ["create-release"] needs: ["create-release"]
strategy: strategy:
matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility matrix: # not using ubuntu-latest to improve compatibility
os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}] os: [{name: ubuntu, version: ubuntu-22.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}]
name: ${{matrix.os.name}} name: ${{matrix.os.name}}
runs-on: ${{matrix.os.version}} runs-on: ${{matrix.os.version}}
steps: steps:

View file

@ -13,8 +13,8 @@ on:
jobs: jobs:
build: build:
strategy: strategy:
matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility matrix: # not using ubuntu-latest to improve compatibility
os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}] os: [{name: ubuntu, version: ubuntu-22.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}]
name: ${{matrix.os.name}} name: ${{matrix.os.name}}
runs-on: ${{matrix.os.version}} runs-on: ${{matrix.os.version}}
steps: steps:

View file

@ -1,148 +0,0 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/AstQuery.h"
#include "Luau/Config.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Scope.h"
#include "Luau/Variant.h"
#include "Luau/Normalize.h"
#include "Luau/TypePack.h"
#include "Luau/TypeArena.h"
#include <mutex>
#include <string>
#include <vector>
#include <optional>
namespace Luau
{
class AstStat;
class ParseError;
struct TypeError;
struct LintWarning;
struct GlobalTypes;
struct ModuleResolver;
struct ParseResult;
struct DcrLogger;
struct TelemetryTypePair
{
std::string annotatedType;
std::string inferredType;
};
struct AnyTypeSummary
{
TypeArena arena;
AstStatBlock* rootSrc = nullptr;
DenseHashSet<TypeId> seenTypeFamilyInstances{nullptr};
int recursionCount = 0;
std::string root;
int strictCount = 0;
DenseHashMap<const void*, bool> seen{nullptr};
AnyTypeSummary();
void traverse(const Module* module, AstStat* src, NotNull<BuiltinTypes> builtinTypes);
std::pair<bool, TypeId> checkForAnyCast(const Scope* scope, AstExprTypeAssertion* expr);
bool containsAny(TypePackId typ);
bool containsAny(TypeId typ);
bool isAnyCast(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool isAnyCall(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasVariadicAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasArgAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
bool hasAnyReturns(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
TypeId checkForFamilyInhabitance(const TypeId instance, Location location);
TypeId lookupType(const AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
TypePackId reconstructTypePack(const AstArray<AstExpr*> exprs, const Module* module, NotNull<BuiltinTypes> builtinTypes);
DenseHashSet<TypeId> seenTypeFunctionInstances{nullptr};
TypeId lookupAnnotation(AstType* annotation, const Module* module, NotNull<BuiltinTypes> builtintypes);
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation, const Module* module);
TypeId checkForTypeFunctionInhabitance(const TypeId instance, const Location location);
enum Pattern : uint64_t
{
Casts,
FuncArg,
FuncRet,
FuncApp,
VarAnnot,
VarAny,
TableProp,
Alias,
Assign,
TypePk
};
struct TypeInfo
{
Pattern code;
std::string node;
TelemetryTypePair type;
explicit TypeInfo(Pattern code, std::string node, TelemetryTypePair type);
};
struct FindReturnAncestry final : public AstVisitor
{
AstNode* currNode{nullptr};
AstNode* stat{nullptr};
Position rootEnd;
bool found = false;
explicit FindReturnAncestry(AstNode* stat, Position rootEnd);
bool visit(AstType* node) override;
bool visit(AstNode* node) override;
bool visit(AstStatFunction* node) override;
bool visit(AstStatLocalFunction* node) override;
};
std::vector<TypeInfo> typeInfo;
/**
* Fabricates a scope that is a child of another scope.
* @param node the lexical node that the scope belongs to.
* @param parent the parent scope of the new scope. Must not be null.
*/
const Scope* childScope(const AstNode* node, const Scope* parent);
std::optional<AstExpr*> matchRequire(const AstExprCall& call);
AstNode* getNode(AstStatBlock* root, AstNode* node);
const Scope* findInnerMostScope(const Location location, const Module* module);
const AstNode* findAstAncestryAtLocation(const AstStatBlock* root, AstNode* node);
void visit(const Scope* scope, AstStat* stat, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatBlock* block, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatIf* ifStatement, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatWhile* while_, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatRepeat* repeat, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatReturn* ret, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatFor* for_, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatForIn* forIn, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatCompoundAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatLocalFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatTypeAlias* alias, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatDeclareGlobal* declareGlobal, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatDeclareClass* declareClass, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatDeclareFunction* declareFunction, const Module* module, NotNull<BuiltinTypes> builtinTypes);
void visit(const Scope* scope, AstStatError* error, const Module* module, NotNull<BuiltinTypes> builtinTypes);
};
} // namespace Luau

View file

@ -70,6 +70,7 @@ Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol
void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName); void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName);
std::string getBuiltinDefinitionSource(); std::string getBuiltinDefinitionSource();
std::string getTypeFunctionDefinitionSource();
void addGlobalBinding(GlobalTypes& globals, const std::string& name, TypeId ty, const std::string& packageName); void addGlobalBinding(GlobalTypes& globals, const std::string& name, TypeId ty, const std::string& packageName);
void addGlobalBinding(GlobalTypes& globals, const std::string& name, Binding binding); void addGlobalBinding(GlobalTypes& globals, const std::string& name, Binding binding);

View file

@ -4,6 +4,7 @@
#include <Luau/NotNull.h> #include <Luau/NotNull.h>
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/Scope.h"
#include <unordered_map> #include <unordered_map>
@ -26,13 +27,22 @@ struct CloneState
* while `clone` will make a deep copy of the entire type and its every component. * while `clone` will make a deep copy of the entire type and its every component.
* *
* Be mindful about which behavior you actually _want_. * Be mindful about which behavior you actually _want_.
*
* Persistent types are not cloned as an optimization.
* If a type is cloned in order to mutate it, 'ignorePersistent' has to be set
*/ */
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false);
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState); TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false);
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState);
TypePackId cloneIncremental(TypePackId tp, TypeArena& dest, CloneState& cloneState, Scope* freshScopeForFreeTypes);
TypeId cloneIncremental(TypeId typeId, TypeArena& dest, CloneState& cloneState, Scope* freshScopeForFreeTypes);
TypeFun cloneIncremental(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState, Scope* freshScopeForFreeTypes);
Binding cloneIncremental(const Binding& binding, TypeArena& dest, CloneState& cloneState, Scope* freshScopeForFreeTypes);
} // namespace Luau } // namespace Luau

View file

@ -50,6 +50,7 @@ struct GeneralizationConstraint
TypeId sourceType; TypeId sourceType;
std::vector<TypeId> interiorTypes; std::vector<TypeId> interiorTypes;
bool hasDeprecatedAttribute = false;
}; };
// variables ~ iterate iterator // variables ~ iterate iterator
@ -109,6 +110,21 @@ struct FunctionCheckConstraint
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes; NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes;
}; };
// table_check expectedType exprType
//
// If `expectedType` is a table type and `exprType` is _also_ a table type,
// propogate the member types of `expectedType` into the types of `exprType`.
// This is used to implement bidirectional inference on table assignment.
// Also see: FunctionCheckConstraint.
struct TableCheckConstraint
{
TypeId expectedType;
TypeId exprType;
AstExprTable* table = nullptr;
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes;
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes;
};
// prim FreeType ExpectedType PrimitiveType // prim FreeType ExpectedType PrimitiveType
// //
// FreeType is bounded below by the singleton type and above by PrimitiveType // FreeType is bounded below by the singleton type and above by PrimitiveType
@ -273,7 +289,8 @@ using ConstraintV = Variant<
UnpackConstraint, UnpackConstraint,
ReduceConstraint, ReduceConstraint,
ReducePackConstraint, ReducePackConstraint,
EqualityConstraint>; EqualityConstraint,
TableCheckConstraint>;
struct Constraint struct Constraint
{ {

View file

@ -96,6 +96,9 @@ struct ConstraintGenerator
// will enqueue them during solving. // will enqueue them during solving.
std::vector<ConstraintPtr> unqueuedConstraints; std::vector<ConstraintPtr> unqueuedConstraints;
// Map a function's signature scope back to its signature type.
DenseHashMap<Scope*, TypeId> scopeToFunction{nullptr};
// The private scope of type aliases for which the type parameters belong to. // The private scope of type aliases for which the type parameters belong to.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr}; DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
@ -114,12 +117,15 @@ struct ConstraintGenerator
// Needed to register all available type functions for execution at later stages. // Needed to register all available type functions for execution at later stages.
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
DenseHashMap<const AstStatTypeFunction*, ScopePtr> astTypeFunctionEnvironmentScopes{nullptr};
// Needed to resolve modules to make 'require' import types properly. // Needed to resolve modules to make 'require' import types properly.
NotNull<ModuleResolver> moduleResolver; NotNull<ModuleResolver> moduleResolver;
// Occasionally constraint generation needs to produce an ICE. // Occasionally constraint generation needs to produce an ICE.
const NotNull<InternalErrorReporter> ice; const NotNull<InternalErrorReporter> ice;
ScopePtr globalScope; ScopePtr globalScope;
ScopePtr typeFunctionScope;
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope; std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
std::vector<RequireCycle> requireCycles; std::vector<RequireCycle> requireCycles;
@ -137,6 +143,7 @@ struct ConstraintGenerator
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> ice, NotNull<InternalErrorReporter> ice,
const ScopePtr& globalScope, const ScopePtr& globalScope,
const ScopePtr& typeFunctionScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
DcrLogger* logger, DcrLogger* logger,
NotNull<DataFlowGraph> dfg, NotNull<DataFlowGraph> dfg,
@ -392,7 +399,7 @@ private:
**/ **/
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics( std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(
const ScopePtr& scope, const ScopePtr& scope,
AstArray<AstGenericType> generics, AstArray<AstGenericType*> generics,
bool useCache = false, bool useCache = false,
bool addTypes = true bool addTypes = true
); );
@ -409,7 +416,7 @@ private:
**/ **/
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks( std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(
const ScopePtr& scope, const ScopePtr& scope,
AstArray<AstGenericTypePack> packs, AstArray<AstGenericTypePack*> packs,
bool useCache = false, bool useCache = false,
bool addTypes = true bool addTypes = true
); );

View file

@ -88,6 +88,7 @@ struct ConstraintSolver
NotNull<TypeFunctionRuntime> typeFunctionRuntime; NotNull<TypeFunctionRuntime> typeFunctionRuntime;
// The entire set of constraints that the solver is trying to resolve. // The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints; std::vector<NotNull<Constraint>> constraints;
NotNull<DenseHashMap<Scope*, TypeId>> scopeToFunction;
NotNull<Scope> rootScope; NotNull<Scope> rootScope;
ModuleName currentModuleName; ModuleName currentModuleName;
@ -118,6 +119,9 @@ struct ConstraintSolver
// A mapping from free types to the number of unresolved constraints that mention them. // A mapping from free types to the number of unresolved constraints that mention them.
DenseHashMap<TypeId, size_t> unresolvedConstraints{{}}; DenseHashMap<TypeId, size_t> unresolvedConstraints{{}};
std::unordered_map<NotNull<const Constraint>, DenseHashSet<TypeId>> maybeMutatedFreeTypes;
std::unordered_map<TypeId, DenseHashSet<const Constraint*>> mutatedFreeTypeToConstraint;
// Irreducible/uninhabited type functions or type pack functions. // Irreducible/uninhabited type functions or type pack functions.
DenseHashSet<const void*> uninhabitedTypeFunctions{{}}; DenseHashSet<const void*> uninhabitedTypeFunctions{{}};
@ -142,6 +146,7 @@ struct ConstraintSolver
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope, NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints, std::vector<NotNull<Constraint>> constraints,
NotNull<DenseHashMap<Scope*, TypeId>> scopeToFunction,
ModuleName moduleName, ModuleName moduleName,
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
std::vector<RequireCycle> requireCycles, std::vector<RequireCycle> requireCycles,
@ -169,6 +174,8 @@ struct ConstraintSolver
bool isDone() const; bool isDone() const;
private: private:
void generalizeOneType(TypeId ty);
/** /**
* Bind a type variable to another type. * Bind a type variable to another type.
* *
@ -201,6 +208,7 @@ public:
bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TableCheckConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint); bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
@ -357,7 +365,7 @@ public:
* @returns a non-free type that generalizes the argument, or `std::nullopt` if one * @returns a non-free type that generalizes the argument, or `std::nullopt` if one
* does not exist * does not exist
*/ */
std::optional<TypeId> generalizeFreeType(NotNull<Scope> scope, TypeId type, bool avoidSealingTables = false); std::optional<TypeId> generalizeFreeType(NotNull<Scope> scope, TypeId type);
/** /**
* Checks the existing set of constraints to see if there exist any that contain * Checks the existing set of constraints to see if there exist any that contain
@ -421,10 +429,7 @@ public:
ToStringOptions opts; ToStringOptions opts;
void fillInDiscriminantTypes( void fillInDiscriminantTypes(NotNull<const Constraint> constraint, const std::vector<std::optional<TypeId>>& discriminantTypes);
NotNull<const Constraint> constraint,
const std::vector<std::optional<TypeId>>& discriminantTypes
);
}; };
void dump(NotNull<Scope> rootScope, struct ToStringOptions& opts); void dump(NotNull<Scope> rootScope, struct ToStringOptions& opts);

View file

@ -38,8 +38,6 @@ struct DataFlowGraph
DefId getDef(const AstExpr* expr) const; DefId getDef(const AstExpr* expr) const;
// Look up the definition optionally, knowing it may not be present. // Look up the definition optionally, knowing it may not be present.
std::optional<DefId> getDefOptional(const AstExpr* expr) const; std::optional<DefId> getDefOptional(const AstExpr* expr) const;
// Look up for the rvalue def for a compound assignment.
std::optional<DefId> getRValueDefForCompoundAssign(const AstExpr* expr) const;
DefId getDef(const AstLocal* local) const; DefId getDef(const AstLocal* local) const;
@ -66,10 +64,6 @@ private:
// All keys in this maps are really only statements that ambiently declares a symbol. // All keys in this maps are really only statements that ambiently declares a symbol.
DenseHashMap<const AstStat*, const Def*> declaredDefs{nullptr}; DenseHashMap<const AstStat*, const Def*> declaredDefs{nullptr};
// Compound assignments are in a weird situation where the local being assigned to is also being used at its
// previous type implicitly in an rvalue position. This map provides the previous binding.
DenseHashMap<const AstExpr*, const Def*> compoundAssignDefs{nullptr};
DenseHashMap<const AstExpr*, const RefinementKey*> astRefinementKeys{nullptr}; DenseHashMap<const AstExpr*, const RefinementKey*> astRefinementKeys{nullptr};
friend struct DataFlowGraphBuilder; friend struct DataFlowGraphBuilder;
}; };
@ -221,8 +215,8 @@ private:
void visitTypeList(AstTypeList l); void visitTypeList(AstTypeList l);
void visitGenerics(AstArray<AstGenericType> g); void visitGenerics(AstArray<AstGenericType*> g);
void visitGenericPacks(AstArray<AstGenericTypePack> g); void visitGenericPacks(AstArray<AstGenericTypePack*> g);
}; };
} // namespace Luau } // namespace Luau

View file

@ -105,7 +105,7 @@ private:
std::vector<Id> storage; std::vector<Id> storage;
}; };
template <typename L> template<typename L>
using Node = EqSat::Node<L>; using Node = EqSat::Node<L>;
using EType = EqSat::Language< using EType = EqSat::Language<
@ -149,7 +149,7 @@ using EType = EqSat::Language<
struct StringCache struct StringCache
{ {
Allocator allocator; Allocator allocator;
DenseHashMap<size_t, StringId> strings{{}}; DenseHashMap<std::string_view, StringId> strings{{}};
std::vector<std::string_view> views; std::vector<std::string_view> views;
StringId add(std::string_view s); StringId add(std::string_view s);

View file

@ -1,8 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include <string> #include <memory>
#include <optional> #include <optional>
#include <string>
#include <vector> #include <vector>
namespace Luau namespace Luau
@ -32,15 +33,71 @@ struct ModuleInfo
bool optional = false; bool optional = false;
}; };
struct RequireAlias
{
std::string alias; // Unprefixed alias name (no leading `@`).
std::vector<std::string> tags = {};
};
struct RequireNode
{
virtual ~RequireNode() {}
// Get the path component representing this node.
virtual std::string getPathComponent() const = 0;
// Get the displayed user-facing label for this node, defaults to getPathComponent()
virtual std::string getLabel() const
{
return getPathComponent();
}
// Get tags to attach to this node's RequireSuggestion (defaults to none).
virtual std::vector<std::string> getTags() const
{
return {};
}
// TODO: resolvePathToNode() can ultimately be replaced with a call into
// require-by-string's path resolution algorithm. This will first require
// generalizing that algorithm to work with a virtual file system.
virtual std::unique_ptr<RequireNode> resolvePathToNode(const std::string& path) const = 0;
// Get children of this node, if any (if this node represents a directory).
virtual std::vector<std::unique_ptr<RequireNode>> getChildren() const = 0;
// A list of the aliases available to this node.
virtual std::vector<RequireAlias> getAvailableAliases() const = 0;
};
struct RequireSuggestion struct RequireSuggestion
{ {
std::string label; std::string label;
std::string fullPath; std::string fullPath;
std::vector<std::string> tags;
}; };
using RequireSuggestions = std::vector<RequireSuggestion>; using RequireSuggestions = std::vector<RequireSuggestion>;
struct RequireSuggester
{
virtual ~RequireSuggester() {}
std::optional<RequireSuggestions> getRequireSuggestions(const ModuleName& requirer, const std::optional<std::string>& pathString) const;
protected:
virtual std::unique_ptr<RequireNode> getNode(const ModuleName& name) const = 0;
private:
std::optional<RequireSuggestions> getRequireSuggestionsImpl(const ModuleName& requirer, const std::optional<std::string>& path) const;
};
struct FileResolver struct FileResolver
{ {
FileResolver() = default;
FileResolver(std::shared_ptr<RequireSuggester> requireSuggester)
: requireSuggester(std::move(requireSuggester))
{
}
virtual ~FileResolver() {} virtual ~FileResolver() {}
virtual std::optional<SourceCode> readSource(const ModuleName& name) = 0; virtual std::optional<SourceCode> readSource(const ModuleName& name) = 0;
@ -60,10 +117,10 @@ struct FileResolver
return std::nullopt; return std::nullopt;
} }
virtual std::optional<RequireSuggestions> getRequireSuggestions(const ModuleName& requirer, const std::optional<std::string>& pathString) const // Make non-virtual when removing FFlagLuauImproveRequireByStringAutocomplete.
{ virtual std::optional<RequireSuggestions> getRequireSuggestions(const ModuleName& requirer, const std::optional<std::string>& pathString) const;
return std::nullopt;
} std::shared_ptr<RequireSuggester> requireSuggester;
}; };
struct NullFileResolver : FileResolver struct NullFileResolver : FileResolver

View file

@ -15,6 +15,28 @@ namespace Luau
{ {
struct FrontendOptions; struct FrontendOptions;
enum class FragmentAutocompleteWaypoint
{
ParseFragmentEnd,
CloneModuleStart,
CloneModuleEnd,
DfgBuildEnd,
CloneAndSquashScopeStart,
CloneAndSquashScopeEnd,
ConstraintSolverStart,
ConstraintSolverEnd,
TypecheckFragmentEnd,
AutocompleteEnd,
COUNT,
};
class IFragmentAutocompleteReporter
{
public:
virtual void reportWaypoint(FragmentAutocompleteWaypoint) = 0;
virtual void reportFragmentString(std::string_view) = 0;
};
enum class FragmentTypeCheckStatus enum class FragmentTypeCheckStatus
{ {
SkipAutocomplete, SkipAutocomplete,
@ -27,6 +49,8 @@ struct FragmentAutocompleteAncestryResult
std::vector<AstLocal*> localStack; std::vector<AstLocal*> localStack;
std::vector<AstNode*> ancestry; std::vector<AstNode*> ancestry;
AstStat* nearestStatement = nullptr; AstStat* nearestStatement = nullptr;
AstStatBlock* parentBlock = nullptr;
Location fragmentSelectionRegion;
}; };
struct FragmentParseResult struct FragmentParseResult
@ -37,6 +61,7 @@ struct FragmentParseResult
AstStat* nearestStatement = nullptr; AstStat* nearestStatement = nullptr;
std::vector<Comment> commentLocations; std::vector<Comment> commentLocations;
std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>(); std::unique_ptr<Allocator> alloc = std::make_unique<Allocator>();
Position scopePos{0, 0};
}; };
struct FragmentTypeCheckResult struct FragmentTypeCheckResult
@ -54,10 +79,29 @@ struct FragmentAutocompleteResult
AutocompleteResult acResults; AutocompleteResult acResults;
}; };
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); struct FragmentRegion
{
Location fragmentLocation;
AstStat* nearestStatement = nullptr; // used for tests
AstStatBlock* parentBlock = nullptr; // used for scope detection
};
FragmentRegion getFragmentRegion(AstStatBlock* root, const Position& cursorPosition);
FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* stale, const Position& cursorPos, AstStatBlock* lastGoodParse);
FragmentAutocompleteAncestryResult findAncestryForFragmentParse_DEPRECATED(AstStatBlock* root, const Position& cursorPos);
std::optional<FragmentParseResult> parseFragment_DEPRECATED(
AstStatBlock* root,
AstNameTable* names,
std::string_view src,
const Position& cursorPos,
std::optional<Position> fragmentEndPosition
);
std::optional<FragmentParseResult> parseFragment( std::optional<FragmentParseResult> parseFragment(
const SourceModule& srcModule, AstStatBlock* stale,
AstStatBlock* mostRecentParse,
AstNameTable* names,
std::string_view src, std::string_view src,
const Position& cursorPos, const Position& cursorPos,
std::optional<Position> fragmentEndPosition std::optional<Position> fragmentEndPosition
@ -69,7 +113,9 @@ std::pair<FragmentTypeCheckStatus, FragmentTypeCheckResult> typecheckFragment(
const Position& cursorPos, const Position& cursorPos,
std::optional<FrontendOptions> opts, std::optional<FrontendOptions> opts,
std::string_view src, std::string_view src,
std::optional<Position> fragmentEndPosition std::optional<Position> fragmentEndPosition,
AstStatBlock* recentParse = nullptr,
IFragmentAutocompleteReporter* reporter = nullptr
); );
FragmentAutocompleteResult fragmentAutocomplete( FragmentAutocompleteResult fragmentAutocomplete(
@ -79,8 +125,71 @@ FragmentAutocompleteResult fragmentAutocomplete(
Position cursorPosition, Position cursorPosition,
std::optional<FrontendOptions> opts, std::optional<FrontendOptions> opts,
StringCompletionCallback callback, StringCompletionCallback callback,
std::optional<Position> fragmentEndPosition = std::nullopt std::optional<Position> fragmentEndPosition = std::nullopt,
AstStatBlock* recentParse = nullptr,
IFragmentAutocompleteReporter* reporter = nullptr
); );
enum class FragmentAutocompleteStatus
{
Success,
FragmentTypeCheckFail,
InternalIce
};
struct FragmentAutocompleteStatusResult
{
FragmentAutocompleteStatus status;
std::optional<FragmentAutocompleteResult> result;
};
struct FragmentContext
{
std::string_view newSrc;
const ParseResult& freshParse;
std::optional<FrontendOptions> opts;
std::optional<Position> DEPRECATED_fragmentEndPosition;
IFragmentAutocompleteReporter* reporter = nullptr;
};
/**
* @brief Attempts to compute autocomplete suggestions from the fragment context.
*
* This function computes autocomplete suggestions using outdated frontend typechecking data
* by patching the fragment context of the new script source content.
*
* @param frontend The Luau Frontend data structure, which may contain outdated typechecking data.
*
* @param moduleName The name of the target module, specifying which script the caller wants to request autocomplete for.
*
* @param cursorPosition The position in the script where the caller wants to trigger autocomplete.
*
* @param context The fragment context that this API will use to patch the outdated typechecking data.
*
* @param stringCompletionCB A callback function that provides autocomplete suggestions for string contexts.
*
* @return
* The status indicating whether `fragmentAutocomplete` ran successfully or failed, along with the reason for failure.
* Also includes autocomplete suggestions if the status is successful.
*
* @usage
* FragmentAutocompleteStatusResult acStatusResult;
* if (shouldFragmentAC)
* acStatusResult = Luau::tryFragmentAutocomplete(...);
*
* if (acStatusResult.status != Successful)
* {
* frontend.check(moduleName, options);
* acStatusResult.acResult = Luau::autocomplete(...);
* }
* return convertResultWithContext(acStatusResult.acResult);
*/
FragmentAutocompleteStatusResult tryFragmentAutocomplete(
Frontend& frontend,
const ModuleName& moduleName,
Position cursorPosition,
FragmentContext context,
StringCompletionCallback stringCompletionCB
);
} // namespace Luau } // namespace Luau

View file

@ -7,9 +7,9 @@
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/RequireTracer.h" #include "Luau/RequireTracer.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/Set.h"
#include "Luau/TypeCheckLimits.h" #include "Luau/TypeCheckLimits.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/AnyTypeSummary.h"
#include <mutex> #include <mutex>
#include <string> #include <string>
@ -31,8 +31,8 @@ struct ModuleResolver;
struct ParseResult; struct ParseResult;
struct HotComment; struct HotComment;
struct BuildQueueItem; struct BuildQueueItem;
struct BuildQueueWorkState;
struct FrontendCancellationToken; struct FrontendCancellationToken;
struct AnyTypeSummary;
struct LoadDefinitionFileResult struct LoadDefinitionFileResult
{ {
@ -56,13 +56,32 @@ struct SourceNode
return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule;
} }
bool hasInvalidModuleDependency(bool forAutocomplete) const
{
return forAutocomplete ? invalidModuleDependencyForAutocomplete : invalidModuleDependency;
}
void setInvalidModuleDependency(bool value, bool forAutocomplete)
{
if (forAutocomplete)
invalidModuleDependencyForAutocomplete = value;
else
invalidModuleDependency = value;
}
ModuleName name; ModuleName name;
std::string humanReadableName; std::string humanReadableName;
DenseHashSet<ModuleName> requireSet{{}}; DenseHashSet<ModuleName> requireSet{{}};
std::vector<std::pair<ModuleName, Location>> requireLocations; std::vector<std::pair<ModuleName, Location>> requireLocations;
Set<ModuleName> dependents{{}};
bool dirtySourceModule = true; bool dirtySourceModule = true;
bool dirtyModule = true; bool dirtyModule = true;
bool dirtyModuleForAutocomplete = true; bool dirtyModuleForAutocomplete = true;
bool invalidModuleDependency = true;
bool invalidModuleDependencyForAutocomplete = true;
double autocompleteLimitsMult = 1.0; double autocompleteLimitsMult = 1.0;
}; };
@ -117,7 +136,7 @@ struct FrontendModuleResolver : ModuleResolver
std::optional<ModuleInfo> resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; std::optional<ModuleInfo> resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override;
std::string getHumanReadableModuleName(const ModuleName& moduleName) const override; std::string getHumanReadableModuleName(const ModuleName& moduleName) const override;
void setModule(const ModuleName& moduleName, ModulePtr module); bool setModule(const ModuleName& moduleName, ModulePtr module);
void clearModules(); void clearModules();
private: private:
@ -151,9 +170,13 @@ struct Frontend
// Parse and typecheck module graph // Parse and typecheck module graph
CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
bool allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete = false) const;
bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;
void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr); void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr);
void traverseDependents(const ModuleName& name, std::function<bool(SourceNode&)> processSubtree);
/** Borrow a pointer into the SourceModule cache. /** Borrow a pointer into the SourceModule cache.
* *
* Returns nullptr if we don't have it. This could mean that the script * Returns nullptr if we don't have it. This could mean that the script
@ -192,6 +215,11 @@ struct Frontend
std::function<void(std::function<void()> task)> executeTask = {}, std::function<void(std::function<void()> task)> executeTask = {},
std::function<bool(size_t done, size_t total)> progress = {} std::function<bool(size_t done, size_t total)> progress = {}
); );
std::vector<ModuleName> checkQueuedModules_DEPRECATED(
std::optional<FrontendOptions> optionOverride = {},
std::function<void(std::function<void()> task)> executeTask = {},
std::function<bool(size_t done, size_t total)> progress = {}
);
std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false);
std::vector<ModuleName> getRequiredScripts(const ModuleName& name); std::vector<ModuleName> getRequiredScripts(const ModuleName& name);
@ -227,6 +255,9 @@ private:
void checkBuildQueueItem(BuildQueueItem& item); void checkBuildQueueItem(BuildQueueItem& item);
void checkBuildQueueItems(std::vector<BuildQueueItem>& items); void checkBuildQueueItems(std::vector<BuildQueueItem>& items);
void recordItemResult(const BuildQueueItem& item); void recordItemResult(const BuildQueueItem& item);
void performQueueItemTask(std::shared_ptr<BuildQueueWorkState> state, size_t itemPos);
void sendQueueItemTask(std::shared_ptr<BuildQueueWorkState> state, size_t itemPos);
void sendQueueCycleItemTask(std::shared_ptr<BuildQueueWorkState> state);
static LintResult classifyLints(const std::vector<LintWarning>& warnings, const Config& config); static LintResult classifyLints(const std::vector<LintWarning>& warnings, const Config& config);
@ -272,6 +303,7 @@ ModulePtr check(
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver, NotNull<FileResolver> fileResolver,
const ScopePtr& globalScope, const ScopePtr& globalScope,
const ScopePtr& typeFunctionScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options, FrontendOptions options,
TypeCheckLimits limits TypeCheckLimits limits
@ -286,6 +318,7 @@ ModulePtr check(
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver, NotNull<FileResolver> fileResolver,
const ScopePtr& globalScope, const ScopePtr& globalScope,
const ScopePtr& typeFunctionScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options, FrontendOptions options,
TypeCheckLimits limits, TypeCheckLimits limits,

View file

@ -12,8 +12,8 @@ std::optional<TypeId> generalize(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> bakedTypes, NotNull<DenseHashSet<TypeId>> cachedTypes,
TypeId ty, TypeId ty
/* avoid sealing tables*/ bool avoidSealingTables = false
); );
} }

View file

@ -19,7 +19,9 @@ struct GlobalTypes
TypeArena globalTypes; TypeArena globalTypes;
SourceModule globalNames; // names for symbols entered into globalScope SourceModule globalNames; // names for symbols entered into globalScope
ScopePtr globalScope; // shared by all modules ScopePtr globalScope; // shared by all modules
ScopePtr globalTypeFunctionScope; // shared by all modules
}; };
} // namespace Luau } // namespace Luau

View file

@ -8,7 +8,6 @@
#include "Luau/ParseResult.h" #include "Luau/ParseResult.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/AnyTypeSummary.h"
#include "Luau/DataFlowGraph.h" #include "Luau/DataFlowGraph.h"
#include <memory> #include <memory>
@ -21,8 +20,13 @@ LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection)
namespace Luau namespace Luau
{ {
using LogLuauProc = void (*)(std::string_view, std::string_view);
extern LogLuauProc logLuau;
void setLogLuau(LogLuauProc ll);
void resetLogLuauProc();
struct Module; struct Module;
struct AnyTypeSummary;
using ScopePtr = std::shared_ptr<struct Scope>; using ScopePtr = std::shared_ptr<struct Scope>;
using ModulePtr = std::shared_ptr<Module>; using ModulePtr = std::shared_ptr<Module>;
@ -80,13 +84,10 @@ struct Module
TypeArena interfaceTypes; TypeArena interfaceTypes;
TypeArena internalTypes; TypeArena internalTypes;
// Summary of Ast Nodes that either contain
// user annotated anys or typechecker inferred anys
AnyTypeSummary ats{};
// Scopes and AST types refer to parse data, so we need to keep that alive // Scopes and AST types refer to parse data, so we need to keep that alive
std::shared_ptr<Allocator> allocator; std::shared_ptr<Allocator> allocator;
std::shared_ptr<AstNameTable> names; std::shared_ptr<AstNameTable> names;
AstStatBlock* root = nullptr;
std::vector<std::pair<Location, ScopePtr>> scopes; // never empty std::vector<std::pair<Location, ScopePtr>> scopes; // never empty

View file

@ -1,9 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/DataFlowGraph.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/DataFlowGraph.h"
namespace Luau namespace Luau
{ {

View file

@ -31,13 +31,4 @@ struct OrderedMap
} }
}; };
struct QuantifierResult
{
TypeId result;
OrderedMap<TypeId, TypeId> insertedGenerics;
OrderedMap<TypePackId, TypePackId> insertedGenericPacks;
};
std::optional<QuantifierResult> quantify(TypeArena* arena, TypeId ty, Scope* scope);
} // namespace Luau } // namespace Luau

View file

@ -53,6 +53,7 @@ struct Proposition
{ {
const RefinementKey* key; const RefinementKey* key;
TypeId discriminantTy; TypeId discriminantTy;
bool implicitFromCall;
}; };
template<typename T> template<typename T>
@ -69,6 +70,7 @@ struct RefinementArena
RefinementId disjunction(RefinementId lhs, RefinementId rhs); RefinementId disjunction(RefinementId lhs, RefinementId rhs);
RefinementId equivalence(RefinementId lhs, RefinementId rhs); RefinementId equivalence(RefinementId lhs, RefinementId rhs);
RefinementId proposition(const RefinementKey* key, TypeId discriminantTy); RefinementId proposition(const RefinementKey* key, TypeId discriminantTy);
RefinementId implicitProposition(const RefinementKey* key, TypeId discriminantTy);
private: private:
TypedAllocator<Refinement> allocator; TypedAllocator<Refinement> allocator;

View file

@ -11,14 +11,12 @@
namespace Luau namespace Luau
{ {
class AstStat; class AstNode;
class AstExpr;
class AstStatBlock; class AstStatBlock;
struct AstLocal;
struct RequireTraceResult struct RequireTraceResult
{ {
DenseHashMap<const AstExpr*, ModuleInfo> exprs{nullptr}; DenseHashMap<const AstNode*, ModuleInfo> exprs{nullptr};
std::vector<std::pair<ModuleName, Location>> requireList; std::vector<std::pair<ModuleName, Location>> requireList;
}; };

View file

@ -35,7 +35,7 @@ struct Scope
explicit Scope(TypePackId returnType); // root scope explicit Scope(TypePackId returnType); // root scope
explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr.
const ScopePtr parent; // null for the root ScopePtr parent; // null for the root
// All the children of this scope. // All the children of this scope.
std::vector<NotNull<Scope>> children; std::vector<NotNull<Scope>> children;
@ -59,6 +59,8 @@ struct Scope
std::optional<TypeId> lookup(Symbol sym) const; std::optional<TypeId> lookup(Symbol sym) const;
std::optional<TypeId> lookupUnrefinedType(DefId def) const; std::optional<TypeId> lookupUnrefinedType(DefId def) const;
std::optional<TypeId> lookupRValueRefinementType(DefId def) const;
std::optional<TypeId> lookup(DefId def) const; std::optional<TypeId> lookup(DefId def) const;
std::optional<std::pair<TypeId, Scope*>> lookupEx(DefId def); std::optional<std::pair<TypeId, Scope*>> lookupEx(DefId def);
std::optional<std::pair<Binding*, Scope*>> lookupEx(Symbol sym); std::optional<std::pair<Binding*, Scope*>> lookupEx(Symbol sym);
@ -71,6 +73,7 @@ struct Scope
// WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2)
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const;
std::optional<std::pair<Symbol, Binding>> linearSearchForBindingPair(const std::string& name, bool traverseScopeChain) const;
RefinementMap refinements; RefinementMap refinements;

View file

@ -19,10 +19,10 @@ struct SimplifyResult
DenseHashSet<TypeId> blockedTypes; DenseHashSet<TypeId> blockedTypes;
}; };
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant); SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right);
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts); SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts);
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant); SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right);
enum class Relation enum class Relation
{ {

View file

@ -6,12 +6,15 @@
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeFwd.h"
#include <vector>
namespace Luau namespace Luau
{ {
struct TypeArena; struct TypeArena;
struct BuiltinTypes; struct BuiltinTypes;
struct Unifier2; struct Unifier2;
struct Subtyping;
class AstExpr; class AstExpr;
TypeId matchLiteralType( TypeId matchLiteralType(
@ -20,6 +23,7 @@ TypeId matchLiteralType(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Unifier2> unifier, NotNull<Unifier2> unifier,
NotNull<Subtyping> subtyping,
TypeId expectedType, TypeId expectedType,
TypeId exprType, TypeId exprType,
const AstExpr* expr, const AstExpr* expr,

View file

@ -65,11 +65,10 @@ T* getMutable(PendingTypePack* pending)
// Log of what TypeIds we are rebinding, to be committed later. // Log of what TypeIds we are rebinding, to be committed later.
struct TxnLog struct TxnLog
{ {
explicit TxnLog(bool useScopes = false) explicit TxnLog()
: typeVarChanges(nullptr) : typeVarChanges(nullptr)
, typePackChanges(nullptr) , typePackChanges(nullptr)
, ownedSeen() , ownedSeen()
, useScopes(useScopes)
, sharedSeen(&ownedSeen) , sharedSeen(&ownedSeen)
{ {
} }

View file

@ -19,7 +19,6 @@
#include <optional> #include <optional>
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
@ -38,6 +37,15 @@ struct Constraint;
struct Subtyping; struct Subtyping;
struct TypeChecker2; struct TypeChecker2;
enum struct Polarity : uint8_t
{
None = 0b000,
Positive = 0b001,
Negative = 0b010,
Mixed = 0b011,
Unknown = 0b100,
};
/** /**
* There are three kinds of type variables: * There are three kinds of type variables:
* - `Free` variables are metavariables, which stand for unconstrained types. * - `Free` variables are metavariables, which stand for unconstrained types.
@ -69,12 +77,16 @@ using Name = std::string;
// A free type is one whose exact shape has yet to be fully determined. // A free type is one whose exact shape has yet to be fully determined.
struct FreeType struct FreeType
{ {
// New constructors
explicit FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound);
// This one got promoted to explicit
explicit FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound);
explicit FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound);
// Old constructors
explicit FreeType(TypeLevel level); explicit FreeType(TypeLevel level);
explicit FreeType(Scope* scope); explicit FreeType(Scope* scope);
FreeType(Scope* scope, TypeLevel level); FreeType(Scope* scope, TypeLevel level);
FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound);
int index; int index;
TypeLevel level; TypeLevel level;
Scope* scope = nullptr; Scope* scope = nullptr;
@ -306,7 +318,8 @@ struct MagicFunctionTypeCheckContext
struct MagicFunction struct MagicFunction
{ {
virtual std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) = 0; virtual std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) = 0;
// Callback to allow custom typechecking of builtin function calls whose argument types // Callback to allow custom typechecking of builtin function calls whose argument types
// will only be resolved after constraint solving. For example, the arguments to string.format // will only be resolved after constraint solving. For example, the arguments to string.format
@ -391,6 +404,7 @@ struct FunctionType
// this flag is used as an optimization to exit early from procedures that manipulate free or generic types. // this flag is used as an optimization to exit early from procedures that manipulate free or generic types.
bool hasNoFreeOrGenericTypes = false; bool hasNoFreeOrGenericTypes = false;
bool isCheckedFunction = false; bool isCheckedFunction = false;
bool isDeprecatedFunction = false;
}; };
enum class TableState enum class TableState
@ -617,7 +631,6 @@ struct UserDefinedFunctionData
AstStatTypeFunction* definition = nullptr; AstStatTypeFunction* definition = nullptr;
DenseHashMap<Name, std::pair<AstStatTypeFunction*, size_t>> environment{""}; DenseHashMap<Name, std::pair<AstStatTypeFunction*, size_t>> environment{""};
DenseHashMap<Name, AstStatTypeFunction*> environment_DEPRECATED{""};
}; };
/** /**

View file

@ -32,9 +32,13 @@ struct TypeArena
TypeId addTV(Type&& tv); TypeId addTV(Type&& tv);
TypeId freshType(TypeLevel level); TypeId freshType(NotNull<BuiltinTypes> builtins, TypeLevel level);
TypeId freshType(Scope* scope); TypeId freshType(NotNull<BuiltinTypes> builtins, Scope* scope);
TypeId freshType(Scope* scope, TypeLevel level); TypeId freshType(NotNull<BuiltinTypes> builtins, Scope* scope, TypeLevel level);
TypeId freshType_DEPRECATED(TypeLevel level);
TypeId freshType_DEPRECATED(Scope* scope);
TypeId freshType_DEPRECATED(Scope* scope, TypeLevel level);
TypePackId freshTypePack(Scope* scope); TypePackId freshTypePack(Scope* scope);

View file

@ -13,6 +13,8 @@
#include "Luau/TypeOrPack.h" #include "Luau/TypeOrPack.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
LUAU_FASTFLAG(LuauImproveTypePathsInErrors)
namespace Luau namespace Luau
{ {
@ -38,18 +40,29 @@ struct Reasonings
std::string toString() std::string toString()
{ {
if (FFlag::LuauImproveTypePathsInErrors && reasons.empty())
return "";
// DenseHashSet ordering is entirely undefined, so we want to // DenseHashSet ordering is entirely undefined, so we want to
// sort the reasons here to achieve a stable error // sort the reasons here to achieve a stable error
// stringification. // stringification.
std::sort(reasons.begin(), reasons.end()); std::sort(reasons.begin(), reasons.end());
std::string allReasons; std::string allReasons = FFlag::LuauImproveTypePathsInErrors ? "\nthis is because " : "";
bool first = true; bool first = true;
for (const std::string& reason : reasons) for (const std::string& reason : reasons)
{ {
if (first) if (FFlag::LuauImproveTypePathsInErrors)
first = false; {
if (reasons.size() > 1)
allReasons += "\n\t * ";
}
else else
allReasons += "\n\t"; {
if (first)
first = false;
else
allReasons += "\n\t";
}
allReasons += reason; allReasons += reason;
} }
@ -63,7 +76,7 @@ void check(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Simplifier> simplifier, NotNull<Simplifier> simplifier,
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<UnifierSharedState> sharedState, NotNull<UnifierSharedState> unifierState,
NotNull<TypeCheckLimits> limits, NotNull<TypeCheckLimits> limits,
DcrLogger* logger, DcrLogger* logger,
const SourceModule& sourceModule, const SourceModule& sourceModule,
@ -116,14 +129,14 @@ private:
std::optional<StackPusher> pushStack(AstNode* node); std::optional<StackPusher> pushStack(AstNode* node);
void checkForInternalTypeFunction(TypeId ty, Location location); void checkForInternalTypeFunction(TypeId ty, Location location);
TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location); TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location);
TypePackId lookupPack(AstExpr* expr); TypePackId lookupPack(AstExpr* expr) const;
TypeId lookupType(AstExpr* expr); TypeId lookupType(AstExpr* expr);
TypeId lookupAnnotation(AstType* annotation); TypeId lookupAnnotation(AstType* annotation);
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation); std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation) const;
TypeId lookupExpectedType(AstExpr* expr); TypeId lookupExpectedType(AstExpr* expr) const;
TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena); TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena) const;
TypePackId reconstructPack(AstArray<AstExpr*> exprs, TypeArena& arena); TypePackId reconstructPack(AstArray<AstExpr*> exprs, TypeArena& arena);
Scope* findInnermostScope(Location location); Scope* findInnermostScope(Location location) const;
void visit(AstStat* stat); void visit(AstStat* stat);
void visit(AstStatIf* ifStatement); void visit(AstStatIf* ifStatement);
void visit(AstStatWhile* whileStatement); void visit(AstStatWhile* whileStatement);
@ -160,7 +173,7 @@ private:
void visit(AstExprVarargs* expr); void visit(AstExprVarargs* expr);
void visitCall(AstExprCall* call); void visitCall(AstExprCall* call);
void visit(AstExprCall* call); void visit(AstExprCall* call);
std::optional<TypeId> tryStripUnionFromNil(TypeId ty); std::optional<TypeId> tryStripUnionFromNil(TypeId ty) const;
TypeId stripFromNilAndReport(TypeId ty, const Location& location); TypeId stripFromNilAndReport(TypeId ty, const Location& location);
void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy); void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy);
void visit(AstExprIndexName* indexName, ValueContext context); void visit(AstExprIndexName* indexName, ValueContext context);
@ -175,7 +188,7 @@ private:
void visit(AstExprInterpString* interpString); void visit(AstExprInterpString* interpString);
void visit(AstExprError* expr); void visit(AstExprError* expr);
TypeId flattenPack(TypePackId pack); TypeId flattenPack(TypePackId pack);
void visitGenerics(AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks); void visitGenerics(AstArray<AstGenericType*> generics, AstArray<AstGenericTypePack*> genericPacks);
void visit(AstType* ty); void visit(AstType* ty);
void visit(AstTypeReference* ty); void visit(AstTypeReference* ty);
void visit(AstTypeTable* table); void visit(AstTypeTable* table);

View file

@ -48,6 +48,9 @@ struct TypeFunctionRuntime
// Evaluation of type functions should only be performed in the absence of parse errors in the source module // Evaluation of type functions should only be performed in the absence of parse errors in the source module
bool allowEvaluation = true; bool allowEvaluation = true;
// Root scope in which the type function operates in, set up by ConstraintGenerator
ScopePtr rootScope;
// Output created by 'print' function // Output created by 'print' function
std::vector<std::string> messages; std::vector<std::string> messages;
@ -174,6 +177,7 @@ struct FunctionGraphReductionResult
DenseHashSet<TypePackId> blockedPacks{nullptr}; DenseHashSet<TypePackId> blockedPacks{nullptr};
DenseHashSet<TypeId> reducedTypes{nullptr}; DenseHashSet<TypeId> reducedTypes{nullptr};
DenseHashSet<TypePackId> reducedPacks{nullptr}; DenseHashSet<TypePackId> reducedPacks{nullptr};
DenseHashSet<TypeId> irreducibleTypes{nullptr};
}; };
/** /**
@ -241,6 +245,9 @@ struct BuiltinTypeFunctions
TypeFunction indexFunc; TypeFunction indexFunc;
TypeFunction rawgetFunc; TypeFunction rawgetFunc;
TypeFunction setmetatableFunc;
TypeFunction getmetatableFunc;
void addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const; void addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const;
}; };

View file

@ -3,6 +3,7 @@
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include <optional> #include <optional>
#include <string> #include <string>
@ -215,9 +216,13 @@ struct TypeFunctionClassType
std::optional<TypeFunctionTypeId> metatable; // metaclass? std::optional<TypeFunctionTypeId> metatable; // metaclass?
std::optional<TypeFunctionTypeId> parent; // this was mistaken, and we should actually be keeping separate read/write types here.
std::optional<TypeFunctionTypeId> parent_DEPRECATED;
std::string name; std::optional<TypeFunctionTypeId> readParent;
std::optional<TypeFunctionTypeId> writeParent;
TypeId classTy;
}; };
struct TypeFunctionGenericType struct TypeFunctionGenericType

View file

@ -28,20 +28,12 @@ struct TypeFunctionRuntimeBuilderState
{ {
NotNull<TypeFunctionContext> ctx; NotNull<TypeFunctionContext> ctx;
// Mapping of class name to ClassType
// Invariant: users can not create a new class types -> any class types that get deserialized must have been an argument to the type function
// Using this invariant, whenever a ClassType is serialized, we can put it into this map
// whenever a ClassType is deserialized, we can use this map to return the corresponding value
DenseHashMap<std::string, TypeId> classesSerialized{{}};
// List of errors that occur during serialization/deserialization // List of errors that occur during serialization/deserialization
// At every iteration of serialization/deserialzation, if this list.size() != 0, we halt the process // At every iteration of serialization/deserialization, if this list.size() != 0, we halt the process
std::vector<std::string> errors{}; std::vector<std::string> errors{};
TypeFunctionRuntimeBuilderState(NotNull<TypeFunctionContext> ctx) TypeFunctionRuntimeBuilderState(NotNull<TypeFunctionContext> ctx)
: ctx(ctx) : ctx(ctx)
, classesSerialized({})
, errors({})
{ {
} }
}; };

View file

@ -399,8 +399,8 @@ private:
const ScopePtr& scope, const ScopePtr& scope,
std::optional<TypeLevel> levelOpt, std::optional<TypeLevel> levelOpt,
const AstNode& node, const AstNode& node,
const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericType*>& genericNames,
const AstArray<AstGenericTypePack>& genericPackNames, const AstArray<AstGenericTypePack*>& genericPackNames,
bool useCache = false bool useCache = false
); );

View file

@ -42,9 +42,19 @@ struct Property
/// element. /// element.
struct Index struct Index
{ {
enum class Variant
{
Pack,
Union,
Intersection
};
/// The 0-based index to use for the lookup. /// The 0-based index to use for the lookup.
size_t index; size_t index;
/// The sort of thing we're indexing from, this is used in stringifying the type path for errors.
Variant variant;
bool operator==(const Index& other) const; bool operator==(const Index& other) const;
}; };
@ -205,6 +215,9 @@ using Path = TypePath::Path;
/// terribly clear to end users of the Luau type system. /// terribly clear to end users of the Luau type system.
std::string toString(const TypePath::Path& path, bool prefixDot = false); std::string toString(const TypePath::Path& path, bool prefixDot = false);
/// Converts a Path to a human readable string for error reporting.
std::string toStringHuman(const TypePath::Path& path);
std::optional<TypeOrPack> traverse(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes); std::optional<TypeOrPack> traverse(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
std::optional<TypeOrPack> traverse(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes); std::optional<TypeOrPack> traverse(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);

View file

@ -40,7 +40,7 @@ struct InConditionalContext
TypeContext* typeContext; TypeContext* typeContext;
TypeContext oldValue; TypeContext oldValue;
InConditionalContext(TypeContext* c) explicit InConditionalContext(TypeContext* c)
: typeContext(c) : typeContext(c)
, oldValue(*c) , oldValue(*c)
{ {

View file

@ -93,10 +93,6 @@ struct Unifier
Unifier(NotNull<Normalizer> normalizer, NotNull<Scope> scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); Unifier(NotNull<Normalizer> normalizer, NotNull<Scope> scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr);
// Configure the Unifier to test for scope subsumption via embedded Scope
// pointers rather than TypeLevels.
void enableNewSolver();
// Test whether the two type vars unify. Never commits the result. // Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId subTy, TypeId superTy); ErrorVec canUnify(TypeId subTy, TypeId superTy);
ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false);
@ -169,7 +165,6 @@ private:
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name); std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name);
TxnLog combineLogsIntoIntersection(std::vector<TxnLog> logs);
TxnLog combineLogsIntoUnion(std::vector<TxnLog> logs); TxnLog combineLogsIntoUnion(std::vector<TxnLog> logs);
public: public:
@ -195,11 +190,6 @@ private:
// Available after regular type pack unification errors // Available after regular type pack unification errors
std::optional<int> firstPackErrorPos; std::optional<int> firstPackErrorPos;
// If true, we do a bunch of small things differently to work better with
// the new type inference engine. Most notably, we use the Scope hierarchy
// directly rather than using TypeLevels.
bool useNewSolver = false;
}; };
void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp);

View file

@ -87,6 +87,9 @@ struct Unifier2
bool unify(const AnyType* subAny, const TableType* superTable); bool unify(const AnyType* subAny, const TableType* superTable);
bool unify(const TableType* subTable, const AnyType* superAny); bool unify(const TableType* subTable, const AnyType* superAny);
bool unify(const MetatableType* subMetatable, const AnyType*);
bool unify(const AnyType*, const MetatableType* superMetatable);
// TODO think about this one carefully. We don't do unions or intersections of type packs // TODO think about this one carefully. We don't do unions or intersections of type packs
bool unify(TypePackId subTp, TypePackId superTp); bool unify(TypePackId subTp, TypePackId superTp);

View file

@ -49,6 +49,26 @@ struct UnifierSharedState
DenseHashSet<TypePackId> tempSeenTp{nullptr}; DenseHashSet<TypePackId> tempSeenTp{nullptr};
UnifierCounters counters; UnifierCounters counters;
bool reentrantTypeReduction = false;
};
struct TypeReductionRentrancyGuard final
{
explicit TypeReductionRentrancyGuard(NotNull<UnifierSharedState> sharedState)
: sharedState{sharedState}
{
sharedState->reentrantTypeReduction = true;
}
~TypeReductionRentrancyGuard()
{
sharedState->reentrantTypeReduction = false;
}
TypeReductionRentrancyGuard(const TypeReductionRentrancyGuard&) = delete;
TypeReductionRentrancyGuard(TypeReductionRentrancyGuard&&) = delete;
private:
NotNull<UnifierSharedState> sharedState;
}; };
} // namespace Luau } // namespace Luau

View file

@ -1,902 +0,0 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/AnyTypeSummary.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/Config.h"
#include "Luau/ConstraintGenerator.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/DcrLogger.h"
#include "Luau/Module.h"
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/StringUtils.h"
#include "Luau/TimeTrace.h"
#include "Luau/ToString.h"
#include "Luau/Transpiler.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeChecker2.h"
#include "Luau/NonStrictTypeChecker.h"
#include "Luau/TypeInfer.h"
#include "Luau/Variant.h"
#include "Luau/VisitType.h"
#include "Luau/TypePack.h"
#include "Luau/TypeOrPack.h"
#include <algorithm>
#include <memory>
#include <chrono>
#include <condition_variable>
#include <exception>
#include <mutex>
#include <stdexcept>
#include <string>
#include <iostream>
#include <stdio.h>
LUAU_FASTFLAGVARIABLE(StudioReportLuauAny2);
LUAU_FASTINTVARIABLE(LuauAnySummaryRecursionLimit, 300);
LUAU_FASTFLAG(DebugLuauMagicTypes);
namespace Luau
{
void AnyTypeSummary::traverse(const Module* module, AstStat* src, NotNull<BuiltinTypes> builtinTypes)
{
visit(findInnerMostScope(src->location, module), src, module, builtinTypes);
}
void AnyTypeSummary::visit(const Scope* scope, AstStat* stat, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
RecursionLimiter limiter{&recursionCount, FInt::LuauAnySummaryRecursionLimit};
if (auto s = stat->as<AstStatBlock>())
return visit(scope, s, module, builtinTypes);
else if (auto i = stat->as<AstStatIf>())
return visit(scope, i, module, builtinTypes);
else if (auto s = stat->as<AstStatWhile>())
return visit(scope, s, module, builtinTypes);
else if (auto s = stat->as<AstStatRepeat>())
return visit(scope, s, module, builtinTypes);
else if (auto r = stat->as<AstStatReturn>())
return visit(scope, r, module, builtinTypes);
else if (auto e = stat->as<AstStatExpr>())
return visit(scope, e, module, builtinTypes);
else if (auto s = stat->as<AstStatLocal>())
return visit(scope, s, module, builtinTypes);
else if (auto s = stat->as<AstStatFor>())
return visit(scope, s, module, builtinTypes);
else if (auto s = stat->as<AstStatForIn>())
return visit(scope, s, module, builtinTypes);
else if (auto a = stat->as<AstStatAssign>())
return visit(scope, a, module, builtinTypes);
else if (auto a = stat->as<AstStatCompoundAssign>())
return visit(scope, a, module, builtinTypes);
else if (auto f = stat->as<AstStatFunction>())
return visit(scope, f, module, builtinTypes);
else if (auto f = stat->as<AstStatLocalFunction>())
return visit(scope, f, module, builtinTypes);
else if (auto a = stat->as<AstStatTypeAlias>())
return visit(scope, a, module, builtinTypes);
else if (auto s = stat->as<AstStatDeclareGlobal>())
return visit(scope, s, module, builtinTypes);
else if (auto s = stat->as<AstStatDeclareFunction>())
return visit(scope, s, module, builtinTypes);
else if (auto s = stat->as<AstStatDeclareClass>())
return visit(scope, s, module, builtinTypes);
else if (auto s = stat->as<AstStatError>())
return visit(scope, s, module, builtinTypes);
}
void AnyTypeSummary::visit(const Scope* scope, AstStatBlock* block, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
RecursionCounter counter{&recursionCount};
if (recursionCount >= FInt::LuauAnySummaryRecursionLimit)
return; // don't report
for (AstStat* stat : block->body)
visit(scope, stat, module, builtinTypes);
}
void AnyTypeSummary::visit(const Scope* scope, AstStatIf* ifStatement, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (ifStatement->thenbody)
{
const Scope* thenScope = findInnerMostScope(ifStatement->thenbody->location, module);
visit(thenScope, ifStatement->thenbody, module, builtinTypes);
}
if (ifStatement->elsebody)
{
const Scope* elseScope = findInnerMostScope(ifStatement->elsebody->location, module);
visit(elseScope, ifStatement->elsebody, module, builtinTypes);
}
}
void AnyTypeSummary::visit(const Scope* scope, AstStatWhile* while_, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
const Scope* whileScope = findInnerMostScope(while_->location, module);
visit(whileScope, while_->body, module, builtinTypes);
}
void AnyTypeSummary::visit(const Scope* scope, AstStatRepeat* repeat, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
const Scope* repeatScope = findInnerMostScope(repeat->location, module);
visit(repeatScope, repeat->body, module, builtinTypes);
}
void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
const Scope* retScope = findInnerMostScope(ret->location, module);
auto ctxNode = getNode(rootSrc, ret);
bool seenTP = false;
for (auto val : ret->list)
{
if (isAnyCall(retScope, val, module, builtinTypes))
{
TelemetryTypePair types;
types.inferredType = toString(lookupType(val, module, builtinTypes));
TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti);
}
if (isAnyCast(retScope, val, module, builtinTypes))
{
if (auto cast = val->as<AstExprTypeAssertion>())
{
TelemetryTypePair types;
types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes));
types.inferredType = toString(lookupType(cast->expr, module, builtinTypes));
TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
if (ret->list.size > 1 && !seenTP)
{
if (containsAny(retScope->returnType))
{
seenTP = true;
TelemetryTypePair types;
types.inferredType = toString(retScope->returnType);
TypeInfo ti{Pattern::TypePk, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
}
void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
auto ctxNode = getNode(rootSrc, local);
TypePackId values = reconstructTypePack(local->values, module, builtinTypes);
auto [head, tail] = flatten(values);
size_t posn = 0;
for (AstLocal* loc : local->vars)
{
if (local->vars.data[0] == loc && posn < local->values.size)
{
if (loc->annotation)
{
auto annot = lookupAnnotation(loc->annotation, module, builtinTypes);
if (containsAny(annot))
{
TelemetryTypePair types;
types.annotatedType = toString(annot);
types.inferredType = toString(lookupType(local->values.data[posn], module, builtinTypes));
TypeInfo ti{Pattern::VarAnnot, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
const AstExprTypeAssertion* maybeRequire = local->values.data[posn]->as<AstExprTypeAssertion>();
if (!maybeRequire)
continue;
if (std::min(local->values.size - 1, posn) < head.size())
{
if (isAnyCast(scope, local->values.data[posn], module, builtinTypes))
{
TelemetryTypePair types;
types.inferredType = toString(head[std::min(local->values.size - 1, posn)]);
TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
else
{
if (std::min(local->values.size - 1, posn) < head.size())
{
if (loc->annotation)
{
auto annot = lookupAnnotation(loc->annotation, module, builtinTypes);
if (containsAny(annot))
{
TelemetryTypePair types;
types.annotatedType = toString(annot);
types.inferredType = toString(head[std::min(local->values.size - 1, posn)]);
TypeInfo ti{Pattern::VarAnnot, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
else
{
if (tail)
{
if (containsAny(*tail))
{
TelemetryTypePair types;
types.inferredType = toString(*tail);
TypeInfo ti{Pattern::VarAny, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
}
++posn;
}
}
void AnyTypeSummary::visit(const Scope* scope, AstStatFor* for_, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
const Scope* forScope = findInnerMostScope(for_->location, module);
visit(forScope, for_->body, module, builtinTypes);
}
void AnyTypeSummary::visit(const Scope* scope, AstStatForIn* forIn, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
const Scope* loopScope = findInnerMostScope(forIn->location, module);
visit(loopScope, forIn->body, module, builtinTypes);
}
void AnyTypeSummary::visit(const Scope* scope, AstStatAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
auto ctxNode = getNode(rootSrc, assign);
TypePackId values = reconstructTypePack(assign->values, module, builtinTypes);
auto [head, tail] = flatten(values);
size_t posn = 0;
for (AstExpr* var : assign->vars)
{
TypeId tp = lookupType(var, module, builtinTypes);
if (containsAny(tp))
{
TelemetryTypePair types;
types.annotatedType = toString(tp);
auto loc = std::min(assign->vars.size - 1, posn);
if (head.size() >= assign->vars.size && posn < head.size())
{
types.inferredType = toString(head[posn]);
}
else if (loc < head.size())
types.inferredType = toString(head[loc]);
else
types.inferredType = toString(builtinTypes->nilType);
TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti);
}
++posn;
}
for (AstExpr* val : assign->values)
{
if (isAnyCall(scope, val, module, builtinTypes))
{
TelemetryTypePair types;
types.inferredType = toString(lookupType(val, module, builtinTypes));
TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti);
}
if (isAnyCast(scope, val, module, builtinTypes))
{
if (auto cast = val->as<AstExprTypeAssertion>())
{
TelemetryTypePair types;
types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes));
types.inferredType = toString(lookupType(val, module, builtinTypes));
TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
if (tail)
{
if (containsAny(*tail))
{
TelemetryTypePair types;
types.inferredType = toString(*tail);
TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
void AnyTypeSummary::visit(const Scope* scope, AstStatCompoundAssign* assign, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
auto ctxNode = getNode(rootSrc, assign);
TelemetryTypePair types;
types.inferredType = toString(lookupType(assign->value, module, builtinTypes));
types.annotatedType = toString(lookupType(assign->var, module, builtinTypes));
if (module->astTypes.contains(assign->var))
{
if (containsAny(*module->astTypes.find(assign->var)))
{
TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
else if (module->astTypePacks.contains(assign->var))
{
if (containsAny(*module->astTypePacks.find(assign->var)))
{
TypeInfo ti{Pattern::Assign, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
if (isAnyCall(scope, assign->value, module, builtinTypes))
{
TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti);
}
if (isAnyCast(scope, assign->value, module, builtinTypes))
{
if (auto cast = assign->value->as<AstExprTypeAssertion>())
{
types.annotatedType = toString(lookupAnnotation(cast->annotation, module, builtinTypes));
types.inferredType = toString(lookupType(cast->expr, module, builtinTypes));
TypeInfo ti{Pattern::Casts, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
}
void AnyTypeSummary::visit(const Scope* scope, AstStatFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
TelemetryTypePair types;
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
if (hasVariadicAnys(scope, function->func, module, builtinTypes))
{
TypeInfo ti{Pattern::VarAny, toString(function), types};
typeInfo.push_back(ti);
}
if (hasArgAnys(scope, function->func, module, builtinTypes))
{
TypeInfo ti{Pattern::FuncArg, toString(function), types};
typeInfo.push_back(ti);
}
if (hasAnyReturns(scope, function->func, module, builtinTypes))
{
TypeInfo ti{Pattern::FuncRet, toString(function), types};
typeInfo.push_back(ti);
}
if (function->func->body->body.size > 0)
visit(scope, function->func->body, module, builtinTypes);
}
void AnyTypeSummary::visit(const Scope* scope, AstStatLocalFunction* function, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
TelemetryTypePair types;
if (hasVariadicAnys(scope, function->func, module, builtinTypes))
{
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
TypeInfo ti{Pattern::VarAny, toString(function), types};
typeInfo.push_back(ti);
}
if (hasArgAnys(scope, function->func, module, builtinTypes))
{
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
TypeInfo ti{Pattern::FuncArg, toString(function), types};
typeInfo.push_back(ti);
}
if (hasAnyReturns(scope, function->func, module, builtinTypes))
{
types.inferredType = toString(lookupType(function->func, module, builtinTypes));
TypeInfo ti{Pattern::FuncRet, toString(function), types};
typeInfo.push_back(ti);
}
if (function->func->body->body.size > 0)
visit(scope, function->func->body, module, builtinTypes);
}
void AnyTypeSummary::visit(const Scope* scope, AstStatTypeAlias* alias, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
auto ctxNode = getNode(rootSrc, alias);
auto annot = lookupAnnotation(alias->type, module, builtinTypes);
if (containsAny(annot))
{
// no expr => no inference for aliases
TelemetryTypePair types;
types.annotatedType = toString(annot);
TypeInfo ti{Pattern::Alias, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
void AnyTypeSummary::visit(const Scope* scope, AstStatExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
auto ctxNode = getNode(rootSrc, expr);
if (isAnyCall(scope, expr->expr, module, builtinTypes))
{
TelemetryTypePair types;
types.inferredType = toString(lookupType(expr->expr, module, builtinTypes));
TypeInfo ti{Pattern::FuncApp, toString(ctxNode), types};
typeInfo.push_back(ti);
}
}
void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareGlobal* declareGlobal, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareClass* declareClass, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(const Scope* scope, AstStatDeclareFunction* declareFunction, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
void AnyTypeSummary::visit(const Scope* scope, AstStatError* error, const Module* module, NotNull<BuiltinTypes> builtinTypes) {}
TypeId AnyTypeSummary::checkForFamilyInhabitance(const TypeId instance, const Location location)
{
if (seenTypeFamilyInstances.find(instance))
return instance;
seenTypeFamilyInstances.insert(instance);
return instance;
}
TypeId AnyTypeSummary::lookupType(const AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
const TypeId* ty = module->astTypes.find(expr);
if (ty)
return checkForFamilyInhabitance(follow(*ty), expr->location);
const TypePackId* tp = module->astTypePacks.find(expr);
if (tp)
{
if (auto fst = first(*tp, /*ignoreHiddenVariadics*/ false))
return checkForFamilyInhabitance(*fst, expr->location);
else if (finite(*tp) && size(*tp) == 0)
return checkForFamilyInhabitance(builtinTypes->nilType, expr->location);
}
return builtinTypes->errorRecoveryType();
}
TypePackId AnyTypeSummary::reconstructTypePack(AstArray<AstExpr*> exprs, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (exprs.size == 0)
return arena.addTypePack(TypePack{{}, std::nullopt});
std::vector<TypeId> head;
for (size_t i = 0; i < exprs.size - 1; ++i)
{
head.push_back(lookupType(exprs.data[i], module, builtinTypes));
}
const TypePackId* tail = module->astTypePacks.find(exprs.data[exprs.size - 1]);
if (tail)
return arena.addTypePack(TypePack{std::move(head), follow(*tail)});
else
return arena.addTypePack(TypePack{std::move(head), builtinTypes->errorRecoveryTypePack()});
}
bool AnyTypeSummary::isAnyCall(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (auto call = expr->as<AstExprCall>())
{
TypePackId args = reconstructTypePack(call->args, module, builtinTypes);
if (containsAny(args))
return true;
TypeId func = lookupType(call->func, module, builtinTypes);
if (containsAny(func))
return true;
}
return false;
}
bool AnyTypeSummary::hasVariadicAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (expr->vararg && expr->varargAnnotation)
{
auto annot = lookupPackAnnotation(expr->varargAnnotation, module);
if (annot && containsAny(*annot))
{
return true;
}
}
return false;
}
bool AnyTypeSummary::hasArgAnys(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (expr->args.size > 0)
{
for (const AstLocal* arg : expr->args)
{
if (arg->annotation)
{
auto annot = lookupAnnotation(arg->annotation, module, builtinTypes);
if (containsAny(annot))
{
return true;
}
}
}
}
return false;
}
bool AnyTypeSummary::hasAnyReturns(const Scope* scope, AstExprFunction* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (!expr->returnAnnotation)
{
return false;
}
for (AstType* ret : expr->returnAnnotation->types)
{
if (containsAny(lookupAnnotation(ret, module, builtinTypes)))
{
return true;
}
}
if (expr->returnAnnotation->tailType)
{
auto annot = lookupPackAnnotation(expr->returnAnnotation->tailType, module);
if (annot && containsAny(*annot))
{
return true;
}
}
return false;
}
bool AnyTypeSummary::isAnyCast(const Scope* scope, AstExpr* expr, const Module* module, NotNull<BuiltinTypes> builtinTypes)
{
if (auto cast = expr->as<AstExprTypeAssertion>())
{
auto annot = lookupAnnotation(cast->annotation, module, builtinTypes);
if (containsAny(annot))
{
return true;
}
}
return false;
}
TypeId AnyTypeSummary::lookupAnnotation(AstType* annotation, const Module* module, NotNull<BuiltinTypes> builtintypes)
{
if (FFlag::DebugLuauMagicTypes)
{
if (auto ref = annotation->as<AstTypeReference>(); ref && ref->parameters.size > 0)
{
if (auto ann = ref->parameters.data[0].type)
{
TypeId argTy = lookupAnnotation(ref->parameters.data[0].type, module, builtintypes);
return follow(argTy);
}
}
}
const TypeId* ty = module->astResolvedTypes.find(annotation);
if (ty)
return checkForTypeFunctionInhabitance(follow(*ty), annotation->location);
else
return checkForTypeFunctionInhabitance(builtintypes->errorRecoveryType(), annotation->location);
}
TypeId AnyTypeSummary::checkForTypeFunctionInhabitance(const TypeId instance, const Location location)
{
if (seenTypeFunctionInstances.find(instance))
return instance;
seenTypeFunctionInstances.insert(instance);
return instance;
}
std::optional<TypePackId> AnyTypeSummary::lookupPackAnnotation(AstTypePack* annotation, const Module* module)
{
const TypePackId* tp = module->astResolvedTypePacks.find(annotation);
if (tp != nullptr)
return {follow(*tp)};
return {};
}
bool AnyTypeSummary::containsAny(TypeId typ)
{
typ = follow(typ);
if (auto t = seen.find(typ); t && !*t)
{
return false;
}
seen[typ] = false;
RecursionCounter counter{&recursionCount};
if (recursionCount >= FInt::LuauAnySummaryRecursionLimit)
{
return false;
}
bool found = false;
if (auto ty = get<AnyType>(typ))
{
found = true;
}
else if (auto ty = get<UnknownType>(typ))
{
found = true;
}
else if (auto ty = get<TableType>(typ))
{
for (auto& [_name, prop] : ty->props)
{
if (FFlag::LuauSolverV2)
{
if (auto newT = follow(prop.readTy))
{
if (containsAny(*newT))
found = true;
}
else if (auto newT = follow(prop.writeTy))
{
if (containsAny(*newT))
found = true;
}
}
else
{
if (containsAny(prop.type()))
found = true;
}
}
}
else if (auto ty = get<IntersectionType>(typ))
{
for (auto part : ty->parts)
{
if (containsAny(part))
{
found = true;
}
}
}
else if (auto ty = get<UnionType>(typ))
{
for (auto option : ty->options)
{
if (containsAny(option))
{
found = true;
}
}
}
else if (auto ty = get<FunctionType>(typ))
{
if (containsAny(ty->argTypes))
found = true;
else if (containsAny(ty->retTypes))
found = true;
}
seen[typ] = found;
return found;
}
bool AnyTypeSummary::containsAny(TypePackId typ)
{
typ = follow(typ);
if (auto t = seen.find(typ); t && !*t)
{
return false;
}
seen[typ] = false;
auto [head, tail] = flatten(typ);
bool found = false;
for (auto tp : head)
{
if (containsAny(tp))
found = true;
}
if (tail)
{
if (auto vtp = get<VariadicTypePack>(tail))
{
if (auto ty = get<AnyType>(follow(vtp->ty)))
{
found = true;
}
}
else if (auto tftp = get<TypeFunctionInstanceTypePack>(tail))
{
for (TypePackId tp : tftp->packArguments)
{
if (containsAny(tp))
{
found = true;
}
}
for (TypeId t : tftp->typeArguments)
{
if (containsAny(t))
{
found = true;
}
}
}
}
seen[typ] = found;
return found;
}
const Scope* AnyTypeSummary::findInnerMostScope(const Location location, const Module* module)
{
const Scope* bestScope = module->getModuleScope().get();
bool didNarrow = false;
do
{
didNarrow = false;
for (auto scope : bestScope->children)
{
if (scope->location.encloses(location))
{
bestScope = scope.get();
didNarrow = true;
break;
}
}
} while (didNarrow && bestScope->children.size() > 0);
return bestScope;
}
std::optional<AstExpr*> AnyTypeSummary::matchRequire(const AstExprCall& call)
{
const char* require = "require";
if (call.args.size != 1)
return std::nullopt;
const AstExprGlobal* funcAsGlobal = call.func->as<AstExprGlobal>();
if (!funcAsGlobal || funcAsGlobal->name != require)
return std::nullopt;
if (call.args.size != 1)
return std::nullopt;
return call.args.data[0];
}
AstNode* AnyTypeSummary::getNode(AstStatBlock* root, AstNode* node)
{
FindReturnAncestry finder(node, root->location.end);
root->visit(&finder);
if (!finder.currNode)
finder.currNode = node;
LUAU_ASSERT(finder.found && finder.currNode);
return finder.currNode;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstStatLocalFunction* node)
{
currNode = node;
return !found;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstStatFunction* node)
{
currNode = node;
return !found;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstType* node)
{
return !found;
}
bool AnyTypeSummary::FindReturnAncestry::visit(AstNode* node)
{
if (node == stat)
{
found = true;
}
if (node->location.end == rootEnd && stat->location.end >= rootEnd)
{
currNode = node;
found = true;
}
return !found;
}
AnyTypeSummary::TypeInfo::TypeInfo(Pattern code, std::string node, TelemetryTypePair type)
: code(code)
, node(node)
, type(type)
{
}
AnyTypeSummary::FindReturnAncestry::FindReturnAncestry(AstNode* stat, Position rootEnd)
: stat(stat)
, rootEnd(rootEnd)
{
}
AnyTypeSummary::AnyTypeSummary() {}
} // namespace Luau

View file

@ -1065,6 +1065,11 @@ struct AstJsonEncoder : public AstVisitor
); );
} }
void write(class AstTypeOptional* node)
{
writeNode(node, "AstTypeOptional", [&]() {});
}
void write(class AstTypeUnion* node) void write(class AstTypeUnion* node)
{ {
writeNode( writeNode(
@ -1146,6 +1151,8 @@ struct AstJsonEncoder : public AstVisitor
return writeString("checked"); return writeString("checked");
case AstAttr::Type::Native: case AstAttr::Type::Native:
return writeString("native"); return writeString("native");
case AstAttr::Type::Deprecated:
return writeString("deprecated");
} }
} }
@ -1161,6 +1168,19 @@ struct AstJsonEncoder : public AstVisitor
); );
} }
bool visit(class AstTypeGroup* node) override
{
writeNode(
node,
"AstTypeGroup",
[&]()
{
write("inner", node->type);
}
);
return false;
}
bool visit(class AstTypeSingletonBool* node) override bool visit(class AstTypeSingletonBool* node) override
{ {
writeNode( writeNode(

View file

@ -2,6 +2,7 @@
#include "Luau/Autocomplete.h" #include "Luau/Autocomplete.h"
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/TimeTrace.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
@ -15,6 +16,9 @@ namespace Luau
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback)
{ {
LUAU_TIMETRACE_SCOPE("Luau::autocomplete", "Autocomplete");
LUAU_TIMETRACE_ARGUMENT("name", moduleName.c_str());
const SourceModule* sourceModule = frontend.getSourceModule(moduleName); const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule) if (!sourceModule)
return {}; return {};

View file

@ -10,6 +10,7 @@
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/FileResolver.h" #include "Luau/FileResolver.h"
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/TimeTrace.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/Subtyping.h" #include "Luau/Subtyping.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
@ -20,12 +21,16 @@
#include <utility> #include <utility>
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions2)
LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTypeInferIterationLimit)
LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAGVARIABLE(DebugLuauMagicVariableNames)
LUAU_FASTFLAG(LuauExposeRequireByStringAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteRefactorsForIncrementalAutocomplete) LUAU_FASTFLAGVARIABLE(LuauAutocompleteRefactorsForIncrementalAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteUseLimits)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteUsesModuleForTypeCompatibility)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteUnionCopyPreviousSeen)
static const std::unordered_set<std::string> kStatementStartingKeywords = static const std::unordered_set<std::string> kStatementStartingKeywords =
{"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -147,44 +152,91 @@ static std::optional<TypeId> findExpectedTypeAt(const Module& module, AstNode* n
return *it; return *it;
} }
static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, TypeArena* typeArena, NotNull<BuiltinTypes> builtinTypes) static bool checkTypeMatch(
const Module& module,
TypeId subTy,
TypeId superTy,
NotNull<Scope> scope,
TypeArena* typeArena,
NotNull<BuiltinTypes> builtinTypes
)
{ {
InternalErrorReporter iceReporter; InternalErrorReporter iceReporter;
UnifierSharedState unifierState(&iceReporter); UnifierSharedState unifierState(&iceReporter);
SimplifierPtr simplifier = newSimplifier(NotNull{typeArena}, builtinTypes); SimplifierPtr simplifier = newSimplifier(NotNull{typeArena}, builtinTypes);
Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}};
if (FFlag::LuauAutocompleteUsesModuleForTypeCompatibility)
if (FFlag::LuauSolverV2)
{ {
TypeCheckLimits limits; if (module.checkedInNewSolver)
TypeFunctionRuntime typeFunctionRuntime{ {
NotNull{&iceReporter}, NotNull{&limits} TypeCheckLimits limits;
}; // TODO: maybe subtyping checks should not invoke user-defined type function runtime TypeFunctionRuntime typeFunctionRuntime{
NotNull{&iceReporter}, NotNull{&limits}
}; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
Subtyping subtyping{ Subtyping subtyping{
builtinTypes, NotNull{typeArena}, NotNull{simplifier.get()}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter} builtinTypes,
}; NotNull{typeArena},
NotNull{simplifier.get()},
NotNull{&normalizer},
NotNull{&typeFunctionRuntime},
NotNull{&iceReporter}
};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype; return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
}
else
{
Unifier unifier(NotNull<Normalizer>{&normalizer}, scope, Location(), Variance::Covariant);
// Cost of normalization can be too high for autocomplete response time requirements
unifier.normalize = false;
unifier.checkInhabited = false;
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
return unifier.canUnify(subTy, superTy).empty();
}
} }
else else
{ {
Unifier unifier(NotNull<Normalizer>{&normalizer}, scope, Location(), Variance::Covariant); if (FFlag::LuauSolverV2)
// Cost of normalization can be too high for autocomplete response time requirements
unifier.normalize = false;
unifier.checkInhabited = false;
if (FFlag::LuauAutocompleteUseLimits)
{ {
TypeCheckLimits limits;
TypeFunctionRuntime typeFunctionRuntime{
NotNull{&iceReporter}, NotNull{&limits}
}; // TODO: maybe subtyping checks should not invoke user-defined type function runtime
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
}
return unifier.canUnify(subTy, superTy).empty(); Subtyping subtyping{
builtinTypes,
NotNull{typeArena},
NotNull{simplifier.get()},
NotNull{&normalizer},
NotNull{&typeFunctionRuntime},
NotNull{&iceReporter}
};
return subtyping.isSubtype(subTy, superTy, scope).isSubtype;
}
else
{
Unifier unifier(NotNull<Normalizer>{&normalizer}, scope, Location(), Variance::Covariant);
// Cost of normalization can be too high for autocomplete response time requirements
unifier.normalize = false;
unifier.checkInhabited = false;
unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit;
unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit;
return unifier.canUnify(subTy, superTy).empty();
}
} }
} }
@ -210,10 +262,10 @@ static TypeCorrectKind checkTypeCorrectKind(
TypeId expectedType = follow(*typeAtPosition); TypeId expectedType = follow(*typeAtPosition);
auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType, &module](const FunctionType* ftv)
{ {
if (std::optional<TypeId> firstRetTy = first(ftv->retTypes)) if (std::optional<TypeId> firstRetTy = first(ftv->retTypes))
return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); return checkTypeMatch(module, *firstRetTy, expectedType, moduleScope, typeArena, builtinTypes);
return false; return false;
}; };
@ -236,7 +288,7 @@ static TypeCorrectKind checkTypeCorrectKind(
} }
} }
return checkTypeMatch(ty, expectedType, moduleScope, typeArena, builtinTypes) ? TypeCorrectKind::Correct : TypeCorrectKind::None; return checkTypeMatch(module, ty, expectedType, moduleScope, typeArena, builtinTypes) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
} }
enum class PropIndexType enum class PropIndexType
@ -287,7 +339,7 @@ static void autocompleteProps(
// When called with '.', but declared with 'self', it is considered invalid if first argument is compatible // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible
if (std::optional<TypeId> firstArgTy = first(ftv->argTypes)) if (std::optional<TypeId> firstArgTy = first(ftv->argTypes))
{ {
if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes)) if (checkTypeMatch(module, rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes))
return calledWithSelf; return calledWithSelf;
} }
@ -433,6 +485,21 @@ static void autocompleteProps(
AutocompleteEntryMap inner; AutocompleteEntryMap inner;
std::unordered_set<TypeId> innerSeen; std::unordered_set<TypeId> innerSeen;
// If we don't do this, and we have the misfortune of receiving a
// recursive union like:
//
// t1 where t1 = t1 | Class
//
// Then we are on a one way journey to a stack overflow.
if (FFlag::LuauAutocompleteUnionCopyPreviousSeen)
{
for (auto ty: seen)
{
if (is<UnionType, IntersectionType>(ty))
innerSeen.insert(ty);
}
}
if (isNil(*iter)) if (isNil(*iter))
{ {
++iter; ++iter;
@ -1295,6 +1362,15 @@ static AutocompleteContext autocompleteExpression(
AstNode* node = ancestry.rbegin()[0]; AstNode* node = ancestry.rbegin()[0];
if (FFlag::DebugLuauMagicVariableNames)
{
InternalErrorReporter ice;
if (auto local = node->as<AstExprLocal>(); local && local->local->name == "_luau_autocomplete_ice")
ice.ice("_luau_autocomplete_ice encountered", local->location);
if (auto global = node->as<AstExprGlobal>(); global && global->name == "_luau_autocomplete_ice")
ice.ice("_luau_autocomplete_ice encountered", global->location);
}
if (node->is<AstExprIndexName>()) if (node->is<AstExprIndexName>())
{ {
if (auto it = module.astTypes.find(node->asExpr())) if (auto it = module.astTypes.find(node->asExpr()))
@ -1461,10 +1537,14 @@ static std::optional<AutocompleteEntryMap> convertRequireSuggestionsToAutocomple
return std::nullopt; return std::nullopt;
AutocompleteEntryMap result; AutocompleteEntryMap result;
for (const RequireSuggestion& suggestion : *suggestions) for (RequireSuggestion& suggestion : *suggestions)
{ {
AutocompleteEntry entry = {AutocompleteEntryKind::RequirePath}; AutocompleteEntry entry = {AutocompleteEntryKind::RequirePath};
entry.insertText = std::move(suggestion.fullPath); entry.insertText = std::move(suggestion.fullPath);
if (FFlag::LuauExposeRequireByStringAutocomplete)
{
entry.tags = std::move(suggestion.tags);
}
result[std::move(suggestion.label)] = std::move(entry); result[std::move(suggestion.label)] = std::move(entry);
} }
return result; return result;
@ -1521,12 +1601,9 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(
{ {
for (const std::string& tag : funcType->tags) for (const std::string& tag : funcType->tags)
{ {
if (FFlag::AutocompleteRequirePathSuggestions2) if (tag == kRequireTagName && fileResolver)
{ {
if (tag == kRequireTagName && fileResolver) return convertRequireSuggestionsToAutocompleteEntryMap(fileResolver->getRequireSuggestions(module->name, candidateString));
{
return convertRequireSuggestionsToAutocompleteEntryMap(fileResolver->getRequireSuggestions(module->name, candidateString));
}
} }
if (std::optional<AutocompleteEntryMap> ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) if (std::optional<AutocompleteEntryMap> ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString))
{ {
@ -1718,6 +1795,7 @@ AutocompleteResult autocomplete_(
StringCompletionCallback callback StringCompletionCallback callback
) )
{ {
LUAU_TIMETRACE_SCOPE("Luau::autocomplete_", "AutocompleteCore");
AstNode* node = ancestry.back(); AstNode* node = ancestry.back();
AstExprConstantNil dummy{Location{}}; AstExprConstantNil dummy{Location{}};

View file

@ -29,81 +29,90 @@
*/ */
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2)
LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix)
LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression) LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression)
LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2) LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType3)
LUAU_FASTFLAG(LuauVectorDefinitionsExtra)
LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType)
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope)
LUAU_FASTFLAGVARIABLE(LuauFollowTableFreeze)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunTypecheck)
namespace Luau namespace Luau
{ {
struct MagicSelect final : MagicFunction struct MagicSelect final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
}; };
struct MagicSetMetatable final : MagicFunction struct MagicSetMetatable final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
}; };
struct MagicAssert final : MagicFunction struct MagicAssert final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
}; };
struct MagicPack final : MagicFunction struct MagicPack final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
}; };
struct MagicRequire final : MagicFunction struct MagicRequire final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
}; };
struct MagicClone final : MagicFunction struct MagicClone final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
}; };
struct MagicFreeze final : MagicFunction struct MagicFreeze final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
}; };
struct MagicFormat final : MagicFunction struct MagicFormat final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
bool typeCheck(const MagicFunctionTypeCheckContext& ctx) override; bool typeCheck(const MagicFunctionTypeCheckContext& ctx) override;
}; };
struct MagicMatch final : MagicFunction struct MagicMatch final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
}; };
struct MagicGmatch final : MagicFunction struct MagicGmatch final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
}; };
struct MagicFind final : MagicFunction struct MagicFind final : MagicFunction
{ {
std::optional<WithPredicate<TypePackId>> handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override; std::optional<WithPredicate<TypePackId>>
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>) override;
bool infer(const MagicFunctionCallContext& ctx) override; bool infer(const MagicFunctionCallContext& ctx) override;
}; };
@ -279,6 +288,22 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string&
} }
} }
static void finalizeGlobalBindings(ScopePtr scope)
{
LUAU_ASSERT(FFlag::LuauUserTypeFunTypecheck);
for (const auto& pair : scope->bindings)
{
persist(pair.second.typeId);
if (TableType* ttv = getMutable<TableType>(pair.second.typeId))
{
if (!ttv->name)
ttv->name = "typeof(" + toString(pair.first) + ")";
}
}
}
void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete) void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete)
{ {
LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); LUAU_ASSERT(!globals.globalTypes.types.isFrozen());
@ -310,28 +335,25 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
addGlobalBinding(globals, "string", it->second.type(), "@luau"); addGlobalBinding(globals, "string", it->second.type(), "@luau");
// Setup 'vector' metatable // Setup 'vector' metatable
if (FFlag::LuauVectorDefinitionsExtra) if (auto it = globals.globalScope->exportedTypeBindings.find("vector"); it != globals.globalScope->exportedTypeBindings.end())
{ {
if (auto it = globals.globalScope->exportedTypeBindings.find("vector"); it != globals.globalScope->exportedTypeBindings.end()) TypeId vectorTy = it->second.type;
{ ClassType* vectorCls = getMutable<ClassType>(vectorTy);
TypeId vectorTy = it->second.type;
ClassType* vectorCls = getMutable<ClassType>(vectorTy);
vectorCls->metatable = arena.addType(TableType{{}, std::nullopt, TypeLevel{}, TableState::Sealed}); vectorCls->metatable = arena.addType(TableType{{}, std::nullopt, TypeLevel{}, TableState::Sealed});
TableType* metatableTy = Luau::getMutable<TableType>(vectorCls->metatable); TableType* metatableTy = Luau::getMutable<TableType>(vectorCls->metatable);
metatableTy->props["__add"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})}; metatableTy->props["__add"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})};
metatableTy->props["__sub"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})}; metatableTy->props["__sub"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})};
metatableTy->props["__unm"] = {makeFunction(arena, vectorTy, {}, {vectorTy})}; metatableTy->props["__unm"] = {makeFunction(arena, vectorTy, {}, {vectorTy})};
std::initializer_list<TypeId> mulOverloads{ std::initializer_list<TypeId> mulOverloads{
makeFunction(arena, vectorTy, {vectorTy}, {vectorTy}), makeFunction(arena, vectorTy, {vectorTy}, {vectorTy}),
makeFunction(arena, vectorTy, {builtinTypes->numberType}, {vectorTy}), makeFunction(arena, vectorTy, {builtinTypes->numberType}, {vectorTy}),
}; };
metatableTy->props["__mul"] = {makeIntersection(arena, mulOverloads)}; metatableTy->props["__mul"] = {makeIntersection(arena, mulOverloads)};
metatableTy->props["__div"] = {makeIntersection(arena, mulOverloads)}; metatableTy->props["__div"] = {makeIntersection(arena, mulOverloads)};
metatableTy->props["__idiv"] = {makeIntersection(arena, mulOverloads)}; metatableTy->props["__idiv"] = {makeIntersection(arena, mulOverloads)};
}
} }
// next<K, V>(t: Table<K, V>, i: K?) -> (K?, V) // next<K, V>(t: Table<K, V>, i: K?) -> (K?, V)
@ -393,14 +415,21 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
// clang-format on // clang-format on
} }
for (const auto& pair : globals.globalScope->bindings) if (FFlag::LuauUserTypeFunTypecheck)
{ {
persist(pair.second.typeId); finalizeGlobalBindings(globals.globalScope);
}
if (TableType* ttv = getMutable<TableType>(pair.second.typeId)) else
{
for (const auto& pair : globals.globalScope->bindings)
{ {
if (!ttv->name) persist(pair.second.typeId);
ttv->name = "typeof(" + toString(pair.first) + ")";
if (TableType* ttv = getMutable<TableType>(pair.second.typeId))
{
if (!ttv->name)
ttv->name = "typeof(" + toString(pair.first) + ")";
}
} }
} }
@ -453,21 +482,66 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
ttv->props["foreachi"].deprecated = true; ttv->props["foreachi"].deprecated = true;
attachMagicFunction(ttv->props["pack"].type(), std::make_shared<MagicPack>()); attachMagicFunction(ttv->props["pack"].type(), std::make_shared<MagicPack>());
if (FFlag::LuauTableCloneClonesType) if (FFlag::LuauTableCloneClonesType3)
attachMagicFunction(ttv->props["clone"].type(), std::make_shared<MagicClone>()); attachMagicFunction(ttv->props["clone"].type(), std::make_shared<MagicClone>());
if (FFlag::LuauTypestateBuiltins2) attachMagicFunction(ttv->props["freeze"].type(), std::make_shared<MagicFreeze>());
attachMagicFunction(ttv->props["freeze"].type(), std::make_shared<MagicFreeze>());
} }
if (FFlag::AutocompleteRequirePathSuggestions2) TypeId requireTy = getGlobalBinding(globals, "require");
attachTag(requireTy, kRequireTagName);
attachMagicFunction(requireTy, std::make_shared<MagicRequire>());
if (FFlag::LuauUserTypeFunTypecheck)
{ {
TypeId requireTy = getGlobalBinding(globals, "require"); // Global scope cannot be the parent of the type checking environment because it can be changed by the embedder
attachTag(requireTy, kRequireTagName); globals.globalTypeFunctionScope->exportedTypeBindings = globals.globalScope->exportedTypeBindings;
attachMagicFunction(requireTy, std::make_shared<MagicRequire>()); globals.globalTypeFunctionScope->builtinTypeNames = globals.globalScope->builtinTypeNames;
}
else // Type function runtime also removes a few standard libraries and globals, so we will take only the ones that are defined
{ static const char* typeFunctionRuntimeBindings[] = {
attachMagicFunction(getGlobalBinding(globals, "require"), std::make_shared<MagicRequire>()); // Libraries
"math",
"table",
"string",
"bit32",
"utf8",
"buffer",
// Globals
"assert",
"error",
"print",
"next",
"ipairs",
"pairs",
"select",
"unpack",
"getmetatable",
"setmetatable",
"rawget",
"rawset",
"rawlen",
"rawequal",
"tonumber",
"tostring",
"type",
"typeof",
};
for (auto& name : typeFunctionRuntimeBindings)
{
AstName astName = globals.globalNames.names->get(name);
LUAU_ASSERT(astName.value);
globals.globalTypeFunctionScope->bindings[astName] = globals.globalScope->bindings[astName];
}
LoadDefinitionFileResult typeFunctionLoadResult = frontend.loadDefinitionFile(
globals, globals.globalTypeFunctionScope, getTypeFunctionDefinitionSource(), "@luau", /* captureComments */ false, false
);
LUAU_ASSERT(typeFunctionLoadResult.success);
finalizeGlobalBindings(globals.globalTypeFunctionScope);
} }
} }
@ -617,10 +691,7 @@ bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context)
if (!fmt) if (!fmt)
{ {
if (FFlag::LuauStringFormatArityFix) context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location);
context.typechecker->reportError(
CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location
);
return true; return true;
} }
@ -645,15 +716,15 @@ bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context)
{ {
switch (shouldSuppressErrors(NotNull{&context.typechecker->normalizer}, actualTy)) switch (shouldSuppressErrors(NotNull{&context.typechecker->normalizer}, actualTy))
{ {
case ErrorSuppression::Suppress: case ErrorSuppression::Suppress:
break; break;
case ErrorSuppression::NormalizationFailed: case ErrorSuppression::NormalizationFailed:
break; break;
case ErrorSuppression::DoNotSuppress: case ErrorSuppression::DoNotSuppress:
Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result);
if (!reasonings.suppressed) if (!reasonings.suppressed)
context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location);
} }
} }
else else
@ -1405,7 +1476,7 @@ std::optional<WithPredicate<TypePackId>> MagicClone::handleOldSolver(
WithPredicate<TypePackId> withPredicate WithPredicate<TypePackId> withPredicate
) )
{ {
LUAU_ASSERT(FFlag::LuauTableCloneClonesType); LUAU_ASSERT(FFlag::LuauTableCloneClonesType3);
auto [paramPack, _predicates] = withPredicate; auto [paramPack, _predicates] = withPredicate;
@ -1420,6 +1491,9 @@ std::optional<WithPredicate<TypePackId>> MagicClone::handleOldSolver(
TypeId inputType = follow(paramTypes[0]); TypeId inputType = follow(paramTypes[0]);
if (!get<TableType>(inputType))
return std::nullopt;
CloneState cloneState{typechecker.builtinTypes}; CloneState cloneState{typechecker.builtinTypes};
TypeId resultType = shallowClone(inputType, arena, cloneState); TypeId resultType = shallowClone(inputType, arena, cloneState);
@ -1429,7 +1503,7 @@ std::optional<WithPredicate<TypePackId>> MagicClone::handleOldSolver(
bool MagicClone::infer(const MagicFunctionCallContext& context) bool MagicClone::infer(const MagicFunctionCallContext& context)
{ {
LUAU_ASSERT(FFlag::LuauTableCloneClonesType); LUAU_ASSERT(FFlag::LuauTableCloneClonesType3);
TypeArena* arena = context.solver->arena; TypeArena* arena = context.solver->arena;
@ -1442,8 +1516,11 @@ bool MagicClone::infer(const MagicFunctionCallContext& context)
TypeId inputType = follow(paramTypes[0]); TypeId inputType = follow(paramTypes[0]);
if (!get<TableType>(inputType))
return false;
CloneState cloneState{context.solver->builtinTypes}; CloneState cloneState{context.solver->builtinTypes};
TypeId resultType = shallowClone(inputType, *arena, cloneState); TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ true);
if (auto tableType = getMutable<TableType>(resultType)) if (auto tableType = getMutable<TableType>(resultType))
{ {
@ -1462,7 +1539,8 @@ bool MagicClone::infer(const MagicFunctionCallContext& context)
static std::optional<TypeId> freezeTable(TypeId inputType, const MagicFunctionCallContext& context) static std::optional<TypeId> freezeTable(TypeId inputType, const MagicFunctionCallContext& context)
{ {
TypeArena* arena = context.solver->arena; TypeArena* arena = context.solver->arena;
if (FFlag::LuauFollowTableFreeze)
inputType = follow(inputType);
if (auto mt = get<MetatableType>(inputType)) if (auto mt = get<MetatableType>(inputType))
{ {
std::optional<TypeId> frozenTable = freezeTable(mt->table, context); std::optional<TypeId> frozenTable = freezeTable(mt->table, context);
@ -1479,7 +1557,7 @@ static std::optional<TypeId> freezeTable(TypeId inputType, const MagicFunctionCa
{ {
// Clone the input type, this will become our final result type after we mutate it. // Clone the input type, this will become our final result type after we mutate it.
CloneState cloneState{context.solver->builtinTypes}; CloneState cloneState{context.solver->builtinTypes};
TypeId resultType = shallowClone(inputType, *arena, cloneState); TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ true);
auto tableTy = getMutable<TableType>(resultType); auto tableTy = getMutable<TableType>(resultType);
// `clone` should not break this. // `clone` should not break this.
LUAU_ASSERT(tableTy); LUAU_ASSERT(tableTy);
@ -1504,15 +1582,14 @@ static std::optional<TypeId> freezeTable(TypeId inputType, const MagicFunctionCa
return std::nullopt; return std::nullopt;
} }
std::optional<WithPredicate<TypePackId>> MagicFreeze::handleOldSolver(struct TypeChecker &, const std::shared_ptr<struct Scope> &, const class AstExprCall &, WithPredicate<TypePackId>) std::optional<WithPredicate<TypePackId>> MagicFreeze::
handleOldSolver(struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)
{ {
return std::nullopt; return std::nullopt;
} }
bool MagicFreeze::infer(const MagicFunctionCallContext& context) bool MagicFreeze::infer(const MagicFunctionCallContext& context)
{ {
LUAU_ASSERT(FFlag::LuauTypestateBuiltins2);
TypeArena* arena = context.solver->arena; TypeArena* arena = context.solver->arena;
const DataFlowGraph* dfg = context.solver->dfg.get(); const DataFlowGraph* dfg = context.solver->dfg.get();
Scope* scope = context.constraint->scope.get(); Scope* scope = context.constraint->scope.get();

View file

@ -1,16 +1,20 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/NotNull.h" #include "Luau/NotNull.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/Unifiable.h" #include "Luau/Unifiable.h"
#include "Luau/VisitType.h"
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
// For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit. // For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit.
LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000) LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000)
LUAU_FASTFLAGVARIABLE(LuauClonedTableAndFunctionTypesMustHaveScopes)
LUAU_FASTFLAGVARIABLE(LuauDoNotClonePersistentBindings)
LUAU_FASTFLAG(LuauIncrementalAutocompleteDemandBasedCloning)
namespace Luau namespace Luau
{ {
@ -27,6 +31,8 @@ const T* get(const Kind& kind)
class TypeCloner class TypeCloner
{ {
protected:
NotNull<TypeArena> arena; NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes; NotNull<BuiltinTypes> builtinTypes;
@ -38,17 +44,31 @@ class TypeCloner
NotNull<SeenTypes> types; NotNull<SeenTypes> types;
NotNull<SeenTypePacks> packs; NotNull<SeenTypePacks> packs;
TypeId forceTy = nullptr;
TypePackId forceTp = nullptr;
int steps = 0; int steps = 0;
public: public:
TypeCloner(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<SeenTypes> types, NotNull<SeenTypePacks> packs) TypeCloner(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<SeenTypes> types,
NotNull<SeenTypePacks> packs,
TypeId forceTy,
TypePackId forceTp
)
: arena(arena) : arena(arena)
, builtinTypes(builtinTypes) , builtinTypes(builtinTypes)
, types(types) , types(types)
, packs(packs) , packs(packs)
, forceTy(forceTy)
, forceTp(forceTp)
{ {
} }
virtual ~TypeCloner() = default;
TypeId clone(TypeId ty) TypeId clone(TypeId ty)
{ {
shallowClone(ty); shallowClone(ty);
@ -107,12 +127,13 @@ private:
} }
} }
protected:
std::optional<TypeId> find(TypeId ty) const std::optional<TypeId> find(TypeId ty) const
{ {
ty = follow(ty, FollowOption::DisableLazyTypeThunks); ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto it = types->find(ty); it != types->end()) if (auto it = types->find(ty); it != types->end())
return it->second; return it->second;
else if (ty->persistent) else if (ty->persistent && ty != forceTy)
return ty; return ty;
return std::nullopt; return std::nullopt;
} }
@ -122,7 +143,7 @@ private:
tp = follow(tp); tp = follow(tp);
if (auto it = packs->find(tp); it != packs->end()) if (auto it = packs->find(tp); it != packs->end())
return it->second; return it->second;
else if (tp->persistent) else if (tp->persistent && tp != forceTp)
return tp; return tp;
return std::nullopt; return std::nullopt;
} }
@ -141,14 +162,14 @@ private:
} }
public: public:
TypeId shallowClone(TypeId ty) virtual TypeId shallowClone(TypeId ty)
{ {
// We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s. // We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s.
ty = follow(ty, FollowOption::DisableLazyTypeThunks); ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto clone = find(ty)) if (auto clone = find(ty))
return *clone; return *clone;
else if (ty->persistent) else if (ty->persistent && ty != forceTy)
return ty; return ty;
TypeId target = arena->addType(ty->ty); TypeId target = arena->addType(ty->ty);
@ -168,13 +189,13 @@ public:
return target; return target;
} }
TypePackId shallowClone(TypePackId tp) virtual TypePackId shallowClone(TypePackId tp)
{ {
tp = follow(tp); tp = follow(tp);
if (auto clone = find(tp)) if (auto clone = find(tp))
return *clone; return *clone;
else if (tp->persistent) else if (tp->persistent && tp != forceTp)
return tp; return tp;
TypePackId target = arena->addTypePack(tp->ty); TypePackId target = arena->addTypePack(tp->ty);
@ -376,7 +397,7 @@ private:
ty = shallowClone(ty); ty = shallowClone(ty);
} }
void cloneChildren(LazyType* t) virtual void cloneChildren(LazyType* t)
{ {
if (auto unwrapped = t->unwrapped.load()) if (auto unwrapped = t->unwrapped.load())
t->unwrapped.store(shallowClone(unwrapped)); t->unwrapped.store(shallowClone(unwrapped));
@ -456,23 +477,127 @@ private:
} }
}; };
class FragmentAutocompleteTypeCloner final : public TypeCloner
{
Scope* replacementForNullScope = nullptr;
public:
FragmentAutocompleteTypeCloner(
NotNull<TypeArena> arena,
NotNull<BuiltinTypes> builtinTypes,
NotNull<SeenTypes> types,
NotNull<SeenTypePacks> packs,
TypeId forceTy,
TypePackId forceTp,
Scope* replacementForNullScope
)
: TypeCloner(arena, builtinTypes, types, packs, forceTy, forceTp)
, replacementForNullScope(replacementForNullScope)
{
LUAU_ASSERT(replacementForNullScope);
}
TypeId shallowClone(TypeId ty) override
{
// We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s.
ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto clone = find(ty))
return *clone;
else if (ty->persistent && ty != forceTy)
return ty;
TypeId target = arena->addType(ty->ty);
asMutable(target)->documentationSymbol = ty->documentationSymbol;
if (auto generic = getMutable<GenericType>(target))
generic->scope = nullptr;
else if (auto free = getMutable<FreeType>(target))
{
free->scope = replacementForNullScope;
}
else if (auto tt = getMutable<TableType>(target))
{
if (FFlag::LuauClonedTableAndFunctionTypesMustHaveScopes)
tt->scope = replacementForNullScope;
}
else if (auto fn = getMutable<FunctionType>(target))
{
if (FFlag::LuauClonedTableAndFunctionTypesMustHaveScopes)
fn->scope = replacementForNullScope;
}
(*types)[ty] = target;
queue.emplace_back(target);
return target;
}
TypePackId shallowClone(TypePackId tp) override
{
tp = follow(tp);
if (auto clone = find(tp))
return *clone;
else if (tp->persistent && tp != forceTp)
return tp;
TypePackId target = arena->addTypePack(tp->ty);
if (auto generic = getMutable<GenericTypePack>(target))
generic->scope = nullptr;
else if (auto free = getMutable<FreeTypePack>(target))
free->scope = replacementForNullScope;
(*packs)[tp] = target;
queue.emplace_back(target);
return target;
}
void cloneChildren(LazyType* t) override
{
// Do not clone lazy types
if (!FFlag::LuauIncrementalAutocompleteDemandBasedCloning)
{
if (auto unwrapped = t->unwrapped.load())
t->unwrapped.store(shallowClone(unwrapped));
}
}
};
} // namespace } // namespace
TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState) TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent)
{ {
if (tp->persistent) if (tp->persistent && !ignorePersistent)
return tp; return tp;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; TypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
nullptr,
ignorePersistent ? tp : nullptr
};
return cloner.shallowClone(tp); return cloner.shallowClone(tp);
} }
TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState) TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent)
{ {
if (typeId->persistent) if (typeId->persistent && !ignorePersistent)
return typeId; return typeId;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; TypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
ignorePersistent ? typeId : nullptr,
nullptr
};
return cloner.shallowClone(typeId); return cloner.shallowClone(typeId);
} }
@ -481,7 +606,7 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
if (tp->persistent) if (tp->persistent)
return tp; return tp;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
return cloner.clone(tp); return cloner.clone(tp);
} }
@ -490,13 +615,13 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
if (typeId->persistent) if (typeId->persistent)
return typeId; return typeId;
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
return cloner.clone(typeId); return cloner.clone(typeId);
} }
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
{ {
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
TypeFun copy = typeFun; TypeFun copy = typeFun;
@ -521,4 +646,110 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
return copy; return copy;
} }
Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState)
{
TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr};
Binding b;
b.deprecated = binding.deprecated;
b.deprecatedSuggestion = binding.deprecatedSuggestion;
b.documentationSymbol = binding.documentationSymbol;
b.location = binding.location;
b.typeId = cloner.clone(binding.typeId);
return b;
}
TypePackId cloneIncremental(TypePackId tp, TypeArena& dest, CloneState& cloneState, Scope* freshScopeForFreeTypes)
{
if (tp->persistent)
return tp;
FragmentAutocompleteTypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
nullptr,
nullptr,
freshScopeForFreeTypes
};
return cloner.clone(tp);
}
TypeId cloneIncremental(TypeId typeId, TypeArena& dest, CloneState& cloneState, Scope* freshScopeForFreeTypes)
{
if (typeId->persistent)
return typeId;
FragmentAutocompleteTypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
nullptr,
nullptr,
freshScopeForFreeTypes
};
return cloner.clone(typeId);
}
TypeFun cloneIncremental(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState, Scope* freshScopeForFreeTypes)
{
FragmentAutocompleteTypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
nullptr,
nullptr,
freshScopeForFreeTypes
};
TypeFun copy = typeFun;
for (auto& param : copy.typeParams)
{
param.ty = cloner.clone(param.ty);
if (param.defaultValue)
param.defaultValue = cloner.clone(*param.defaultValue);
}
for (auto& param : copy.typePackParams)
{
param.tp = cloner.clone(param.tp);
if (param.defaultValue)
param.defaultValue = cloner.clone(*param.defaultValue);
}
copy.type = cloner.clone(copy.type);
return copy;
}
Binding cloneIncremental(const Binding& binding, TypeArena& dest, CloneState& cloneState, Scope* freshScopeForFreeTypes)
{
FragmentAutocompleteTypeCloner cloner{
NotNull{&dest},
cloneState.builtinTypes,
NotNull{&cloneState.seenTypes},
NotNull{&cloneState.seenTypePacks},
nullptr,
nullptr,
freshScopeForFreeTypes
};
Binding b;
b.deprecated = binding.deprecated;
b.deprecatedSuggestion = binding.deprecatedSuggestion;
b.documentationSymbol = binding.documentationSymbol;
b.location = binding.location;
b.typeId = FFlag::LuauDoNotClonePersistentBindings && binding.typeId->persistent ? binding.typeId : cloner.clone(binding.typeId);
return b;
}
} // namespace Luau } // namespace Luau

View file

@ -3,6 +3,8 @@
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_FASTFLAG(DebugLuauGreedyGeneralization)
namespace Luau namespace Luau
{ {
@ -111,6 +113,11 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
{ {
rci.traverse(fchc->argsPack); rci.traverse(fchc->argsPack);
} }
else if (auto fcc = get<FunctionCallConstraint>(*this); fcc && FFlag::DebugLuauGreedyGeneralization)
{
rci.traverse(fcc->fn);
rci.traverse(fcc->argsPack);
}
else if (auto ptc = get<PrimitiveTypeConstraint>(*this)) else if (auto ptc = get<PrimitiveTypeConstraint>(*this))
{ {
rci.traverse(ptc->freeType); rci.traverse(ptc->freeType);
@ -118,7 +125,8 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
else if (auto hpc = get<HasPropConstraint>(*this)) else if (auto hpc = get<HasPropConstraint>(*this))
{ {
rci.traverse(hpc->resultType); rci.traverse(hpc->resultType);
// `HasPropConstraints` should not mutate `subjectType`. if (FFlag::DebugLuauGreedyGeneralization)
rci.traverse(hpc->subjectType);
} }
else if (auto hic = get<HasIndexerConstraint>(*this)) else if (auto hic = get<HasIndexerConstraint>(*this))
{ {
@ -146,6 +154,10 @@ DenseHashSet<TypeId> Constraint::getMaybeMutatedFreeTypes() const
{ {
rci.traverse(rpc->tp); rci.traverse(rpc->tp);
} }
else if (auto tcc = get<TableCheckConstraint>(*this))
{
rci.traverse(tcc->exprType);
}
return types; return types;
} }

File diff suppressed because it is too large Load diff

View file

@ -27,17 +27,20 @@
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
LUAU_FASTFLAGVARIABLE(DebugLuauAssertOnForcedConstraint)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies)
LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings) LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings)
LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500)
LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification) LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification)
LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations)
LUAU_FASTFLAGVARIABLE(LuauAllowNilAssignmentToIndexer)
LUAU_FASTFLAG(LuauUserTypeFunNoExtraConstraint)
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope)
LUAU_FASTFLAGVARIABLE(LuauAlwaysFillInFunctionCallDiscriminantTypes) LUAU_FASTFLAGVARIABLE(LuauTrackInteriorFreeTablesOnScope)
LUAU_FASTFLAGVARIABLE(LuauPrecalculateMutatedFreeTypes2)
LUAU_FASTFLAGVARIABLE(DebugLuauGreedyGeneralization)
LUAU_FASTFLAG(LuauSearchForRefineableType)
LUAU_FASTFLAG(LuauDeprecatedAttribute)
LUAU_FASTFLAG(LuauBidirectionalInferenceCollectIndexerTypes)
LUAU_FASTFLAG(LuauNewTypeFunReductionChecks2)
namespace Luau namespace Luau
{ {
@ -328,6 +331,7 @@ ConstraintSolver::ConstraintSolver(
NotNull<TypeFunctionRuntime> typeFunctionRuntime, NotNull<TypeFunctionRuntime> typeFunctionRuntime,
NotNull<Scope> rootScope, NotNull<Scope> rootScope,
std::vector<NotNull<Constraint>> constraints, std::vector<NotNull<Constraint>> constraints,
NotNull<DenseHashMap<Scope*, TypeId>> scopeToFunction,
ModuleName moduleName, ModuleName moduleName,
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
std::vector<RequireCycle> requireCycles, std::vector<RequireCycle> requireCycles,
@ -341,6 +345,7 @@ ConstraintSolver::ConstraintSolver(
, simplifier(simplifier) , simplifier(simplifier)
, typeFunctionRuntime(typeFunctionRuntime) , typeFunctionRuntime(typeFunctionRuntime)
, constraints(std::move(constraints)) , constraints(std::move(constraints))
, scopeToFunction(scopeToFunction)
, rootScope(rootScope) , rootScope(rootScope)
, currentModuleName(std::move(moduleName)) , currentModuleName(std::move(moduleName))
, dfg(dfg) , dfg(dfg)
@ -355,13 +360,33 @@ ConstraintSolver::ConstraintSolver(
{ {
unsolvedConstraints.emplace_back(c); unsolvedConstraints.emplace_back(c);
// initialize the reference counts for the free types in this constraint. if (FFlag::LuauPrecalculateMutatedFreeTypes2)
for (auto ty : c->getMaybeMutatedFreeTypes())
{ {
// increment the reference count for `ty` auto maybeMutatedTypesPerConstraint = c->getMaybeMutatedFreeTypes();
auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0); for (auto ty : maybeMutatedTypesPerConstraint)
refCount += 1; {
auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0);
refCount += 1;
if (FFlag::DebugLuauGreedyGeneralization)
{
auto [it, fresh] = mutatedFreeTypeToConstraint.try_emplace(ty, DenseHashSet<const Constraint*>{nullptr});
it->second.insert(c.get());
}
}
maybeMutatedFreeTypes.emplace(c, maybeMutatedTypesPerConstraint);
} }
else
{
// initialize the reference counts for the free types in this constraint.
for (auto ty : c->getMaybeMutatedFreeTypes())
{
// increment the reference count for `ty`
auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0);
refCount += 1;
}
}
for (NotNull<const Constraint> dep : c->dependencies) for (NotNull<const Constraint> dep : c->dependencies)
{ {
@ -439,6 +464,9 @@ void ConstraintSolver::run()
snapshot = logger->prepareStepSnapshot(rootScope, c, force, unsolvedConstraints); snapshot = logger->prepareStepSnapshot(rootScope, c, force, unsolvedConstraints);
} }
if (FFlag::DebugLuauAssertOnForcedConstraint)
LUAU_ASSERT(!force);
bool success = tryDispatch(c, force); bool success = tryDispatch(c, force);
progress |= success; progress |= success;
@ -448,20 +476,60 @@ void ConstraintSolver::run()
unblock(c); unblock(c);
unsolvedConstraints.erase(unsolvedConstraints.begin() + ptrdiff_t(i)); unsolvedConstraints.erase(unsolvedConstraints.begin() + ptrdiff_t(i));
// decrement the referenced free types for this constraint if we dispatched successfully! if (FFlag::LuauPrecalculateMutatedFreeTypes2)
for (auto ty : c->getMaybeMutatedFreeTypes())
{ {
size_t& refCount = unresolvedConstraints[ty]; const auto maybeMutated = maybeMutatedFreeTypes.find(c);
if (refCount > 0) if (maybeMutated != maybeMutatedFreeTypes.end())
refCount -= 1; {
DenseHashSet<TypeId> seen{nullptr};
for (auto ty : maybeMutated->second)
{
// There is a high chance that this type has been rebound
// across blocked types, rebound free types, pending
// expansion types, etc, so we need to follow it.
ty = follow(ty);
// We have two constraints that are designed to wait for the if (FFlag::DebugLuauGreedyGeneralization)
// refCount on a free type to be equal to 1: the {
// PrimitiveTypeConstraint and ReduceConstraint. We if (seen.contains(ty))
// therefore wake any constraint waiting for a free type's continue;
// refcount to be 1 or 0. seen.insert(ty);
if (refCount <= 1) }
unblock(ty, Location{});
size_t& refCount = unresolvedConstraints[ty];
if (refCount > 0)
refCount -= 1;
// We have two constraints that are designed to wait for the
// refCount on a free type to be equal to 1: the
// PrimitiveTypeConstraint and ReduceConstraint. We
// therefore wake any constraint waiting for a free type's
// refcount to be 1 or 0.
if (refCount <= 1)
unblock(ty, Location{});
if (FFlag::DebugLuauGreedyGeneralization && refCount == 0)
generalizeOneType(ty);
}
}
}
else
{
// decrement the referenced free types for this constraint if we dispatched successfully!
for (auto ty : c->getMaybeMutatedFreeTypes())
{
size_t& refCount = unresolvedConstraints[ty];
if (refCount > 0)
refCount -= 1;
// We have two constraints that are designed to wait for the
// refCount on a free type to be equal to 1: the
// PrimitiveTypeConstraint and ReduceConstraint. We
// therefore wake any constraint waiting for a free type's
// refcount to be 1 or 0.
if (refCount <= 1)
unblock(ty, Location{});
}
} }
if (logger) if (logger)
@ -558,16 +626,152 @@ bool ConstraintSolver::isDone() const
return unsolvedConstraints.empty(); return unsolvedConstraints.empty();
} }
namespace struct TypeSearcher : TypeVisitor
{ {
TypeId needle;
Polarity current = Polarity::Positive;
struct TypeAndLocation size_t count = 0;
{ Polarity result = Polarity::None;
TypeId typeId;
Location location; explicit TypeSearcher(TypeId needle)
: TypeSearcher(needle, Polarity::Positive)
{}
explicit TypeSearcher(TypeId needle, Polarity initialPolarity)
: needle(needle)
, current(initialPolarity)
{}
bool visit(TypeId ty) override
{
if (ty == needle)
{
++count;
result = Polarity(size_t(result) | size_t(current));
}
return true;
}
void flip()
{
switch (current)
{
case Polarity::Positive:
current = Polarity::Negative;
break;
case Polarity::Negative:
current = Polarity::Positive;
break;
default:
break;
}
}
bool visit(TypeId ty, const FunctionType& ft) override
{
flip();
traverse(ft.argTypes);
flip();
traverse(ft.retTypes);
return false;
}
// bool visit(TypeId ty, const TableType& tt) override
// {
// }
bool visit(TypeId ty, const ClassType&) override
{
return false;
}
}; };
} // namespace void ConstraintSolver::generalizeOneType(TypeId ty)
{
ty = follow(ty);
const FreeType* freeTy = get<FreeType>(ty);
std::string saveme = toString(ty, opts);
// Some constraints (like prim) will also replace a free type with something
// concrete. If so, our work is already done.
if (!freeTy)
return;
NotNull<Scope> tyScope{freeTy->scope};
// TODO: If freeTy occurs within the enclosing function's type, we need to
// check to see whether this type should instead be generic.
TypeId newBound = follow(freeTy->upperBound);
TypeId* functionTyPtr = nullptr;
while (true)
{
functionTyPtr = scopeToFunction->find(tyScope);
if (functionTyPtr || !tyScope->parent)
break;
else if (tyScope->parent)
tyScope = NotNull{tyScope->parent.get()};
else
break;
}
if (ty == newBound)
ty = builtinTypes->unknownType;
if (!functionTyPtr)
{
asMutable(ty)->reassign(Type{BoundType{follow(freeTy->upperBound)}});
}
else
{
const TypeId functionTy = follow(*functionTyPtr);
FunctionType* const function = getMutable<FunctionType>(functionTy);
LUAU_ASSERT(function);
TypeSearcher ts{ty};
ts.traverse(functionTy);
const TypeId upperBound = follow(freeTy->upperBound);
const TypeId lowerBound = follow(freeTy->lowerBound);
switch (ts.result)
{
case Polarity::None:
asMutable(ty)->reassign(Type{BoundType{upperBound}});
break;
case Polarity::Negative:
case Polarity::Mixed:
if (get<UnknownType>(upperBound) && ts.count > 1)
{
asMutable(ty)->reassign(Type{GenericType{tyScope}});
function->generics.emplace_back(ty);
}
else
asMutable(ty)->reassign(Type{BoundType{upperBound}});
break;
case Polarity::Positive:
if (get<UnknownType>(lowerBound) && ts.count > 1)
{
asMutable(ty)->reassign(Type{GenericType{tyScope}});
function->generics.emplace_back(ty);
}
else
asMutable(ty)->reassign(Type{BoundType{lowerBound}});
break;
default:
LUAU_ASSERT(!"Unreachable");
}
}
}
void ConstraintSolver::bind(NotNull<const Constraint> constraint, TypeId ty, TypeId boundTo) void ConstraintSolver::bind(NotNull<const Constraint> constraint, TypeId ty, TypeId boundTo)
{ {
@ -642,6 +846,8 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*fcc, constraint); success = tryDispatch(*fcc, constraint);
else if (auto fcc = get<FunctionCheckConstraint>(*constraint)) else if (auto fcc = get<FunctionCheckConstraint>(*constraint))
success = tryDispatch(*fcc, constraint); success = tryDispatch(*fcc, constraint);
else if (auto tcc = get<TableCheckConstraint>(*constraint))
success = tryDispatch(*tcc, constraint);
else if (auto fcc = get<PrimitiveTypeConstraint>(*constraint)) else if (auto fcc = get<PrimitiveTypeConstraint>(*constraint))
success = tryDispatch(*fcc, constraint); success = tryDispatch(*fcc, constraint);
else if (auto hpc = get<HasPropConstraint>(*constraint)) else if (auto hpc = get<HasPropConstraint>(*constraint))
@ -699,26 +905,25 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull<co
else if (get<PendingExpansionType>(generalizedType)) else if (get<PendingExpansionType>(generalizedType))
return block(generalizedType, constraint); return block(generalizedType, constraint);
std::optional<QuantifierResult> generalized;
std::optional<TypeId> generalizedTy = generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, c.sourceType); std::optional<TypeId> generalizedTy = generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, c.sourceType);
if (generalizedTy) if (!generalizedTy)
generalized = QuantifierResult{*generalizedTy}; // FIXME insertedGenerics and insertedGenericPacks
else
reportError(CodeTooComplex{}, constraint->location); reportError(CodeTooComplex{}, constraint->location);
if (generalized) if (generalizedTy)
{ {
if (get<BlockedType>(generalizedType)) if (get<BlockedType>(generalizedType))
bind(constraint, generalizedType, generalized->result); bind(constraint, generalizedType, *generalizedTy);
else else
unify(constraint, generalizedType, generalized->result); unify(constraint, generalizedType, *generalizedTy);
for (auto [free, gen] : generalized->insertedGenerics.pairings) if (FFlag::LuauDeprecatedAttribute)
unify(constraint, free, gen); {
if (FunctionType* fty = getMutable<FunctionType>(follow(generalizedType)))
for (auto [free, gen] : generalized->insertedGenericPacks.pairings) {
unify(constraint, free, gen); if (c.hasDeprecatedAttribute)
fty->isDeprecatedFunction = true;
}
}
} }
else else
{ {
@ -732,12 +937,12 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull<co
// clang-tidy doesn't understand this is safe. // clang-tidy doesn't understand this is safe.
if (constraint->scope->interiorFreeTypes) if (constraint->scope->interiorFreeTypes)
for (TypeId ty : *constraint->scope->interiorFreeTypes) // NOLINT(bugprone-unchecked-optional-access) for (TypeId ty : *constraint->scope->interiorFreeTypes) // NOLINT(bugprone-unchecked-optional-access)
generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty);
} }
else else
{ {
for (TypeId ty : c.interiorTypes) for (TypeId ty : c.interiorTypes)
generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty);
} }
@ -823,6 +1028,9 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull<const Co
TypeId tableTy = TypeId tableTy =
arena->addType(TableType{TableType::Props{}, TableIndexer{keyTy, valueTy}, TypeLevel{}, constraint->scope, TableState::Free}); arena->addType(TableType{TableType::Props{}, TableIndexer{keyTy, valueTy}, TypeLevel{}, constraint->scope, TableState::Free});
if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope)
trackInteriorFreeType(constraint->scope, tableTy);
unify(constraint, nextTy, tableTy); unify(constraint, nextTy, tableTy);
auto it = begin(c.variables); auto it = begin(c.variables);
@ -959,16 +1167,6 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
if (auto typeFn = get<TypeFunctionInstanceType>(follow(tf->type))) if (auto typeFn = get<TypeFunctionInstanceType>(follow(tf->type)))
pushConstraint(NotNull(constraint->scope.get()), constraint->location, ReduceConstraint{tf->type}); pushConstraint(NotNull(constraint->scope.get()), constraint->location, ReduceConstraint{tf->type});
if (!FFlag::LuauUserTypeFunNoExtraConstraint)
{
// If there are no parameters to the type function we can just use the type directly
if (tf->typeParams.empty() && tf->typePackParams.empty())
{
bindResult(tf->type);
return true;
}
}
// Due to how pending expansion types and TypeFun's are created // Due to how pending expansion types and TypeFun's are created
// If this check passes, we have created a cyclic / corecursive type alias // If this check passes, we have created a cyclic / corecursive type alias
// of size 0 // of size 0
@ -981,14 +1179,11 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
return true; return true;
} }
if (FFlag::LuauUserTypeFunNoExtraConstraint) // If there are no parameters to the type function we can just use the type directly
if (tf->typeParams.empty() && tf->typePackParams.empty())
{ {
// If there are no parameters to the type function we can just use the type directly bindResult(tf->type);
if (tf->typeParams.empty() && tf->typePackParams.empty()) return true;
{
bindResult(tf->type);
return true;
}
} }
auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments);
@ -1136,12 +1331,9 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
target = follow(instantiated); target = follow(instantiated);
} }
if (FFlag::LuauNewSolverPopulateTableLocations) // This is a new type - redefine the location.
{ ttv->definitionLocation = constraint->location;
// This is a new type - redefine the location. ttv->definitionModuleName = currentModuleName;
ttv->definitionLocation = constraint->location;
ttv->definitionModuleName = currentModuleName;
}
ttv->instantiatedTypeParams = typeArguments; ttv->instantiatedTypeParams = typeArguments;
ttv->instantiatedTypePackParams = packArguments; ttv->instantiatedTypePackParams = packArguments;
@ -1154,38 +1346,35 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
return true; return true;
} }
void ConstraintSolver::fillInDiscriminantTypes( void ConstraintSolver::fillInDiscriminantTypes(NotNull<const Constraint> constraint, const std::vector<std::optional<TypeId>>& discriminantTypes)
NotNull<const Constraint> constraint,
const std::vector<std::optional<TypeId>>& discriminantTypes
)
{ {
for (std::optional<TypeId> ty : discriminantTypes) for (std::optional<TypeId> ty : discriminantTypes)
{ {
if (!ty) if (!ty)
continue; continue;
// If the discriminant type has been transmuted, we need to unblock them. if (FFlag::LuauSearchForRefineableType)
if (!isBlocked(*ty))
{ {
unblock(*ty, constraint->location); if (isBlocked(*ty))
continue; // We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored.
} emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->noRefineType);
if (FFlag::LuauRemoveNotAnyHack) // We also need to unconditionally unblock these types, otherwise
{ // you end up with funky looking "Blocked on *no-refine*."
// We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored. unblock(*ty, constraint->location);
emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->noRefineType);
} }
else else
{ {
// We use `any` here because the discriminant type may be pointed at by both branches,
// where the discriminant type is not negated, and the other where it is negated, i.e. // If the discriminant type has been transmuted, we need to unblock them.
// `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` if (!isBlocked(*ty))
// v.s. {
// `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` unblock(*ty, constraint->location);
// continue;
// In practice, users cannot negate `any`, so this is an implementation detail we can always change. }
emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->anyType);
// We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored.
emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->noRefineType);
} }
} }
} }
@ -1196,17 +1385,24 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
TypePackId argsPack = follow(c.argsPack); TypePackId argsPack = follow(c.argsPack);
TypePackId result = follow(c.result); TypePackId result = follow(c.result);
if (isBlocked(fn) || hasUnresolvedConstraints(fn)) if (FFlag::DebugLuauGreedyGeneralization)
{ {
return block(c.fn, constraint); if (isBlocked(fn))
return block(c.fn, constraint);
}
else
{
if (isBlocked(fn) || hasUnresolvedConstraints(fn))
{
return block(c.fn, constraint);
}
} }
if (get<AnyType>(fn)) if (get<AnyType>(fn))
{ {
emplaceTypePack<BoundTypePack>(asMutable(c.result), builtinTypes->anyTypePack); emplaceTypePack<BoundTypePack>(asMutable(c.result), builtinTypes->anyTypePack);
unblock(c.result, constraint->location); unblock(c.result, constraint->location);
if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes) fillInDiscriminantTypes(constraint, c.discriminantTypes);
fillInDiscriminantTypes(constraint, c.discriminantTypes);
return true; return true;
} }
@ -1214,16 +1410,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
if (get<ErrorType>(fn)) if (get<ErrorType>(fn))
{ {
bind(constraint, c.result, builtinTypes->errorRecoveryTypePack()); bind(constraint, c.result, builtinTypes->errorRecoveryTypePack());
if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes) fillInDiscriminantTypes(constraint, c.discriminantTypes);
fillInDiscriminantTypes(constraint, c.discriminantTypes);
return true; return true;
} }
if (get<NeverType>(fn)) if (get<NeverType>(fn))
{ {
bind(constraint, c.result, builtinTypes->neverTypePack); bind(constraint, c.result, builtinTypes->neverTypePack);
if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes) fillInDiscriminantTypes(constraint, c.discriminantTypes);
fillInDiscriminantTypes(constraint, c.discriminantTypes);
return true; return true;
} }
@ -1304,44 +1498,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
emplace<FreeTypePack>(constraint, c.result, constraint->scope); emplace<FreeTypePack>(constraint, c.result, constraint->scope);
} }
if (FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes) fillInDiscriminantTypes(constraint, c.discriminantTypes);
{
fillInDiscriminantTypes(constraint, c.discriminantTypes);
}
else
{
// NOTE: This is the body of the `fillInDiscriminantTypes` helper.
for (std::optional<TypeId> ty : c.discriminantTypes)
{
if (!ty)
continue;
// If the discriminant type has been transmuted, we need to unblock them.
if (!isBlocked(*ty))
{
unblock(*ty, constraint->location);
continue;
}
if (FFlag::LuauRemoveNotAnyHack)
{
// We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored.
emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->noRefineType);
}
else
{
// We use `any` here because the discriminant type may be pointed at by both branches,
// where the discriminant type is not negated, and the other where it is negated, i.e.
// `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never`
// v.s.
// `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T`
//
// In practice, users cannot negate `any`, so this is an implementation detail we can always change.
emplaceType<BoundType>(asMutable(follow(*ty)), builtinTypes->anyType);
}
}
}
OverloadResolver resolver{ OverloadResolver resolver{
builtinTypes, builtinTypes,
@ -1407,6 +1564,43 @@ static AstExpr* unwrapGroup(AstExpr* expr)
return expr; return expr;
} }
struct ContainsGenerics : public TypeOnceVisitor
{
DenseHashSet<const void*> generics{nullptr};
bool found = false;
bool visit(TypeId ty) override
{
return !found;
}
bool visit(TypeId ty, const GenericType&) override
{
found |= generics.contains(ty);
return true;
}
bool visit(TypeId ty, const TypeFunctionInstanceType&) override
{
return !found;
}
bool visit(TypePackId tp, const GenericTypePack&) override
{
found |= generics.contains(tp);
return !found;
}
bool hasGeneric(TypeId ty)
{
traverse(ty);
auto ret = found;
found = false;
return ret;
}
};
bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint) bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<const Constraint> constraint)
{ {
TypeId fn = follow(c.fn); TypeId fn = follow(c.fn);
@ -1449,36 +1643,49 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
DenseHashMap<TypeId, TypeId> replacements{nullptr}; DenseHashMap<TypeId, TypeId> replacements{nullptr};
DenseHashMap<TypePackId, TypePackId> replacementPacks{nullptr}; DenseHashMap<TypePackId, TypePackId> replacementPacks{nullptr};
ContainsGenerics containsGenerics;
for (auto generic : ftv->generics) for (auto generic : ftv->generics)
{
replacements[generic] = builtinTypes->unknownType; replacements[generic] = builtinTypes->unknownType;
if (FFlag::LuauBidirectionalInferenceCollectIndexerTypes)
containsGenerics.generics.insert(generic);
}
for (auto genericPack : ftv->genericPacks) for (auto genericPack : ftv->genericPacks)
{
replacementPacks[genericPack] = builtinTypes->unknownTypePack; replacementPacks[genericPack] = builtinTypes->unknownTypePack;
if (FFlag::LuauBidirectionalInferenceCollectIndexerTypes)
containsGenerics.generics.insert(genericPack);
}
// If the type of the function has generics, we don't actually want to push any of the generics themselves // If the type of the function has generics, we don't actually want to push any of the generics themselves
// into the argument types as expected types because this creates an unnecessary loop. Instead, we want to // into the argument types as expected types because this creates an unnecessary loop. Instead, we want to
// replace these types with `unknown` (and `...unknown`) to keep any structure but not create the cycle. // replace these types with `unknown` (and `...unknown`) to keep any structure but not create the cycle.
if (!replacements.empty() || !replacementPacks.empty()) if (!FFlag::LuauBidirectionalInferenceCollectIndexerTypes)
{ {
Replacer replacer{arena, std::move(replacements), std::move(replacementPacks)}; if (!replacements.empty() || !replacementPacks.empty())
std::optional<TypeId> res = replacer.substitute(fn);
if (res)
{ {
if (*res != fn) Replacer replacer{arena, std::move(replacements), std::move(replacementPacks)};
std::optional<TypeId> res = replacer.substitute(fn);
if (res)
{ {
FunctionType* ftvMut = getMutable<FunctionType>(*res); if (*res != fn)
LUAU_ASSERT(ftvMut); {
ftvMut->generics.clear(); FunctionType* ftvMut = getMutable<FunctionType>(*res);
ftvMut->genericPacks.clear(); LUAU_ASSERT(ftvMut);
ftvMut->generics.clear();
ftvMut->genericPacks.clear();
}
fn = *res;
ftv = get<FunctionType>(*res);
LUAU_ASSERT(ftv);
// we've potentially copied type functions here, so we need to reproduce their reduce constraint.
reproduceConstraints(constraint->scope, constraint->location, replacer);
} }
fn = *res;
ftv = get<FunctionType>(*res);
LUAU_ASSERT(ftv);
// we've potentially copied type functions here, so we need to reproduce their reduce constraint.
reproduceConstraints(constraint->scope, constraint->location, replacer);
} }
} }
@ -1497,6 +1704,10 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
(*c.astExpectedTypes)[expr] = expectedArgTy; (*c.astExpectedTypes)[expr] = expectedArgTy;
// Generic types are skipped over entirely, for now.
if (FFlag::LuauBidirectionalInferenceCollectIndexerTypes && containsGenerics.hasGeneric(expectedArgTy))
continue;
const FunctionType* expectedLambdaTy = get<FunctionType>(expectedArgTy); const FunctionType* expectedLambdaTy = get<FunctionType>(expectedArgTy);
const FunctionType* lambdaTy = get<FunctionType>(actualArgTy); const FunctionType* lambdaTy = get<FunctionType>(actualArgTy);
const AstExprFunction* lambdaExpr = expr->as<AstExprFunction>(); const AstExprFunction* lambdaExpr = expr->as<AstExprFunction>();
@ -1524,8 +1735,11 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
else if (expr->is<AstExprTable>()) else if (expr->is<AstExprTable>())
{ {
Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
Subtyping sp{builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, NotNull{&iceReporter}};
std::vector<TypeId> toBlock; std::vector<TypeId> toBlock;
(void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr, toBlock); (void)matchLiteralType(
c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, NotNull{&sp}, expectedArgTy, actualArgTy, expr, toBlock
);
LUAU_ASSERT(toBlock.empty()); LUAU_ASSERT(toBlock.empty());
} }
} }
@ -1533,6 +1747,29 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
return true; return true;
} }
bool ConstraintSolver::tryDispatch(const TableCheckConstraint& c, NotNull<const Constraint> constraint)
{
// This is expensive as we need to traverse a (potentially large)
// literal up front in order to determine if there are any blocked
// types, otherwise we may run `matchTypeLiteral` multiple times,
// which right now may fail due to being non-idempotent (it
// destructively updates the underlying literal type).
auto blockedTypes = findBlockedTypesIn(c.table, c.astTypes);
for (const auto ty : blockedTypes)
{
block(ty, constraint);
}
if (!blockedTypes.empty())
return false;
Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
Subtyping sp{builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, NotNull{&iceReporter}};
std::vector<TypeId> toBlock;
(void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, NotNull{&sp}, c.expectedType, c.exprType, c.table, toBlock);
LUAU_ASSERT(toBlock.empty());
return true;
}
bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint) bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint)
{ {
std::optional<TypeId> expectedType = c.expectedType ? std::make_optional<TypeId>(follow(*c.expectedType)) : std::nullopt; std::optional<TypeId> expectedType = c.expectedType ? std::make_optional<TypeId>(follow(*c.expectedType)) : std::nullopt;
@ -1854,6 +2091,10 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const
else else
{ {
TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope}); TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope});
if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope)
trackInteriorFreeType(constraint->scope, newUpperBound);
TableType* upperTable = getMutable<TableType>(newUpperBound); TableType* upperTable = getMutable<TableType>(newUpperBound);
LUAU_ASSERT(upperTable); LUAU_ASSERT(upperTable);
@ -1883,7 +2124,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const
bind( bind(
constraint, constraint,
c.propType, c.propType,
isIndex && FFlag::LuauAllowNilAssignmentToIndexer ? arena->addType(UnionType{{propTy, builtinTypes->nilType}}) : propTy isIndex ? arena->addType(UnionType{{propTy, builtinTypes->nilType}}) : propTy
); );
unify(constraint, rhsType, propTy); unify(constraint, rhsType, propTy);
return true; return true;
@ -1981,8 +2222,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull<const
bind( bind(
constraint, constraint,
c.propType, c.propType,
FFlag::LuauAllowNilAssignmentToIndexer ? arena->addType(UnionType{{lhsTable->indexer->indexResultType, builtinTypes->nilType}}) arena->addType(UnionType{{lhsTable->indexer->indexResultType, builtinTypes->nilType}})
: lhsTable->indexer->indexResultType
); );
return true; return true;
} }
@ -2035,8 +2275,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNull<const
bind( bind(
constraint, constraint,
c.propType, c.propType,
FFlag::LuauAllowNilAssignmentToIndexer ? arena->addType(UnionType{{lhsClass->indexer->indexResultType, builtinTypes->nilType}}) arena->addType(UnionType{{lhsClass->indexer->indexResultType, builtinTypes->nilType}})
: lhsClass->indexer->indexResultType
); );
return true; return true;
} }
@ -2180,11 +2419,18 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull<const Cons
for (TypePackId r : result.reducedPacks) for (TypePackId r : result.reducedPacks)
unblock(r, constraint->location); unblock(r, constraint->location);
if (FFlag::LuauNewTypeFunReductionChecks2)
{
for (TypeId ity : result.irreducibleTypes)
uninhabitedTypeFunctions.insert(ity);
}
bool reductionFinished = result.blockedTypes.empty() && result.blockedPacks.empty(); bool reductionFinished = result.blockedTypes.empty() && result.blockedPacks.empty();
ty = follow(ty); ty = follow(ty);
// If we couldn't reduce this type function, stick it in the set! // If we couldn't reduce this type function, stick it in the set!
if (get<TypeFunctionInstanceType>(ty)) if (get<TypeFunctionInstanceType>(ty) && (!FFlag::LuauNewTypeFunReductionChecks2 || !result.irreducibleTypes.find(ty)))
typeFunctionsToFinalize[ty] = constraint; typeFunctionsToFinalize[ty] = constraint;
if (force || reductionFinished) if (force || reductionFinished)
@ -2637,6 +2883,10 @@ TablePropLookupResult ConstraintSolver::lookupTableProp(
NotNull<Scope> scope{ft->scope}; NotNull<Scope> scope{ft->scope};
const TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, scope}); const TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, scope});
if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope)
trackInteriorFreeType(constraint->scope, newUpperBound);
TableType* tt = getMutable<TableType>(newUpperBound); TableType* tt = getMutable<TableType>(newUpperBound);
LUAU_ASSERT(tt); LUAU_ASSERT(tt);
TypeId propType = freshType(arena, builtinTypes, scope); TypeId propType = freshType(arena, builtinTypes, scope);
@ -3080,9 +3330,27 @@ void ConstraintSolver::shiftReferences(TypeId source, TypeId target)
auto [targetRefs, _] = unresolvedConstraints.try_insert(target, 0); auto [targetRefs, _] = unresolvedConstraints.try_insert(target, 0);
targetRefs += count; targetRefs += count;
// Any constraint that might have mutated source may now mutate target
if (FFlag::DebugLuauGreedyGeneralization)
{
auto it = mutatedFreeTypeToConstraint.find(source);
if (it != mutatedFreeTypeToConstraint.end())
{
auto [it2, fresh] = mutatedFreeTypeToConstraint.try_emplace(target, DenseHashSet<const Constraint*>{nullptr});
for (const Constraint* constraint : it->second)
{
it2->second.insert(constraint);
auto [it3, fresh2] = maybeMutatedFreeTypes.try_emplace(NotNull{constraint}, DenseHashSet<TypeId>{nullptr});
it3->second.insert(target);
}
}
}
} }
std::optional<TypeId> ConstraintSolver::generalizeFreeType(NotNull<Scope> scope, TypeId type, bool avoidSealingTables) std::optional<TypeId> ConstraintSolver::generalizeFreeType(NotNull<Scope> scope, TypeId type)
{ {
TypeId t = follow(type); TypeId t = follow(type);
if (get<FreeType>(t)) if (get<FreeType>(t))
@ -3097,7 +3365,7 @@ std::optional<TypeId> ConstraintSolver::generalizeFreeType(NotNull<Scope> scope,
// that until all constraint generation is complete. // that until all constraint generation is complete.
} }
return generalize(NotNull{arena}, builtinTypes, scope, generalizedTypes, type, avoidSealingTables); return generalize(NotNull{arena}, builtinTypes, scope, generalizedTypes, type);
} }
bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty)

View file

@ -13,7 +13,6 @@
LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauTypestateBuiltins2)
namespace Luau namespace Luau
{ {
@ -83,12 +82,6 @@ std::optional<DefId> DataFlowGraph::getDefOptional(const AstExpr* expr) const
return NotNull{*def}; return NotNull{*def};
} }
std::optional<DefId> DataFlowGraph::getRValueDefForCompoundAssign(const AstExpr* expr) const
{
auto def = compoundAssignDefs.find(expr);
return def ? std::optional<DefId>(*def) : std::nullopt;
}
DefId DataFlowGraph::getDef(const AstLocal* local) const DefId DataFlowGraph::getDef(const AstLocal* local) const
{ {
auto def = localDefs.find(local); auto def = localDefs.find(local);
@ -879,7 +872,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c)
{ {
visitExpr(c->func); visitExpr(c->func);
if (FFlag::LuauTypestateBuiltins2 && shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin())) if (shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin()))
{ {
AstExpr* firstArg = *c->args.begin(); AstExpr* firstArg = *c->args.begin();
@ -912,8 +905,17 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c)
for (AstExpr* arg : c->args) for (AstExpr* arg : c->args)
visitExpr(arg); visitExpr(arg);
// calls should be treated as subscripted. // We treat function calls as "subscripted" as they could potentially
return {defArena->freshCell(/* subscripted */ true), nullptr}; // return a subscripted value, consider:
//
// local function foo(tbl: {[string]: woof)
// return tbl["foobarbaz"]
// end
//
// local v = foo({})
//
// We want to consider `v` to be subscripted here.
return {defArena->freshCell(/*subscripted=*/true)};
} }
DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexName* i) DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexName* i)
@ -1160,6 +1162,8 @@ void DataFlowGraphBuilder::visitType(AstType* t)
return visitType(f); return visitType(f);
else if (auto tyof = t->as<AstTypeTypeof>()) else if (auto tyof = t->as<AstTypeTypeof>())
return visitType(tyof); return visitType(tyof);
else if (auto o = t->as<AstTypeOptional>())
return;
else if (auto u = t->as<AstTypeUnion>()) else if (auto u = t->as<AstTypeUnion>())
return visitType(u); return visitType(u);
else if (auto i = t->as<AstTypeIntersection>()) else if (auto i = t->as<AstTypeIntersection>())
@ -1170,6 +1174,8 @@ void DataFlowGraphBuilder::visitType(AstType* t)
return; // ok return; // ok
else if (auto s = t->as<AstTypeSingletonString>()) else if (auto s = t->as<AstTypeSingletonString>())
return; // ok return; // ok
else if (auto g = t->as<AstTypeGroup>())
return visitType(g->type);
else else
handle->ice("Unknown AstType in DataFlowGraphBuilder::visitType"); handle->ice("Unknown AstType in DataFlowGraphBuilder::visitType");
} }
@ -1259,21 +1265,21 @@ void DataFlowGraphBuilder::visitTypeList(AstTypeList l)
visitTypePack(l.tailType); visitTypePack(l.tailType);
} }
void DataFlowGraphBuilder::visitGenerics(AstArray<AstGenericType> g) void DataFlowGraphBuilder::visitGenerics(AstArray<AstGenericType*> g)
{ {
for (AstGenericType generic : g) for (AstGenericType* generic : g)
{ {
if (generic.defaultValue) if (generic->defaultValue)
visitType(generic.defaultValue); visitType(generic->defaultValue);
} }
} }
void DataFlowGraphBuilder::visitGenericPacks(AstArray<AstGenericTypePack> g) void DataFlowGraphBuilder::visitGenericPacks(AstArray<AstGenericTypePack*> g)
{ {
for (AstGenericTypePack generic : g) for (AstGenericTypePack* generic : g)
{ {
if (generic.defaultValue) if (generic->defaultValue)
visitTypePack(generic.defaultValue); visitTypePack(generic->defaultValue);
} }
} }

View file

@ -1,207 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAGVARIABLE(LuauVectorDefinitionsExtra)
LUAU_FASTFLAG(LuauBufferBitMethods2)
LUAU_FASTFLAGVARIABLE(LuauMathMapDefinition)
LUAU_FASTFLAG(LuauVector2Constructor)
namespace Luau namespace Luau
{ {
static const std::string kBuiltinDefinitionLuaSrcChecked_DEPRECATED = R"BUILTIN_SRC(
declare bit32: {
band: @checked (...number) -> number,
bor: @checked (...number) -> number,
bxor: @checked (...number) -> number,
btest: @checked (number, ...number) -> boolean,
rrotate: @checked (x: number, disp: number) -> number,
lrotate: @checked (x: number, disp: number) -> number,
lshift: @checked (x: number, disp: number) -> number,
arshift: @checked (x: number, disp: number) -> number,
rshift: @checked (x: number, disp: number) -> number,
bnot: @checked (x: number) -> number,
extract: @checked (n: number, field: number, width: number?) -> number,
replace: @checked (n: number, v: number, field: number, width: number?) -> number,
countlz: @checked (n: number) -> number,
countrz: @checked (n: number) -> number,
byteswap: @checked (n: number) -> number,
}
declare math: {
frexp: @checked (n: number) -> (number, number),
ldexp: @checked (s: number, e: number) -> number,
fmod: @checked (x: number, y: number) -> number,
modf: @checked (n: number) -> (number, number),
pow: @checked (x: number, y: number) -> number,
exp: @checked (n: number) -> number,
ceil: @checked (n: number) -> number,
floor: @checked (n: number) -> number,
abs: @checked (n: number) -> number,
sqrt: @checked (n: number) -> number,
log: @checked (n: number, base: number?) -> number,
log10: @checked (n: number) -> number,
rad: @checked (n: number) -> number,
deg: @checked (n: number) -> number,
sin: @checked (n: number) -> number,
cos: @checked (n: number) -> number,
tan: @checked (n: number) -> number,
sinh: @checked (n: number) -> number,
cosh: @checked (n: number) -> number,
tanh: @checked (n: number) -> number,
atan: @checked (n: number) -> number,
acos: @checked (n: number) -> number,
asin: @checked (n: number) -> number,
atan2: @checked (y: number, x: number) -> number,
min: @checked (number, ...number) -> number,
max: @checked (number, ...number) -> number,
pi: number,
huge: number,
randomseed: @checked (seed: number) -> (),
random: @checked (number?, number?) -> number,
sign: @checked (n: number) -> number,
clamp: @checked (n: number, min: number, max: number) -> number,
noise: @checked (x: number, y: number?, z: number?) -> number,
round: @checked (n: number) -> number,
map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number,
}
type DateTypeArg = {
year: number,
month: number,
day: number,
hour: number?,
min: number?,
sec: number?,
isdst: boolean?,
}
type DateTypeResult = {
year: number,
month: number,
wday: number,
yday: number,
day: number,
hour: number,
min: number,
sec: number,
isdst: boolean,
}
declare os: {
time: (time: DateTypeArg?) -> number,
date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string),
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
@checked declare function require(target: any): any
@checked declare function getfenv(target: any): { [string]: any }
declare _G: any
declare _VERSION: string
declare function gcinfo(): number
declare function print<T...>(...: T...)
declare function type<T>(value: T): string
declare function typeof<T>(value: T): string
-- `assert` has a magic function attached that will give more detailed type information
declare function assert<T>(value: T, errorMessage: string?): T
declare function error<T>(message: T, level: number?): never
declare function tostring<T>(value: T): string
declare function tonumber<T>(value: T, radix: number?): number?
declare function rawequal<T1, T2>(a: T1, b: T2): boolean
declare function rawget<K, V>(tab: {[K]: V}, k: K): V
declare function rawset<K, V>(tab: {[K]: V}, k: K, v: V): {[K]: V}
declare function rawlen<K, V>(obj: {[K]: V} | string): number
declare function setfenv<T..., R...>(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)?
declare function ipairs<V>(tab: {V}): (({V}, number) -> (number?, V), {V}, number)
declare function pcall<A..., R...>(f: (A...) -> R..., ...: A...): (boolean, R...)
-- FIXME: The actual type of `xpcall` is:
-- <E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...)
-- Since we can't represent the return value, we use (boolean, R1...).
declare function xpcall<E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...)
-- `select` has a magic function attached to provide more detailed type information
declare function select<A...>(i: string | number, ...: A...): ...any
-- FIXME: This type is not entirely correct - `loadstring` returns a function or
-- (nil, string).
declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?)
@checked declare function newproxy(mt: boolean?): any
declare coroutine: {
create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread,
status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended",
wrap: <A..., R...>(f: (A...) -> R...) -> ((A...) -> R...),
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: @checked (co: thread) -> (boolean, any)
}
declare table: {
concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
maxn: <V>(t: {V}) -> number,
remove: <V>(t: {V}, number?) -> V?,
sort: <V>(t: {V}, comp: ((V, V) -> boolean)?) -> (),
create: <V>(count: number, value: V?) -> {V},
find: <V>(haystack: {V}, needle: V, init: number?) -> number?,
unpack: <V>(list: {V}, i: number?, j: number?) -> ...V,
pack: <V>(...V) -> { n: number, [number]: V },
getn: <V>(t: {V}) -> number,
foreach: <K, V>(t: {[K]: V}, f: (K, V) -> ()) -> (),
foreachi: <V>({V}, (number, V) -> ()) -> (),
move: <V>(src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V},
clear: <K, V>(table: {[K]: V}) -> (),
isfrozen: <K, V>(t: {[K]: V}) -> boolean,
}
declare debug: {
info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
}
declare utf8: {
char: @checked (...number) -> string,
charpattern: string,
codes: @checked (str: string) -> ((string, number) -> (number, number), string, number),
codepoint: @checked (str: string, i: number?, j: number?) -> ...number,
len: @checked (s: string, i: number?, j: number?) -> (number?, number?),
offset: @checked (s: string, n: number?, i: number?) -> number,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionBaseSrc = R"BUILTIN_SRC( static const std::string kBuiltinDefinitionBaseSrc = R"BUILTIN_SRC(
@checked declare function require(target: any): any @checked declare function require(target: any): any
@ -405,7 +207,7 @@ declare table: {
static const std::string kBuiltinDefinitionDebugSrc = R"BUILTIN_SRC( static const std::string kBuiltinDefinitionDebugSrc = R"BUILTIN_SRC(
declare debug: { declare debug: {
info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...), info: ((thread: thread, level: number, options: string) -> ...any) & ((level: number, options: string) -> ...any) & (<A..., R1...>(func: (A...) -> R1..., options: string) -> ...any),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
} }
@ -424,37 +226,6 @@ declare utf8: {
)BUILTIN_SRC"; )BUILTIN_SRC";
static const std::string kBuiltinDefinitionBufferSrc_DEPRECATED = R"BUILTIN_SRC(
--- Buffer API
declare buffer: {
create: @checked (size: number) -> buffer,
fromstring: @checked (str: string) -> buffer,
tostring: @checked (b: buffer) -> string,
len: @checked (b: buffer) -> number,
copy: @checked (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (),
fill: @checked (b: buffer, offset: number, value: number, count: number?) -> (),
readi8: @checked (b: buffer, offset: number) -> number,
readu8: @checked (b: buffer, offset: number) -> number,
readi16: @checked (b: buffer, offset: number) -> number,
readu16: @checked (b: buffer, offset: number) -> number,
readi32: @checked (b: buffer, offset: number) -> number,
readu32: @checked (b: buffer, offset: number) -> number,
readf32: @checked (b: buffer, offset: number) -> number,
readf64: @checked (b: buffer, offset: number) -> number,
writei8: @checked (b: buffer, offset: number, value: number) -> (),
writeu8: @checked (b: buffer, offset: number, value: number) -> (),
writei16: @checked (b: buffer, offset: number, value: number) -> (),
writeu16: @checked (b: buffer, offset: number, value: number) -> (),
writei32: @checked (b: buffer, offset: number, value: number) -> (),
writeu32: @checked (b: buffer, offset: number, value: number) -> (),
writef32: @checked (b: buffer, offset: number, value: number) -> (),
writef64: @checked (b: buffer, offset: number, value: number) -> (),
readstring: @checked (b: buffer, offset: number, count: number) -> string,
writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC( static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC(
--- Buffer API --- Buffer API
declare buffer: { declare buffer: {
@ -488,88 +259,6 @@ declare buffer: {
)BUILTIN_SRC"; )BUILTIN_SRC";
static const std::string kBuiltinDefinitionVectorSrc_NoExtra_NoVector2Ctor_DEPRECATED = R"BUILTIN_SRC(
-- TODO: this will be replaced with a built-in primitive type
declare class vector end
declare vector: {
create: @checked (x: number, y: number, z: number) -> vector,
magnitude: @checked (vec: vector) -> number,
normalize: @checked (vec: vector) -> vector,
cross: @checked (vec1: vector, vec2: vector) -> vector,
dot: @checked (vec1: vector, vec2: vector) -> number,
angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number,
floor: @checked (vec: vector) -> vector,
ceil: @checked (vec: vector) -> vector,
abs: @checked (vec: vector) -> vector,
sign: @checked (vec: vector) -> vector,
clamp: @checked (vec: vector, min: vector, max: vector) -> vector,
max: @checked (vector, ...vector) -> vector,
min: @checked (vector, ...vector) -> vector,
zero: vector,
one: vector,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionVectorSrc_NoExtra_DEPRECATED = R"BUILTIN_SRC(
-- TODO: this will be replaced with a built-in primitive type
declare class vector end
declare vector: {
create: @checked (x: number, y: number, z: number?) -> vector,
magnitude: @checked (vec: vector) -> number,
normalize: @checked (vec: vector) -> vector,
cross: @checked (vec1: vector, vec2: vector) -> vector,
dot: @checked (vec1: vector, vec2: vector) -> number,
angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number,
floor: @checked (vec: vector) -> vector,
ceil: @checked (vec: vector) -> vector,
abs: @checked (vec: vector) -> vector,
sign: @checked (vec: vector) -> vector,
clamp: @checked (vec: vector, min: vector, max: vector) -> vector,
max: @checked (vector, ...vector) -> vector,
min: @checked (vector, ...vector) -> vector,
zero: vector,
one: vector,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionVectorSrc_NoVector2Ctor_DEPRECATED = R"BUILTIN_SRC(
-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties
declare class vector
x: number
y: number
z: number
end
declare vector: {
create: @checked (x: number, y: number, z: number) -> vector,
magnitude: @checked (vec: vector) -> number,
normalize: @checked (vec: vector) -> vector,
cross: @checked (vec1: vector, vec2: vector) -> vector,
dot: @checked (vec1: vector, vec2: vector) -> number,
angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number,
floor: @checked (vec: vector) -> vector,
ceil: @checked (vec: vector) -> vector,
abs: @checked (vec: vector) -> vector,
sign: @checked (vec: vector) -> vector,
clamp: @checked (vec: vector, min: vector, max: vector) -> vector,
max: @checked (vector, ...vector) -> vector,
min: @checked (vector, ...vector) -> vector,
zero: vector,
one: vector,
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC( static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC(
-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties -- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties
@ -602,37 +291,98 @@ declare vector: {
std::string getBuiltinDefinitionSource() std::string getBuiltinDefinitionSource()
{ {
std::string result = FFlag::LuauMathMapDefinition ? kBuiltinDefinitionBaseSrc : kBuiltinDefinitionLuaSrcChecked_DEPRECATED; std::string result = kBuiltinDefinitionBaseSrc;
if (FFlag::LuauMathMapDefinition) result += kBuiltinDefinitionBit32Src;
{ result += kBuiltinDefinitionMathSrc;
result += kBuiltinDefinitionBit32Src; result += kBuiltinDefinitionOsSrc;
result += kBuiltinDefinitionMathSrc; result += kBuiltinDefinitionCoroutineSrc;
result += kBuiltinDefinitionOsSrc; result += kBuiltinDefinitionTableSrc;
result += kBuiltinDefinitionCoroutineSrc; result += kBuiltinDefinitionDebugSrc;
result += kBuiltinDefinitionTableSrc; result += kBuiltinDefinitionUtf8Src;
result += kBuiltinDefinitionDebugSrc; result += kBuiltinDefinitionBufferSrc;
result += kBuiltinDefinitionUtf8Src; result += kBuiltinDefinitionVectorSrc;
}
result += FFlag::LuauBufferBitMethods2 ? kBuiltinDefinitionBufferSrc : kBuiltinDefinitionBufferSrc_DEPRECATED;
if (FFlag::LuauVectorDefinitionsExtra)
{
if (FFlag::LuauVector2Constructor)
result += kBuiltinDefinitionVectorSrc;
else
result += kBuiltinDefinitionVectorSrc_NoVector2Ctor_DEPRECATED;
}
else
{
if (FFlag::LuauVector2Constructor)
result += kBuiltinDefinitionVectorSrc_NoExtra_DEPRECATED;
else
result += kBuiltinDefinitionVectorSrc_NoExtra_NoVector2Ctor_DEPRECATED;
}
return result; return result;
} }
// TODO: split into separate tagged unions when the new solver can appropriately handle that.
static const std::string kBuiltinDefinitionTypesSrc = R"BUILTIN_SRC(
export type type = {
tag: "nil" | "unknown" | "never" | "any" | "boolean" | "number" | "string" | "buffer" | "thread" |
"singleton" | "negation" | "union" | "intesection" | "table" | "function" | "class" | "generic",
is: (self: type, arg: string) -> boolean,
-- for singleton type
value: (self: type) -> (string | boolean | nil),
-- for negation type
inner: (self: type) -> type,
-- for union and intersection types
components: (self: type) -> {type},
-- for table type
setproperty: (self: type, key: type, value: type?) -> (),
setreadproperty: (self: type, key: type, value: type?) -> (),
setwriteproperty: (self: type, key: type, value: type?) -> (),
readproperty: (self: type, key: type) -> type?,
writeproperty: (self: type, key: type) -> type?,
properties: (self: type) -> { [type]: { read: type?, write: type? } },
setindexer: (self: type, index: type, result: type) -> (),
setreadindexer: (self: type, index: type, result: type) -> (),
setwriteindexer: (self: type, index: type, result: type) -> (),
indexer: (self: type) -> { index: type, readresult: type, writeresult: type }?,
readindexer: (self: type) -> { index: type, result: type }?,
writeindexer: (self: type) -> { index: type, result: type }?,
setmetatable: (self: type, arg: type) -> (),
metatable: (self: type) -> type?,
-- for function type
setparameters: (self: type, head: {type}?, tail: type?) -> (),
parameters: (self: type) -> { head: {type}?, tail: type? },
setreturns: (self: type, head: {type}?, tail: type? ) -> (),
returns: (self: type) -> { head: {type}?, tail: type? },
setgenerics: (self: type, {type}?) -> (),
generics: (self: type) -> {type},
-- for class type
-- 'properties', 'metatable', 'indexer', 'readindexer' and 'writeindexer' are shared with table type
readparent: (self: type) -> type?,
writeparent: (self: type) -> type?,
-- for generic type
name: (self: type) -> string?,
ispack: (self: type) -> boolean,
}
declare types: {
unknown: type,
never: type,
any: type,
boolean: type,
number: type,
string: type,
thread: type,
buffer: type,
singleton: @checked (arg: string | boolean | nil) -> type,
generic: @checked (name: string, ispack: boolean?) -> type,
negationof: @checked (arg: type) -> type,
unionof: @checked (...type) -> type,
intersectionof: @checked (...type) -> type,
newtable: @checked (props: {[type]: type} | {[type]: { read: type, write: type } } | nil, indexer: { index: type, readresult: type, writeresult: type }?, metatable: type?) -> type,
newfunction: @checked (parameters: { head: {type}?, tail: type? }?, returns: { head: {type}?, tail: type? }?, generics: {type}?) -> type,
copy: @checked (arg: type) -> type,
}
)BUILTIN_SRC";
std::string getTypeFunctionDefinitionSource()
{
return kBuiltinDefinitionTypesSrc;
}
} // namespace Luau } // namespace Luau

View file

@ -92,18 +92,24 @@ size_t TTable::Hash::operator()(const TTable& value) const
return hash; return hash;
} }
uint32_t StringCache::add(std::string_view s) StringId StringCache::add(std::string_view s)
{ {
size_t hash = std::hash<std::string_view>()(s); /* Important subtlety: This use of DenseHashMap<std::string_view, StringId>
if (uint32_t* it = strings.find(hash)) * is okay because std::hash<std::string_view> works solely on the bytes
* referred by the string_view.
*
* In other words, two string views which contain the same bytes will have
* the same hash whether or not their addresses are the same.
*/
if (StringId* it = strings.find(s))
return *it; return *it;
char* storage = static_cast<char*>(allocator.allocate(s.size())); char* storage = static_cast<char*>(allocator.allocate(s.size()));
memcpy(storage, s.data(), s.size()); memcpy(storage, s.data(), s.size());
uint32_t result = uint32_t(views.size()); StringId result = StringId(views.size());
views.emplace_back(storage, s.size()); views.emplace_back(storage, s.size());
strings[hash] = result; strings[s] = result;
return result; return result;
} }
@ -390,6 +396,17 @@ Id toId(
{ {
LUAU_ASSERT(tfun->packArguments.empty()); LUAU_ASSERT(tfun->packArguments.empty());
if (tfun->userFuncName)
{
// TODO: User defined type functions are pseudo-effectful: error
// reporting is done via the `print` statement, so running a
// UDTF multiple times may end up double erroring. egraphs
// currently may induce type functions to be reduced multiple
// times. We should probably opt _not_ to process user defined
// type functions at all.
return egraph.add(TOpaque{ty});
}
std::vector<Id> parts; std::vector<Id> parts;
parts.reserve(tfun->typeArguments.size()); parts.reserve(tfun->typeArguments.size());
for (TypeId part : tfun->typeArguments) for (TypeId part : tfun->typeArguments)

View file

@ -8,6 +8,7 @@
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/TypeChecker2.h"
#include "Luau/TypeFunction.h" #include "Luau/TypeFunction.h"
#include <optional> #include <optional>
@ -17,6 +18,7 @@
#include <unordered_set> #include <unordered_set>
LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10) LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10)
LUAU_FASTFLAG(LuauNonStrictFuncDefErrorFix)
static std::string wrongNumberOfArgsString( static std::string wrongNumberOfArgsString(
size_t expectedCount, size_t expectedCount,
@ -116,7 +118,10 @@ struct ErrorConverter
size_t luauIndentTypeMismatchMaxTypeLength = size_t(FInt::LuauIndentTypeMismatchMaxTypeLength); size_t luauIndentTypeMismatchMaxTypeLength = size_t(FInt::LuauIndentTypeMismatchMaxTypeLength);
if (givenType.length() <= luauIndentTypeMismatchMaxTypeLength || wantedType.length() <= luauIndentTypeMismatchMaxTypeLength) if (givenType.length() <= luauIndentTypeMismatchMaxTypeLength || wantedType.length() <= luauIndentTypeMismatchMaxTypeLength)
return "Type " + given + " could not be converted into " + wanted; return "Type " + given + " could not be converted into " + wanted;
return "Type\n " + given + "\ncould not be converted into\n " + wanted; if (FFlag::LuauImproveTypePathsInErrors)
return "Type\n\t" + given + "\ncould not be converted into\n\t" + wanted;
else
return "Type\n " + given + "\ncould not be converted into\n " + wanted;
}; };
if (givenTypeName == wantedTypeName) if (givenTypeName == wantedTypeName)
@ -751,8 +756,15 @@ struct ErrorConverter
std::string operator()(const NonStrictFunctionDefinitionError& e) const std::string operator()(const NonStrictFunctionDefinitionError& e) const
{ {
return "Argument " + e.argument + " with type '" + toString(e.argumentType) + "' in function '" + e.functionName + if (FFlag::LuauNonStrictFuncDefErrorFix && e.functionName.empty())
"' is used in a way that will run time error"; {
return "Argument " + e.argument + " with type '" + toString(e.argumentType) + "' is used in a way that will run time error";
}
else
{
return "Argument " + e.argument + " with type '" + toString(e.argumentType) + "' in function '" + e.functionName +
"' is used in a way that will run time error";
}
} }
std::string operator()(const PropertyAccessViolation& e) const std::string operator()(const PropertyAccessViolation& e) const

View file

@ -0,0 +1,172 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/FileResolver.h"
#include "Luau/Common.h"
#include "Luau/StringUtils.h"
#include <algorithm>
#include <memory>
#include <optional>
#include <string_view>
#include <utility>
LUAU_FASTFLAGVARIABLE(LuauExposeRequireByStringAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauEscapeCharactersInRequireSuggestions)
LUAU_FASTFLAGVARIABLE(LuauHideImpossibleRequireSuggestions)
namespace Luau
{
static std::optional<RequireSuggestions> processRequireSuggestions(std::optional<RequireSuggestions> suggestions)
{
if (!suggestions)
return suggestions;
if (FFlag::LuauEscapeCharactersInRequireSuggestions)
{
for (RequireSuggestion& suggestion : *suggestions)
{
suggestion.fullPath = escape(suggestion.fullPath);
}
}
return suggestions;
}
static RequireSuggestions makeSuggestionsFromAliases(std::vector<RequireAlias> aliases)
{
RequireSuggestions result;
for (RequireAlias& alias : aliases)
{
RequireSuggestion suggestion;
suggestion.label = "@" + std::move(alias.alias);
suggestion.fullPath = suggestion.label;
suggestion.tags = std::move(alias.tags);
result.push_back(std::move(suggestion));
}
return result;
}
static RequireSuggestions makeSuggestionsForFirstComponent(std::unique_ptr<RequireNode> node)
{
RequireSuggestions result = makeSuggestionsFromAliases(node->getAvailableAliases());
result.push_back(RequireSuggestion{"./", "./", {}});
result.push_back(RequireSuggestion{"../", "../", {}});
return result;
}
static RequireSuggestions makeSuggestionsFromNode(std::unique_ptr<RequireNode> node, const std::string_view path, bool isPartialPath)
{
LUAU_ASSERT(!path.empty());
RequireSuggestions result;
const size_t lastSlashInPath = path.find_last_of('/');
if (lastSlashInPath != std::string_view::npos)
{
// Add a suggestion for the parent directory
RequireSuggestion parentSuggestion;
parentSuggestion.label = "..";
// TODO: after exposing require-by-string's path normalization API, this
// if-else can be replaced. Instead, we can simply normalize the result
// of inserting ".." at the end of the current path.
if (lastSlashInPath >= 2 && path.substr(lastSlashInPath - 2, 3) == "../")
{
parentSuggestion.fullPath = path.substr(0, lastSlashInPath + 1);
parentSuggestion.fullPath += "..";
}
else
{
parentSuggestion.fullPath = path.substr(0, lastSlashInPath);
}
result.push_back(std::move(parentSuggestion));
}
std::string fullPathPrefix;
if (isPartialPath)
{
// ./path/to/chi -> ./path/to/
fullPathPrefix += path.substr(0, lastSlashInPath + 1);
}
else
{
if (path.back() == '/')
{
// ./path/to/ -> ./path/to/
fullPathPrefix += path;
}
else
{
// ./path/to -> ./path/to/
fullPathPrefix += path;
fullPathPrefix += "/";
}
}
for (const std::unique_ptr<RequireNode>& child : node->getChildren())
{
if (!child)
continue;
std::string pathComponent = child->getPathComponent();
if (FFlag::LuauHideImpossibleRequireSuggestions)
{
// If path component contains a slash, it cannot be required by string.
// There's no point suggesting it.
if (pathComponent.find('/') != std::string::npos)
continue;
}
RequireSuggestion suggestion;
suggestion.label = isPartialPath || path.back() == '/' ? child->getLabel() : "/" + child->getLabel();
suggestion.fullPath = fullPathPrefix + std::move(pathComponent);
suggestion.tags = child->getTags();
result.push_back(std::move(suggestion));
}
return result;
}
std::optional<RequireSuggestions> RequireSuggester::getRequireSuggestionsImpl(const ModuleName& requirer, const std::optional<std::string>& path)
const
{
if (!path)
return std::nullopt;
std::unique_ptr<RequireNode> requirerNode = getNode(requirer);
if (!requirerNode)
return std::nullopt;
const size_t slashPos = path->find_last_of('/');
if (slashPos == std::string::npos)
return makeSuggestionsForFirstComponent(std::move(requirerNode));
// If path already points at a Node, return the Node's children as paths.
if (std::unique_ptr<RequireNode> node = requirerNode->resolvePathToNode(*path))
return makeSuggestionsFromNode(std::move(node), *path, /* isPartialPath = */ false);
// Otherwise, recover a partial path and use this to generate suggestions.
if (std::unique_ptr<RequireNode> partialNode = requirerNode->resolvePathToNode(path->substr(0, slashPos)))
return makeSuggestionsFromNode(std::move(partialNode), *path, /* isPartialPath = */ true);
return std::nullopt;
}
std::optional<RequireSuggestions> RequireSuggester::getRequireSuggestions(const ModuleName& requirer, const std::optional<std::string>& path) const
{
return processRequireSuggestions(getRequireSuggestionsImpl(requirer, path));
}
std::optional<RequireSuggestions> FileResolver::getRequireSuggestions(const ModuleName& requirer, const std::optional<std::string>& path) const
{
if (!FFlag::LuauExposeRequireByStringAutocomplete)
return std::nullopt;
return requireSuggester ? requireSuggester->getRequireSuggestions(requirer, path) : std::nullopt;
}
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/AnyTypeSummary.h"
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/Common.h" #include "Luau/Common.h"
@ -39,7 +38,6 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(LuauStoreCommentsForDefinitionFiles)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile)
@ -48,10 +46,12 @@ LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode)
LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode) LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)
LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAGVARIABLE(LuauModuleHoldsAstRoot)
LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule)
LUAU_FASTFLAGVARIABLE(LuauReferenceAllocatorInNewSolver) LUAU_FASTFLAGVARIABLE(LuauFixMultithreadTypecheck)
LUAU_FASTFLAGVARIABLE(LuauSelectivelyRetainDFGArena)
LUAU_FASTFLAG(LuauTypeFunResultInAutocomplete)
namespace Luau namespace Luau
{ {
@ -81,6 +81,20 @@ struct BuildQueueItem
Frontend::Stats stats; Frontend::Stats stats;
}; };
struct BuildQueueWorkState
{
std::function<void(std::function<void()> task)> executeTask;
std::vector<BuildQueueItem> buildQueueItems;
std::mutex mtx;
std::condition_variable cv;
std::vector<size_t> readyQueueItems;
size_t processing = 0;
size_t remaining = 0;
};
std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments) std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments)
{ {
for (const HotComment& hc : hotcomments) for (const HotComment& hc : hotcomments)
@ -138,7 +152,7 @@ static ParseResult parseSourceForModule(std::string_view source, Luau::SourceMod
sourceModule.root = parseResult.root; sourceModule.root = parseResult.root;
sourceModule.mode = Mode::Definition; sourceModule.mode = Mode::Definition;
if (FFlag::LuauStoreCommentsForDefinitionFiles && options.captureComments) if (options.captureComments)
{ {
sourceModule.hotcomments = parseResult.hotcomments; sourceModule.hotcomments = parseResult.hotcomments;
sourceModule.commentLocations = parseResult.commentLocations; sourceModule.commentLocations = parseResult.commentLocations;
@ -445,20 +459,6 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
if (item.name == name) if (item.name == name)
checkResult.lintResult = item.module->lintResult; checkResult.lintResult = item.module->lintResult;
if (FFlag::StudioReportLuauAny2 && item.options.retainFullTypeGraphs)
{
if (item.module)
{
const SourceModule& sourceModule = *item.sourceModule;
if (sourceModule.mode == Luau::Mode::Strict)
{
item.module->ats.root = toString(sourceModule.root);
}
item.module->ats.rootSrc = sourceModule.root;
item.module->ats.traverse(item.module.get(), sourceModule.root, NotNull{&builtinTypes_});
}
}
} }
return checkResult; return checkResult;
@ -480,6 +480,203 @@ std::vector<ModuleName> Frontend::checkQueuedModules(
std::function<bool(size_t done, size_t total)> progress std::function<bool(size_t done, size_t total)> progress
) )
{ {
if (!FFlag::LuauFixMultithreadTypecheck)
{
return checkQueuedModules_DEPRECATED(optionOverride, executeTask, progress);
}
FrontendOptions frontendOptions = optionOverride.value_or(options);
if (FFlag::LuauSolverV2)
frontendOptions.forAutocomplete = false;
// By taking data into locals, we make sure queue is cleared at the end, even if an ICE or a different exception is thrown
std::vector<ModuleName> currModuleQueue;
std::swap(currModuleQueue, moduleQueue);
DenseHashSet<Luau::ModuleName> seen{{}};
std::shared_ptr<BuildQueueWorkState> state = std::make_shared<BuildQueueWorkState>();
for (const ModuleName& name : currModuleQueue)
{
if (seen.contains(name))
continue;
if (!isDirty(name, frontendOptions.forAutocomplete))
{
seen.insert(name);
continue;
}
std::vector<ModuleName> queue;
bool cycleDetected = parseGraph(
queue,
name,
frontendOptions.forAutocomplete,
[&seen](const ModuleName& name)
{
return seen.contains(name);
}
);
addBuildQueueItems(state->buildQueueItems, queue, cycleDetected, seen, frontendOptions);
}
if (state->buildQueueItems.empty())
return {};
// We need a mapping from modules to build queue slots
std::unordered_map<ModuleName, size_t> moduleNameToQueue;
for (size_t i = 0; i < state->buildQueueItems.size(); i++)
{
BuildQueueItem& item = state->buildQueueItems[i];
moduleNameToQueue[item.name] = i;
}
// Default task execution is single-threaded and immediate
if (!executeTask)
{
executeTask = [](std::function<void()> task)
{
task();
};
}
state->executeTask = executeTask;
state->remaining = state->buildQueueItems.size();
// Record dependencies between modules
for (size_t i = 0; i < state->buildQueueItems.size(); i++)
{
BuildQueueItem& item = state->buildQueueItems[i];
for (const ModuleName& dep : item.sourceNode->requireSet)
{
if (auto it = sourceNodes.find(dep); it != sourceNodes.end())
{
if (it->second->hasDirtyModule(frontendOptions.forAutocomplete))
{
item.dirtyDependencies++;
state->buildQueueItems[moduleNameToQueue[dep]].reverseDeps.push_back(i);
}
}
}
}
// In the first pass, check all modules with no pending dependencies
for (size_t i = 0; i < state->buildQueueItems.size(); i++)
{
if (state->buildQueueItems[i].dirtyDependencies == 0)
sendQueueItemTask(state, i);
}
// If not a single item was found, a cycle in the graph was hit
if (state->processing == 0)
sendQueueCycleItemTask(state);
std::vector<size_t> nextItems;
std::optional<size_t> itemWithException;
bool cancelled = false;
while (state->remaining != 0)
{
{
std::unique_lock guard(state->mtx);
// If nothing is ready yet, wait
state->cv.wait(
guard,
[state]
{
return !state->readyQueueItems.empty();
}
);
// Handle checked items
for (size_t i : state->readyQueueItems)
{
const BuildQueueItem& item = state->buildQueueItems[i];
// If exception was thrown, stop adding new items and wait for processing items to complete
if (item.exception)
itemWithException = i;
if (item.module && item.module->cancelled)
cancelled = true;
if (itemWithException || cancelled)
break;
recordItemResult(item);
// Notify items that were waiting for this dependency
for (size_t reverseDep : item.reverseDeps)
{
BuildQueueItem& reverseDepItem = state->buildQueueItems[reverseDep];
LUAU_ASSERT(reverseDepItem.dirtyDependencies != 0);
reverseDepItem.dirtyDependencies--;
// In case of a module cycle earlier, check if unlocked an item that was already processed
if (!reverseDepItem.processing && reverseDepItem.dirtyDependencies == 0)
nextItems.push_back(reverseDep);
}
}
LUAU_ASSERT(state->processing >= state->readyQueueItems.size());
state->processing -= state->readyQueueItems.size();
LUAU_ASSERT(state->remaining >= state->readyQueueItems.size());
state->remaining -= state->readyQueueItems.size();
state->readyQueueItems.clear();
}
if (progress)
{
if (!progress(state->buildQueueItems.size() - state->remaining, state->buildQueueItems.size()))
cancelled = true;
}
// Items cannot be submitted while holding the lock
for (size_t i : nextItems)
sendQueueItemTask(state, i);
nextItems.clear();
if (state->processing == 0)
{
// Typechecking might have been cancelled by user, don't return partial results
if (cancelled)
return {};
// We might have stopped because of a pending exception
if (itemWithException)
recordItemResult(state->buildQueueItems[*itemWithException]);
}
// If we aren't done, but don't have anything processing, we hit a cycle
if (state->remaining != 0 && state->processing == 0)
sendQueueCycleItemTask(state);
}
std::vector<ModuleName> checkedModules;
checkedModules.reserve(state->buildQueueItems.size());
for (size_t i = 0; i < state->buildQueueItems.size(); i++)
checkedModules.push_back(std::move(state->buildQueueItems[i].name));
return checkedModules;
}
std::vector<ModuleName> Frontend::checkQueuedModules_DEPRECATED(
std::optional<FrontendOptions> optionOverride,
std::function<void(std::function<void()> task)> executeTask,
std::function<bool(size_t done, size_t total)> progress
)
{
LUAU_ASSERT(!FFlag::LuauFixMultithreadTypecheck);
FrontendOptions frontendOptions = optionOverride.value_or(options); FrontendOptions frontendOptions = optionOverride.value_or(options);
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
frontendOptions.forAutocomplete = false; frontendOptions.forAutocomplete = false;
@ -820,6 +1017,13 @@ bool Frontend::parseGraph(
topseen = Permanent; topseen = Permanent;
buildQueue.push_back(top->name); buildQueue.push_back(top->name);
// at this point we know all valid dependencies are processed into SourceNodes
for (const ModuleName& dep : top->requireSet)
{
if (auto it = sourceNodes.find(dep); it != sourceNodes.end())
it->second->dependents.insert(top->name);
}
} }
else else
{ {
@ -1049,6 +1253,11 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
freeze(module->interfaceTypes); freeze(module->interfaceTypes);
module->internalTypes.clear(); module->internalTypes.clear();
if (FFlag::LuauSelectivelyRetainDFGArena)
{
module->defArena.allocator.clear();
module->keyArena.allocator.clear();
}
module->astTypes.clear(); module->astTypes.clear();
module->astTypePacks.clear(); module->astTypePacks.clear();
@ -1102,17 +1311,35 @@ void Frontend::recordItemResult(const BuildQueueItem& item)
if (item.exception) if (item.exception)
std::rethrow_exception(item.exception); std::rethrow_exception(item.exception);
bool replacedModule = false;
if (item.options.forAutocomplete) if (item.options.forAutocomplete)
{ {
moduleResolverForAutocomplete.setModule(item.name, item.module); replacedModule = moduleResolverForAutocomplete.setModule(item.name, item.module);
item.sourceNode->dirtyModuleForAutocomplete = false; item.sourceNode->dirtyModuleForAutocomplete = false;
} }
else else
{ {
moduleResolver.setModule(item.name, item.module); replacedModule = moduleResolver.setModule(item.name, item.module);
item.sourceNode->dirtyModule = false; item.sourceNode->dirtyModule = false;
} }
if (replacedModule)
{
LUAU_TIMETRACE_SCOPE("Frontend::invalidateDependentModules", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", item.name.c_str());
traverseDependents(
item.name,
[forAutocomplete = item.options.forAutocomplete](SourceNode& sourceNode)
{
bool traverseSubtree = !sourceNode.hasInvalidModuleDependency(forAutocomplete);
sourceNode.setInvalidModuleDependency(true, forAutocomplete);
return traverseSubtree;
}
);
}
item.sourceNode->setInvalidModuleDependency(false, item.options.forAutocomplete);
stats.timeCheck += item.stats.timeCheck; stats.timeCheck += item.stats.timeCheck;
stats.timeLint += item.stats.timeLint; stats.timeLint += item.stats.timeLint;
@ -1120,6 +1347,58 @@ void Frontend::recordItemResult(const BuildQueueItem& item)
stats.filesNonstrict += item.stats.filesNonstrict; stats.filesNonstrict += item.stats.filesNonstrict;
} }
void Frontend::performQueueItemTask(std::shared_ptr<BuildQueueWorkState> state, size_t itemPos)
{
BuildQueueItem& item = state->buildQueueItems[itemPos];
try
{
checkBuildQueueItem(item);
}
catch (...)
{
item.exception = std::current_exception();
}
{
std::unique_lock guard(state->mtx);
state->readyQueueItems.push_back(itemPos);
}
state->cv.notify_one();
}
void Frontend::sendQueueItemTask(std::shared_ptr<BuildQueueWorkState> state, size_t itemPos)
{
BuildQueueItem& item = state->buildQueueItems[itemPos];
LUAU_ASSERT(!item.processing);
item.processing = true;
state->processing++;
state->executeTask(
[this, state, itemPos]()
{
performQueueItemTask(state, itemPos);
}
);
}
void Frontend::sendQueueCycleItemTask(std::shared_ptr<BuildQueueWorkState> state)
{
for (size_t i = 0; i < state->buildQueueItems.size(); i++)
{
BuildQueueItem& item = state->buildQueueItems[i];
if (!item.processing)
{
sendQueueItemTask(state, i);
break;
}
}
}
ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const
{ {
ScopePtr result; ScopePtr result;
@ -1147,6 +1426,12 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config
return result; return result;
} }
bool Frontend::allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete) const
{
auto it = sourceNodes.find(name);
return it != sourceNodes.end() && !it->second->hasInvalidModuleDependency(forAutocomplete);
}
bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
{ {
auto it = sourceNodes.find(name); auto it = sourceNodes.find(name);
@ -1161,16 +1446,35 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
*/ */
void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty) void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::markDirty", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
traverseDependents(
name,
[markedDirty](SourceNode& sourceNode)
{
if (markedDirty)
markedDirty->push_back(sourceNode.name);
if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
return false;
sourceNode.dirtySourceModule = true;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
return true;
}
);
}
void Frontend::traverseDependents(const ModuleName& name, std::function<bool(SourceNode&)> processSubtree)
{
LUAU_TIMETRACE_SCOPE("Frontend::traverseDependents", "Frontend");
if (sourceNodes.count(name) == 0) if (sourceNodes.count(name) == 0)
return; return;
std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps;
for (const auto& module : sourceNodes)
{
for (const auto& dep : module.second->requireSet)
reverseDeps[dep].push_back(module.first);
}
std::vector<ModuleName> queue{name}; std::vector<ModuleName> queue{name};
while (!queue.empty()) while (!queue.empty())
@ -1181,22 +1485,10 @@ void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* marked
LUAU_ASSERT(sourceNodes.count(next) > 0); LUAU_ASSERT(sourceNodes.count(next) > 0);
SourceNode& sourceNode = *sourceNodes[next]; SourceNode& sourceNode = *sourceNodes[next];
if (markedDirty) if (!processSubtree(sourceNode))
markedDirty->push_back(next);
if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
continue; continue;
sourceNode.dirtySourceModule = true; const Set<ModuleName>& dependents = sourceNode.dependents;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
if (0 == reverseDeps.count(next))
continue;
sourceModules.erase(next);
const std::vector<ModuleName>& dependents = reverseDeps[next];
queue.insert(queue.end(), dependents.begin(), dependents.end()); queue.insert(queue.end(), dependents.begin(), dependents.end());
} }
} }
@ -1224,6 +1516,7 @@ ModulePtr check(
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver, NotNull<FileResolver> fileResolver,
const ScopePtr& parentScope, const ScopePtr& parentScope,
const ScopePtr& typeFunctionScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options, FrontendOptions options,
TypeCheckLimits limits, TypeCheckLimits limits,
@ -1240,6 +1533,7 @@ ModulePtr check(
moduleResolver, moduleResolver,
fileResolver, fileResolver,
parentScope, parentScope,
typeFunctionScope,
std::move(prepareModuleScope), std::move(prepareModuleScope),
options, options,
limits, limits,
@ -1301,6 +1595,7 @@ ModulePtr check(
NotNull<ModuleResolver> moduleResolver, NotNull<ModuleResolver> moduleResolver,
NotNull<FileResolver> fileResolver, NotNull<FileResolver> fileResolver,
const ScopePtr& parentScope, const ScopePtr& parentScope,
const ScopePtr& typeFunctionScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
FrontendOptions options, FrontendOptions options,
TypeCheckLimits limits, TypeCheckLimits limits,
@ -1313,18 +1608,16 @@ ModulePtr check(
LUAU_TIMETRACE_ARGUMENT("name", sourceModule.humanReadableName.c_str()); LUAU_TIMETRACE_ARGUMENT("name", sourceModule.humanReadableName.c_str());
ModulePtr result = std::make_shared<Module>(); ModulePtr result = std::make_shared<Module>();
if (FFlag::LuauStoreSolverTypeOnModule) result->checkedInNewSolver = true;
result->checkedInNewSolver = true;
result->name = sourceModule.name; result->name = sourceModule.name;
result->humanReadableName = sourceModule.humanReadableName; result->humanReadableName = sourceModule.humanReadableName;
result->mode = mode; result->mode = mode;
result->internalTypes.owningModule = result.get(); result->internalTypes.owningModule = result.get();
result->interfaceTypes.owningModule = result.get(); result->interfaceTypes.owningModule = result.get();
if (FFlag::LuauReferenceAllocatorInNewSolver) result->allocator = sourceModule.allocator;
{ result->names = sourceModule.names;
result->allocator = sourceModule.allocator; if (FFlag::LuauModuleHoldsAstRoot)
result->names = sourceModule.names; result->root = sourceModule.root;
}
iceHandler->moduleName = sourceModule.name; iceHandler->moduleName = sourceModule.name;
@ -1349,7 +1642,7 @@ ModulePtr check(
SimplifierPtr simplifier = newSimplifier(NotNull{&result->internalTypes}, builtinTypes); SimplifierPtr simplifier = newSimplifier(NotNull{&result->internalTypes}, builtinTypes);
TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}}; TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}};
typeFunctionRuntime.allowEvaluation = sourceModule.parseErrors.empty(); typeFunctionRuntime.allowEvaluation = FFlag::LuauTypeFunResultInAutocomplete || sourceModule.parseErrors.empty();
ConstraintGenerator cg{ ConstraintGenerator cg{
result, result,
@ -1360,6 +1653,7 @@ ModulePtr check(
builtinTypes, builtinTypes,
iceHandler, iceHandler,
parentScope, parentScope,
typeFunctionScope,
std::move(prepareModuleScope), std::move(prepareModuleScope),
logger.get(), logger.get(),
NotNull{&dfg}, NotNull{&dfg},
@ -1375,6 +1669,7 @@ ModulePtr check(
NotNull{&typeFunctionRuntime}, NotNull{&typeFunctionRuntime},
NotNull(cg.rootScope), NotNull(cg.rootScope),
borrowConstraints(cg.constraints), borrowConstraints(cg.constraints),
NotNull{&cg.scopeToFunction},
result->name, result->name,
moduleResolver, moduleResolver,
requireCycles, requireCycles,
@ -1541,6 +1836,7 @@ ModulePtr Frontend::check(
NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver},
NotNull{fileResolver}, NotNull{fileResolver},
environmentScope ? *environmentScope : globals.globalScope, environmentScope ? *environmentScope : globals.globalScope,
globals.globalTypeFunctionScope,
prepareModuleScopeWrap, prepareModuleScopeWrap,
options, options,
typeCheckLimits, typeCheckLimits,
@ -1638,6 +1934,14 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(const ModuleName&
sourceNode->name = sourceModule->name; sourceNode->name = sourceModule->name;
sourceNode->humanReadableName = sourceModule->humanReadableName; sourceNode->humanReadableName = sourceModule->humanReadableName;
// clear all prior dependents. we will re-add them after parsing the rest of the graph
for (const auto& [moduleName, _] : sourceNode->requireLocations)
{
if (auto depIt = sourceNodes.find(moduleName); depIt != sourceNodes.end())
depIt->second->dependents.erase(sourceNode->name);
}
sourceNode->requireSet.clear(); sourceNode->requireSet.clear();
sourceNode->requireLocations.clear(); sourceNode->requireLocations.clear();
sourceNode->dirtySourceModule = false; sourceNode->dirtySourceModule = false;
@ -1759,11 +2063,13 @@ std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName&
return frontend->fileResolver->getHumanReadableModuleName(moduleName); return frontend->fileResolver->getHumanReadableModuleName(moduleName);
} }
void FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module) bool FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module)
{ {
std::scoped_lock lock(moduleMutex); std::scoped_lock lock(moduleMutex);
bool replaced = modules.count(moduleName) > 0;
modules[moduleName] = std::move(module); modules[moduleName] = std::move(module);
return replaced;
} }
void FrontendModuleResolver::clearModules() void FrontendModuleResolver::clearModules()

View file

@ -2,6 +2,8 @@
#include "Luau/Generalization.h" #include "Luau/Generalization.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
@ -10,7 +12,7 @@
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete)
LUAU_FASTFLAGVARIABLE(LuauGeneralizationRemoveRecursiveUpperBound) LUAU_FASTFLAGVARIABLE(LuauGeneralizationRemoveRecursiveUpperBound2)
namespace Luau namespace Luau
{ {
@ -28,7 +30,6 @@ struct MutatingGeneralizer : TypeOnceVisitor
std::vector<TypePackId> genericPacks; std::vector<TypePackId> genericPacks;
bool isWithinFunction = false; bool isWithinFunction = false;
bool avoidSealingTables = false;
MutatingGeneralizer( MutatingGeneralizer(
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
@ -36,8 +37,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes, NotNull<DenseHashSet<TypeId>> cachedTypes,
DenseHashMap<const void*, size_t> positiveTypes, DenseHashMap<const void*, size_t> positiveTypes,
DenseHashMap<const void*, size_t> negativeTypes, DenseHashMap<const void*, size_t> negativeTypes
bool avoidSealingTables
) )
: TypeOnceVisitor(/* skipBoundTypes */ true) : TypeOnceVisitor(/* skipBoundTypes */ true)
, arena(arena) , arena(arena)
@ -46,11 +46,10 @@ struct MutatingGeneralizer : TypeOnceVisitor
, cachedTypes(cachedTypes) , cachedTypes(cachedTypes)
, positiveTypes(std::move(positiveTypes)) , positiveTypes(std::move(positiveTypes))
, negativeTypes(std::move(negativeTypes)) , negativeTypes(std::move(negativeTypes))
, avoidSealingTables(avoidSealingTables)
{ {
} }
static void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement) void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement)
{ {
haystack = follow(haystack); haystack = follow(haystack);
@ -97,6 +96,10 @@ struct MutatingGeneralizer : TypeOnceVisitor
LUAU_ASSERT(onlyType != haystack); LUAU_ASSERT(onlyType != haystack);
emplaceType<BoundType>(asMutable(haystack), onlyType); emplaceType<BoundType>(asMutable(haystack), onlyType);
} }
else if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound2 && ut->options.empty())
{
emplaceType<BoundType>(asMutable(haystack), builtinTypes->neverType);
}
return; return;
} }
@ -140,6 +143,10 @@ struct MutatingGeneralizer : TypeOnceVisitor
LUAU_ASSERT(onlyType != needle); LUAU_ASSERT(onlyType != needle);
emplaceType<BoundType>(asMutable(needle), onlyType); emplaceType<BoundType>(asMutable(needle), onlyType);
} }
else if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound2 && it->parts.empty())
{
emplaceType<BoundType>(asMutable(needle), builtinTypes->unknownType);
}
return; return;
} }
@ -233,53 +240,6 @@ struct MutatingGeneralizer : TypeOnceVisitor
else else
{ {
TypeId ub = follow(ft->upperBound); TypeId ub = follow(ft->upperBound);
if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound)
{
// If the upper bound is a union type or an intersection type,
// and one of it's members is the free type we're
// generalizing, don't include it in the upper bound. For a
// free type such as:
//
// t1 where t1 = D <: 'a <: (A | B | C | t1)
//
// Naively replacing it with it's upper bound creates:
//
// t1 where t1 = A | B | C | t1
//
// It makes sense to just optimize this and exclude the
// recursive component by semantic subtyping rules.
if (auto itv = get<IntersectionType>(ub))
{
std::vector<TypeId> newIds;
newIds.reserve(itv->parts.size());
for (auto part : itv)
{
if (part != ty)
newIds.push_back(part);
}
if (newIds.size() == 1)
ub = newIds[0];
else if (newIds.size() > 0)
ub = arena->addType(IntersectionType{std::move(newIds)});
}
else if (auto utv = get<UnionType>(ub))
{
std::vector<TypeId> newIds;
newIds.reserve(utv->options.size());
for (auto part : utv)
{
if (part != ty)
newIds.push_back(part);
}
if (newIds.size() == 1)
ub = newIds[0];
else if (newIds.size() > 0)
ub = arena->addType(UnionType{std::move(newIds)});
}
}
if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == ty) if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == ty)
upperFree->lowerBound = builtinTypes->neverType; upperFree->lowerBound = builtinTypes->neverType;
else else
@ -329,8 +289,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
TableType* tt = getMutable<TableType>(ty); TableType* tt = getMutable<TableType>(ty);
LUAU_ASSERT(tt); LUAU_ASSERT(tt);
if (!avoidSealingTables) tt->state = TableState::Sealed;
tt->state = TableState::Sealed;
return true; return true;
} }
@ -369,26 +328,19 @@ struct FreeTypeSearcher : TypeVisitor
{ {
} }
enum Polarity Polarity polarity = Polarity::Positive;
{
Positive,
Negative,
Both,
};
Polarity polarity = Positive;
void flip() void flip()
{ {
switch (polarity) switch (polarity)
{ {
case Positive: case Polarity::Positive:
polarity = Negative; polarity = Polarity::Negative;
break; break;
case Negative: case Polarity::Negative:
polarity = Positive; polarity = Polarity::Positive;
break; break;
case Both: default:
break; break;
} }
} }
@ -396,11 +348,11 @@ struct FreeTypeSearcher : TypeVisitor
DenseHashSet<const void*> seenPositive{nullptr}; DenseHashSet<const void*> seenPositive{nullptr};
DenseHashSet<const void*> seenNegative{nullptr}; DenseHashSet<const void*> seenNegative{nullptr};
bool seenWithPolarity(const void* ty) bool seenWithCurrentPolarity(const void* ty)
{ {
switch (polarity) switch (polarity)
{ {
case Positive: case Polarity::Positive:
{ {
if (seenPositive.contains(ty)) if (seenPositive.contains(ty))
return true; return true;
@ -408,7 +360,7 @@ struct FreeTypeSearcher : TypeVisitor
seenPositive.insert(ty); seenPositive.insert(ty);
return false; return false;
} }
case Negative: case Polarity::Negative:
{ {
if (seenNegative.contains(ty)) if (seenNegative.contains(ty))
return true; return true;
@ -416,7 +368,7 @@ struct FreeTypeSearcher : TypeVisitor
seenNegative.insert(ty); seenNegative.insert(ty);
return false; return false;
} }
case Both: case Polarity::Mixed:
{ {
if (seenPositive.contains(ty) && seenNegative.contains(ty)) if (seenPositive.contains(ty) && seenNegative.contains(ty))
return true; return true;
@ -425,6 +377,8 @@ struct FreeTypeSearcher : TypeVisitor
seenNegative.insert(ty); seenNegative.insert(ty);
return false; return false;
} }
default:
LUAU_ASSERT(!"Unreachable");
} }
return false; return false;
@ -438,7 +392,7 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty) override bool visit(TypeId ty) override
{ {
if (cachedTypes->contains(ty) || seenWithPolarity(ty)) if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty))
return false; return false;
LUAU_ASSERT(ty); LUAU_ASSERT(ty);
@ -447,7 +401,7 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const FreeType& ft) override bool visit(TypeId ty, const FreeType& ft) override
{ {
if (cachedTypes->contains(ty) || seenWithPolarity(ty)) if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty))
return false; return false;
if (!subsumes(scope, ft.scope)) if (!subsumes(scope, ft.scope))
@ -455,16 +409,18 @@ struct FreeTypeSearcher : TypeVisitor
switch (polarity) switch (polarity)
{ {
case Positive: case Polarity::Positive:
positiveTypes[ty]++; positiveTypes[ty]++;
break; break;
case Negative: case Polarity::Negative:
negativeTypes[ty]++; negativeTypes[ty]++;
break; break;
case Both: case Polarity::Mixed:
positiveTypes[ty]++; positiveTypes[ty]++;
negativeTypes[ty]++; negativeTypes[ty]++;
break; break;
default:
LUAU_ASSERT(!"Unreachable");
} }
return true; return true;
@ -472,23 +428,25 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const TableType& tt) override bool visit(TypeId ty, const TableType& tt) override
{ {
if (cachedTypes->contains(ty) || seenWithPolarity(ty)) if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty))
return false; return false;
if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope))
{ {
switch (polarity) switch (polarity)
{ {
case Positive: case Polarity::Positive:
positiveTypes[ty]++; positiveTypes[ty]++;
break; break;
case Negative: case Polarity::Negative:
negativeTypes[ty]++; negativeTypes[ty]++;
break; break;
case Both: case Polarity::Mixed:
positiveTypes[ty]++; positiveTypes[ty]++;
negativeTypes[ty]++; negativeTypes[ty]++;
break; break;
default:
LUAU_ASSERT(!"Unreachable");
} }
} }
@ -501,7 +459,7 @@ struct FreeTypeSearcher : TypeVisitor
LUAU_ASSERT(prop.isShared() || FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete); LUAU_ASSERT(prop.isShared() || FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
Polarity p = polarity; Polarity p = polarity;
polarity = Both; polarity = Polarity::Mixed;
traverse(prop.type()); traverse(prop.type());
polarity = p; polarity = p;
} }
@ -518,7 +476,7 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const FunctionType& ft) override bool visit(TypeId ty, const FunctionType& ft) override
{ {
if (cachedTypes->contains(ty) || seenWithPolarity(ty)) if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty))
return false; return false;
flip(); flip();
@ -537,7 +495,7 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypePackId tp, const FreeTypePack& ftp) override bool visit(TypePackId tp, const FreeTypePack& ftp) override
{ {
if (seenWithPolarity(tp)) if (seenWithCurrentPolarity(tp))
return false; return false;
if (!subsumes(scope, ftp.scope)) if (!subsumes(scope, ftp.scope))
@ -545,16 +503,18 @@ struct FreeTypeSearcher : TypeVisitor
switch (polarity) switch (polarity)
{ {
case Positive: case Polarity::Positive:
positiveTypes[tp]++; positiveTypes[tp]++;
break; break;
case Negative: case Polarity::Negative:
negativeTypes[tp]++; negativeTypes[tp]++;
break; break;
case Both: case Polarity::Mixed:
positiveTypes[tp]++; positiveTypes[tp]++;
negativeTypes[tp]++; negativeTypes[tp]++;
break; break;
default:
LUAU_ASSERT(!"Unreachable");
} }
return true; return true;
@ -584,7 +544,7 @@ struct TypeCacher : TypeOnceVisitor
{ {
} }
void cache(TypeId ty) void cache(TypeId ty) const
{ {
cachedTypes->insert(ty); cachedTypes->insert(ty);
} }
@ -1009,8 +969,7 @@ std::optional<TypeId> generalize(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<Scope> scope, NotNull<Scope> scope,
NotNull<DenseHashSet<TypeId>> cachedTypes, NotNull<DenseHashSet<TypeId>> cachedTypes,
TypeId ty, TypeId ty
bool avoidSealingTables
) )
{ {
ty = follow(ty); ty = follow(ty);
@ -1021,7 +980,7 @@ std::optional<TypeId> generalize(
FreeTypeSearcher fts{scope, cachedTypes}; FreeTypeSearcher fts{scope, cachedTypes};
fts.traverse(ty); fts.traverse(ty);
MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes)};
gen.traverse(ty); gen.traverse(ty);

View file

@ -9,6 +9,7 @@ GlobalTypes::GlobalTypes(NotNull<BuiltinTypes> builtinTypes)
: builtinTypes(builtinTypes) : builtinTypes(builtinTypes)
{ {
globalScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); globalScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}));
globalTypeFunctionScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}));
globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType}); globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType});
globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType}); globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType});

View file

@ -11,6 +11,7 @@
#include <algorithm> #include <algorithm>
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau namespace Luau
{ {
@ -163,7 +164,7 @@ TypeId ReplaceGenerics::clean(TypeId ty)
} }
else else
{ {
return addType(FreeType{scope, level}); return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, scope, level) : addType(FreeType{scope, level});
} }
} }

View file

@ -19,6 +19,8 @@ LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAttribute) LUAU_FASTFLAG(LuauAttribute)
LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute) LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute)
LUAU_FASTFLAG(LuauDeprecatedAttribute)
namespace Luau namespace Luau
{ {
@ -2280,6 +2282,57 @@ private:
{ {
} }
bool visit(AstExprLocal* node) override
{
if (FFlag::LuauDeprecatedAttribute)
{
const FunctionType* fty = getFunctionType(node);
bool shouldReport = fty && fty->isDeprecatedFunction && !inScope(fty);
if (shouldReport)
report(node->location, node->local->name.value);
}
return true;
}
bool visit(AstExprGlobal* node) override
{
if (FFlag::LuauDeprecatedAttribute)
{
const FunctionType* fty = getFunctionType(node);
bool shouldReport = fty && fty->isDeprecatedFunction && !inScope(fty);
if (shouldReport)
report(node->location, node->name.value);
}
return true;
}
bool visit(AstStatLocalFunction* node) override
{
if (FFlag::LuauDeprecatedAttribute)
{
check(node->func);
return false;
}
else
return true;
}
bool visit(AstStatFunction* node) override
{
if (FFlag::LuauDeprecatedAttribute)
{
check(node->func);
return false;
}
else
return true;
}
bool visit(AstExprIndexName* node) override bool visit(AstExprIndexName* node) override
{ {
if (std::optional<TypeId> ty = context->getType(node->expr)) if (std::optional<TypeId> ty = context->getType(node->expr))
@ -2325,18 +2378,59 @@ private:
if (prop && prop->deprecated) if (prop && prop->deprecated)
report(node->location, *prop, cty->name.c_str(), node->index.value); report(node->location, *prop, cty->name.c_str(), node->index.value);
else if (FFlag::LuauDeprecatedAttribute && prop)
{
if (std::optional<TypeId> ty = prop->readTy)
{
const FunctionType* fty = get<FunctionType>(follow(ty));
bool shouldReport = fty && fty->isDeprecatedFunction && !inScope(fty);
if (shouldReport)
{
const char* className = nullptr;
if (AstExprGlobal* global = node->expr->as<AstExprGlobal>())
className = global->name.value;
const char* functionName = node->index.value;
report(node->location, className, functionName);
}
}
}
} }
else if (const TableType* tty = get<TableType>(ty)) else if (const TableType* tty = get<TableType>(ty))
{ {
auto prop = tty->props.find(node->index.value); auto prop = tty->props.find(node->index.value);
if (prop != tty->props.end() && prop->second.deprecated) if (prop != tty->props.end())
{ {
// strip synthetic typeof() for builtin tables if (prop->second.deprecated)
if (tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')') {
report(node->location, prop->second, tty->name->substr(7, tty->name->length() - 8).c_str(), node->index.value); // strip synthetic typeof() for builtin tables
else if (tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')')
report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); report(node->location, prop->second, tty->name->substr(7, tty->name->length() - 8).c_str(), node->index.value);
else
report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value);
}
else if (FFlag::LuauDeprecatedAttribute)
{
if (std::optional<TypeId> ty = prop->second.readTy)
{
const FunctionType* fty = get<FunctionType>(follow(ty));
bool shouldReport = fty && fty->isDeprecatedFunction && !inScope(fty);
if (shouldReport)
{
const char* className = nullptr;
if (AstExprGlobal* global = node->expr->as<AstExprGlobal>())
className = global->name.value;
const char* functionName = node->index.value;
report(node->location, className, functionName);
}
}
}
} }
} }
} }
@ -2355,6 +2449,26 @@ private:
} }
} }
void check(AstExprFunction* func)
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
LUAU_ASSERT(func);
const FunctionType* fty = getFunctionType(func);
bool isDeprecated = fty && fty->isDeprecatedFunction;
// If a function is deprecated, we don't want to flag its recursive uses.
// So we push it on a stack while its body is being analyzed.
// When a deprecated function is used, we check the stack to ensure that we are not inside that function.
if (isDeprecated)
pushScope(fty);
func->visit(this);
if (isDeprecated)
popScope(fty);
}
void report(const Location& location, const Property& prop, const char* container, const char* field) void report(const Location& location, const Property& prop, const char* container, const char* field)
{ {
std::string suggestion = prop.deprecatedSuggestion.empty() ? "" : format(", use '%s' instead", prop.deprecatedSuggestion.c_str()); std::string suggestion = prop.deprecatedSuggestion.empty() ? "" : format(", use '%s' instead", prop.deprecatedSuggestion.c_str());
@ -2364,6 +2478,63 @@ private:
else else
emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s' is deprecated%s", field, suggestion.c_str()); emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s' is deprecated%s", field, suggestion.c_str());
} }
void report(const Location& location, const char* tableName, const char* functionName)
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
if (tableName)
emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s.%s' is deprecated", tableName, functionName);
else
emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s' is deprecated", functionName);
}
void report(const Location& location, const char* functionName)
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Function '%s' is deprecated", functionName);
}
std::vector<const FunctionType*> functionTypeScopeStack;
void pushScope(const FunctionType* fty)
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
LUAU_ASSERT(fty);
functionTypeScopeStack.push_back(fty);
}
void popScope(const FunctionType* fty)
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
LUAU_ASSERT(fty);
LUAU_ASSERT(fty == functionTypeScopeStack.back());
functionTypeScopeStack.pop_back();
}
bool inScope(const FunctionType* fty) const
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
LUAU_ASSERT(fty);
return std::find(functionTypeScopeStack.begin(), functionTypeScopeStack.end(), fty) != functionTypeScopeStack.end();
}
const FunctionType* getFunctionType(AstExpr* node)
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
std::optional<TypeId> ty = context->getType(node);
if (!ty)
return nullptr;
const FunctionType* fty = get<FunctionType>(follow(ty));
return fty;
}
}; };
class LintTableOperations : AstVisitor class LintTableOperations : AstVisitor

View file

@ -20,6 +20,26 @@ LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteCommentDetection)
namespace Luau namespace Luau
{ {
static void defaultLogLuau(std::string_view context, std::string_view input)
{
// The default is to do nothing because we don't want to mess with
// the xml parsing done by the dcr script.
}
Luau::LogLuauProc logLuau = &defaultLogLuau;
void setLogLuau(LogLuauProc ll)
{
logLuau = ll;
}
void resetLogLuauProc()
{
logLuau = &defaultLogLuau;
}
static bool contains_DEPRECATED(Position pos, Comment comment) static bool contains_DEPRECATED(Position pos, Comment comment)
{ {
if (comment.location.contains(pos)) if (comment.location.contains(pos))

View file

@ -2,6 +2,7 @@
#include "Luau/NonStrictTypeChecker.h" #include "Luau/NonStrictTypeChecker.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Simplify.h" #include "Luau/Simplify.h"
#include "Luau/Type.h" #include "Luau/Type.h"
@ -14,12 +15,15 @@
#include "Luau/TypeFunction.h" #include "Luau/TypeFunction.h"
#include "Luau/Def.h" #include "Luau/Def.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h"
#include <iostream> #include <iostream>
#include <iterator> #include <iterator>
LUAU_FASTFLAGVARIABLE(LuauCountSelfCallsNonstrict) LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
LUAU_FASTFLAGVARIABLE(LuauNonStrictVisitorImprovements)
LUAU_FASTFLAGVARIABLE(LuauNewNonStrictWarnOnUnknownGlobals)
LUAU_FASTFLAGVARIABLE(LuauNonStrictFuncDefErrorFix)
namespace Luau namespace Luau
{ {
@ -211,7 +215,7 @@ struct NonStrictTypeChecker
return *fst; return *fst;
else if (auto ftp = get<FreeTypePack>(pack)) else if (auto ftp = get<FreeTypePack>(pack))
{ {
TypeId result = arena->addType(FreeType{ftp->scope}); TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, ftp->scope) : arena->addType(FreeType{ftp->scope});
TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope}); TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope});
TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack)); TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack));
@ -341,8 +345,9 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatIf* ifStatement) NonStrictContext visit(AstStatIf* ifStatement)
{ {
NonStrictContext condB = visit(ifStatement->condition); NonStrictContext condB = visit(ifStatement->condition, ValueContext::RValue);
NonStrictContext branchContext; NonStrictContext branchContext;
// If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error // If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error
if (ifStatement->elsebody) if (ifStatement->elsebody)
{ {
@ -350,17 +355,32 @@ struct NonStrictTypeChecker
NonStrictContext elseBody = visit(ifStatement->elsebody); NonStrictContext elseBody = visit(ifStatement->elsebody);
branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody); branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody);
} }
return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext); return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext);
} }
NonStrictContext visit(AstStatWhile* whileStatement) NonStrictContext visit(AstStatWhile* whileStatement)
{ {
return {}; if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext condition = visit(whileStatement->condition, ValueContext::RValue);
NonStrictContext body = visit(whileStatement->body);
return NonStrictContext::disjunction(builtinTypes, arena, condition, body);
}
else
return {};
} }
NonStrictContext visit(AstStatRepeat* repeatStatement) NonStrictContext visit(AstStatRepeat* repeatStatement)
{ {
return {}; if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext body = visit(repeatStatement->body);
NonStrictContext condition = visit(repeatStatement->condition, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, body, condition);
}
else
return {};
} }
NonStrictContext visit(AstStatBreak* breakStatement) NonStrictContext visit(AstStatBreak* breakStatement)
@ -375,49 +395,94 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatReturn* returnStatement) NonStrictContext visit(AstStatReturn* returnStatement)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
// TODO: this is believing existing code, but i'm not sure if this makes sense
// for how the contexts are handled
for (AstExpr* expr : returnStatement->list)
visit(expr, ValueContext::RValue);
}
return {}; return {};
} }
NonStrictContext visit(AstStatExpr* expr) NonStrictContext visit(AstStatExpr* expr)
{ {
return visit(expr->expr); return visit(expr->expr, ValueContext::RValue);
} }
NonStrictContext visit(AstStatLocal* local) NonStrictContext visit(AstStatLocal* local)
{ {
for (AstExpr* rhs : local->values) for (AstExpr* rhs : local->values)
visit(rhs); visit(rhs, ValueContext::RValue);
return {}; return {};
} }
NonStrictContext visit(AstStatFor* forStatement) NonStrictContext visit(AstStatFor* forStatement)
{ {
return {}; if (FFlag::LuauNonStrictVisitorImprovements)
{
// TODO: throwing out context based on same principle as existing code?
if (forStatement->from)
visit(forStatement->from, ValueContext::RValue);
if (forStatement->to)
visit(forStatement->to, ValueContext::RValue);
if (forStatement->step)
visit(forStatement->step, ValueContext::RValue);
return visit(forStatement->body);
}
else
{
return {};
}
} }
NonStrictContext visit(AstStatForIn* forInStatement) NonStrictContext visit(AstStatForIn* forInStatement)
{ {
return {}; if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* rhs : forInStatement->values)
visit(rhs, ValueContext::RValue);
return visit(forInStatement->body);
}
else
{
return {};
}
} }
NonStrictContext visit(AstStatAssign* assign) NonStrictContext visit(AstStatAssign* assign)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* lhs : assign->vars)
visit(lhs, ValueContext::LValue);
for (AstExpr* rhs : assign->values)
visit(rhs, ValueContext::RValue);
}
return {}; return {};
} }
NonStrictContext visit(AstStatCompoundAssign* compoundAssign) NonStrictContext visit(AstStatCompoundAssign* compoundAssign)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
visit(compoundAssign->var, ValueContext::LValue);
visit(compoundAssign->value, ValueContext::RValue);
}
return {}; return {};
} }
NonStrictContext visit(AstStatFunction* statFn) NonStrictContext visit(AstStatFunction* statFn)
{ {
return visit(statFn->func); return visit(statFn->func, ValueContext::RValue);
} }
NonStrictContext visit(AstStatLocalFunction* localFn) NonStrictContext visit(AstStatLocalFunction* localFn)
{ {
return visit(localFn->func); return visit(localFn->func, ValueContext::RValue);
} }
NonStrictContext visit(AstStatTypeAlias* typeAlias) NonStrictContext visit(AstStatTypeAlias* typeAlias)
@ -447,14 +512,22 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstStatError* error) NonStrictContext visit(AstStatError* error)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstStat* stat : error->statements)
visit(stat);
for (AstExpr* expr : error->expressions)
visit(expr, ValueContext::RValue);
}
return {}; return {};
} }
NonStrictContext visit(AstExpr* expr) NonStrictContext visit(AstExpr* expr, ValueContext context)
{ {
auto pusher = pushStack(expr); auto pusher = pushStack(expr);
if (auto e = expr->as<AstExprGroup>()) if (auto e = expr->as<AstExprGroup>())
return visit(e); return visit(e, context);
else if (auto e = expr->as<AstExprConstantNil>()) else if (auto e = expr->as<AstExprConstantNil>())
return visit(e); return visit(e);
else if (auto e = expr->as<AstExprConstantBool>()) else if (auto e = expr->as<AstExprConstantBool>())
@ -464,17 +537,17 @@ struct NonStrictTypeChecker
else if (auto e = expr->as<AstExprConstantString>()) else if (auto e = expr->as<AstExprConstantString>())
return visit(e); return visit(e);
else if (auto e = expr->as<AstExprLocal>()) else if (auto e = expr->as<AstExprLocal>())
return visit(e); return visit(e, context);
else if (auto e = expr->as<AstExprGlobal>()) else if (auto e = expr->as<AstExprGlobal>())
return visit(e); return visit(e, context);
else if (auto e = expr->as<AstExprVarargs>()) else if (auto e = expr->as<AstExprVarargs>())
return visit(e); return visit(e);
else if (auto e = expr->as<AstExprCall>()) else if (auto e = expr->as<AstExprCall>())
return visit(e); return visit(e);
else if (auto e = expr->as<AstExprIndexName>()) else if (auto e = expr->as<AstExprIndexName>())
return visit(e); return visit(e, context);
else if (auto e = expr->as<AstExprIndexExpr>()) else if (auto e = expr->as<AstExprIndexExpr>())
return visit(e); return visit(e, context);
else if (auto e = expr->as<AstExprFunction>()) else if (auto e = expr->as<AstExprFunction>())
return visit(e); return visit(e);
else if (auto e = expr->as<AstExprTable>()) else if (auto e = expr->as<AstExprTable>())
@ -498,9 +571,12 @@ struct NonStrictTypeChecker
} }
} }
NonStrictContext visit(AstExprGroup* group) NonStrictContext visit(AstExprGroup* group, ValueContext context)
{ {
return {}; if (FFlag::LuauNonStrictVisitorImprovements)
return visit(group->expr, context);
else
return {};
} }
NonStrictContext visit(AstExprConstantNil* expr) NonStrictContext visit(AstExprConstantNil* expr)
@ -523,34 +599,36 @@ struct NonStrictTypeChecker
return {}; return {};
} }
NonStrictContext visit(AstExprLocal* local) NonStrictContext visit(AstExprLocal* local, ValueContext context)
{ {
return {}; return {};
} }
NonStrictContext visit(AstExprGlobal* global) NonStrictContext visit(AstExprGlobal* global, ValueContext context)
{ {
if (FFlag::LuauNewNonStrictWarnOnUnknownGlobals)
{
// We don't file unknown symbols for LValues.
if (context == ValueContext::LValue)
return {};
NotNull<Scope> scope = stack.back();
if (!scope->lookup(global->name))
{
reportError(UnknownSymbol{global->name.value, UnknownSymbol::Binding}, global->location);
}
}
return {}; return {};
} }
NonStrictContext visit(AstExprVarargs* global) NonStrictContext visit(AstExprVarargs* varargs)
{ {
return {}; return {};
} }
NonStrictContext visit(AstExprCall* call) NonStrictContext visit(AstExprCall* call)
{ {
if (FFlag::LuauCountSelfCallsNonstrict)
return visitCall(call);
else
return visitCall_DEPRECATED(call);
}
// rename this to `visit` when `FFlag::LuauCountSelfCallsNonstrict` is removed, and clean up above `visit`.
NonStrictContext visitCall(AstExprCall* call)
{
LUAU_ASSERT(FFlag::LuauCountSelfCallsNonstrict);
NonStrictContext fresh{}; NonStrictContext fresh{};
TypeId* originalCallTy = module->astOriginalCallTypes.find(call->func); TypeId* originalCallTy = module->astOriginalCallTypes.find(call->func);
if (!originalCallTy) if (!originalCallTy)
@ -659,117 +737,24 @@ struct NonStrictTypeChecker
return fresh; return fresh;
} }
// Remove with `FFlag::LuauCountSelfCallsNonstrict` clean up. NonStrictContext visit(AstExprIndexName* indexName, ValueContext context)
NonStrictContext visitCall_DEPRECATED(AstExprCall* call)
{ {
LUAU_ASSERT(!FFlag::LuauCountSelfCallsNonstrict); if (FFlag::LuauNonStrictVisitorImprovements)
return visit(indexName->expr, context);
else
return {};
}
NonStrictContext fresh{}; NonStrictContext visit(AstExprIndexExpr* indexExpr, ValueContext context)
TypeId* originalCallTy = module->astOriginalCallTypes.find(call->func); {
if (!originalCallTy) if (FFlag::LuauNonStrictVisitorImprovements)
return fresh;
TypeId fnTy = *originalCallTy;
if (auto fn = get<FunctionType>(follow(fnTy)))
{ {
if (fn->isCheckedFunction) NonStrictContext expr = visit(indexExpr->expr, context);
{ NonStrictContext index = visit(indexExpr->index, ValueContext::RValue);
// We know fn is a checked function, which means it looks like: return NonStrictContext::disjunction(builtinTypes, arena, expr, index);
// (S1, ... SN) -> T &
// (~S1, unknown^N-1) -> error &
// (unknown, ~S2, unknown^N-2) -> error
// ...
// ...
// (unknown^N-1, ~S_N) -> error
std::vector<TypeId> argTypes;
argTypes.reserve(call->args.size);
// Pad out the arg types array with the types you would expect to see
TypePackIterator curr = begin(fn->argTypes);
TypePackIterator fin = end(fn->argTypes);
while (curr != fin)
{
argTypes.push_back(*curr);
++curr;
}
if (auto argTail = curr.tail())
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*argTail)))
{
while (argTypes.size() < call->args.size)
{
argTypes.push_back(vtp->ty);
}
}
}
std::string functionName = getFunctionNameAsString(*call->func).value_or("");
if (call->args.size > argTypes.size())
{
// We are passing more arguments than we expect, so we should error
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location);
return fresh;
}
for (size_t i = 0; i < call->args.size; i++)
{
// For example, if the arg is "hi"
// The actual arg type is string
// The expected arg type is number
// The type of the argument in the overload is ~number
// We will compare arg and ~number
AstExpr* arg = call->args.data[i];
TypeId expectedArgType = argTypes[i];
std::shared_ptr<const NormalizedType> norm = normalizer.normalize(expectedArgType);
DefId def = dfg->getDef(arg);
TypeId runTimeErrorTy;
// If we're dealing with any, negating any will cause all subtype tests to fail
// However, when someone calls this function, they're going to want to be able to pass it anything,
// for that reason, we manually inject never into the context so that the runtime test will always pass.
if (!norm)
reportError(NormalizationTooComplex{}, arg->location);
if (norm && get<AnyType>(norm->tops))
runTimeErrorTy = builtinTypes->neverType;
else
runTimeErrorTy = getOrCreateNegation(expectedArgType);
fresh.addContext(def, runTimeErrorTy);
}
// Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types
for (size_t i = 0; i < call->args.size; i++)
{
AstExpr* arg = call->args.data[i];
if (auto runTimeFailureType = willRunTimeError(arg, fresh))
reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location);
}
if (call->args.size < argTypes.size())
{
// We are passing fewer arguments than we expect
// so we need to ensure that the rest of the args are optional.
bool remainingArgsOptional = true;
for (size_t i = call->args.size; i < argTypes.size(); i++)
remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]);
if (!remainingArgsOptional)
{
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location);
return fresh;
}
}
}
} }
else
return fresh; return {};
}
NonStrictContext visit(AstExprIndexName* indexName)
{
return {};
}
NonStrictContext visit(AstExprIndexExpr* indexExpr)
{
return {};
} }
NonStrictContext visit(AstExprFunction* exprFn) NonStrictContext visit(AstExprFunction* exprFn)
@ -780,7 +765,17 @@ struct NonStrictTypeChecker
for (AstLocal* local : exprFn->args) for (AstLocal* local : exprFn->args)
{ {
if (std::optional<TypeId> ty = willRunTimeErrorFunctionDefinition(local, remainder)) if (std::optional<TypeId> ty = willRunTimeErrorFunctionDefinition(local, remainder))
reportError(NonStrictFunctionDefinitionError{exprFn->debugname.value, local->name.value, *ty}, local->location); {
if (FFlag::LuauNonStrictFuncDefErrorFix)
{
const char* debugname = exprFn->debugname.value;
reportError(NonStrictFunctionDefinitionError{debugname ? debugname : "", local->name.value, *ty}, local->location);
}
else
{
reportError(NonStrictFunctionDefinitionError{exprFn->debugname.value, local->name.value, *ty}, local->location);
}
}
remainder.remove(dfg->getDef(local)); remainder.remove(dfg->getDef(local));
} }
return remainder; return remainder;
@ -788,39 +783,74 @@ struct NonStrictTypeChecker
NonStrictContext visit(AstExprTable* table) NonStrictContext visit(AstExprTable* table)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (auto [_, key, value] : table->items)
{
if (key)
visit(key, ValueContext::RValue);
visit(value, ValueContext::RValue);
}
}
return {}; return {};
} }
NonStrictContext visit(AstExprUnary* unary) NonStrictContext visit(AstExprUnary* unary)
{ {
return {}; if (FFlag::LuauNonStrictVisitorImprovements)
return visit(unary->expr, ValueContext::RValue);
else
return {};
} }
NonStrictContext visit(AstExprBinary* binary) NonStrictContext visit(AstExprBinary* binary)
{ {
return {}; if (FFlag::LuauNonStrictVisitorImprovements)
{
NonStrictContext lhs = visit(binary->left, ValueContext::RValue);
NonStrictContext rhs = visit(binary->right, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, lhs, rhs);
}
else
return {};
} }
NonStrictContext visit(AstExprTypeAssertion* typeAssertion) NonStrictContext visit(AstExprTypeAssertion* typeAssertion)
{ {
return {}; if (FFlag::LuauNonStrictVisitorImprovements)
return visit(typeAssertion->expr, ValueContext::RValue);
else
return {};
} }
NonStrictContext visit(AstExprIfElse* ifElse) NonStrictContext visit(AstExprIfElse* ifElse)
{ {
NonStrictContext condB = visit(ifElse->condition); NonStrictContext condB = visit(ifElse->condition, ValueContext::RValue);
NonStrictContext thenB = visit(ifElse->trueExpr); NonStrictContext thenB = visit(ifElse->trueExpr, ValueContext::RValue);
NonStrictContext elseB = visit(ifElse->falseExpr); NonStrictContext elseB = visit(ifElse->falseExpr, ValueContext::RValue);
return NonStrictContext::disjunction(builtinTypes, arena, condB, NonStrictContext::conjunction(builtinTypes, arena, thenB, elseB)); return NonStrictContext::disjunction(builtinTypes, arena, condB, NonStrictContext::conjunction(builtinTypes, arena, thenB, elseB));
} }
NonStrictContext visit(AstExprInterpString* interpString) NonStrictContext visit(AstExprInterpString* interpString)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* expr : interpString->expressions)
visit(expr, ValueContext::RValue);
}
return {}; return {};
} }
NonStrictContext visit(AstExprError* error) NonStrictContext visit(AstExprError* error)
{ {
if (FFlag::LuauNonStrictVisitorImprovements)
{
for (AstExpr* expr : error->expressions)
visit(expr, ValueContext::RValue);
}
return {}; return {};
} }

View file

@ -17,12 +17,15 @@
LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant) LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant)
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeNegatedErrorToAnError)
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAGVARIABLE(LuauNormalizeIntersectErrorToAnError)
LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000)
LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200) LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200)
LUAU_FASTFLAGVARIABLE(LuauNormalizationTracksCyclicPairsThroughInhabitance); LUAU_FASTINTVARIABLE(LuauNormalizeUnionLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauIntersectNormalsNeedsToTrackResourceLimits); LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauFixInfiniteRecursionInNormalization)
LUAU_FASTFLAGVARIABLE(LuauNormalizedBufferIsNotUnknown)
LUAU_FASTFLAGVARIABLE(LuauNormalizeLimitFunctionSet)
namespace Luau namespace Luau
{ {
@ -304,7 +307,9 @@ bool NormalizedType::isUnknown() const
// Otherwise, we can still be unknown! // Otherwise, we can still be unknown!
bool hasAllPrimitives = isPrim(booleans, PrimitiveType::Boolean) && isPrim(nils, PrimitiveType::NilType) && isNumber(numbers) && bool hasAllPrimitives = isPrim(booleans, PrimitiveType::Boolean) && isPrim(nils, PrimitiveType::NilType) && isNumber(numbers) &&
strings.isString() && isPrim(threads, PrimitiveType::Thread) && isThread(threads); strings.isString() &&
(FFlag::LuauNormalizedBufferIsNotUnknown ? isThread(threads) && isBuffer(buffers)
: isPrim(threads, PrimitiveType::Thread) && isThread(threads));
// Check is class // Check is class
bool isTopClass = false; bool isTopClass = false;
@ -579,7 +584,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ
{ {
left = follow(left); left = follow(left);
right = follow(right); right = follow(right);
// We're asking if intersection is inahbited between left and right but we've already seen them .... // We're asking if intersection is inhabited between left and right but we've already seen them ....
if (cacheInhabitance) if (cacheInhabitance)
{ {
@ -1685,6 +1690,13 @@ NormalizationResult Normalizer::unionNormals(NormalizedType& here, const Normali
return res; return res;
} }
if (FFlag::LuauNormalizeLimitFunctionSet)
{
// Limit based on worst-case expansion of the function unions
if (here.functions.parts.size() * there.functions.parts.size() >= size_t(FInt::LuauNormalizeUnionLimit))
return NormalizationResult::HitLimits;
}
here.booleans = unionOfBools(here.booleans, there.booleans); here.booleans = unionOfBools(here.booleans, there.booleans);
unionClasses(here.classes, there.classes); unionClasses(here.classes, there.classes);
@ -1696,6 +1708,7 @@ NormalizationResult Normalizer::unionNormals(NormalizedType& here, const Normali
here.buffers = (get<NeverType>(there.buffers) ? here.buffers : there.buffers); here.buffers = (get<NeverType>(there.buffers) ? here.buffers : there.buffers);
unionFunctions(here.functions, there.functions); unionFunctions(here.functions, there.functions);
unionTables(here.tables, there.tables); unionTables(here.tables, there.tables);
return NormalizationResult::True; return NormalizationResult::True;
} }
@ -1735,7 +1748,7 @@ NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, N
return NormalizationResult::True; return NormalizationResult::True;
} }
// See above for an explaination of `ignoreSmallerTyvars`. // See above for an explanation of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::unionNormalWithTy( NormalizationResult Normalizer::unionNormalWithTy(
NormalizedType& here, NormalizedType& here,
TypeId there, TypeId there,
@ -2285,9 +2298,24 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th
else if (isSubclass(there, hereTy)) else if (isSubclass(there, hereTy))
{ {
TypeIds negations = std::move(hereNegations); TypeIds negations = std::move(hereNegations);
bool emptyIntersectWithNegation = false;
for (auto nIt = negations.begin(); nIt != negations.end();) for (auto nIt = negations.begin(); nIt != negations.end();)
{ {
if (isSubclass(there, *nIt))
{
// Hitting this block means that the incoming class is a
// subclass of this type, _and_ one of its negations is a
// superclass of this type, e.g.:
//
// Dog & ~Animal
//
// Clearly this intersects to never, so we mark this class as
// being removed from the normalized class type.
emptyIntersectWithNegation = true;
break;
}
if (!isSubclass(*nIt, there)) if (!isSubclass(*nIt, there))
{ {
nIt = negations.erase(nIt); nIt = negations.erase(nIt);
@ -2300,7 +2328,8 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th
it = heres.ordering.erase(it); it = heres.ordering.erase(it);
heres.classes.erase(hereTy); heres.classes.erase(hereTy);
heres.pushPair(there, std::move(negations)); if (!emptyIntersectWithNegation)
heres.pushPair(there, std::move(negations));
break; break;
} }
// If the incoming class is a superclass of the current class, we don't // If the incoming class is a superclass of the current class, we don't
@ -2585,11 +2614,31 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
{ {
if (tprop.readTy.has_value()) if (tprop.readTy.has_value())
{ {
// if the intersection of the read types of a property is uninhabited, the whole table is `never`. if (FFlag::LuauFixInfiniteRecursionInNormalization)
// We've seen these table prop elements before and we're about to ask if their intersection
// is inhabited
if (FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance)
{ {
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
// If any property is going to get mapped to `never`, we can just call the entire table `never`.
// Since this check is syntactic, we may sometimes miss simplifying tables with complex uninhabited properties.
// Prior versions of this code attempted to do this semantically using the normalization machinery, but this
// mistakenly causes infinite loops when giving more complex recursive table types. As it stands, this approach
// will continue to scale as simplification is improved, but we may wish to reintroduce the semantic approach
// once we have revisited the usage of seen sets systematically (and possibly with some additional guarding to recognize
// when types are infinitely-recursive with non-pointer identical instances of them, or some guard to prevent that
// construction altogether). See also: `gh1632_no_infinite_recursion_in_normalization`
if (get<NeverType>(ty))
return {builtinTypes->neverType};
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
}
else
{
// if the intersection of the read types of a property is uninhabited, the whole table is `never`.
// We've seen these table prop elements before and we're about to ask if their intersection
// is inhabited
auto pair1 = std::pair{*hprop.readTy, *tprop.readTy}; auto pair1 = std::pair{*hprop.readTy, *tprop.readTy};
auto pair2 = std::pair{*tprop.readTy, *hprop.readTy}; auto pair2 = std::pair{*tprop.readTy, *hprop.readTy};
if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2)) if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2))
@ -2604,6 +2653,8 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
seenTablePropPairs.insert(pair2); seenTablePropPairs.insert(pair2);
} }
// FIXME(ariel): this is being added in a flag removal, so not changing the semantics here, but worth noting that this
// fresh `seenSet` is definitely a bug. we already have `seenSet` from the parameter that _should_ have been used here.
Set<TypeId> seenSet{nullptr}; Set<TypeId> seenSet{nullptr};
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet); NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet);
@ -2617,34 +2668,6 @@ std::optional<TypeId> Normalizer::intersectionOfTables(TypeId here, TypeId there
hereSubThere &= (ty == hprop.readTy); hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy); thereSubHere &= (ty == tprop.readTy);
} }
else
{
if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy))
{
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
return {builtinTypes->neverType};
}
else
{
seenSet.insert(*hprop.readTy);
seenSet.insert(*tprop.readTy);
}
NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy);
seenSet.erase(*hprop.readTy);
seenSet.erase(*tprop.readTy);
if (NormalizationResult::True != res)
return {builtinTypes->neverType};
TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result;
prop.readTy = ty;
hereSubThere &= (ty == hprop.readTy);
thereSubHere &= (ty == tprop.readTy);
}
} }
else else
{ {
@ -3040,15 +3063,12 @@ NormalizationResult Normalizer::intersectTyvarsWithTy(
return NormalizationResult::True; return NormalizationResult::True;
} }
// See above for an explaination of `ignoreSmallerTyvars`. // See above for an explanation of `ignoreSmallerTyvars`.
NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars)
{ {
if (FFlag::LuauIntersectNormalsNeedsToTrackResourceLimits) RecursionCounter _rc(&sharedState->counters.recursionCount);
{ if (!withinResourceLimits())
RecursionCounter _rc(&sharedState->counters.recursionCount); return NormalizationResult::HitLimits;
if (!withinResourceLimits())
return NormalizationResult::HitLimits;
}
if (!get<NeverType>(there.tops)) if (!get<NeverType>(there.tops))
{ {
@ -3061,11 +3081,17 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor
return unionNormals(here, there, ignoreSmallerTyvars); return unionNormals(here, there, ignoreSmallerTyvars);
} }
// Limit based on worst-case expansion of the table intersection // Limit based on worst-case expansion of the table/function intersections
// This restriction can be relaxed when table intersection simplification is improved // This restriction can be relaxed when table intersection simplification is improved
if (here.tables.size() * there.tables.size() >= size_t(FInt::LuauNormalizeIntersectionLimit)) if (here.tables.size() * there.tables.size() >= size_t(FInt::LuauNormalizeIntersectionLimit))
return NormalizationResult::HitLimits; return NormalizationResult::HitLimits;
if (FFlag::LuauNormalizeLimitFunctionSet)
{
if (here.functions.parts.size() * there.functions.parts.size() >= size_t(FInt::LuauNormalizeIntersectionLimit))
return NormalizationResult::HitLimits;
}
here.booleans = intersectionOfBools(here.booleans, there.booleans); here.booleans = intersectionOfBools(here.booleans, there.booleans);
intersectClasses(here.classes, there.classes); intersectClasses(here.classes, there.classes);
@ -3201,7 +3227,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(
{ {
TypeId errors = here.errors; TypeId errors = here.errors;
clearNormal(here); clearNormal(here);
here.errors = errors; here.errors = FFlag::LuauNormalizeIntersectErrorToAnError && get<ErrorType>(errors) ? errors : there;
} }
else if (const PrimitiveType* ptv = get<PrimitiveType>(there)) else if (const PrimitiveType* ptv = get<PrimitiveType>(there))
{ {
@ -3298,8 +3324,18 @@ NormalizationResult Normalizer::intersectNormalWithTy(
clearNormal(here); clearNormal(here);
return NormalizationResult::True; return NormalizationResult::True;
} }
else if (FFlag::LuauNormalizeNegatedErrorToAnError && get<ErrorType>(t))
{
// ~error is still an error, so intersecting with the negation is the same as intersecting with a type
TypeId errors = here.errors;
clearNormal(here);
here.errors = FFlag::LuauNormalizeIntersectErrorToAnError && get<ErrorType>(errors) ? errors : t;
}
else if (auto nt = get<NegationType>(t)) else if (auto nt = get<NegationType>(t))
{
here.tyvars = std::move(tyvars);
return intersectNormalWithTy(here, nt->ty, seenTablePropPairs, seenSetTypes); return intersectNormalWithTy(here, nt->ty, seenTablePropPairs, seenSetTypes);
}
else else
{ {
// TODO negated unions, intersections, table, and function. // TODO negated unions, intersections, table, and function.

View file

@ -107,134 +107,4 @@ void quantify(TypeId ty, TypeLevel level)
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
} }
struct PureQuantifier : Substitution
{
Scope* scope;
OrderedMap<TypeId, TypeId> insertedGenerics;
OrderedMap<TypePackId, TypePackId> insertedGenericPacks;
bool seenMutableType = false;
bool seenGenericType = false;
PureQuantifier(TypeArena* arena, Scope* scope)
: Substitution(TxnLog::empty(), arena)
, scope(scope)
{
}
bool isDirty(TypeId ty) override
{
LUAU_ASSERT(ty == follow(ty));
if (auto ftv = get<FreeType>(ty))
{
bool result = subsumes(scope, ftv->scope);
seenMutableType |= result;
return result;
}
else if (auto ttv = get<TableType>(ty))
{
if (ttv->state == TableState::Free)
seenMutableType = true;
else if (ttv->state == TableState::Generic)
seenGenericType = true;
return (ttv->state == TableState::Unsealed || ttv->state == TableState::Free) && subsumes(scope, ttv->scope);
}
return false;
}
bool isDirty(TypePackId tp) override
{
if (auto ftp = get<FreeTypePack>(tp))
{
return subsumes(scope, ftp->scope);
}
return false;
}
TypeId clean(TypeId ty) override
{
if (auto ftv = get<FreeType>(ty))
{
TypeId result = arena->addType(GenericType{scope});
insertedGenerics.push(ty, result);
return result;
}
else if (auto ttv = get<TableType>(ty))
{
TypeId result = arena->addType(TableType{});
TableType* resultTable = getMutable<TableType>(result);
LUAU_ASSERT(resultTable);
*resultTable = *ttv;
resultTable->level = TypeLevel{};
resultTable->scope = scope;
if (ttv->state == TableState::Free)
{
resultTable->state = TableState::Generic;
insertedGenerics.push(ty, result);
}
else if (ttv->state == TableState::Unsealed)
resultTable->state = TableState::Sealed;
return result;
}
return ty;
}
TypePackId clean(TypePackId tp) override
{
if (auto ftp = get<FreeTypePack>(tp))
{
TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{scope}});
insertedGenericPacks.push(tp, result);
return result;
}
return tp;
}
bool ignoreChildren(TypeId ty) override
{
if (get<ClassType>(ty))
return true;
return ty->persistent;
}
bool ignoreChildren(TypePackId ty) override
{
return ty->persistent;
}
};
std::optional<QuantifierResult> quantify(TypeArena* arena, TypeId ty, Scope* scope)
{
PureQuantifier quantifier{arena, scope};
std::optional<TypeId> result = quantifier.substitute(ty);
if (!result)
return std::nullopt;
FunctionType* ftv = getMutable<FunctionType>(*result);
LUAU_ASSERT(ftv);
ftv->scope = scope;
for (auto k : quantifier.insertedGenerics.keys)
{
TypeId g = quantifier.insertedGenerics.pairings[k];
if (get<GenericType>(g))
ftv->generics.push_back(g);
}
for (auto k : quantifier.insertedGenericPacks.keys)
ftv->genericPacks.push_back(quantifier.insertedGenericPacks.pairings[k]);
ftv->hasNoFreeOrGenericTypes = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType;
return std::optional<QuantifierResult>({*result, std::move(quantifier.insertedGenerics), std::move(quantifier.insertedGenericPacks)});
}
} // namespace Luau } // namespace Luau

View file

@ -54,7 +54,15 @@ RefinementId RefinementArena::proposition(const RefinementKey* key, TypeId discr
if (!key) if (!key)
return nullptr; return nullptr;
return NotNull{allocator.allocate(Proposition{key, discriminantTy})}; return NotNull{allocator.allocate(Proposition{key, discriminantTy, false})};
}
RefinementId RefinementArena::implicitProposition(const RefinementKey* key, TypeId discriminantTy)
{
if (!key)
return nullptr;
return NotNull{allocator.allocate(Proposition{key, discriminantTy, true})};
} }
} // namespace Luau } // namespace Luau

View file

@ -65,7 +65,7 @@ struct RequireTracer : AstVisitor
return true; return true;
} }
AstExpr* getDependent(AstExpr* node) AstExpr* getDependent_DEPRECATED(AstExpr* node)
{ {
if (AstExprLocal* expr = node->as<AstExprLocal>()) if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local]; return locals[expr->local];
@ -78,6 +78,27 @@ struct RequireTracer : AstVisitor
else else
return nullptr; return nullptr;
} }
AstNode* getDependent(AstNode* node)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local];
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return expr->expr;
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return expr->expr;
else if (AstExprCall* expr = node->as<AstExprCall>(); expr && expr->self)
return expr->func->as<AstExprIndexName>()->expr;
else if (AstExprGroup* expr = node->as<AstExprGroup>())
return expr->expr;
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
return expr->annotation;
else if (AstTypeGroup* expr = node->as<AstTypeGroup>())
return expr->type;
else if (AstTypeTypeof* expr = node->as<AstTypeTypeof>())
return expr->expr;
else
return nullptr;
}
void process() void process()
{ {
@ -91,13 +112,15 @@ struct RequireTracer : AstVisitor
// push all dependent expressions to the work stack; note that the vector is modified during traversal // push all dependent expressions to the work stack; note that the vector is modified during traversal
for (size_t i = 0; i < work.size(); ++i) for (size_t i = 0; i < work.size(); ++i)
if (AstExpr* dep = getDependent(work[i])) {
if (AstNode* dep = getDependent(work[i]))
work.push_back(dep); work.push_back(dep);
}
// resolve all expressions to a module info // resolve all expressions to a module info
for (size_t i = work.size(); i > 0; --i) for (size_t i = work.size(); i > 0; --i)
{ {
AstExpr* expr = work[i - 1]; AstNode* expr = work[i - 1];
// when multiple expressions depend on the same one we push it to work queue multiple times // when multiple expressions depend on the same one we push it to work queue multiple times
if (result.exprs.contains(expr)) if (result.exprs.contains(expr))
@ -105,19 +128,22 @@ struct RequireTracer : AstVisitor
std::optional<ModuleInfo> info; std::optional<ModuleInfo> info;
if (AstExpr* dep = getDependent(expr)) if (AstNode* dep = getDependent(expr))
{ {
const ModuleInfo* context = result.exprs.find(dep); const ModuleInfo* context = result.exprs.find(dep);
// locals just inherit their dependent context, no resolution required if (context && expr->is<AstExprLocal>())
if (expr->is<AstExprLocal>()) info = *context; // locals just inherit their dependent context, no resolution required
info = context ? std::optional<ModuleInfo>(*context) : std::nullopt; else if (context && (expr->is<AstExprGroup>() || expr->is<AstTypeGroup>()))
else info = *context; // simple group nodes propagate their value
info = fileResolver->resolveModule(context, expr); else if (context && (expr->is<AstTypeTypeof>() || expr->is<AstExprTypeAssertion>()))
info = *context; // typeof type annotations will resolve to the typeof content
else if (AstExpr* asExpr = expr->asExpr())
info = fileResolver->resolveModule(context, asExpr);
} }
else else if (AstExpr* asExpr = expr->asExpr())
{ {
info = fileResolver->resolveModule(&moduleContext, expr); info = fileResolver->resolveModule(&moduleContext, asExpr);
} }
if (info) if (info)
@ -150,7 +176,7 @@ struct RequireTracer : AstVisitor
ModuleName currentModuleName; ModuleName currentModuleName;
DenseHashMap<AstLocal*, AstExpr*> locals; DenseHashMap<AstLocal*, AstExpr*> locals;
std::vector<AstExpr*> work; std::vector<AstNode*> work;
std::vector<AstExprCall*> requireCalls; std::vector<AstExprCall*> requireCalls;
}; };

View file

@ -84,6 +84,17 @@ std::optional<TypeId> Scope::lookupUnrefinedType(DefId def) const
return std::nullopt; return std::nullopt;
} }
std::optional<TypeId> Scope::lookupRValueRefinementType(DefId def) const
{
for (const Scope* current = this; current; current = current->parent.get())
{
if (auto ty = current->rvalueRefinements.find(def))
return *ty;
}
return std::nullopt;
}
std::optional<TypeId> Scope::lookup(DefId def) const std::optional<TypeId> Scope::lookup(DefId def) const
{ {
for (const Scope* current = this; current; current = current->parent.get()) for (const Scope* current = this; current; current = current->parent.get())
@ -181,6 +192,29 @@ std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bo
return std::nullopt; return std::nullopt;
} }
std::optional<std::pair<Symbol, Binding>> Scope::linearSearchForBindingPair(const std::string& name, bool traverseScopeChain) const
{
const Scope* scope = this;
while (scope)
{
for (auto& [n, binding] : scope->bindings)
{
if (n.local && n.local->name == name.c_str())
return {{n, binding}};
else if (n.global.value && n.global == name.c_str())
return {{n, binding}};
}
scope = scope->parent.get();
if (!traverseScopeChain)
break;
}
return std::nullopt;
}
// Updates the `this` scope with the assignments from the `childScope` including ones that doesn't exist in `this`. // Updates the `this` scope with the assignments from the `childScope` including ones that doesn't exist in `this`.
void Scope::inheritAssignments(const ScopePtr& childScope) void Scope::inheritAssignments(const ScopePtr& childScope)
{ {

View file

@ -31,16 +31,16 @@ struct TypeSimplifier
int recursionDepth = 0; int recursionDepth = 0;
TypeId mkNegation(TypeId ty); TypeId mkNegation(TypeId ty) const;
TypeId intersectFromParts(std::set<TypeId> parts); TypeId intersectFromParts(std::set<TypeId> parts);
TypeId intersectUnionWithType(TypeId unionTy, TypeId right); TypeId intersectUnionWithType(TypeId left, TypeId right);
TypeId intersectUnions(TypeId left, TypeId right); TypeId intersectUnions(TypeId left, TypeId right);
TypeId intersectNegatedUnion(TypeId unionTy, TypeId right); TypeId intersectNegatedUnion(TypeId left, TypeId right);
TypeId intersectTypeWithNegation(TypeId a, TypeId b); TypeId intersectTypeWithNegation(TypeId left, TypeId right);
TypeId intersectNegations(TypeId a, TypeId b); TypeId intersectNegations(TypeId left, TypeId right);
TypeId intersectIntersectionWithType(TypeId left, TypeId right); TypeId intersectIntersectionWithType(TypeId left, TypeId right);
@ -48,8 +48,8 @@ struct TypeSimplifier
// unions, intersections, or negations. // unions, intersections, or negations.
std::optional<TypeId> basicIntersect(TypeId left, TypeId right); std::optional<TypeId> basicIntersect(TypeId left, TypeId right);
TypeId intersect(TypeId ty, TypeId discriminant); TypeId intersect(TypeId left, TypeId right);
TypeId union_(TypeId ty, TypeId discriminant); TypeId union_(TypeId left, TypeId right);
TypeId simplify(TypeId ty); TypeId simplify(TypeId ty);
TypeId simplify(TypeId ty, DenseHashSet<TypeId>& seen); TypeId simplify(TypeId ty, DenseHashSet<TypeId>& seen);
@ -573,7 +573,7 @@ Relation relate(TypeId left, TypeId right)
return relate(left, right, seen); return relate(left, right, seen);
} }
TypeId TypeSimplifier::mkNegation(TypeId ty) TypeId TypeSimplifier::mkNegation(TypeId ty) const
{ {
TypeId result = nullptr; TypeId result = nullptr;

View file

@ -13,6 +13,7 @@ LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256) LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256)
LUAU_FASTFLAG(LuauSyntheticErrors) LUAU_FASTFLAG(LuauSyntheticErrors)
LUAU_FASTFLAG(LuauDeprecatedAttribute)
namespace Luau namespace Luau
{ {
@ -102,6 +103,8 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a
clone.tags = a.tags; clone.tags = a.tags;
clone.argNames = a.argNames; clone.argNames = a.argNames;
clone.isCheckedFunction = a.isCheckedFunction; clone.isCheckedFunction = a.isCheckedFunction;
if (FFlag::LuauDeprecatedAttribute)
clone.isDeprecatedFunction = a.isDeprecatedFunction;
return dest.addType(std::move(clone)); return dest.addType(std::move(clone));
} }
else if constexpr (std::is_same_v<T, TableType>) else if constexpr (std::is_same_v<T, TableType>)

View file

@ -22,7 +22,7 @@
#include <algorithm> #include <algorithm>
LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity) LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity)
LUAU_FASTFLAGVARIABLE(LuauRetrySubtypingWithoutHiddenPack) LUAU_FASTFLAGVARIABLE(LuauSubtypingStopAtNormFail)
namespace Luau namespace Luau
{ {
@ -416,6 +416,14 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope
SubtypingResult result = isCovariantWith(env, subTy, superTy, scope); SubtypingResult result = isCovariantWith(env, subTy, superTy, scope);
if (FFlag::LuauSubtypingStopAtNormFail && result.normalizationTooComplex)
{
if (result.isCacheable)
resultCache[{subTy, superTy}] = result;
return result;
}
for (const auto& [subTy, bounds] : env.mappedGenerics) for (const auto& [subTy, bounds] : env.mappedGenerics)
{ {
const auto& lb = bounds.lowerBound; const auto& lb = bounds.lowerBound;
@ -593,7 +601,12 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
if (!result.isSubtype && !result.normalizationTooComplex) if (!result.isSubtype && !result.normalizationTooComplex)
{ {
SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy), scope); SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy), scope);
if (semantic.isSubtype)
if (FFlag::LuauSubtypingStopAtNormFail && semantic.normalizationTooComplex)
{
result = semantic;
}
else if (semantic.isSubtype)
{ {
semantic.reasoning.clear(); semantic.reasoning.clear();
result = semantic; result = semantic;
@ -608,7 +621,12 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
if (!result.isSubtype && !result.normalizationTooComplex) if (!result.isSubtype && !result.normalizationTooComplex)
{ {
SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy), scope); SubtypingResult semantic = isCovariantWith(env, normalizer->normalize(subTy), normalizer->normalize(superTy), scope);
if (semantic.isSubtype)
if (FFlag::LuauSubtypingStopAtNormFail && semantic.normalizationTooComplex)
{
result = semantic;
}
else if (semantic.isSubtype)
{ {
// Clear the semantic reasoning, as any reasonings within // Clear the semantic reasoning, as any reasonings within
// potentially contain invalid paths. // potentially contain invalid paths.
@ -754,7 +772,8 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
// Match head types pairwise // Match head types pairwise
for (size_t i = 0; i < headSize; ++i) for (size_t i = 0; i < headSize; ++i)
results.push_back(isCovariantWith(env, subHead[i], superHead[i], scope).withBothComponent(TypePath::Index{i})); results.push_back(isCovariantWith(env, subHead[i], superHead[i], scope).withBothComponent(TypePath::Index{i, TypePath::Index::Variant::Pack})
);
// Handle mismatched head sizes // Handle mismatched head sizes
@ -767,7 +786,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
for (size_t i = headSize; i < superHead.size(); ++i) for (size_t i = headSize; i < superHead.size(); ++i)
results.push_back(isCovariantWith(env, vt->ty, superHead[i], scope) results.push_back(isCovariantWith(env, vt->ty, superHead[i], scope)
.withSubPath(TypePath::PathBuilder().tail().variadic().build()) .withSubPath(TypePath::PathBuilder().tail().variadic().build())
.withSuperComponent(TypePath::Index{i})); .withSuperComponent(TypePath::Index{i, TypePath::Index::Variant::Pack}));
} }
else if (auto gt = get<GenericTypePack>(*subTail)) else if (auto gt = get<GenericTypePack>(*subTail))
{ {
@ -821,7 +840,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
{ {
for (size_t i = headSize; i < subHead.size(); ++i) for (size_t i = headSize; i < subHead.size(); ++i)
results.push_back(isCovariantWith(env, subHead[i], vt->ty, scope) results.push_back(isCovariantWith(env, subHead[i], vt->ty, scope)
.withSubComponent(TypePath::Index{i}) .withSubComponent(TypePath::Index{i, TypePath::Index::Variant::Pack})
.withSuperPath(TypePath::PathBuilder().tail().variadic().build())); .withSuperPath(TypePath::PathBuilder().tail().variadic().build()));
} }
else if (auto gt = get<GenericTypePack>(*superTail)) else if (auto gt = get<GenericTypePack>(*superTail))
@ -859,7 +878,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId
else else
return SubtypingResult{false} return SubtypingResult{false}
.withSuperComponent(TypePath::PackField::Tail) .withSuperComponent(TypePath::PackField::Tail)
.withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}); .withError({scope->location, UnexpectedTypePackInSubtyping{*superTail}});
} }
else else
return {false}; return {false};
@ -1082,6 +1101,10 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
for (TypeId ty : superUnion) for (TypeId ty : superUnion)
{ {
SubtypingResult next = isCovariantWith(env, subTy, ty, scope); SubtypingResult next = isCovariantWith(env, subTy, ty, scope);
if (FFlag::LuauSubtypingStopAtNormFail && next.normalizationTooComplex)
return SubtypingResult{false, /* normalizationTooComplex */ true};
if (next.isSubtype) if (next.isSubtype)
return SubtypingResult{true}; return SubtypingResult{true};
} }
@ -1100,7 +1123,13 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Unio
std::vector<SubtypingResult> subtypings; std::vector<SubtypingResult> subtypings;
size_t i = 0; size_t i = 0;
for (TypeId ty : subUnion) for (TypeId ty : subUnion)
subtypings.push_back(isCovariantWith(env, ty, superTy, scope).withSubComponent(TypePath::Index{i++})); {
subtypings.push_back(isCovariantWith(env, ty, superTy, scope).withSubComponent(TypePath::Index{i++, TypePath::Index::Variant::Union}));
if (FFlag::LuauSubtypingStopAtNormFail && subtypings.back().normalizationTooComplex)
return SubtypingResult{false, /* normalizationTooComplex */ true};
}
return SubtypingResult::all(subtypings); return SubtypingResult::all(subtypings);
} }
@ -1110,7 +1139,13 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypeId sub
std::vector<SubtypingResult> subtypings; std::vector<SubtypingResult> subtypings;
size_t i = 0; size_t i = 0;
for (TypeId ty : superIntersection) for (TypeId ty : superIntersection)
subtypings.push_back(isCovariantWith(env, subTy, ty, scope).withSuperComponent(TypePath::Index{i++})); {
subtypings.push_back(isCovariantWith(env, subTy, ty, scope).withSuperComponent(TypePath::Index{i++, TypePath::Index::Variant::Intersection}));
if (FFlag::LuauSubtypingStopAtNormFail && subtypings.back().normalizationTooComplex)
return SubtypingResult{false, /* normalizationTooComplex */ true};
}
return SubtypingResult::all(subtypings); return SubtypingResult::all(subtypings);
} }
@ -1120,7 +1155,13 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Inte
std::vector<SubtypingResult> subtypings; std::vector<SubtypingResult> subtypings;
size_t i = 0; size_t i = 0;
for (TypeId ty : subIntersection) for (TypeId ty : subIntersection)
subtypings.push_back(isCovariantWith(env, ty, superTy, scope).withSubComponent(TypePath::Index{i++})); {
subtypings.push_back(isCovariantWith(env, ty, superTy, scope).withSubComponent(TypePath::Index{i++, TypePath::Index::Variant::Intersection}));
if (FFlag::LuauSubtypingStopAtNormFail && subtypings.back().normalizationTooComplex)
return SubtypingResult{false, /* normalizationTooComplex */ true};
}
return SubtypingResult::any(subtypings); return SubtypingResult::any(subtypings);
} }
@ -1410,7 +1451,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Meta
// of the supertype table. // of the supertype table.
// //
// There's a flaw here in that if the __index metamethod contributes a new // There's a flaw here in that if the __index metamethod contributes a new
// field that would satisfy the subtyping relationship, we'll erronously say // field that would satisfy the subtyping relationship, we'll erroneously say
// that the metatable isn't a subtype of the table, even though they have // that the metatable isn't a subtype of the table, even though they have
// compatible properties/shapes. We'll revisit this later when we have a // compatible properties/shapes. We'll revisit this later when we have a
// better understanding of how important this is. // better understanding of how important this is.
@ -1474,7 +1515,7 @@ SubtypingResult Subtyping::isCovariantWith(
// If subtyping failed in the argument packs, we should check if there's a hidden variadic tail and try ignoring it. // If subtyping failed in the argument packs, we should check if there's a hidden variadic tail and try ignoring it.
// This might cause subtyping correctly because the sub type here may not have a hidden variadic tail or equivalent. // This might cause subtyping correctly because the sub type here may not have a hidden variadic tail or equivalent.
if (FFlag::LuauRetrySubtypingWithoutHiddenPack && !result.isSubtype) if (!result.isSubtype)
{ {
auto [arguments, tail] = flatten(superFunction->argTypes); auto [arguments, tail] = flatten(superFunction->argTypes);
@ -1760,7 +1801,12 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Type
{ {
results.emplace_back(); results.emplace_back();
for (TypeId superTy : superTypes) for (TypeId superTy : superTypes)
{
results.back().orElse(isCovariantWith(env, subTy, superTy, scope)); results.back().orElse(isCovariantWith(env, subTy, superTy, scope));
if (FFlag::LuauSubtypingStopAtNormFail && results.back().normalizationTooComplex)
return SubtypingResult{false, /* normalizationTooComplex */ true};
}
} }
return SubtypingResult::all(results); return SubtypingResult::all(results);

View file

@ -4,7 +4,6 @@
#include "Luau/Common.h" #include "Luau/Common.h"
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauSymbolEquality)
namespace Luau namespace Luau
{ {
@ -15,10 +14,8 @@ bool Symbol::operator==(const Symbol& rhs) const
return local == rhs.local; return local == rhs.local;
else if (global.value) else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
else if (FFlag::LuauSolverV2 || FFlag::LuauSymbolEquality)
return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
else else
return false; return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
} }
std::string toString(const Symbol& name) std::string toString(const Symbol& name)

View file

@ -1,14 +1,22 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TableLiteralInference.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Common.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/Simplify.h" #include "Luau/Simplify.h"
#include "Luau/Subtyping.h"
#include "Luau/Type.h" #include "Luau/Type.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h" #include "Luau/Unifier2.h"
LUAU_FASTFLAGVARIABLE(LuauBidirectionalInferenceUpcast)
LUAU_FASTFLAGVARIABLE(LuauBidirectionalInferenceCollectIndexerTypes)
LUAU_FASTFLAGVARIABLE(LuauBidirectionalFailsafe)
namespace Luau namespace Luau
{ {
@ -109,6 +117,7 @@ TypeId matchLiteralType(
NotNull<BuiltinTypes> builtinTypes, NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena, NotNull<TypeArena> arena,
NotNull<Unifier2> unifier, NotNull<Unifier2> unifier,
NotNull<Subtyping> subtyping,
TypeId expectedType, TypeId expectedType,
TypeId exprType, TypeId exprType,
const AstExpr* expr, const AstExpr* expr,
@ -129,17 +138,38 @@ TypeId matchLiteralType(
* things like replace explicit named properties with indexers as required * things like replace explicit named properties with indexers as required
* by the expected type. * by the expected type.
*/ */
if (!isLiteral(expr)) if (!isLiteral(expr))
return exprType; {
if (FFlag::LuauBidirectionalInferenceUpcast)
{
auto result = subtyping->isSubtype(/*subTy=*/exprType, /*superTy=*/expectedType, unifier->scope);
return result.isSubtype ? expectedType : exprType;
}
else
return exprType;
}
expectedType = follow(expectedType); expectedType = follow(expectedType);
exprType = follow(exprType); exprType = follow(exprType);
if (get<AnyType>(expectedType) || get<UnknownType>(expectedType)) if (FFlag::LuauBidirectionalInferenceCollectIndexerTypes)
{ {
// "Narrowing" to unknown or any is not going to do anything useful. // The intent of `matchLiteralType` is to upcast values when it's safe
return exprType; // to do so. it's always safe to upcast to `any` or `unknown`, so we
// can unconditionally do so here.
if (is<AnyType, UnknownType>(expectedType))
return expectedType;
} }
else
{
if (get<AnyType>(expectedType) || get<UnknownType>(expectedType))
{
// "Narrowing" to unknown or any is not going to do anything useful.
return exprType;
}
}
if (expr->is<AstExprConstantString>()) if (expr->is<AstExprConstantString>())
{ {
@ -207,11 +237,29 @@ TypeId matchLiteralType(
return exprType; return exprType;
} }
// TODO: lambdas
if (FFlag::LuauBidirectionalInferenceUpcast && expr->is<AstExprFunction>())
{
// TODO: Push argument / return types into the lambda. For now, just do
// the non-literal thing: check for a subtype and upcast if valid.
auto result = subtyping->isSubtype(/*subTy=*/exprType, /*superTy=*/expectedType, unifier->scope);
return result.isSubtype
? expectedType
: exprType;
}
if (auto exprTable = expr->as<AstExprTable>()) if (auto exprTable = expr->as<AstExprTable>())
{ {
TableType* const tableTy = getMutable<TableType>(exprType); TableType* const tableTy = getMutable<TableType>(exprType);
// This can occur if we have an expression like:
//
// { x = {}, x = 42 }
//
// The type of this will be `{ x: number }`
if (FFlag::LuauBidirectionalFailsafe && !tableTy)
return exprType;
LUAU_ASSERT(tableTy); LUAU_ASSERT(tableTy);
const TableType* expectedTableTy = get<TableType>(expectedType); const TableType* expectedTableTy = get<TableType>(expectedType);
@ -226,7 +274,7 @@ TypeId matchLiteralType(
if (tt) if (tt)
{ {
TypeId res = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *tt, exprType, expr, toBlock); TypeId res = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, subtyping, *tt, exprType, expr, toBlock);
parts.push_back(res); parts.push_back(res);
return arena->addType(UnionType{std::move(parts)}); return arena->addType(UnionType{std::move(parts)});
@ -236,6 +284,11 @@ TypeId matchLiteralType(
return exprType; return exprType;
} }
DenseHashSet<AstExprConstantString*> keysToDelete{nullptr};
DenseHashSet<TypeId> indexerKeyTypes{nullptr};
DenseHashSet<TypeId> indexerValueTypes{nullptr};
for (const AstExprTable::Item& item : exprTable->items) for (const AstExprTable::Item& item : exprTable->items)
{ {
if (isRecord(item)) if (isRecord(item))
@ -243,12 +296,20 @@ TypeId matchLiteralType(
const AstArray<char>& s = item.key->as<AstExprConstantString>()->value; const AstArray<char>& s = item.key->as<AstExprConstantString>()->value;
std::string keyStr{s.data, s.data + s.size}; std::string keyStr{s.data, s.data + s.size};
auto it = tableTy->props.find(keyStr); auto it = tableTy->props.find(keyStr);
// This can occur, potentially, if we are re-entrant.
if (FFlag::LuauBidirectionalFailsafe && it == tableTy->props.end())
continue;
LUAU_ASSERT(it != tableTy->props.end()); LUAU_ASSERT(it != tableTy->props.end());
Property& prop = it->second; Property& prop = it->second;
// Table literals always initially result in shared read-write types // If we encounter a duplcate property, we may have already
LUAU_ASSERT(prop.isShared()); // set it to be read-only. If that's the case, the only thing
// that will definitely crash is trying to access a write
// only property.
LUAU_ASSERT(!prop.isWriteOnly());
TypeId propTy = *prop.readTy; TypeId propTy = *prop.readTy;
auto it2 = expectedTableTy->props.find(keyStr); auto it2 = expectedTableTy->props.find(keyStr);
@ -269,18 +330,28 @@ TypeId matchLiteralType(
builtinTypes, builtinTypes,
arena, arena,
unifier, unifier,
subtyping,
expectedTableTy->indexer->indexResultType, expectedTableTy->indexer->indexResultType,
propTy, propTy,
item.value, item.value,
toBlock toBlock
); );
if (tableTy->indexer) if (FFlag::LuauBidirectionalInferenceCollectIndexerTypes)
unifier->unify(matchedType, tableTy->indexer->indexResultType); {
indexerKeyTypes.insert(arena->addType(SingletonType{StringSingleton{keyStr}}));
indexerValueTypes.insert(matchedType);
}
else else
tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType}; {
if (tableTy->indexer)
unifier->unify(matchedType, tableTy->indexer->indexResultType);
else
tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType};
}
keysToDelete.insert(item.key->as<AstExprConstantString>());
tableTy->props.erase(keyStr);
} }
// If it's just an extra property and the expected type // If it's just an extra property and the expected type
@ -304,21 +375,21 @@ TypeId matchLiteralType(
if (expectedProp.isShared()) if (expectedProp.isShared())
{ {
matchedType = matchedType =
matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value, toBlock); matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, subtyping, *expectedReadTy, propTy, item.value, toBlock);
prop.readTy = matchedType; prop.readTy = matchedType;
prop.writeTy = matchedType; prop.writeTy = matchedType;
} }
else if (expectedReadTy) else if (expectedReadTy)
{ {
matchedType = matchedType =
matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value, toBlock); matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, subtyping, *expectedReadTy, propTy, item.value, toBlock);
prop.readTy = matchedType; prop.readTy = matchedType;
prop.writeTy.reset(); prop.writeTy.reset();
} }
else if (expectedWriteTy) else if (expectedWriteTy)
{ {
matchedType = matchedType =
matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedWriteTy, propTy, item.value, toBlock); matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, subtyping, *expectedWriteTy, propTy, item.value, toBlock);
prop.readTy.reset(); prop.readTy.reset();
prop.writeTy = matchedType; prop.writeTy = matchedType;
} }
@ -335,6 +406,11 @@ TypeId matchLiteralType(
LUAU_ASSERT(matchedType); LUAU_ASSERT(matchedType);
(*astExpectedTypes)[item.value] = matchedType; (*astExpectedTypes)[item.value] = matchedType;
// NOTE: We do *not* add to the potential indexer types here.
// I think this is correct to support something like:
//
// { [string]: number, foo: boolean }
//
} }
else if (item.kind == AstExprTable::Item::List) else if (item.kind == AstExprTable::Item::List)
{ {
@ -352,15 +428,25 @@ TypeId matchLiteralType(
builtinTypes, builtinTypes,
arena, arena,
unifier, unifier,
subtyping,
expectedTableTy->indexer->indexResultType, expectedTableTy->indexer->indexResultType,
*propTy, *propTy,
item.value, item.value,
toBlock toBlock
); );
// if the index result type is the prop type, we can replace it with the matched type here. if (FFlag::LuauBidirectionalInferenceCollectIndexerTypes)
if (tableTy->indexer->indexResultType == *propTy) {
tableTy->indexer->indexResultType = matchedType; indexerKeyTypes.insert(builtinTypes->numberType);
indexerValueTypes.insert(matchedType);
}
else
{
// if the index result type is the prop type, we can replace it with the matched type here.
if (tableTy->indexer->indexResultType == *propTy)
tableTy->indexer->indexResultType = matchedType;
}
} }
} }
else if (item.kind == AstExprTable::Item::General) else if (item.kind == AstExprTable::Item::General)
@ -382,11 +468,25 @@ TypeId matchLiteralType(
// Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings) // Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings)
if (!item.key->as<AstExprConstantString>() && expectedTableTy->indexer) if (!item.key->as<AstExprConstantString>() && expectedTableTy->indexer)
(*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType; (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType;
if (FFlag::LuauBidirectionalInferenceCollectIndexerTypes)
{
indexerKeyTypes.insert(tKey);
indexerValueTypes.insert(tProp);
}
} }
else else
LUAU_ASSERT(!"Unexpected"); LUAU_ASSERT(!"Unexpected");
} }
for (const auto& key : keysToDelete)
{
const AstArray<char>& s = key->value;
std::string keyStr{s.data, s.data + s.size};
tableTy->props.erase(keyStr);
}
// Keys that the expectedType says we should have, but that aren't // Keys that the expectedType says we should have, but that aren't
// specified by the AST fragment. // specified by the AST fragment.
// //
@ -436,9 +536,39 @@ TypeId matchLiteralType(
// have one too. // have one too.
// TODO: If the expected table also has an indexer, we might want to // TODO: If the expected table also has an indexer, we might want to
// push the expected indexer's types into it. // push the expected indexer's types into it.
if (expectedTableTy->indexer && !tableTy->indexer) if (FFlag::LuauBidirectionalInferenceCollectIndexerTypes && expectedTableTy->indexer)
{ {
tableTy->indexer = expectedTableTy->indexer; if (indexerValueTypes.size() > 0 && indexerKeyTypes.size() > 0)
{
TypeId inferredKeyType = builtinTypes->neverType;
TypeId inferredValueType = builtinTypes->neverType;
for (auto kt: indexerKeyTypes)
{
auto simplified = simplifyUnion(builtinTypes, arena, inferredKeyType, kt);
inferredKeyType = simplified.result;
}
for (auto vt: indexerValueTypes)
{
auto simplified = simplifyUnion(builtinTypes, arena, inferredValueType, vt);
inferredValueType = simplified.result;
}
tableTy->indexer = TableIndexer{inferredKeyType, inferredValueType};
auto keyCheck = subtyping->isSubtype(inferredKeyType, expectedTableTy->indexer->indexType, unifier->scope);
if (keyCheck.isSubtype)
tableTy->indexer->indexType = expectedTableTy->indexer->indexType;
auto valueCheck = subtyping->isSubtype(inferredValueType, expectedTableTy->indexer->indexResultType, unifier->scope);
if (valueCheck.isSubtype)
tableTy->indexer->indexResultType = expectedTableTy->indexer->indexResultType;
}
else
LUAU_ASSERT(indexerKeyTypes.empty() && indexerValueTypes.empty());
}
else
{
if (expectedTableTy->indexer && !tableTy->indexer)
{
tableTy->indexer = expectedTableTy->indexer;
}
} }
} }

View file

@ -1865,6 +1865,8 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
} }
else if constexpr (std::is_same_v<T, EqualityConstraint>) else if constexpr (std::is_same_v<T, EqualityConstraint>)
return "equality: " + tos(c.resultType) + " ~ " + tos(c.assignmentType); return "equality: " + tos(c.resultType) + " ~ " + tos(c.assignmentType);
else if constexpr (std::is_same_v<T, TableCheckConstraint>)
return "table_check " + tos(c.expectedType) + " :> " + tos(c.exprType);
else else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch"); static_assert(always_false_v<T>, "Non-exhaustive constraint switch");
}; };

File diff suppressed because it is too large Load diff

View file

@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauFreeTypesMustHaveBounds)
namespace Luau namespace Luau
{ {
@ -478,24 +479,12 @@ bool hasLength(TypeId ty, DenseHashSet<TypeId>& seen, int* recursionCount)
return false; return false;
} }
FreeType::FreeType(TypeLevel level) // New constructors
FreeType::FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound)
: index(Unifiable::freshIndex()) : index(Unifiable::freshIndex())
, level(level) , level(level)
, scope(nullptr) , lowerBound(lowerBound)
{ , upperBound(upperBound)
}
FreeType::FreeType(Scope* scope)
: index(Unifiable::freshIndex())
, level{}
, scope(scope)
{
}
FreeType::FreeType(Scope* scope, TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
{ {
} }
@ -507,6 +496,40 @@ FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound)
{ {
} }
FreeType::FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
, lowerBound(lowerBound)
, upperBound(upperBound)
{
}
// Old constructors
FreeType::FreeType(TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(nullptr)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
FreeType::FreeType(Scope* scope)
: index(Unifiable::freshIndex())
, level{}
, scope(scope)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
FreeType::FreeType(Scope* scope, TypeLevel level)
: index(Unifiable::freshIndex())
, level(level)
, scope(scope)
{
LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds);
}
GenericType::GenericType() GenericType::GenericType()
: index(Unifiable::freshIndex()) : index(Unifiable::freshIndex())
, name("g" + std::to_string(index)) , name("g" + std::to_string(index))

View file

@ -3,6 +3,7 @@
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena); LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena);
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau namespace Luau
{ {
@ -22,7 +23,34 @@ TypeId TypeArena::addTV(Type&& tv)
return allocated; return allocated;
} }
TypeId TypeArena::freshType(TypeLevel level) TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{level, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, Scope* scope)
{
TypeId allocated = types.allocate(FreeType{scope, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(NotNull<BuiltinTypes> builtins, Scope* scope, TypeLevel level)
{
TypeId allocated = types.allocate(FreeType{scope, level, builtins->neverType, builtins->unknownType});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType_DEPRECATED(TypeLevel level)
{ {
TypeId allocated = types.allocate(FreeType{level}); TypeId allocated = types.allocate(FreeType{level});
@ -31,7 +59,7 @@ TypeId TypeArena::freshType(TypeLevel level)
return allocated; return allocated;
} }
TypeId TypeArena::freshType(Scope* scope) TypeId TypeArena::freshType_DEPRECATED(Scope* scope)
{ {
TypeId allocated = types.allocate(FreeType{scope}); TypeId allocated = types.allocate(FreeType{scope});
@ -40,7 +68,7 @@ TypeId TypeArena::freshType(Scope* scope)
return allocated; return allocated;
} }
TypeId TypeArena::freshType(Scope* scope, TypeLevel level) TypeId TypeArena::freshType_DEPRECATED(Scope* scope, TypeLevel level)
{ {
TypeId allocated = types.allocate(FreeType{scope, level}); TypeId allocated = types.allocate(FreeType{scope, level});

View file

@ -13,6 +13,8 @@
#include <string> #include <string>
LUAU_FASTFLAG(LuauStoreCSTData2)
static char* allocateString(Luau::Allocator& allocator, std::string_view contents) static char* allocateString(Luau::Allocator& allocator, std::string_view contents)
{ {
char* result = (char*)allocator.allocate(contents.size() + 1); char* result = (char*)allocator.allocate(contents.size() + 1);
@ -261,24 +263,24 @@ public:
if (hasSeen(&ftv)) if (hasSeen(&ftv))
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Cycle>"), std::nullopt, Location()); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Cycle>"), std::nullopt, Location());
AstArray<AstGenericType> generics; AstArray<AstGenericType*> generics;
generics.size = ftv.generics.size(); generics.size = ftv.generics.size();
generics.data = static_cast<AstGenericType*>(allocator->allocate(sizeof(AstGenericType) * generics.size)); generics.data = static_cast<AstGenericType**>(allocator->allocate(sizeof(AstGenericType) * generics.size));
size_t numGenerics = 0; size_t numGenerics = 0;
for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it)
{ {
if (auto gtv = get<GenericType>(*it)) if (auto gtv = get<GenericType>(*it))
generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr}; generics.data[numGenerics++] = allocator->alloc<AstGenericType>(Location(), AstName(gtv->name.c_str()), nullptr);
} }
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack*> genericPacks;
genericPacks.size = ftv.genericPacks.size(); genericPacks.size = ftv.genericPacks.size();
genericPacks.data = static_cast<AstGenericTypePack*>(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); genericPacks.data = static_cast<AstGenericTypePack**>(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size));
size_t numGenericPacks = 0; size_t numGenericPacks = 0;
for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it)
{ {
if (auto gtv = get<GenericTypePack>(*it)) if (auto gtv = get<GenericTypePack>(*it))
genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; genericPacks.data[numGenericPacks++] = allocator->alloc<AstGenericTypePack>(Location(), AstName(gtv->name.c_str()), nullptr);
} }
AstArray<AstType*> argTypes; AstArray<AstType*> argTypes;
@ -305,7 +307,8 @@ public:
std::optional<AstArgumentName>* arg = &argNames.data[i++]; std::optional<AstArgumentName>* arg = &argNames.data[i++];
if (el) if (el)
new (arg) std::optional<AstArgumentName>(AstArgumentName(AstName(el->name.c_str()), el->location)); new (arg)
std::optional<AstArgumentName>(AstArgumentName(AstName(el->name.c_str()), FFlag::LuauStoreCSTData2 ? Location() : el->location));
else else
new (arg) std::optional<AstArgumentName>(); new (arg) std::optional<AstArgumentName>();
} }

View file

@ -7,7 +7,6 @@
#include "Luau/DcrLogger.h" #include "Luau/DcrLogger.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Instantiation.h" #include "Luau/Instantiation.h"
#include "Luau/Metamethods.h" #include "Luau/Metamethods.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
@ -27,13 +26,14 @@
#include "Luau/VisitType.h" #include "Luau/VisitType.h"
#include <algorithm> #include <algorithm>
#include <iostream> #include <sstream>
#include <ostream>
LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(InferGlobalTypes) LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
LUAU_FASTFLAGVARIABLE(LuauTableKeysAreRValues) LUAU_FASTFLAGVARIABLE(LuauImproveTypePathsInErrors)
LUAU_FASTFLAG(LuauUserTypeFunTypecheck)
LUAU_FASTFLAGVARIABLE(LuauTypeCheckerAcceptNumberConcats)
namespace Luau namespace Luau
{ {
@ -176,7 +176,7 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor
DenseHashSet<TypeId> mentionedFunctions{nullptr}; DenseHashSet<TypeId> mentionedFunctions{nullptr};
DenseHashSet<TypePackId> mentionedFunctionPacks{nullptr}; DenseHashSet<TypePackId> mentionedFunctionPacks{nullptr};
InternalTypeFunctionFinder(std::vector<TypeId>& declStack) explicit InternalTypeFunctionFinder(std::vector<TypeId>& declStack)
{ {
TypeFunctionFinder f; TypeFunctionFinder f;
for (TypeId fn : declStack) for (TypeId fn : declStack)
@ -507,7 +507,7 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l
return instance; return instance;
} }
TypePackId TypeChecker2::lookupPack(AstExpr* expr) TypePackId TypeChecker2::lookupPack(AstExpr* expr) const
{ {
// If a type isn't in the type graph, it probably means that a recursion limit was exceeded. // If a type isn't in the type graph, it probably means that a recursion limit was exceeded.
// We'll just return anyType in these cases. Typechecking against any is very fast and this // We'll just return anyType in these cases. Typechecking against any is very fast and this
@ -557,7 +557,7 @@ TypeId TypeChecker2::lookupAnnotation(AstType* annotation)
return checkForTypeFunctionInhabitance(follow(*ty), annotation->location); return checkForTypeFunctionInhabitance(follow(*ty), annotation->location);
} }
std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) const
{ {
TypePackId* tp = module->astResolvedTypePacks.find(annotation); TypePackId* tp = module->astResolvedTypePacks.find(annotation);
if (tp != nullptr) if (tp != nullptr)
@ -565,7 +565,7 @@ std::optional<TypePackId> TypeChecker2::lookupPackAnnotation(AstTypePack* annota
return {}; return {};
} }
TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) const
{ {
if (TypeId* ty = module->astExpectedTypes.find(expr)) if (TypeId* ty = module->astExpectedTypes.find(expr))
return follow(*ty); return follow(*ty);
@ -573,7 +573,7 @@ TypeId TypeChecker2::lookupExpectedType(AstExpr* expr)
return builtinTypes->anyType; return builtinTypes->anyType;
} }
TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) const
{ {
if (TypeId* ty = module->astExpectedTypes.find(expr)) if (TypeId* ty = module->astExpectedTypes.find(expr))
return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt}); return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt});
@ -597,7 +597,7 @@ TypePackId TypeChecker2::reconstructPack(AstArray<AstExpr*> exprs, TypeArena& ar
return arena.addTypePack(TypePack{head, tail}); return arena.addTypePack(TypePack{head, tail});
} }
Scope* TypeChecker2::findInnermostScope(Location location) Scope* TypeChecker2::findInnermostScope(Location location) const
{ {
Scope* bestScope = module->getModuleScope().get(); Scope* bestScope = module->getModuleScope().get();
@ -1205,7 +1205,8 @@ void TypeChecker2::visit(AstStatTypeAlias* stat)
void TypeChecker2::visit(AstStatTypeFunction* stat) void TypeChecker2::visit(AstStatTypeFunction* stat)
{ {
// TODO: add type checking for user-defined type functions if (FFlag::LuauUserTypeFunTypecheck)
visit(stat->body);
} }
void TypeChecker2::visit(AstTypeList types) void TypeChecker2::visit(AstTypeList types)
@ -1359,7 +1360,7 @@ void TypeChecker2::visit(AstExprGlobal* expr)
{ {
reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location);
} }
else if (FFlag::InferGlobalTypes) else
{ {
if (scope->shouldWarnGlobal(expr->name.value) && !warnedGlobals.contains(expr->name.value)) if (scope->shouldWarnGlobal(expr->name.value) && !warnedGlobals.contains(expr->name.value))
{ {
@ -1564,7 +1565,7 @@ void TypeChecker2::visit(AstExprCall* call)
visitCall(call); visitCall(call);
} }
std::optional<TypeId> TypeChecker2::tryStripUnionFromNil(TypeId ty) std::optional<TypeId> TypeChecker2::tryStripUnionFromNil(TypeId ty) const
{ {
if (const UnionType* utv = get<UnionType>(ty)) if (const UnionType* utv = get<UnionType>(ty))
{ {
@ -1851,16 +1852,8 @@ void TypeChecker2::visit(AstExprTable* expr)
{ {
for (const AstExprTable::Item& item : expr->items) for (const AstExprTable::Item& item : expr->items)
{ {
if (FFlag::LuauTableKeysAreRValues) if (item.key)
{ visit(item.key, ValueContext::RValue);
if (item.key)
visit(item.key, ValueContext::RValue);
}
else
{
if (item.key)
visit(item.key, ValueContext::LValue);
}
visit(item.value, ValueContext::RValue); visit(item.value, ValueContext::RValue);
} }
} }
@ -2108,7 +2101,10 @@ TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey)
} }
else else
{ {
expectedRets = module->internalTypes.addTypePack({module->internalTypes.freshType(scope, TypeLevel{})}); expectedRets = module->internalTypes.addTypePack(
{FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, scope, TypeLevel{})
: module->internalTypes.freshType_DEPRECATED(scope, TypeLevel{})}
);
} }
TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets)); TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets));
@ -2234,10 +2230,21 @@ TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey)
return builtinTypes->numberType; return builtinTypes->numberType;
case AstExprBinary::Op::Concat: case AstExprBinary::Op::Concat:
testIsSubtype(leftType, builtinTypes->stringType, expr->left->location); {
testIsSubtype(rightType, builtinTypes->stringType, expr->right->location); if (FFlag::LuauTypeCheckerAcceptNumberConcats)
{
const TypeId numberOrString = module->internalTypes.addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}});
testIsSubtype(leftType, numberOrString, expr->left->location);
testIsSubtype(rightType, numberOrString, expr->right->location);
}
else
{
testIsSubtype(leftType, builtinTypes->stringType, expr->left->location);
testIsSubtype(rightType, builtinTypes->stringType, expr->right->location);
}
return builtinTypes->stringType; return builtinTypes->stringType;
}
case AstExprBinary::Op::CompareGe: case AstExprBinary::Op::CompareGe:
case AstExprBinary::Op::CompareGt: case AstExprBinary::Op::CompareGt:
case AstExprBinary::Op::CompareLe: case AstExprBinary::Op::CompareLe:
@ -2360,7 +2367,8 @@ TypeId TypeChecker2::flattenPack(TypePackId pack)
return *fst; return *fst;
else if (auto ftp = get<FreeTypePack>(pack)) else if (auto ftp = get<FreeTypePack>(pack))
{ {
TypeId result = module->internalTypes.addType(FreeType{ftp->scope}); TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, ftp->scope)
: module->internalTypes.addType(FreeType{ftp->scope});
TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope});
TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack)); TypePack* resultPack = emplaceTypePack<TypePack>(asMutable(pack));
@ -2377,30 +2385,30 @@ TypeId TypeChecker2::flattenPack(TypePackId pack)
ice->ice("flattenPack got a weird pack!"); ice->ice("flattenPack got a weird pack!");
} }
void TypeChecker2::visitGenerics(AstArray<AstGenericType> generics, AstArray<AstGenericTypePack> genericPacks) void TypeChecker2::visitGenerics(AstArray<AstGenericType*> generics, AstArray<AstGenericTypePack*> genericPacks)
{ {
DenseHashSet<AstName> seen{AstName{}}; DenseHashSet<AstName> seen{AstName{}};
for (const auto& g : generics) for (const auto* g : generics)
{ {
if (seen.contains(g.name)) if (seen.contains(g->name))
reportError(DuplicateGenericParameter{g.name.value}, g.location); reportError(DuplicateGenericParameter{g->name.value}, g->location);
else else
seen.insert(g.name); seen.insert(g->name);
if (g.defaultValue) if (g->defaultValue)
visit(g.defaultValue); visit(g->defaultValue);
} }
for (const auto& g : genericPacks) for (const auto* g : genericPacks)
{ {
if (seen.contains(g.name)) if (seen.contains(g->name))
reportError(DuplicateGenericParameter{g.name.value}, g.location); reportError(DuplicateGenericParameter{g->name.value}, g->location);
else else
seen.insert(g.name); seen.insert(g->name);
if (g.defaultValue) if (g->defaultValue)
visit(g.defaultValue); visit(g->defaultValue);
} }
} }
@ -2422,6 +2430,8 @@ void TypeChecker2::visit(AstType* ty)
return visit(t); return visit(t);
else if (auto t = ty->as<AstTypeIntersection>()) else if (auto t = ty->as<AstTypeIntersection>())
return visit(t); return visit(t);
else if (auto t = ty->as<AstTypeGroup>())
return visit(t->type);
} }
void TypeChecker2::visit(AstTypeReference* ty) void TypeChecker2::visit(AstTypeReference* ty)
@ -2707,20 +2717,61 @@ Reasonings TypeChecker2::explainReasonings_(TID subTy, TID superTy, Location loc
if (!subLeafTy && !superLeafTy && !subLeafTp && !superLeafTp) if (!subLeafTy && !superLeafTy && !subLeafTp && !superLeafTp)
ice->ice("Subtyping test returned a reasoning where one path ends at a type and the other ends at a pack.", location); ice->ice("Subtyping test returned a reasoning where one path ends at a type and the other ends at a pack.", location);
std::string relation = "a subtype of"; if (FFlag::LuauImproveTypePathsInErrors)
if (reasoning.variance == SubtypingVariance::Invariant) {
relation = "exactly"; std::string relation = "a subtype of";
else if (reasoning.variance == SubtypingVariance::Contravariant) if (reasoning.variance == SubtypingVariance::Invariant)
relation = "a supertype of"; relation = "exactly";
else if (reasoning.variance == SubtypingVariance::Contravariant)
relation = "a supertype of";
std::string reason; std::string subLeafAsString = toString(subLeaf);
if (reasoning.subPath == reasoning.superPath) // if the string is empty, it must be an empty type pack
reason = "at " + toString(reasoning.subPath) + ", " + toString(subLeaf) + " is not " + relation + " " + toString(superLeaf); if (subLeafAsString.empty())
subLeafAsString = "()";
std::string superLeafAsString = toString(superLeaf);
// if the string is empty, it must be an empty type pack
if (superLeafAsString.empty())
superLeafAsString = "()";
std::stringstream baseReasonBuilder;
baseReasonBuilder << "`" << subLeafAsString << "` is not " << relation << " `" << superLeafAsString << "`";
std::string baseReason = baseReasonBuilder.str();
std::stringstream reason;
if (reasoning.subPath == reasoning.superPath)
reason << toStringHuman(reasoning.subPath) << "`" << subLeafAsString << "` in the former type and `" << superLeafAsString
<< "` in the latter type, and " << baseReason;
else if (!reasoning.subPath.empty() && !reasoning.superPath.empty())
reason << toStringHuman(reasoning.subPath) << "`" << subLeafAsString << "` and " << toStringHuman(reasoning.superPath) << "`"
<< superLeafAsString << "`, and " << baseReason;
else if (!reasoning.subPath.empty())
reason << toStringHuman(reasoning.subPath) << "`" << subLeafAsString << "`, which is not " << relation << " `" << superLeafAsString
<< "`";
else
reason << toStringHuman(reasoning.superPath) << "`" << superLeafAsString << "`, and " << baseReason;
reasons.push_back(reason.str());
}
else else
reason = "type " + toString(subTy) + toString(reasoning.subPath, /* prefixDot */ true) + " (" + toString(subLeaf) + ") is not " + {
relation + " " + toString(superTy) + toString(reasoning.superPath, /* prefixDot */ true) + " (" + toString(superLeaf) + ")"; std::string relation = "a subtype of";
if (reasoning.variance == SubtypingVariance::Invariant)
relation = "exactly";
else if (reasoning.variance == SubtypingVariance::Contravariant)
relation = "a supertype of";
reasons.push_back(reason); std::string reason;
if (reasoning.subPath == reasoning.superPath)
reason = "at " + toString(reasoning.subPath) + ", " + toString(subLeaf) + " is not " + relation + " " + toString(superLeaf);
else
reason = "type " + toString(subTy) + toString(reasoning.subPath, /* prefixDot */ true) + " (" + toString(subLeaf) + ") is not " +
relation + " " + toString(superTy) + toString(reasoning.superPath, /* prefixDot */ true) + " (" + toString(superLeaf) + ")";
reasons.push_back(reason);
}
// if we haven't already proved this isn't suppressing, we have to keep checking. // if we haven't already proved this isn't suppressing, we have to keep checking.
if (suppressed) if (suppressed)

File diff suppressed because it is too large Load diff

View file

@ -14,12 +14,7 @@
#include <vector> #include <vector>
LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixInner) LUAU_FASTFLAGVARIABLE(LuauTypeFunReadWriteParents)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunPrintToError)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunThreadBuffer)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunGenerics)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunCloneTail)
namespace Luau namespace Luau
{ {
@ -137,11 +132,9 @@ static std::string getTag(lua_State* L, TypeFunctionTypeId ty)
return "number"; return "number";
else if (auto s = get<TypeFunctionPrimitiveType>(ty); s && s->type == TypeFunctionPrimitiveType::Type::String) else if (auto s = get<TypeFunctionPrimitiveType>(ty); s && s->type == TypeFunctionPrimitiveType::Type::String)
return "string"; return "string";
else if (auto s = get<TypeFunctionPrimitiveType>(ty); else if (auto s = get<TypeFunctionPrimitiveType>(ty); s && s->type == TypeFunctionPrimitiveType::Type::Thread)
FFlag::LuauUserTypeFunThreadBuffer && s && s->type == TypeFunctionPrimitiveType::Type::Thread)
return "thread"; return "thread";
else if (auto s = get<TypeFunctionPrimitiveType>(ty); else if (auto s = get<TypeFunctionPrimitiveType>(ty); s && s->type == TypeFunctionPrimitiveType::Type::Buffer)
FFlag::LuauUserTypeFunThreadBuffer && s && s->type == TypeFunctionPrimitiveType::Type::Buffer)
return "buffer"; return "buffer";
else if (get<TypeFunctionUnknownType>(ty)) else if (get<TypeFunctionUnknownType>(ty))
return "unknown"; return "unknown";
@ -163,7 +156,7 @@ static std::string getTag(lua_State* L, TypeFunctionTypeId ty)
return "function"; return "function";
else if (get<TypeFunctionClassType>(ty)) else if (get<TypeFunctionClassType>(ty))
return "class"; return "class";
else if (FFlag::LuauUserTypeFunGenerics && get<TypeFunctionGenericType>(ty)) else if (get<TypeFunctionGenericType>(ty))
return "generic"; return "generic";
LUAU_UNREACHABLE(); LUAU_UNREACHABLE();
@ -432,21 +425,11 @@ static int getNegatedValue(lua_State* L)
luaL_error(L, "type.inner: expected 1 argument, but got %d", argumentCount); luaL_error(L, "type.inner: expected 1 argument, but got %d", argumentCount);
TypeFunctionTypeId self = getTypeUserData(L, 1); TypeFunctionTypeId self = getTypeUserData(L, 1);
if (FFlag::LuauUserTypeFunFixInner) if (auto tfnt = get<TypeFunctionNegationType>(self); tfnt)
{ allocTypeUserData(L, tfnt->type->type);
if (auto tfnt = get<TypeFunctionNegationType>(self); tfnt)
allocTypeUserData(L, tfnt->type->type);
else
luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str());
}
else else
{ luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str());
if (auto tfnt = get<TypeFunctionNegationType>(self); !tfnt)
allocTypeUserData(L, tfnt->type->type);
else
luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str());
}
return 1; return 1;
} }
@ -687,10 +670,8 @@ static int readTableProp(lua_State* L)
auto prop = tftt->props.at(tfsst->value); auto prop = tftt->props.at(tfsst->value);
if (prop.readTy) if (prop.readTy)
allocTypeUserData(L, (*prop.readTy)->type); allocTypeUserData(L, (*prop.readTy)->type);
else if (FFlag::LuauUserTypeFunFixNoReadWrite)
lua_pushnil(L);
else else
luaL_error(L, "type.readproperty: property %s is write-only, and therefore does not have a read type.", tfsst->value.c_str()); lua_pushnil(L);
return 1; return 1;
} }
@ -727,10 +708,8 @@ static int writeTableProp(lua_State* L)
auto prop = tftt->props.at(tfsst->value); auto prop = tftt->props.at(tfsst->value);
if (prop.writeTy) if (prop.writeTy)
allocTypeUserData(L, (*prop.writeTy)->type); allocTypeUserData(L, (*prop.writeTy)->type);
else if (FFlag::LuauUserTypeFunFixNoReadWrite)
lua_pushnil(L);
else else
luaL_error(L, "type.writeproperty: property %s is read-only, and therefore does not have a write type.", tfsst->value.c_str()); lua_pushnil(L);
return 1; return 1;
} }
@ -950,99 +929,6 @@ static void pushTypePack(lua_State* L, TypeFunctionTypePackId tp)
} }
} }
static int createFunction_DEPRECATED(lua_State* L)
{
int argumentCount = lua_gettop(L);
if (argumentCount > 2)
luaL_error(L, "types.newfunction: expected 0-2 arguments, but got %d", argumentCount);
TypeFunctionTypePackId argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{});
if (lua_istable(L, 1))
{
std::vector<TypeFunctionTypeId> head{};
lua_getfield(L, 1, "head");
if (lua_istable(L, -1))
{
int argSize = lua_objlen(L, -1);
for (int i = 1; i <= argSize; i++)
{
lua_pushinteger(L, i);
lua_gettable(L, -2);
if (lua_isnil(L, -1))
{
lua_pop(L, 1);
break;
}
TypeFunctionTypeId ty = getTypeUserData(L, -1);
head.push_back(ty);
lua_pop(L, 1); // Remove `ty` from stack
}
}
lua_pop(L, 1); // Pop the "head" field
std::optional<TypeFunctionTypePackId> tail;
lua_getfield(L, 1, "tail");
if (auto type = optionalTypeUserData(L, -1))
tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type});
lua_pop(L, 1); // Pop the "tail" field
if (head.size() == 0 && tail.has_value())
argTypes = *tail;
else
argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail});
}
else if (!lua_isnoneornil(L, 1))
luaL_typeerrorL(L, 1, "table");
TypeFunctionTypePackId retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{});
if (lua_istable(L, 2))
{
std::vector<TypeFunctionTypeId> head{};
lua_getfield(L, 2, "head");
if (lua_istable(L, -1))
{
int argSize = lua_objlen(L, -1);
for (int i = 1; i <= argSize; i++)
{
lua_pushinteger(L, i);
lua_gettable(L, -2);
if (lua_isnil(L, -1))
{
lua_pop(L, 1);
break;
}
TypeFunctionTypeId ty = getTypeUserData(L, -1);
head.push_back(ty);
lua_pop(L, 1); // Remove `ty` from stack
}
}
lua_pop(L, 1); // Pop the "head" field
std::optional<TypeFunctionTypePackId> tail;
lua_getfield(L, 2, "tail");
if (auto type = optionalTypeUserData(L, -1))
tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type});
lua_pop(L, 1); // Pop the "tail" field
if (head.size() == 0 && tail.has_value())
retTypes = *tail;
else
retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail});
}
else if (!lua_isnoneornil(L, 2))
luaL_typeerrorL(L, 2, "table");
allocTypeUserData(L, TypeFunctionFunctionType{{}, {}, argTypes, retTypes});
return 1;
}
// Luau: `types.newfunction(parameters: {head: {type}?, tail: type?}, returns: {head: {type}?, tail: type?}, generics: {type}?) -> type` // Luau: `types.newfunction(parameters: {head: {type}?, tail: type?}, returns: {head: {type}?, tail: type?}, generics: {type}?) -> type`
// Returns the type instance representing a function // Returns the type instance representing a function
static int createFunction(lua_State* L) static int createFunction(lua_State* L)
@ -1111,45 +997,7 @@ static int setFunctionParameters(lua_State* L)
if (!tfft) if (!tfft)
luaL_error(L, "type.setparameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); luaL_error(L, "type.setparameters: expected self to be a function, but got %s instead", getTag(L, self).c_str());
if (FFlag::LuauUserTypeFunGenerics) tfft->argTypes = getTypePack(L, 2, 3);
{
tfft->argTypes = getTypePack(L, 2, 3);
}
else
{
std::vector<TypeFunctionTypeId> head{};
if (lua_istable(L, 2))
{
int argSize = lua_objlen(L, 2);
for (int i = 1; i <= argSize; i++)
{
lua_pushinteger(L, i);
lua_gettable(L, 2);
if (lua_isnil(L, -1))
{
lua_pop(L, 1);
break;
}
TypeFunctionTypeId ty = getTypeUserData(L, -1);
head.push_back(ty);
lua_pop(L, 1); // Remove `ty` from stack
}
}
else if (!lua_isnoneornil(L, 2))
luaL_typeerrorL(L, 2, "table");
std::optional<TypeFunctionTypePackId> tail;
if (auto type = optionalTypeUserData(L, 3))
tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type});
if (head.size() == 0 && tail.has_value()) // Make argTypes a variadic type pack
tfft->argTypes = *tail;
else // Make argTypes a type pack
tfft->argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail});
}
return 0; return 0;
} }
@ -1167,59 +1015,7 @@ static int getFunctionParameters(lua_State* L)
if (!tfft) if (!tfft)
luaL_error(L, "type.parameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); luaL_error(L, "type.parameters: expected self to be a function, but got %s instead", getTag(L, self).c_str());
if (FFlag::LuauUserTypeFunGenerics) pushTypePack(L, tfft->argTypes);
{
pushTypePack(L, tfft->argTypes);
}
else
{
if (auto tftp = get<TypeFunctionTypePack>(tfft->argTypes))
{
int size = 0;
if (tftp->head.size() > 0)
size++;
if (tftp->tail.has_value())
size++;
lua_createtable(L, 0, size);
int argSize = (int)tftp->head.size();
if (argSize > 0)
{
lua_createtable(L, argSize, 0);
for (int i = 0; i < argSize; i++)
{
allocTypeUserData(L, tftp->head[i]->type);
lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed
}
lua_setfield(L, -2, "head");
}
if (tftp->tail.has_value())
{
auto tfvp = get<TypeFunctionVariadicTypePack>(*tftp->tail);
if (!tfvp)
LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment");
allocTypeUserData(L, tfvp->type->type);
lua_setfield(L, -2, "tail");
}
return 1;
}
if (auto tfvp = get<TypeFunctionVariadicTypePack>(tfft->argTypes))
{
lua_createtable(L, 0, 1);
allocTypeUserData(L, tfvp->type->type);
lua_setfield(L, -2, "tail");
return 1;
}
lua_createtable(L, 0, 0);
}
return 1; return 1;
} }
@ -1237,45 +1033,7 @@ static int setFunctionReturns(lua_State* L)
if (!tfft) if (!tfft)
luaL_error(L, "type.setreturns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); luaL_error(L, "type.setreturns: expected self to be a function, but got %s instead", getTag(L, self).c_str());
if (FFlag::LuauUserTypeFunGenerics) tfft->retTypes = getTypePack(L, 2, 3);
{
tfft->retTypes = getTypePack(L, 2, 3);
}
else
{
std::vector<TypeFunctionTypeId> head{};
if (lua_istable(L, 2))
{
int argSize = lua_objlen(L, 2);
for (int i = 1; i <= argSize; i++)
{
lua_pushinteger(L, i);
lua_gettable(L, 2);
if (lua_isnil(L, -1))
{
lua_pop(L, 1);
break;
}
TypeFunctionTypeId ty = getTypeUserData(L, -1);
head.push_back(ty);
lua_pop(L, 1); // Remove `ty` from stack
}
}
else if (!lua_isnoneornil(L, 2))
luaL_typeerrorL(L, 2, "table");
std::optional<TypeFunctionTypePackId> tail;
if (auto type = optionalTypeUserData(L, 3))
tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type});
if (head.size() == 0 && tail.has_value()) // Make retTypes a variadic type pack
tfft->retTypes = *tail;
else // Make retTypes a type pack
tfft->retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail});
}
return 0; return 0;
} }
@ -1293,59 +1051,7 @@ static int getFunctionReturns(lua_State* L)
if (!tfft) if (!tfft)
luaL_error(L, "type.returns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); luaL_error(L, "type.returns: expected self to be a function, but got %s instead", getTag(L, self).c_str());
if (FFlag::LuauUserTypeFunGenerics) pushTypePack(L, tfft->retTypes);
{
pushTypePack(L, tfft->retTypes);
}
else
{
if (auto tftp = get<TypeFunctionTypePack>(tfft->retTypes))
{
int size = 0;
if (tftp->head.size() > 0)
size++;
if (tftp->tail.has_value())
size++;
lua_createtable(L, 0, size);
int argSize = (int)tftp->head.size();
if (argSize > 0)
{
lua_createtable(L, argSize, 0);
for (int i = 0; i < argSize; i++)
{
allocTypeUserData(L, tftp->head[i]->type);
lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed
}
lua_setfield(L, -2, "head");
}
if (tftp->tail.has_value())
{
auto tfvp = get<TypeFunctionVariadicTypePack>(*tftp->tail);
if (!tfvp)
LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment");
allocTypeUserData(L, tfvp->type->type);
lua_setfield(L, -2, "tail");
}
return 1;
}
if (auto tfvp = get<TypeFunctionVariadicTypePack>(tfft->retTypes))
{
lua_createtable(L, 0, 1);
allocTypeUserData(L, tfvp->type->type);
lua_setfield(L, -2, "tail");
return 1;
}
lua_createtable(L, 0, 0);
}
return 1; return 1;
} }
@ -1401,7 +1107,7 @@ static int getFunctionGenerics(lua_State* L)
// Luau: `self:parent() -> type` // Luau: `self:parent() -> type`
// Returns the parent of a class type // Returns the parent of a class type
static int getClassParent(lua_State* L) static int getClassParent_DEPRECATED(lua_State* L)
{ {
int argumentCount = lua_gettop(L); int argumentCount = lua_gettop(L);
if (argumentCount != 1) if (argumentCount != 1)
@ -1413,10 +1119,54 @@ static int getClassParent(lua_State* L)
luaL_error(L, "type.parent: expected self to be a class, but got %s instead", getTag(L, self).c_str()); luaL_error(L, "type.parent: expected self to be a class, but got %s instead", getTag(L, self).c_str());
// If the parent does not exist, we should return nil // If the parent does not exist, we should return nil
if (!tfct->parent) if (!tfct->parent_DEPRECATED)
lua_pushnil(L); lua_pushnil(L);
else else
allocTypeUserData(L, (*tfct->parent)->type); allocTypeUserData(L, (*tfct->parent_DEPRECATED)->type);
return 1;
}
// Luau: `self:readparent() -> type`
// Returns the read type of the class' parent
static int getReadParent(lua_State* L)
{
int argumentCount = lua_gettop(L);
if (argumentCount != 1)
luaL_error(L, "type.parent: expected 1 arguments, but got %d", argumentCount);
TypeFunctionTypeId self = getTypeUserData(L, 1);
auto tfct = get<TypeFunctionClassType>(self);
if (!tfct)
luaL_error(L, "type.parent: expected self to be a class, but got %s instead", getTag(L, self).c_str());
// If the parent does not exist, we should return nil
if (!tfct->readParent)
lua_pushnil(L);
else
allocTypeUserData(L, (*tfct->readParent)->type);
return 1;
}
//
// Luau: `self:writeparent() -> type`
// Returns the write type of the class' parent
static int getWriteParent(lua_State* L)
{
int argumentCount = lua_gettop(L);
if (argumentCount != 1)
luaL_error(L, "type.parent: expected 1 arguments, but got %d", argumentCount);
TypeFunctionTypeId self = getTypeUserData(L, 1);
auto tfct = get<TypeFunctionClassType>(self);
if (!tfct)
luaL_error(L, "type.parent: expected self to be a class, but got %s instead", getTag(L, self).c_str());
// If the parent does not exist, we should return nil
if (!tfct->writeParent)
lua_pushnil(L);
else
allocTypeUserData(L, (*tfct->writeParent)->type);
return 1; return 1;
} }
@ -1759,8 +1509,8 @@ void registerTypesLibrary(lua_State* L)
{"boolean", createBoolean}, {"boolean", createBoolean},
{"number", createNumber}, {"number", createNumber},
{"string", createString}, {"string", createString},
{FFlag::LuauUserTypeFunThreadBuffer ? "thread" : nullptr, FFlag::LuauUserTypeFunThreadBuffer ? createThread : nullptr}, {"thread", createThread},
{FFlag::LuauUserTypeFunThreadBuffer ? "buffer" : nullptr, FFlag::LuauUserTypeFunThreadBuffer ? createBuffer : nullptr}, {"buffer", createBuffer},
{nullptr, nullptr} {nullptr, nullptr}
}; };
@ -1770,9 +1520,9 @@ void registerTypesLibrary(lua_State* L)
{"unionof", createUnion}, {"unionof", createUnion},
{"intersectionof", createIntersection}, {"intersectionof", createIntersection},
{"newtable", createTable}, {"newtable", createTable},
{"newfunction", FFlag::LuauUserTypeFunGenerics ? createFunction : createFunction_DEPRECATED}, {"newfunction", createFunction},
{"copy", deepCopy}, {"copy", deepCopy},
{FFlag::LuauUserTypeFunGenerics ? "generic" : nullptr, FFlag::LuauUserTypeFunGenerics ? createGeneric : nullptr}, {"generic", createGeneric},
{nullptr, nullptr} {nullptr, nullptr}
}; };
@ -1844,15 +1594,18 @@ void registerTypeUserData(lua_State* L)
{"components", getComponents}, {"components", getComponents},
// Class type methods // Class type methods
{"parent", getClassParent}, {FFlag::LuauTypeFunReadWriteParents ? "readparent" : "parent", FFlag::LuauTypeFunReadWriteParents ? getReadParent : getClassParent_DEPRECATED},
// Function type methods (cont.) // Function type methods (cont.)
{FFlag::LuauUserTypeFunGenerics ? "setgenerics" : nullptr, FFlag::LuauUserTypeFunGenerics ? setFunctionGenerics : nullptr}, {"setgenerics", setFunctionGenerics},
{FFlag::LuauUserTypeFunGenerics ? "generics" : nullptr, FFlag::LuauUserTypeFunGenerics ? getFunctionGenerics : nullptr}, {"generics", getFunctionGenerics},
// Generic type methods // Generic type methods
{FFlag::LuauUserTypeFunGenerics ? "name" : nullptr, FFlag::LuauUserTypeFunGenerics ? getGenericName : nullptr}, {"name", getGenericName},
{FFlag::LuauUserTypeFunGenerics ? "ispack" : nullptr, FFlag::LuauUserTypeFunGenerics ? getGenericIsPack : nullptr}, {"ispack", getGenericIsPack},
// move this under Class type methods when removing FFlagLuauTypeFunReadWriteParents
{FFlag::LuauTypeFunReadWriteParents ? "writeparent" : nullptr, FFlag::LuauTypeFunReadWriteParents ? getWriteParent : nullptr},
{nullptr, nullptr} {nullptr, nullptr}
}; };
@ -1860,6 +1613,9 @@ void registerTypeUserData(lua_State* L)
// Create and register metatable for type userdata // Create and register metatable for type userdata
luaL_newmetatable(L, "type"); luaL_newmetatable(L, "type");
lua_pushstring(L, "type");
lua_setfield(L, -2, "__type");
// Protect metatable from being changed // Protect metatable from being changed
lua_pushstring(L, "The metatable is locked"); lua_pushstring(L, "The metatable is locked");
lua_setfield(L, -2, "__metatable"); lua_setfield(L, -2, "__metatable");
@ -1898,7 +1654,9 @@ static int print(lua_State* L)
size_t l = 0; size_t l = 0;
const char* s = luaL_tolstring(L, i, &l); // convert to string using __tostring et al const char* s = luaL_tolstring(L, i, &l); // convert to string using __tostring et al
if (i > 1) if (i > 1)
result.append('\t', 1); {
result.append(1, '\t');
}
result.append(s, l); result.append(s, l);
lua_pop(L, 1); lua_pop(L, 1);
} }
@ -1941,29 +1699,16 @@ void setTypeFunctionEnvironment(lua_State* L)
luaopen_base(L); luaopen_base(L);
lua_pop(L, 1); lua_pop(L, 1);
if (FFlag::LuauUserTypeFunPrintToError) // Remove certain global functions from the base library
static const char* unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"};
for (auto& name : unavailableGlobals)
{ {
// Remove certain global functions from the base library lua_pushcfunction(L, unsupportedFunction, name);
static const char* unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"}; lua_setglobal(L, name);
for (auto& name : unavailableGlobals) }
{
lua_pushcfunction(L, unsupportedFunction, name);
lua_setglobal(L, name);
}
lua_pushcfunction(L, print, "print"); lua_pushcfunction(L, print, "print");
lua_setglobal(L, "print"); lua_setglobal(L, "print");
}
else
{
// Remove certain global functions from the base library
static const std::string unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"};
for (auto& name : unavailableGlobals)
{
lua_pushcfunction(L, unsupportedFunction, "Removing global function from type function environment");
lua_setglobal(L, name.c_str());
}
}
} }
void resetTypeFunctionState(lua_State* L) void resetTypeFunctionState(lua_State* L)
@ -2003,14 +1748,14 @@ bool areEqual(SeenSet& seen, const TypeFunctionSingletonType& lhs, const TypeFun
{ {
const TypeFunctionBooleanSingleton* lp = get<TypeFunctionBooleanSingleton>(&lhs); const TypeFunctionBooleanSingleton* lp = get<TypeFunctionBooleanSingleton>(&lhs);
const TypeFunctionBooleanSingleton* rp = get<TypeFunctionBooleanSingleton>(&lhs); const TypeFunctionBooleanSingleton* rp = get<TypeFunctionBooleanSingleton>(&rhs);
if (lp && rp) if (lp && rp)
return lp->value == rp->value; return lp->value == rp->value;
} }
{ {
const TypeFunctionStringSingleton* lp = get<TypeFunctionStringSingleton>(&lhs); const TypeFunctionStringSingleton* lp = get<TypeFunctionStringSingleton>(&lhs);
const TypeFunctionStringSingleton* rp = get<TypeFunctionStringSingleton>(&lhs); const TypeFunctionStringSingleton* rp = get<TypeFunctionStringSingleton>(&rhs);
if (lp && rp) if (lp && rp)
return lp->value == rp->value; return lp->value == rp->value;
} }
@ -2119,25 +1864,22 @@ bool areEqual(SeenSet& seen, const TypeFunctionFunctionType& lhs, const TypeFunc
if (seenSetContains(seen, &lhs, &rhs)) if (seenSetContains(seen, &lhs, &rhs))
return true; return true;
if (FFlag::LuauUserTypeFunGenerics) if (lhs.generics.size() != rhs.generics.size())
return false;
for (auto l = lhs.generics.begin(), r = rhs.generics.begin(); l != lhs.generics.end() && r != rhs.generics.end(); ++l, ++r)
{ {
if (lhs.generics.size() != rhs.generics.size()) if (!areEqual(seen, **l, **r))
return false; return false;
}
for (auto l = lhs.generics.begin(), r = rhs.generics.begin(); l != lhs.generics.end() && r != rhs.generics.end(); ++l, ++r) if (lhs.genericPacks.size() != rhs.genericPacks.size())
{ return false;
if (!areEqual(seen, **l, **r))
return false;
}
if (lhs.genericPacks.size() != rhs.genericPacks.size()) for (auto l = lhs.genericPacks.begin(), r = rhs.genericPacks.begin(); l != lhs.genericPacks.end() && r != rhs.genericPacks.end(); ++l, ++r)
{
if (!areEqual(seen, **l, **r))
return false; return false;
for (auto l = lhs.genericPacks.begin(), r = rhs.genericPacks.begin(); l != lhs.genericPacks.end() && r != rhs.genericPacks.end(); ++l, ++r)
{
if (!areEqual(seen, **l, **r))
return false;
}
} }
if (bool(lhs.argTypes) != bool(rhs.argTypes)) if (bool(lhs.argTypes) != bool(rhs.argTypes))
@ -2166,7 +1908,7 @@ bool areEqual(SeenSet& seen, const TypeFunctionClassType& lhs, const TypeFunctio
if (seenSetContains(seen, &lhs, &rhs)) if (seenSetContains(seen, &lhs, &rhs))
return true; return true;
return lhs.name == rhs.name; return lhs.classTy == rhs.classTy;
} }
bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType& rhs) bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType& rhs)
@ -2240,14 +1982,11 @@ bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType
return areEqual(seen, *lf, *rf); return areEqual(seen, *lf, *rf);
} }
if (FFlag::LuauUserTypeFunGenerics)
{ {
{ const TypeFunctionGenericType* lg = get<TypeFunctionGenericType>(&lhs);
const TypeFunctionGenericType* lg = get<TypeFunctionGenericType>(&lhs); const TypeFunctionGenericType* rg = get<TypeFunctionGenericType>(&rhs);
const TypeFunctionGenericType* rg = get<TypeFunctionGenericType>(&rhs); if (lg && rg)
if (lg && rg) return lg->isNamed == rg->isNamed && lg->isPack == rg->isPack && lg->name == rg->name;
return lg->isNamed == rg->isNamed && lg->isPack == rg->isPack && lg->name == rg->name;
}
} }
return false; return false;
@ -2296,14 +2035,11 @@ bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunct
return areEqual(seen, *lv, *rv); return areEqual(seen, *lv, *rv);
} }
if (FFlag::LuauUserTypeFunGenerics)
{ {
{ const TypeFunctionGenericTypePack* lg = get<TypeFunctionGenericTypePack>(&lhs);
const TypeFunctionGenericTypePack* lg = get<TypeFunctionGenericTypePack>(&lhs); const TypeFunctionGenericTypePack* rg = get<TypeFunctionGenericTypePack>(&rhs);
const TypeFunctionGenericTypePack* rg = get<TypeFunctionGenericTypePack>(&rhs); if (lg && rg)
if (lg && rg) return lg->isNamed == rg->isNamed && lg->name == rg->name;
return lg->isNamed == rg->isNamed && lg->name == rg->name;
}
} }
return false; return false;
@ -2495,12 +2231,10 @@ private:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String));
break; break;
case TypeFunctionPrimitiveType::Thread: case TypeFunctionPrimitiveType::Thread:
if (FFlag::LuauUserTypeFunThreadBuffer) target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread));
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread));
break; break;
case TypeFunctionPrimitiveType::Buffer: case TypeFunctionPrimitiveType::Buffer:
if (FFlag::LuauUserTypeFunThreadBuffer) target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer));
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer));
break; break;
default: default:
break; break;
@ -2534,7 +2268,7 @@ private:
} }
else if (auto c = get<TypeFunctionClassType>(ty)) else if (auto c = get<TypeFunctionClassType>(ty))
target = ty; // Don't copy a class since they are immutable target = ty; // Don't copy a class since they are immutable
else if (auto g = get<TypeFunctionGenericType>(ty); FFlag::LuauUserTypeFunGenerics && g) else if (auto g = get<TypeFunctionGenericType>(ty))
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionGenericType{g->isNamed, g->isPack, g->name}); target = typeFunctionRuntime->typeArena.allocate(TypeFunctionGenericType{g->isNamed, g->isPack, g->name});
else else
LUAU_ASSERT(!"Unknown type"); LUAU_ASSERT(!"Unknown type");
@ -2555,7 +2289,7 @@ private:
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}});
else if (auto vPack = get<TypeFunctionVariadicTypePack>(tp)) else if (auto vPack = get<TypeFunctionVariadicTypePack>(tp))
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{});
else if (auto gPack = get<TypeFunctionGenericTypePack>(tp); gPack && FFlag::LuauUserTypeFunGenerics) else if (auto gPack = get<TypeFunctionGenericTypePack>(tp))
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionGenericTypePack{gPack->isNamed, gPack->name}); target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionGenericTypePack{gPack->isNamed, gPack->name});
else else
LUAU_ASSERT(!"Unknown type"); LUAU_ASSERT(!"Unknown type");
@ -2589,8 +2323,7 @@ private:
cloneChildren(f1, f2); cloneChildren(f1, f2);
else if (auto [c1, c2] = std::tuple{getMutable<TypeFunctionClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2) else if (auto [c1, c2] = std::tuple{getMutable<TypeFunctionClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2)
cloneChildren(c1, c2); cloneChildren(c1, c2);
else if (auto [g1, g2] = std::tuple{getMutable<TypeFunctionGenericType>(ty), getMutable<TypeFunctionGenericType>(tfti)}; else if (auto [g1, g2] = std::tuple{getMutable<TypeFunctionGenericType>(ty), getMutable<TypeFunctionGenericType>(tfti)}; g1 && g2)
FFlag::LuauUserTypeFunGenerics && g1 && g2)
cloneChildren(g1, g2); cloneChildren(g1, g2);
else else
LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types
@ -2604,7 +2337,7 @@ private:
vPack1 && vPack2) vPack1 && vPack2)
cloneChildren(vPack1, vPack2); cloneChildren(vPack1, vPack2);
else if (auto [gPack1, gPack2] = std::tuple{getMutable<TypeFunctionGenericTypePack>(tp), getMutable<TypeFunctionGenericTypePack>(tftp)}; else if (auto [gPack1, gPack2] = std::tuple{getMutable<TypeFunctionGenericTypePack>(tp), getMutable<TypeFunctionGenericTypePack>(tftp)};
FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2) gPack1 && gPack2)
cloneChildren(gPack1, gPack2); cloneChildren(gPack1, gPack2);
else else
LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types
@ -2686,16 +2419,13 @@ private:
void cloneChildren(TypeFunctionFunctionType* f1, TypeFunctionFunctionType* f2) void cloneChildren(TypeFunctionFunctionType* f1, TypeFunctionFunctionType* f2)
{ {
if (FFlag::LuauUserTypeFunGenerics) f2->generics.reserve(f1->generics.size());
{ for (auto ty : f1->generics)
f2->generics.reserve(f1->generics.size()); f2->generics.push_back(shallowClone(ty));
for (auto ty : f1->generics)
f2->generics.push_back(shallowClone(ty));
f2->genericPacks.reserve(f1->genericPacks.size()); f2->genericPacks.reserve(f1->genericPacks.size());
for (auto tp : f1->genericPacks) for (auto tp : f1->genericPacks)
f2->genericPacks.push_back(shallowClone(tp)); f2->genericPacks.push_back(shallowClone(tp));
}
f2->argTypes = shallowClone(f1->argTypes); f2->argTypes = shallowClone(f1->argTypes);
f2->retTypes = shallowClone(f1->retTypes); f2->retTypes = shallowClone(f1->retTypes);
@ -2716,11 +2446,8 @@ private:
for (TypeFunctionTypeId& ty : t1->head) for (TypeFunctionTypeId& ty : t1->head)
t2->head.push_back(shallowClone(ty)); t2->head.push_back(shallowClone(ty));
if (FFlag::LuauUserTypeFunCloneTail) if (t1->tail)
{ t2->tail = shallowClone(*t1->tail);
if (t1->tail)
t2->tail = shallowClone(*t1->tail);
}
} }
void cloneChildren(TypeFunctionVariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) void cloneChildren(TypeFunctionVariadicTypePack* v1, TypeFunctionVariadicTypePack* v2)

View file

@ -19,9 +19,7 @@
// used to control the recursion limit of any operations done by user-defined type functions // used to control the recursion limit of any operations done by user-defined type functions
// currently, controls serialization, deserialization, and `type.copy` // currently, controls serialization, deserialization, and `type.copy`
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000); LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000);
LUAU_FASTFLAG(LuauTypeFunReadWriteParents)
LUAU_FASTFLAG(LuauUserTypeFunThreadBuffer)
LUAU_FASTFLAG(LuauUserTypeFunGenerics)
namespace Luau namespace Luau
{ {
@ -161,26 +159,10 @@ private:
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String));
break; break;
case PrimitiveType::Thread: case PrimitiveType::Thread:
if (FFlag::LuauUserTypeFunThreadBuffer) target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread));
{
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread));
}
else
{
std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
break; break;
case PrimitiveType::Buffer: case PrimitiveType::Buffer:
if (FFlag::LuauUserTypeFunThreadBuffer) target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer));
{
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer));
}
else
{
std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str());
state->errors.push_back(error);
}
break; break;
case PrimitiveType::Function: case PrimitiveType::Function:
case PrimitiveType::Table: case PrimitiveType::Table:
@ -226,10 +208,13 @@ private:
} }
else if (auto c = get<ClassType>(ty)) else if (auto c = get<ClassType>(ty))
{ {
state->classesSerialized[c->name] = ty; // Since there aren't any new class types being created in type functions, we will deserialize by using a direct reference to the original
target = typeFunctionRuntime->typeArena.allocate(TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, c->name}); // class
target = typeFunctionRuntime->typeArena.allocate(
TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, ty}
);
} }
else if (auto g = get<GenericType>(ty); FFlag::LuauUserTypeFunGenerics && g) else if (auto g = get<GenericType>(ty))
{ {
Name name = g->name; Name name = g->name;
@ -262,7 +247,7 @@ private:
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}});
else if (auto vPack = get<VariadicTypePack>(tp)) else if (auto vPack = get<VariadicTypePack>(tp))
target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{});
else if (auto gPack = get<GenericTypePack>(tp); FFlag::LuauUserTypeFunGenerics && gPack) else if (auto gPack = get<GenericTypePack>(tp))
{ {
Name name = gPack->name; Name name = gPack->name;
@ -308,8 +293,7 @@ private:
serializeChildren(f1, f2); serializeChildren(f1, f2);
else if (auto [c1, c2] = std::tuple{get<ClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2) else if (auto [c1, c2] = std::tuple{get<ClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2)
serializeChildren(c1, c2); serializeChildren(c1, c2);
else if (auto [g1, g2] = std::tuple{get<GenericType>(ty), getMutable<TypeFunctionGenericType>(tfti)}; else if (auto [g1, g2] = std::tuple{get<GenericType>(ty), getMutable<TypeFunctionGenericType>(tfti)}; g1 && g2)
FFlag::LuauUserTypeFunGenerics && g1 && g2)
serializeChildren(g1, g2); serializeChildren(g1, g2);
else else
{ // Either this or ty and tfti do not represent the same type { // Either this or ty and tfti do not represent the same type
@ -324,8 +308,7 @@ private:
serializeChildren(tPack1, tPack2); serializeChildren(tPack1, tPack2);
else if (auto [vPack1, vPack2] = std::tuple{get<VariadicTypePack>(tp), getMutable<TypeFunctionVariadicTypePack>(tftp)}; vPack1 && vPack2) else if (auto [vPack1, vPack2] = std::tuple{get<VariadicTypePack>(tp), getMutable<TypeFunctionVariadicTypePack>(tftp)}; vPack1 && vPack2)
serializeChildren(vPack1, vPack2); serializeChildren(vPack1, vPack2);
else if (auto [gPack1, gPack2] = std::tuple{get<GenericTypePack>(tp), getMutable<TypeFunctionGenericTypePack>(tftp)}; else if (auto [gPack1, gPack2] = std::tuple{get<GenericTypePack>(tp), getMutable<TypeFunctionGenericTypePack>(tftp)}; gPack1 && gPack2)
FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2)
serializeChildren(gPack1, gPack2); serializeChildren(gPack1, gPack2);
else else
{ // Either this or ty and tfti do not represent the same type { // Either this or ty and tfti do not represent the same type
@ -416,16 +399,13 @@ private:
void serializeChildren(const FunctionType* f1, TypeFunctionFunctionType* f2) void serializeChildren(const FunctionType* f1, TypeFunctionFunctionType* f2)
{ {
if (FFlag::LuauUserTypeFunGenerics) f2->generics.reserve(f1->generics.size());
{ for (auto ty : f1->generics)
f2->generics.reserve(f1->generics.size()); f2->generics.push_back(shallowSerialize(ty));
for (auto ty : f1->generics)
f2->generics.push_back(shallowSerialize(ty));
f2->genericPacks.reserve(f1->genericPacks.size()); f2->genericPacks.reserve(f1->genericPacks.size());
for (auto tp : f1->genericPacks) for (auto tp : f1->genericPacks)
f2->genericPacks.push_back(shallowSerialize(tp)); f2->genericPacks.push_back(shallowSerialize(tp));
}
f2->argTypes = shallowSerialize(f1->argTypes); f2->argTypes = shallowSerialize(f1->argTypes);
f2->retTypes = shallowSerialize(f1->retTypes); f2->retTypes = shallowSerialize(f1->retTypes);
@ -453,7 +433,20 @@ private:
c2->metatable = shallowSerialize(*c1->metatable); c2->metatable = shallowSerialize(*c1->metatable);
if (c1->parent) if (c1->parent)
c2->parent = shallowSerialize(*c1->parent); {
TypeFunctionTypeId parent = shallowSerialize(*c1->parent);
if (FFlag::LuauTypeFunReadWriteParents)
{
// we don't yet have read/write parents in the type inference engine.
c2->readParent = parent;
c2->writeParent = parent;
}
else
{
c2->parent_DEPRECATED = parent;
}
}
} }
void serializeChildren(const GenericType* g1, TypeFunctionGenericType* g2) void serializeChildren(const GenericType* g1, TypeFunctionGenericType* g2)
@ -590,14 +583,11 @@ private:
deserializeChildren(tfti, ty); deserializeChildren(tfti, ty);
if (FFlag::LuauUserTypeFunGenerics) // If we have completed working on all children of a function, remove the generic parameters from scope
if (!functionScopes.empty() && queue.size() == functionScopes.back().oldQueueSize && state->errors.empty())
{ {
// If we have completed working on all children of a function, remove the generic parameters from scope closeFunctionScope(functionScopes.back().function);
if (!functionScopes.empty() && queue.size() == functionScopes.back().oldQueueSize && state->errors.empty()) functionScopes.pop_back();
{
closeFunctionScope(functionScopes.back().function);
functionScopes.pop_back();
}
} }
} }
} }
@ -670,16 +660,10 @@ private:
target = state->ctx->builtins->stringType; target = state->ctx->builtins->stringType;
break; break;
case TypeFunctionPrimitiveType::Type::Thread: case TypeFunctionPrimitiveType::Type::Thread:
if (FFlag::LuauUserTypeFunThreadBuffer) target = state->ctx->builtins->threadType;
target = state->ctx->builtins->threadType;
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
break; break;
case TypeFunctionPrimitiveType::Type::Buffer: case TypeFunctionPrimitiveType::Type::Buffer:
if (FFlag::LuauUserTypeFunThreadBuffer) target = state->ctx->builtins->bufferType;
target = state->ctx->builtins->bufferType;
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
break; break;
default: default:
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
@ -720,12 +704,9 @@ private:
} }
else if (auto c = get<TypeFunctionClassType>(ty)) else if (auto c = get<TypeFunctionClassType>(ty))
{ {
if (auto result = state->classesSerialized.find(c->name)) target = c->classTy;
target = *result;
else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious class type is being deserialized");
} }
else if (auto g = get<TypeFunctionGenericType>(ty); FFlag::LuauUserTypeFunGenerics && g) else if (auto g = get<TypeFunctionGenericType>(ty))
{ {
if (g->isPack) if (g->isPack)
{ {
@ -775,7 +756,7 @@ private:
{ {
target = state->ctx->arena->addTypePack(VariadicTypePack{}); target = state->ctx->arena->addTypePack(VariadicTypePack{});
} }
else if (auto gPack = get<TypeFunctionGenericTypePack>(tp); FFlag::LuauUserTypeFunGenerics && gPack) else if (auto gPack = get<TypeFunctionGenericTypePack>(tp))
{ {
auto it = std::find_if( auto it = std::find_if(
genericPacks.rbegin(), genericPacks.rbegin(),
@ -832,8 +813,7 @@ private:
deserializeChildren(f2, f1); deserializeChildren(f2, f1);
else if (auto [c1, c2] = std::tuple{getMutable<ClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2) else if (auto [c1, c2] = std::tuple{getMutable<ClassType>(ty), getMutable<TypeFunctionClassType>(tfti)}; c1 && c2)
deserializeChildren(c2, c1); deserializeChildren(c2, c1);
else if (auto [g1, g2] = std::tuple{getMutable<GenericType>(ty), getMutable<TypeFunctionGenericType>(tfti)}; else if (auto [g1, g2] = std::tuple{getMutable<GenericType>(ty), getMutable<TypeFunctionGenericType>(tfti)}; g1 && g2)
FFlag::LuauUserTypeFunGenerics && g1 && g2)
deserializeChildren(g2, g1); deserializeChildren(g2, g1);
else else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
@ -846,8 +826,7 @@ private:
else if (auto [vPack1, vPack2] = std::tuple{getMutable<VariadicTypePack>(tp), getMutable<TypeFunctionVariadicTypePack>(tftp)}; else if (auto [vPack1, vPack2] = std::tuple{getMutable<VariadicTypePack>(tp), getMutable<TypeFunctionVariadicTypePack>(tftp)};
vPack1 && vPack2) vPack1 && vPack2)
deserializeChildren(vPack2, vPack1); deserializeChildren(vPack2, vPack1);
else if (auto [gPack1, gPack2] = std::tuple{getMutable<GenericTypePack>(tp), getMutable<TypeFunctionGenericTypePack>(tftp)}; else if (auto [gPack1, gPack2] = std::tuple{getMutable<GenericTypePack>(tp), getMutable<TypeFunctionGenericTypePack>(tftp)}; gPack1 && gPack2)
FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2)
deserializeChildren(gPack2, gPack1); deserializeChildren(gPack2, gPack1);
else else
state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized");
@ -932,64 +911,60 @@ private:
void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1) void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1)
{ {
if (FFlag::LuauUserTypeFunGenerics) functionScopes.push_back({queue.size(), f2});
std::set<std::pair<bool, std::string>> genericNames;
// Introduce generic function parameters into scope
for (auto ty : f2->generics)
{ {
functionScopes.push_back({queue.size(), f2}); auto gty = get<TypeFunctionGenericType>(ty);
LUAU_ASSERT(gty && !gty->isPack);
std::set<std::pair<bool, std::string>> genericNames; std::pair<bool, std::string> nameKey = std::make_pair(gty->isNamed, gty->name);
// Introduce generic function parameters into scope // Duplicates are not allowed
for (auto ty : f2->generics) if (genericNames.find(nameKey) != genericNames.end())
{ {
auto gty = get<TypeFunctionGenericType>(ty); state->errors.push_back(format("Duplicate type parameter '%s'", gty->name.c_str()));
LUAU_ASSERT(gty && !gty->isPack); return;
std::pair<bool, std::string> nameKey = std::make_pair(gty->isNamed, gty->name);
// Duplicates are not allowed
if (genericNames.find(nameKey) != genericNames.end())
{
state->errors.push_back(format("Duplicate type parameter '%s'", gty->name.c_str()));
return;
}
genericNames.insert(nameKey);
TypeId mapping = state->ctx->arena->addTV(Type(gty->isNamed ? GenericType{state->ctx->scope.get(), gty->name} : GenericType{}));
genericTypes.push_back({gty->isNamed, gty->name, mapping});
} }
for (auto tp : f2->genericPacks) genericNames.insert(nameKey);
{
auto gtp = get<TypeFunctionGenericTypePack>(tp);
LUAU_ASSERT(gtp);
std::pair<bool, std::string> nameKey = std::make_pair(gtp->isNamed, gtp->name); TypeId mapping = state->ctx->arena->addTV(Type(gty->isNamed ? GenericType{state->ctx->scope.get(), gty->name} : GenericType{}));
genericTypes.push_back({gty->isNamed, gty->name, mapping});
// Duplicates are not allowed
if (genericNames.find(nameKey) != genericNames.end())
{
state->errors.push_back(format("Duplicate type parameter '%s'", gtp->name.c_str()));
return;
}
genericNames.insert(nameKey);
TypePackId mapping =
state->ctx->arena->addTypePack(TypePackVar(gtp->isNamed ? GenericTypePack{state->ctx->scope.get(), gtp->name} : GenericTypePack{})
);
genericPacks.push_back({gtp->isNamed, gtp->name, mapping});
}
f1->generics.reserve(f2->generics.size());
for (auto ty : f2->generics)
f1->generics.push_back(shallowDeserialize(ty));
f1->genericPacks.reserve(f2->genericPacks.size());
for (auto tp : f2->genericPacks)
f1->genericPacks.push_back(shallowDeserialize(tp));
} }
for (auto tp : f2->genericPacks)
{
auto gtp = get<TypeFunctionGenericTypePack>(tp);
LUAU_ASSERT(gtp);
std::pair<bool, std::string> nameKey = std::make_pair(gtp->isNamed, gtp->name);
// Duplicates are not allowed
if (genericNames.find(nameKey) != genericNames.end())
{
state->errors.push_back(format("Duplicate type parameter '%s'", gtp->name.c_str()));
return;
}
genericNames.insert(nameKey);
TypePackId mapping =
state->ctx->arena->addTypePack(TypePackVar(gtp->isNamed ? GenericTypePack{state->ctx->scope.get(), gtp->name} : GenericTypePack{}));
genericPacks.push_back({gtp->isNamed, gtp->name, mapping});
}
f1->generics.reserve(f2->generics.size());
for (auto ty : f2->generics)
f1->generics.push_back(shallowDeserialize(ty));
f1->genericPacks.reserve(f2->genericPacks.size());
for (auto tp : f2->genericPacks)
f1->genericPacks.push_back(shallowDeserialize(tp));
if (f2->argTypes) if (f2->argTypes)
f1->argTypes = shallowDeserialize(f2->argTypes); f1->argTypes = shallowDeserialize(f2->argTypes);

View file

@ -32,7 +32,10 @@ LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauOldSolverCreatesChildScopePointers) LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
LUAU_FASTFLAG(LuauModuleHoldsAstRoot)
namespace Luau namespace Luau
{ {
@ -253,6 +256,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
currentModule->type = module.type; currentModule->type = module.type;
currentModule->allocator = module.allocator; currentModule->allocator = module.allocator;
currentModule->names = module.names; currentModule->names = module.names;
if (FFlag::LuauModuleHoldsAstRoot)
currentModule->root = module.root;
iceHandler->moduleName = module.name; iceHandler->moduleName = module.name;
normalizer.arena = &currentModule->internalTypes; normalizer.arena = &currentModule->internalTypes;
@ -761,8 +766,12 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& state
struct Demoter : Substitution struct Demoter : Substitution
{ {
Demoter(TypeArena* arena) TypeArena* arena = nullptr;
NotNull<BuiltinTypes> builtins;
Demoter(TypeArena* arena, NotNull<BuiltinTypes> builtins)
: Substitution(TxnLog::empty(), arena) : Substitution(TxnLog::empty(), arena)
, arena(arena)
, builtins(builtins)
{ {
} }
@ -788,7 +797,8 @@ struct Demoter : Substitution
{ {
auto ftv = get<FreeType>(ty); auto ftv = get<FreeType>(ty);
LUAU_ASSERT(ftv); LUAU_ASSERT(ftv);
return addType(FreeType{demotedLevel(ftv->level)}); return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtins, demotedLevel(ftv->level))
: addType(FreeType{demotedLevel(ftv->level)});
} }
TypePackId clean(TypePackId tp) override TypePackId clean(TypePackId tp) override
@ -835,7 +845,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& retur
} }
} }
Demoter demoter{&currentModule->internalTypes}; Demoter demoter{&currentModule->internalTypes, builtinTypes};
demoter.demote(expectedTypes); demoter.demote(expectedTypes);
TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type;
@ -4408,7 +4418,7 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
} }
} }
Demoter demoter{&currentModule->internalTypes}; Demoter demoter{&currentModule->internalTypes, builtinTypes};
demoter.demote(expectedTypes); demoter.demote(expectedTypes);
return expectedTypes; return expectedTypes;
@ -5205,12 +5215,9 @@ LUAU_NOINLINE void TypeChecker::reportErrorCodeTooComplex(const Location& locati
ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel) ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel)
{ {
ScopePtr scope = std::make_shared<Scope>(parent, subLevel); ScopePtr scope = std::make_shared<Scope>(parent, subLevel);
if (FFlag::LuauOldSolverCreatesChildScopePointers) scope->location = location;
{ scope->returnType = parent->returnType;
scope->location = location; parent->children.emplace_back(scope.get());
scope->returnType = parent->returnType;
parent->children.emplace_back(scope.get());
}
currentModule->scopes.push_back(std::make_pair(location, scope)); currentModule->scopes.push_back(std::make_pair(location, scope));
return scope; return scope;
@ -5222,12 +5229,9 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio
ScopePtr scope = std::make_shared<Scope>(parent); ScopePtr scope = std::make_shared<Scope>(parent);
scope->level = parent->level; scope->level = parent->level;
scope->varargPack = parent->varargPack; scope->varargPack = parent->varargPack;
if (FFlag::LuauOldSolverCreatesChildScopePointers) scope->location = location;
{ scope->returnType = parent->returnType;
scope->location = location; parent->children.emplace_back(scope.get());
scope->returnType = parent->returnType;
parent->children.emplace_back(scope.get());
}
currentModule->scopes.push_back(std::make_pair(location, scope)); currentModule->scopes.push_back(std::make_pair(location, scope));
return scope; return scope;
@ -5273,7 +5277,8 @@ TypeId TypeChecker::freshType(const ScopePtr& scope)
TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::freshType(TypeLevel level)
{ {
return currentModule->internalTypes.addType(Type(FreeType(level))); return FFlag::LuauFreeTypesMustHaveBounds ? currentModule->internalTypes.freshType(builtinTypes, level)
: currentModule->internalTypes.addType(Type(FreeType(level)));
} }
TypeId TypeChecker::singletonType(bool value) TypeId TypeChecker::singletonType(bool value)
@ -5716,8 +5721,18 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
TypeId ty = checkExpr(scope, *typeOf->expr).type; TypeId ty = checkExpr(scope, *typeOf->expr).type;
return ty; return ty;
} }
else if (annotation.is<AstTypeOptional>())
{
return builtinTypes->nilType;
}
else if (const auto& un = annotation.as<AstTypeUnion>()) else if (const auto& un = annotation.as<AstTypeUnion>())
{ {
if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
{
if (un->types.size == 1)
return resolveType(scope, *un->types.data[0]);
}
std::vector<TypeId> types; std::vector<TypeId> types;
for (AstType* ann : un->types) for (AstType* ann : un->types)
types.push_back(resolveType(scope, *ann)); types.push_back(resolveType(scope, *ann));
@ -5726,12 +5741,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno
} }
else if (const auto& un = annotation.as<AstTypeIntersection>()) else if (const auto& un = annotation.as<AstTypeIntersection>())
{ {
if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType)
{
if (un->types.size == 1)
return resolveType(scope, *un->types.data[0]);
}
std::vector<TypeId> types; std::vector<TypeId> types;
for (AstType* ann : un->types) for (AstType* ann : un->types)
types.push_back(resolveType(scope, *ann)); types.push_back(resolveType(scope, *ann));
return addType(IntersectionType{types}); return addType(IntersectionType{types});
} }
else if (const auto& g = annotation.as<AstTypeGroup>())
{
return resolveType(scope, *g->type);
}
else if (const auto& tsb = annotation.as<AstTypeSingletonBool>()) else if (const auto& tsb = annotation.as<AstTypeSingletonBool>())
{ {
return singletonType(tsb->value); return singletonType(tsb->value);
@ -5889,8 +5914,8 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(
const ScopePtr& scope, const ScopePtr& scope,
std::optional<TypeLevel> levelOpt, std::optional<TypeLevel> levelOpt,
const AstNode& node, const AstNode& node,
const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericType*>& genericNames,
const AstArray<AstGenericTypePack>& genericPackNames, const AstArray<AstGenericTypePack*>& genericPackNames,
bool useCache bool useCache
) )
{ {
@ -5900,14 +5925,14 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(
std::vector<GenericTypeDefinition> generics; std::vector<GenericTypeDefinition> generics;
for (const AstGenericType& generic : genericNames) for (const AstGenericType* generic : genericNames)
{ {
std::optional<TypeId> defaultValue; std::optional<TypeId> defaultValue;
if (generic.defaultValue) if (generic->defaultValue)
defaultValue = resolveType(scope, *generic.defaultValue); defaultValue = resolveType(scope, *generic->defaultValue);
Name n = generic.name.value; Name n = generic->name.value;
// These generics are the only thing that will ever be added to scope, so we can be certain that // These generics are the only thing that will ever be added to scope, so we can be certain that
// a collision can only occur when two generic types have the same name. // a collision can only occur when two generic types have the same name.
@ -5936,14 +5961,14 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(
std::vector<GenericTypePackDefinition> genericPacks; std::vector<GenericTypePackDefinition> genericPacks;
for (const AstGenericTypePack& genericPack : genericPackNames) for (const AstGenericTypePack* genericPack : genericPackNames)
{ {
std::optional<TypePackId> defaultValue; std::optional<TypePackId> defaultValue;
if (genericPack.defaultValue) if (genericPack->defaultValue)
defaultValue = resolveTypePack(scope, *genericPack.defaultValue); defaultValue = resolveTypePack(scope, *genericPack->defaultValue);
Name n = genericPack.name.value; Name n = genericPack->name.value;
// These generics are the only thing that will ever be added to scope, so we can be certain that // These generics are the only thing that will ever be added to scope, so we can be certain that
// a collision can only occur when two generic types have the same name. // a collision can only occur when two generic types have the same name.

View file

@ -14,6 +14,7 @@
#include <type_traits> #include <type_traits>
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAGVARIABLE(LuauDisableNewSolverAssertsInMixedMode)
// Maximum number of steps to follow when traversing a path. May not always // Maximum number of steps to follow when traversing a path. May not always
// equate to the number of components in a path, depending on the traversal // equate to the number of components in a path, depending on the traversal
@ -156,14 +157,16 @@ Path PathBuilder::build()
PathBuilder& PathBuilder::readProp(std::string name) PathBuilder& PathBuilder::readProp(std::string name)
{ {
LUAU_ASSERT(FFlag::LuauSolverV2); if (!FFlag::LuauDisableNewSolverAssertsInMixedMode)
LUAU_ASSERT(FFlag::LuauSolverV2);
components.push_back(Property{std::move(name), true}); components.push_back(Property{std::move(name), true});
return *this; return *this;
} }
PathBuilder& PathBuilder::writeProp(std::string name) PathBuilder& PathBuilder::writeProp(std::string name)
{ {
LUAU_ASSERT(FFlag::LuauSolverV2); if (!FFlag::LuauDisableNewSolverAssertsInMixedMode)
LUAU_ASSERT(FFlag::LuauSolverV2);
components.push_back(Property{std::move(name), false}); components.push_back(Property{std::move(name), false});
return *this; return *this;
} }
@ -636,6 +639,247 @@ std::string toString(const TypePath::Path& path, bool prefixDot)
return result.str(); return result.str();
} }
std::string toStringHuman(const TypePath::Path& path)
{
LUAU_ASSERT(FFlag::LuauSolverV2);
enum class State
{
Initial,
Normal,
Property,
PendingIs,
PendingAs,
PendingWhich,
};
std::stringstream result;
State state = State::Initial;
bool last = false;
auto strComponent = [&](auto&& c)
{
using T = std::decay_t<decltype(c)>;
if constexpr (std::is_same_v<T, TypePath::Property>)
{
if (state == State::PendingIs)
result << ", ";
switch (state)
{
case State::Initial:
case State::PendingIs:
if (c.isRead)
result << "accessing `";
else
result << "writing to `";
break;
case State::Property:
// if the previous state was a property, then we're doing a sequence of indexing
result << '.';
break;
default:
break;
}
result << c.name;
state = State::Property;
}
else if constexpr (std::is_same_v<T, TypePath::Index>)
{
size_t humanIndex = c.index + 1;
if (state == State::Initial && !last)
result << "in" << ' ';
else if (state == State::PendingIs)
result << ' ' << "has" << ' ';
else if (state == State::Property)
result << '`' << ' ' << "has" << ' ';
result << "the " << humanIndex;
switch (humanIndex)
{
case 1:
result << "st";
break;
case 2:
result << "nd";
break;
case 3:
result << "rd";
break;
default:
result << "th";
}
switch (c.variant)
{
case TypePath::Index::Variant::Pack:
result << ' ' << "entry in the type pack";
break;
case TypePath::Index::Variant::Union:
result << ' ' << "component of the union";
break;
case TypePath::Index::Variant::Intersection:
result << ' ' << "component of the intersection";
break;
}
if (state == State::PendingWhich)
result << ' ' << "which";
if (state == State::PendingIs || state == State::Property)
state = State::PendingAs;
else
state = State::PendingIs;
}
else if constexpr (std::is_same_v<T, TypePath::TypeField>)
{
if (state == State::Initial && !last)
result << "in" << ' ';
else if (state == State::PendingIs)
result << ", ";
else if (state == State::Property)
result << '`' << ' ' << "has" << ' ';
switch (c)
{
case TypePath::TypeField::Table:
result << "the table portion";
if (state == State::Property)
state = State::PendingAs;
else
state = State::PendingIs;
break;
case TypePath::TypeField::Metatable:
result << "the metatable portion";
if (state == State::Property)
state = State::PendingAs;
else
state = State::PendingIs;
break;
case TypePath::TypeField::LowerBound:
result << "the lower bound of" << ' ';
state = State::Normal;
break;
case TypePath::TypeField::UpperBound:
result << "the upper bound of" << ' ';
state = State::Normal;
break;
case TypePath::TypeField::IndexLookup:
result << "the index type";
if (state == State::Property)
state = State::PendingAs;
else
state = State::PendingIs;
break;
case TypePath::TypeField::IndexResult:
result << "the result of indexing";
if (state == State::Property)
state = State::PendingAs;
else
state = State::PendingIs;
break;
case TypePath::TypeField::Negated:
result << "the negation" << ' ';
state = State::Normal;
break;
case TypePath::TypeField::Variadic:
result << "the variadic" << ' ';
state = State::Normal;
break;
}
}
else if constexpr (std::is_same_v<T, TypePath::PackField>)
{
if (state == State::PendingIs)
result << ", ";
else if (state == State::Property)
result << "`, ";
switch (c)
{
case TypePath::PackField::Arguments:
if (state == State::Initial)
result << "it" << ' ';
else if (state == State::PendingIs)
result << "the function" << ' ';
result << "takes";
break;
case TypePath::PackField::Returns:
if (state == State::Initial)
result << "it" << ' ';
else if (state == State::PendingIs)
result << "the function" << ' ';
result << "returns";
break;
case TypePath::PackField::Tail:
if (state == State::Initial)
result << "it has" << ' ';
result << "a tail of";
break;
}
if (state == State::PendingIs)
{
result << ' ';
state = State::PendingWhich;
}
else
{
result << ' ';
state = State::Normal;
}
}
else if constexpr (std::is_same_v<T, TypePath::Reduction>)
{
if (state == State::Initial)
result << "it" << ' ';
result << "reduces to" << ' ';
state = State::Normal;
}
else
{
static_assert(always_false_v<T>, "Unhandled Component variant");
}
};
size_t count = 0;
for (const TypePath::Component& component : path.components)
{
count++;
if (count == path.components.size())
last = true;
Luau::visit(strComponent, component);
}
switch (state)
{
case State::Property:
result << "` results in ";
break;
case State::PendingWhich:
// pending `which` becomes `is` if it's at the end
result << "is" << ' ';
break;
case State::PendingIs:
result << ' ' << "is" << ' ';
break;
case State::PendingAs:
result << ' ' << "as" << ' ';
break;
default:
break;
}
return result.str();
}
static bool traverse(TraversalState& state, const Path& path) static bool traverse(TraversalState& state, const Path& path)
{ {
auto step = [&state](auto&& c) auto step = [&state](auto&& c)

View file

@ -5,6 +5,7 @@
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include <algorithm> #include <algorithm>
@ -12,7 +13,8 @@
LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete); LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete);
LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope); LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope);
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
LUAU_FASTFLAG(LuauDisableNewSolverAssertsInMixedMode)
namespace Luau namespace Luau
{ {
@ -323,7 +325,7 @@ TypePack extendTypePack(
trackInteriorFreeType(ftp->scope, t); trackInteriorFreeType(ftp->scope, t);
} }
else else
t = arena.freshType(ftp->scope); t = FFlag::LuauFreeTypesMustHaveBounds ? arena.freshType(builtinTypes, ftp->scope) : arena.freshType_DEPRECATED(ftp->scope);
} }
newPack.head.push_back(t); newPack.head.push_back(t);
@ -548,7 +550,10 @@ std::vector<TypeId> findBlockedArgTypesIn(AstExprCall* expr, NotNull<DenseHashMa
void trackInteriorFreeType(Scope* scope, TypeId ty) void trackInteriorFreeType(Scope* scope, TypeId ty)
{ {
LUAU_ASSERT(FFlag::LuauSolverV2 && FFlag::LuauTrackInteriorFreeTypesOnScope); if (FFlag::LuauDisableNewSolverAssertsInMixedMode)
LUAU_ASSERT(FFlag::LuauTrackInteriorFreeTypesOnScope);
else
LUAU_ASSERT(FFlag::LuauSolverV2 && FFlag::LuauTrackInteriorFreeTypesOnScope);
for (; scope; scope = scope->parent.get()) for (; scope; scope = scope->parent.get())
{ {
if (scope->interiorFreeTypes) if (scope->interiorFreeTypes)

View file

@ -24,6 +24,7 @@ const size_t kPageSize = sysconf(_SC_PAGESIZE);
#endif #endif
#endif #endif
#include <stdint.h>
#include <stdlib.h> #include <stdlib.h>
LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauFreezeArena)

View file

@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping)
LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering) LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering)
LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart) LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart)
LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds)
namespace Luau namespace Luau
{ {
@ -32,38 +33,20 @@ struct PromoteTypeLevels final : TypeOnceVisitor
const TypeArena* typeArena = nullptr; const TypeArena* typeArena = nullptr;
TypeLevel minLevel; TypeLevel minLevel;
Scope* outerScope = nullptr; PromoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel)
bool useScopes;
PromoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes)
: log(log) : log(log)
, typeArena(typeArena) , typeArena(typeArena)
, minLevel(minLevel) , minLevel(minLevel)
, outerScope(outerScope)
, useScopes(useScopes)
{ {
} }
template<typename TID, typename T> template<typename TID, typename T>
void promote(TID ty, T* t) void promote(TID ty, T* t)
{ {
if (useScopes && !t)
return;
LUAU_ASSERT(t); LUAU_ASSERT(t);
if (useScopes) if (minLevel.subsumesStrict(t->level))
{ log.changeLevel(ty, minLevel);
if (subsumesStrict(outerScope, t->scope))
log.changeScope(ty, NotNull{outerScope});
}
else
{
if (minLevel.subsumesStrict(t->level))
{
log.changeLevel(ty, minLevel);
}
}
} }
bool visit(TypeId ty) override bool visit(TypeId ty) override
@ -140,23 +123,23 @@ struct PromoteTypeLevels final : TypeOnceVisitor
} }
}; };
static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes, TypeId ty) static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty)
{ {
// Type levels of types from other modules are already global, so we don't need to promote anything inside // Type levels of types from other modules are already global, so we don't need to promote anything inside
if (ty->owningArena != typeArena) if (ty->owningArena != typeArena)
return; return;
PromoteTypeLevels ptl{log, typeArena, minLevel, outerScope, useScopes}; PromoteTypeLevels ptl{log, typeArena, minLevel};
ptl.traverse(ty); ptl.traverse(ty);
} }
void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes, TypePackId tp) void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp)
{ {
// Type levels of types from other modules are already global, so we don't need to promote anything inside // Type levels of types from other modules are already global, so we don't need to promote anything inside
if (tp->owningArena != typeArena) if (tp->owningArena != typeArena)
return; return;
PromoteTypeLevels ptl{log, typeArena, minLevel, outerScope, useScopes}; PromoteTypeLevels ptl{log, typeArena, minLevel};
ptl.traverse(tp); ptl.traverse(tp);
} }
@ -369,12 +352,9 @@ static std::optional<std::pair<Luau::Name, const SingletonType*>> getTableMatchT
} }
template<typename TY_A, typename TY_B> template<typename TY_A, typename TY_B>
static bool subsumes(bool useScopes, TY_A* left, TY_B* right) static bool subsumes(TY_A* left, TY_B* right)
{ {
if (useScopes) return left->level.subsumes(right->level);
return subsumes(left->scope, right->scope);
else
return left->level.subsumes(right->level);
} }
TypeMismatch::Context Unifier::mismatchContext() TypeMismatch::Context Unifier::mismatchContext()
@ -463,7 +443,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
auto superFree = log.getMutable<FreeType>(superTy); auto superFree = log.getMutable<FreeType>(superTy);
auto subFree = log.getMutable<FreeType>(subTy); auto subFree = log.getMutable<FreeType>(subTy);
if (superFree && subFree && subsumes(useNewSolver, superFree, subFree)) if (superFree && subFree && subsumes(superFree, subFree))
{ {
if (!occursCheck(subTy, superTy, /* reversed = */ false)) if (!occursCheck(subTy, superTy, /* reversed = */ false))
log.replace(subTy, BoundType(superTy)); log.replace(subTy, BoundType(superTy));
@ -474,7 +454,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
{ {
if (!occursCheck(superTy, subTy, /* reversed = */ true)) if (!occursCheck(superTy, subTy, /* reversed = */ true))
{ {
if (subsumes(useNewSolver, superFree, subFree)) if (subsumes(superFree, subFree))
{ {
log.changeLevel(subTy, superFree->level); log.changeLevel(subTy, superFree->level);
} }
@ -488,7 +468,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
{ {
// Unification can't change the level of a generic. // Unification can't change the level of a generic.
auto subGeneric = log.getMutable<GenericType>(subTy); auto subGeneric = log.getMutable<GenericType>(subTy);
if (subGeneric && !subsumes(useNewSolver, subGeneric, superFree)) if (subGeneric && !subsumes(subGeneric, superFree))
{ {
// TODO: a more informative error message? CLI-39912 // TODO: a more informative error message? CLI-39912
reportError(location, GenericError{"Generic subtype escaping scope"}); reportError(location, GenericError{"Generic subtype escaping scope"});
@ -497,7 +477,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (!occursCheck(superTy, subTy, /* reversed = */ true)) if (!occursCheck(superTy, subTy, /* reversed = */ true))
{ {
promoteTypeLevels(log, types, superFree->level, superFree->scope, useNewSolver, subTy); promoteTypeLevels(log, types, superFree->level, subTy);
Widen widen{types, builtinTypes}; Widen widen{types, builtinTypes};
log.replace(superTy, BoundType(widen(subTy))); log.replace(superTy, BoundType(widen(subTy)));
@ -514,7 +494,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
// Unification can't change the level of a generic. // Unification can't change the level of a generic.
auto superGeneric = log.getMutable<GenericType>(superTy); auto superGeneric = log.getMutable<GenericType>(superTy);
if (superGeneric && !subsumes(useNewSolver, superGeneric, subFree)) if (superGeneric && !subsumes(superGeneric, subFree))
{ {
// TODO: a more informative error message? CLI-39912 // TODO: a more informative error message? CLI-39912
reportError(location, GenericError{"Generic supertype escaping scope"}); reportError(location, GenericError{"Generic supertype escaping scope"});
@ -523,7 +503,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (!occursCheck(subTy, superTy, /* reversed = */ false)) if (!occursCheck(subTy, superTy, /* reversed = */ false))
{ {
promoteTypeLevels(log, types, subFree->level, subFree->scope, useNewSolver, superTy); promoteTypeLevels(log, types, subFree->level, superTy);
log.replace(subTy, BoundType(superTy)); log.replace(subTy, BoundType(superTy));
} }
@ -535,7 +515,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
auto superGeneric = log.getMutable<GenericType>(superTy); auto superGeneric = log.getMutable<GenericType>(superTy);
auto subGeneric = log.getMutable<GenericType>(subTy); auto subGeneric = log.getMutable<GenericType>(subTy);
if (superGeneric && subGeneric && subsumes(useNewSolver, superGeneric, subGeneric)) if (superGeneric && subGeneric && subsumes(superGeneric, subGeneric))
{ {
if (!occursCheck(subTy, superTy, /* reversed = */ false)) if (!occursCheck(subTy, superTy, /* reversed = */ false))
log.replace(subTy, BoundType(superTy)); log.replace(subTy, BoundType(superTy));
@ -752,9 +732,6 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ
std::unique_ptr<Unifier> innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState->tryUnify_(type, superTy); innerState->tryUnify_(type, superTy);
if (useNewSolver)
logs.push_back(std::move(innerState->log));
if (auto e = hasUnificationTooComplex(innerState->errors)) if (auto e = hasUnificationTooComplex(innerState->errors))
unificationTooComplex = e; unificationTooComplex = e;
else if (innerState->failure) else if (innerState->failure)
@ -869,13 +846,8 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
if (!innerState->failure) if (!innerState->failure)
{ {
found = true; found = true;
if (useNewSolver) log.concat(std::move(innerState->log));
logs.push_back(std::move(innerState->log)); break;
else
{
log.concat(std::move(innerState->log));
break;
}
} }
else if (innerState->errors.empty()) else if (innerState->errors.empty())
{ {
@ -894,9 +866,6 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp
} }
} }
if (useNewSolver)
log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types});
if (unificationTooComplex) if (unificationTooComplex)
{ {
reportError(*unificationTooComplex); reportError(*unificationTooComplex);
@ -974,16 +943,10 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I
firstFailedOption = {innerState->errors.front()}; firstFailedOption = {innerState->errors.front()};
} }
if (useNewSolver) log.concat(std::move(innerState->log));
logs.push_back(std::move(innerState->log));
else
log.concat(std::move(innerState->log));
failure |= innerState->failure; failure |= innerState->failure;
} }
if (useNewSolver)
log.concat(combineLogsIntoIntersection(std::move(logs)));
if (unificationTooComplex) if (unificationTooComplex)
reportError(*unificationTooComplex); reportError(*unificationTooComplex);
else if (firstFailedOption) else if (firstFailedOption)
@ -1031,28 +994,6 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
} }
} }
if (useNewSolver && normalize)
{
// Sometimes a negation type is inside one of the types, e.g. { p: number } & { p: ~number }.
NegationTypeFinder finder;
finder.traverse(subTy);
if (finder.found)
{
// It is possible that A & B <: T even though A </: T and B </: T
// for example (string?) & ~nil <: string.
// We deal with this by type normalization.
std::shared_ptr<const NormalizedType> subNorm = normalizer->normalize(subTy);
std::shared_ptr<const NormalizedType> superNorm = normalizer->normalize(superTy);
if (subNorm && superNorm)
tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible");
else
reportError(location, NormalizationTooComplex{});
return;
}
}
std::vector<TxnLog> logs; std::vector<TxnLog> logs;
for (size_t i = 0; i < uv->parts.size(); ++i) for (size_t i = 0; i < uv->parts.size(); ++i)
@ -1069,7 +1010,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
{ {
found = true; found = true;
errorsSuppressed = innerState->failure; errorsSuppressed = innerState->failure;
if (useNewSolver || innerState->failure) if (innerState->failure)
logs.push_back(std::move(innerState->log)); logs.push_back(std::move(innerState->log));
else else
{ {
@ -1084,9 +1025,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType*
} }
} }
if (useNewSolver) if (errorsSuppressed)
log.concat(combineLogsIntoIntersection(std::move(logs)));
else if (errorsSuppressed)
log.concat(std::move(logs.front())); log.concat(std::move(logs.front()));
if (unificationTooComplex) if (unificationTooComplex)
@ -1200,24 +1139,6 @@ void Unifier::tryUnifyNormalizedTypes(
} }
} }
if (useNewSolver)
{
for (TypeId superTable : superNorm.tables)
{
std::unique_ptr<Unifier> innerState = makeChildUnifier();
innerState->tryUnify(subClass, superTable);
if (innerState->errors.empty())
{
found = true;
log.concat(std::move(innerState->log));
break;
}
else if (auto e = hasUnificationTooComplex(innerState->errors))
return reportError(*e);
}
}
if (!found) if (!found)
{ {
return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()});
@ -1502,12 +1423,6 @@ struct WeirdIter
} }
}; };
void Unifier::enableNewSolver()
{
useNewSolver = true;
log.useScopes = true;
}
ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy)
{ {
std::unique_ptr<Unifier> s = makeChildUnifier(); std::unique_ptr<Unifier> s = makeChildUnifier();
@ -1587,8 +1502,6 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
if (!occursCheck(superTp, subTp, /* reversed = */ true)) if (!occursCheck(superTp, subTp, /* reversed = */ true))
{ {
Widen widen{types, builtinTypes}; Widen widen{types, builtinTypes};
if (useNewSolver)
promoteTypeLevels(log, types, superFree->level, superFree->scope, /*useScopes*/ true, subTp);
log.replace(superTp, Unifiable::Bound<TypePackId>(widen(subTp))); log.replace(superTp, Unifiable::Bound<TypePackId>(widen(subTp)));
} }
} }
@ -1596,8 +1509,6 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
{ {
if (!occursCheck(subTp, superTp, /* reversed = */ false)) if (!occursCheck(subTp, superTp, /* reversed = */ false))
{ {
if (useNewSolver)
promoteTypeLevels(log, types, subFree->level, subFree->scope, /*useScopes*/ true, superTp);
log.replace(subTp, Unifiable::Bound<TypePackId>(superTp)); log.replace(subTp, Unifiable::Bound<TypePackId>(superTp));
} }
} }
@ -1648,7 +1559,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
if (FFlag::LuauSolverV2) if (FFlag::LuauSolverV2)
return freshType(NotNull{types}, builtinTypes, scope); return freshType(NotNull{types}, builtinTypes, scope);
else else
return types->freshType(scope, level); return FFlag::LuauFreeTypesMustHaveBounds ? types->freshType(builtinTypes, scope, level) : types->freshType_DEPRECATED(scope, level);
}; };
const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt});
@ -1687,74 +1598,28 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
// If both are at the end, we're done // If both are at the end, we're done
if (!superIter.good() && !subIter.good()) if (!superIter.good() && !subIter.good())
{ {
if (useNewSolver) const bool lFreeTail = superTpv->tail && log.getMutable<FreeTypePack>(log.follow(*superTpv->tail)) != nullptr;
const bool rFreeTail = subTpv->tail && log.getMutable<FreeTypePack>(log.follow(*subTpv->tail)) != nullptr;
if (lFreeTail && rFreeTail)
{ {
if (subIter.tail() && superIter.tail()) tryUnify_(*subTpv->tail, *superTpv->tail);
tryUnify_(*subIter.tail(), *superIter.tail());
else if (subIter.tail())
{
const TypePackId subTail = log.follow(*subIter.tail());
if (log.get<FreeTypePack>(subTail))
tryUnify_(subTail, emptyTp);
else if (log.get<GenericTypePack>(subTail))
reportError(location, TypePackMismatch{subTail, emptyTp});
else if (log.get<VariadicTypePack>(subTail) || log.get<ErrorTypePack>(subTail))
{
// Nothing. This is ok.
}
else
{
ice("Unexpected subtype tail pack " + toString(subTail), location);
}
}
else if (superIter.tail())
{
const TypePackId superTail = log.follow(*superIter.tail());
if (log.get<FreeTypePack>(superTail))
tryUnify_(emptyTp, superTail);
else if (log.get<GenericTypePack>(superTail))
reportError(location, TypePackMismatch{emptyTp, superTail});
else if (log.get<VariadicTypePack>(superTail) || log.get<ErrorTypePack>(superTail))
{
// Nothing. This is ok.
}
else
{
ice("Unexpected supertype tail pack " + toString(superTail), location);
}
}
else
{
// Nothing. This is ok.
}
} }
else else if (lFreeTail)
{ {
const bool lFreeTail = superTpv->tail && log.getMutable<FreeTypePack>(log.follow(*superTpv->tail)) != nullptr; tryUnify_(emptyTp, *superTpv->tail);
const bool rFreeTail = subTpv->tail && log.getMutable<FreeTypePack>(log.follow(*subTpv->tail)) != nullptr; }
if (lFreeTail && rFreeTail) else if (rFreeTail)
{ {
tryUnify_(emptyTp, *subTpv->tail);
}
else if (subTpv->tail && superTpv->tail)
{
if (log.getMutable<VariadicTypePack>(superIter.packId))
tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index));
else if (log.getMutable<VariadicTypePack>(subIter.packId))
tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index));
else
tryUnify_(*subTpv->tail, *superTpv->tail); tryUnify_(*subTpv->tail, *superTpv->tail);
}
else if (lFreeTail)
{
tryUnify_(emptyTp, *superTpv->tail);
}
else if (rFreeTail)
{
tryUnify_(emptyTp, *subTpv->tail);
}
else if (subTpv->tail && superTpv->tail)
{
if (log.getMutable<VariadicTypePack>(superIter.packId))
tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index));
else if (log.getMutable<VariadicTypePack>(subIter.packId))
tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index));
else
tryUnify_(*subTpv->tail, *superTpv->tail);
}
} }
break; break;
@ -2211,7 +2076,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
variance = Invariant; variance = Invariant;
std::unique_ptr<Unifier> innerState = makeChildUnifier(); std::unique_ptr<Unifier> innerState = makeChildUnifier();
if (useNewSolver || FFlag::LuauFixIndexerSubtypingOrdering) if (FFlag::LuauFixIndexerSubtypingOrdering)
innerState->tryUnify_(prop.type(), superTable->indexer->indexResultType); innerState->tryUnify_(prop.type(), superTable->indexer->indexResultType);
else else
{ {
@ -2496,49 +2361,8 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
{ {
case TableState::Free: case TableState::Free:
{ {
if (useNewSolver) tryUnify_(subTy, superMetatable->table);
{ log.bindTable(subTy, superTy);
std::unique_ptr<Unifier> innerState = makeChildUnifier();
bool missingProperty = false;
for (const auto& [propName, prop] : subTable->props)
{
if (std::optional<TypeId> mtPropTy = findTablePropertyRespectingMeta(superTy, propName))
{
innerState->tryUnify(prop.type(), *mtPropTy);
}
else
{
reportError(mismatchError);
missingProperty = true;
break;
}
}
if (const TableType* superTable = log.get<TableType>(log.follow(superMetatable->table)))
{
// TODO: Unify indexers.
}
if (auto e = hasUnificationTooComplex(innerState->errors))
reportError(*e);
else if (!innerState->errors.empty())
reportError(TypeError{
location,
TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState->errors.front(), mismatchContext()}
});
else if (!missingProperty)
{
log.concat(std::move(innerState->log));
log.bindTable(subTy, superTy);
failure |= innerState->failure;
}
}
else
{
tryUnify_(subTy, superMetatable->table);
log.bindTable(subTy, superTy);
}
break; break;
} }
@ -2864,18 +2688,9 @@ std::optional<TypeId> Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N
return Luau::findTablePropertyRespectingMeta(builtinTypes, errors, lhsType, name, location); return Luau::findTablePropertyRespectingMeta(builtinTypes, errors, lhsType, name, location);
} }
TxnLog Unifier::combineLogsIntoIntersection(std::vector<TxnLog> logs)
{
LUAU_ASSERT(useNewSolver);
TxnLog result(useNewSolver);
for (TxnLog& log : logs)
result.concatAsIntersections(std::move(log), NotNull{types});
return result;
}
TxnLog Unifier::combineLogsIntoUnion(std::vector<TxnLog> logs) TxnLog Unifier::combineLogsIntoUnion(std::vector<TxnLog> logs)
{ {
TxnLog result(useNewSolver); TxnLog result;
for (TxnLog& log : logs) for (TxnLog& log : logs)
result.concatAsUnion(std::move(log), NotNull{types}); result.concatAsUnion(std::move(log), NotNull{types});
return result; return result;
@ -3020,9 +2835,6 @@ std::unique_ptr<Unifier> Unifier::makeChildUnifier()
u->normalize = normalize; u->normalize = normalize;
u->checkInhabited = checkInhabited; u->checkInhabited = checkInhabited;
if (useNewSolver)
u->enableNewSolver();
return u; return u;
} }

View file

@ -18,6 +18,8 @@
#include <optional> #include <optional>
LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAGVARIABLE(LuauUnifyMetatableWithAny)
LUAU_FASTFLAG(LuauExtraFollows)
namespace Luau namespace Luau
{ {
@ -235,6 +237,10 @@ bool Unifier2::unify(TypeId subTy, TypeId superTy)
auto superMetatable = get<MetatableType>(superTy); auto superMetatable = get<MetatableType>(superTy);
if (subMetatable && superMetatable) if (subMetatable && superMetatable)
return unify(subMetatable, superMetatable); return unify(subMetatable, superMetatable);
else if (FFlag::LuauUnifyMetatableWithAny && subMetatable && superAny)
return unify(subMetatable, superAny);
else if (FFlag::LuauUnifyMetatableWithAny && subAny && superMetatable)
return unify(subAny, superMetatable);
else if (subMetatable) // if we only have one metatable, unify with the inner table else if (subMetatable) // if we only have one metatable, unify with the inner table
return unify(subMetatable->table, superTy); return unify(subMetatable->table, superTy);
else if (superMetatable) // if we only have one metatable, unify with the inner table else if (superMetatable) // if we only have one metatable, unify with the inner table
@ -277,7 +283,7 @@ bool Unifier2::unifyFreeWithType(TypeId subTy, TypeId superTy)
if (superArgTail) if (superArgTail)
return doDefault(); return doDefault();
const IntersectionType* upperBoundIntersection = get<IntersectionType>(subFree->upperBound); const IntersectionType* upperBoundIntersection = get<IntersectionType>(FFlag::LuauExtraFollows ? upperBound : subFree->upperBound);
if (!upperBoundIntersection) if (!upperBoundIntersection)
return doDefault(); return doDefault();
@ -524,6 +530,16 @@ bool Unifier2::unify(const TableType* subTable, const AnyType* superAny)
return true; return true;
} }
bool Unifier2::unify(const MetatableType* subMetatable, const AnyType*)
{
return unify(subMetatable->metatable, builtinTypes->anyType) && unify(subMetatable->table, builtinTypes->anyType);
}
bool Unifier2::unify(const AnyType*, const MetatableType* superMetatable)
{
return unify(builtinTypes->anyType, superMetatable->metatable) && unify(builtinTypes->anyType, superMetatable->table);
}
// FIXME? This should probably return an ErrorVec or an optional<TypeError> // FIXME? This should probably return an ErrorVec or an optional<TypeError>
// rather than a boolean to signal an occurs check failure. // rather than a boolean to signal an occurs check failure.
bool Unifier2::unify(TypePackId subTp, TypePackId superTp) bool Unifier2::unify(TypePackId subTp, TypePackId superTp)
@ -634,38 +650,33 @@ struct FreeTypeSearcher : TypeVisitor
{ {
} }
enum Polarity Polarity polarity = Polarity::Positive;
{
Positive,
Negative,
Both,
};
Polarity polarity = Positive;
void flip() void flip()
{ {
switch (polarity) switch (polarity)
{ {
case Positive: case Polarity::Positive:
polarity = Negative; polarity = Polarity::Negative;
break; break;
case Negative: case Polarity::Negative:
polarity = Positive; polarity = Polarity::Positive;
break; break;
case Both: case Polarity::Mixed:
break; break;
default:
LUAU_ASSERT(!"Unreachable");
} }
} }
DenseHashSet<const void*> seenPositive{nullptr}; DenseHashSet<const void*> seenPositive{nullptr};
DenseHashSet<const void*> seenNegative{nullptr}; DenseHashSet<const void*> seenNegative{nullptr};
bool seenWithPolarity(const void* ty) bool seenWithCurrentPolarity(const void* ty)
{ {
switch (polarity) switch (polarity)
{ {
case Positive: case Polarity::Positive:
{ {
if (seenPositive.contains(ty)) if (seenPositive.contains(ty))
return true; return true;
@ -673,7 +684,7 @@ struct FreeTypeSearcher : TypeVisitor
seenPositive.insert(ty); seenPositive.insert(ty);
return false; return false;
} }
case Negative: case Polarity::Negative:
{ {
if (seenNegative.contains(ty)) if (seenNegative.contains(ty))
return true; return true;
@ -681,7 +692,7 @@ struct FreeTypeSearcher : TypeVisitor
seenNegative.insert(ty); seenNegative.insert(ty);
return false; return false;
} }
case Both: case Polarity::Mixed:
{ {
if (seenPositive.contains(ty) && seenNegative.contains(ty)) if (seenPositive.contains(ty) && seenNegative.contains(ty))
return true; return true;
@ -690,6 +701,8 @@ struct FreeTypeSearcher : TypeVisitor
seenNegative.insert(ty); seenNegative.insert(ty);
return false; return false;
} }
default:
LUAU_ASSERT(!"Unreachable");
} }
return false; return false;
@ -703,7 +716,7 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty) override bool visit(TypeId ty) override
{ {
if (seenWithPolarity(ty)) if (seenWithCurrentPolarity(ty))
return false; return false;
LUAU_ASSERT(ty); LUAU_ASSERT(ty);
@ -712,7 +725,7 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const FreeType& ft) override bool visit(TypeId ty, const FreeType& ft) override
{ {
if (seenWithPolarity(ty)) if (seenWithCurrentPolarity(ty))
return false; return false;
if (!subsumes(scope, ft.scope)) if (!subsumes(scope, ft.scope))
@ -720,16 +733,18 @@ struct FreeTypeSearcher : TypeVisitor
switch (polarity) switch (polarity)
{ {
case Positive: case Polarity::Positive:
positiveTypes[ty]++; positiveTypes[ty]++;
break; break;
case Negative: case Polarity::Negative:
negativeTypes[ty]++; negativeTypes[ty]++;
break; break;
case Both: case Polarity::Mixed:
positiveTypes[ty]++; positiveTypes[ty]++;
negativeTypes[ty]++; negativeTypes[ty]++;
break; break;
default:
LUAU_ASSERT(!"Unreachable");
} }
return true; return true;
@ -737,23 +752,25 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const TableType& tt) override bool visit(TypeId ty, const TableType& tt) override
{ {
if (seenWithPolarity(ty)) if (seenWithCurrentPolarity(ty))
return false; return false;
if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope)) if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope))
{ {
switch (polarity) switch (polarity)
{ {
case Positive: case Polarity::Positive:
positiveTypes[ty]++; positiveTypes[ty]++;
break; break;
case Negative: case Polarity::Negative:
negativeTypes[ty]++; negativeTypes[ty]++;
break; break;
case Both: case Polarity::Mixed:
positiveTypes[ty]++; positiveTypes[ty]++;
negativeTypes[ty]++; negativeTypes[ty]++;
break; break;
default:
LUAU_ASSERT(!"Unreachable");
} }
} }
@ -766,7 +783,7 @@ struct FreeTypeSearcher : TypeVisitor
LUAU_ASSERT(prop.isShared()); LUAU_ASSERT(prop.isShared());
Polarity p = polarity; Polarity p = polarity;
polarity = Both; polarity = Polarity::Mixed;
traverse(prop.type()); traverse(prop.type());
polarity = p; polarity = p;
} }
@ -783,7 +800,7 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypeId ty, const FunctionType& ft) override bool visit(TypeId ty, const FunctionType& ft) override
{ {
if (seenWithPolarity(ty)) if (seenWithCurrentPolarity(ty))
return false; return false;
flip(); flip();
@ -802,7 +819,7 @@ struct FreeTypeSearcher : TypeVisitor
bool visit(TypePackId tp, const FreeTypePack& ftp) override bool visit(TypePackId tp, const FreeTypePack& ftp) override
{ {
if (seenWithPolarity(tp)) if (seenWithCurrentPolarity(tp))
return false; return false;
if (!subsumes(scope, ftp.scope)) if (!subsumes(scope, ftp.scope))
@ -810,16 +827,18 @@ struct FreeTypeSearcher : TypeVisitor
switch (polarity) switch (polarity)
{ {
case Positive: case Polarity::Positive:
positiveTypes[tp]++; positiveTypes[tp]++;
break; break;
case Negative: case Polarity::Negative:
negativeTypes[tp]++; negativeTypes[tp]++;
break; break;
case Both: case Polarity::Mixed:
positiveTypes[tp]++; positiveTypes[tp]++;
negativeTypes[tp]++; negativeTypes[tp]++;
break; break;
default:
LUAU_ASSERT(!"Unreachable");
} }
return true; return true;

View file

@ -38,7 +38,7 @@ private:
{ {
Page* next; Page* next;
char data[8192]; alignas(8) char data[8192];
}; };
Page* root; Page* root;

View file

@ -120,20 +120,6 @@ struct AstTypeList
using AstArgumentName = std::pair<AstName, Location>; // TODO: remove and replace when we get a common struct for this pair instead of AstName using AstArgumentName = std::pair<AstName, Location>; // TODO: remove and replace when we get a common struct for this pair instead of AstName
struct AstGenericType
{
AstName name;
Location location;
AstType* defaultValue = nullptr;
};
struct AstGenericTypePack
{
AstName name;
Location location;
AstTypePack* defaultValue = nullptr;
};
extern int gAstRttiIndex; extern int gAstRttiIndex;
template<typename T> template<typename T>
@ -208,6 +194,7 @@ public:
{ {
Checked, Checked,
Native, Native,
Deprecated,
}; };
AstAttr(const Location& location, Type type); AstAttr(const Location& location, Type type);
@ -253,6 +240,32 @@ public:
bool hasSemicolon; bool hasSemicolon;
}; };
class AstGenericType : public AstNode
{
public:
LUAU_RTTI(AstGenericType)
explicit AstGenericType(const Location& location, AstName name, AstType* defaultValue = nullptr);
void visit(AstVisitor* visitor) override;
AstName name;
AstType* defaultValue = nullptr;
};
class AstGenericTypePack : public AstNode
{
public:
LUAU_RTTI(AstGenericTypePack)
explicit AstGenericTypePack(const Location& location, AstName name, AstTypePack* defaultValue = nullptr);
void visit(AstVisitor* visitor) override;
AstName name;
AstTypePack* defaultValue = nullptr;
};
class AstExprGroup : public AstExpr class AstExprGroup : public AstExpr
{ {
public: public:
@ -424,8 +437,8 @@ public:
AstExprFunction( AstExprFunction(
const Location& location, const Location& location,
const AstArray<AstAttr*>& attributes, const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
AstLocal* self, AstLocal* self,
const AstArray<AstLocal*>& args, const AstArray<AstLocal*>& args,
bool vararg, bool vararg,
@ -441,10 +454,11 @@ public:
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
bool hasNativeAttribute() const; bool hasNativeAttribute() const;
bool hasAttribute(AstAttr::Type attributeType) const;
AstArray<AstAttr*> attributes; AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics; AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack*> genericPacks;
AstLocal* self; AstLocal* self;
AstArray<AstLocal*> args; AstArray<AstLocal*> args;
std::optional<AstTypeList> returnAnnotation; std::optional<AstTypeList> returnAnnotation;
@ -857,8 +871,8 @@ public:
const Location& location, const Location& location,
const AstName& name, const AstName& name,
const Location& nameLocation, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
AstType* type, AstType* type,
bool exported bool exported
); );
@ -867,8 +881,8 @@ public:
AstName name; AstName name;
Location nameLocation; Location nameLocation;
AstArray<AstGenericType> generics; AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack*> genericPacks;
AstType* type; AstType* type;
bool exported; bool exported;
}; };
@ -878,14 +892,22 @@ class AstStatTypeFunction : public AstStat
public: public:
LUAU_RTTI(AstStatTypeFunction); LUAU_RTTI(AstStatTypeFunction);
AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body, bool exported); AstStatTypeFunction(
const Location& location,
const AstName& name,
const Location& nameLocation,
AstExprFunction* body,
bool exported,
bool hasErrors
);
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
AstName name; AstName name;
Location nameLocation; Location nameLocation;
AstExprFunction* body; AstExprFunction* body = nullptr;
bool exported; bool exported = false;
bool hasErrors = false;
}; };
class AstStatDeclareGlobal : public AstStat class AstStatDeclareGlobal : public AstStat
@ -911,8 +933,8 @@ public:
const Location& location, const Location& location,
const AstName& name, const AstName& name,
const Location& nameLocation, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& params, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, const AstArray<AstArgumentName>& paramNames,
bool vararg, bool vararg,
@ -925,8 +947,8 @@ public:
const AstArray<AstAttr*>& attributes, const AstArray<AstAttr*>& attributes,
const AstName& name, const AstName& name,
const Location& nameLocation, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& params, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, const AstArray<AstArgumentName>& paramNames,
bool vararg, bool vararg,
@ -938,12 +960,13 @@ public:
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
bool isCheckedFunction() const; bool isCheckedFunction() const;
bool hasAttribute(AstAttr::Type attributeType) const;
AstArray<AstAttr*> attributes; AstArray<AstAttr*> attributes;
AstName name; AstName name;
Location nameLocation; Location nameLocation;
AstArray<AstGenericType> generics; AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack*> genericPacks;
AstTypeList params; AstTypeList params;
AstArray<AstArgumentName> paramNames; AstArray<AstArgumentName> paramNames;
bool vararg = false; bool vararg = false;
@ -1074,8 +1097,8 @@ public:
AstTypeFunction( AstTypeFunction(
const Location& location, const Location& location,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& argTypes, const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes const AstTypeList& returnTypes
@ -1084,8 +1107,8 @@ public:
AstTypeFunction( AstTypeFunction(
const Location& location, const Location& location,
const AstArray<AstAttr*>& attributes, const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& argTypes, const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes const AstTypeList& returnTypes
@ -1094,10 +1117,11 @@ public:
void visit(AstVisitor* visitor) override; void visit(AstVisitor* visitor) override;
bool isCheckedFunction() const; bool isCheckedFunction() const;
bool hasAttribute(AstAttr::Type attributeType) const;
AstArray<AstAttr*> attributes; AstArray<AstAttr*> attributes;
AstArray<AstGenericType> generics; AstArray<AstGenericType*> generics;
AstArray<AstGenericTypePack> genericPacks; AstArray<AstGenericTypePack*> genericPacks;
AstTypeList argTypes; AstTypeList argTypes;
AstArray<std::optional<AstArgumentName>> argNames; AstArray<std::optional<AstArgumentName>> argNames;
AstTypeList returnTypes; AstTypeList returnTypes;
@ -1115,6 +1139,16 @@ public:
AstExpr* expr; AstExpr* expr;
}; };
class AstTypeOptional : public AstType
{
public:
LUAU_RTTI(AstTypeOptional)
AstTypeOptional(const Location& location);
void visit(AstVisitor* visitor) override;
};
class AstTypeUnion : public AstType class AstTypeUnion : public AstType
{ {
public: public:
@ -1204,6 +1238,18 @@ public:
const AstArray<char> value; const AstArray<char> value;
}; };
class AstTypeGroup : public AstType
{
public:
LUAU_RTTI(AstTypeGroup)
explicit AstTypeGroup(const Location& location, AstType* type);
void visit(AstVisitor* visitor) override;
AstType* type;
};
class AstTypePack : public AstNode class AstTypePack : public AstNode
{ {
public: public:
@ -1264,6 +1310,16 @@ public:
return visit(static_cast<AstNode*>(node)); return visit(static_cast<AstNode*>(node));
} }
virtual bool visit(class AstGenericType* node)
{
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstGenericTypePack* node)
{
return visit(static_cast<AstNode*>(node));
}
virtual bool visit(class AstExpr* node) virtual bool visit(class AstExpr* node)
{ {
return visit(static_cast<AstNode*>(node)); return visit(static_cast<AstNode*>(node));
@ -1415,6 +1471,10 @@ public:
{ {
return visit(static_cast<AstStat*>(node)); return visit(static_cast<AstStat*>(node));
} }
virtual bool visit(class AstStatTypeFunction* node)
{
return visit(static_cast<AstStat*>(node));
}
virtual bool visit(class AstStatDeclareFunction* node) virtual bool visit(class AstStatDeclareFunction* node)
{ {
return visit(static_cast<AstStat*>(node)); return visit(static_cast<AstStat*>(node));
@ -1454,6 +1514,10 @@ public:
{ {
return visit(static_cast<AstType*>(node)); return visit(static_cast<AstType*>(node));
} }
virtual bool visit(class AstTypeOptional* node)
{
return visit(static_cast<AstType*>(node));
}
virtual bool visit(class AstTypeUnion* node) virtual bool visit(class AstTypeUnion* node)
{ {
return visit(static_cast<AstType*>(node)); return visit(static_cast<AstType*>(node));
@ -1470,6 +1534,10 @@ public:
{ {
return visit(static_cast<AstType*>(node)); return visit(static_cast<AstType*>(node));
} }
virtual bool visit(class AstTypeGroup* node)
{
return visit(static_cast<AstType*>(node));
}
virtual bool visit(class AstTypeError* node) virtual bool visit(class AstTypeError* node)
{ {
return visit(static_cast<AstType*>(node)); return visit(static_cast<AstType*>(node));

492
Ast/include/Luau/Cst.h Normal file
View file

@ -0,0 +1,492 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Location.h"
#include <string>
namespace Luau
{
extern int gCstRttiIndex;
template<typename T>
struct CstRtti
{
static const int value;
};
template<typename T>
const int CstRtti<T>::value = ++gCstRttiIndex;
#define LUAU_CST_RTTI(Class) \
static int CstClassIndex() \
{ \
return CstRtti<Class>::value; \
}
class CstNode
{
public:
explicit CstNode(int classIndex)
: classIndex(classIndex)
{
}
template<typename T>
bool is() const
{
return classIndex == T::CstClassIndex();
}
template<typename T>
T* as()
{
return classIndex == T::CstClassIndex() ? static_cast<T*>(this) : nullptr;
}
template<typename T>
const T* as() const
{
return classIndex == T::CstClassIndex() ? static_cast<const T*>(this) : nullptr;
}
const int classIndex;
};
class CstExprConstantNumber : public CstNode
{
public:
LUAU_CST_RTTI(CstExprConstantNumber)
explicit CstExprConstantNumber(const AstArray<char>& value);
AstArray<char> value;
};
class CstExprConstantString : public CstNode
{
public:
LUAU_CST_RTTI(CstExprConstantNumber)
enum QuoteStyle
{
QuotedSingle,
QuotedDouble,
QuotedRaw,
QuotedInterp,
};
CstExprConstantString(AstArray<char> sourceString, QuoteStyle quoteStyle, unsigned int blockDepth);
AstArray<char> sourceString;
QuoteStyle quoteStyle;
unsigned int blockDepth;
};
class CstExprCall : public CstNode
{
public:
LUAU_CST_RTTI(CstExprCall)
CstExprCall(std::optional<Position> openParens, std::optional<Position> closeParens, AstArray<Position> commaPositions);
std::optional<Position> openParens;
std::optional<Position> closeParens;
AstArray<Position> commaPositions;
};
class CstExprIndexExpr : public CstNode
{
public:
LUAU_CST_RTTI(CstExprIndexExpr)
CstExprIndexExpr(Position openBracketPosition, Position closeBracketPosition);
Position openBracketPosition;
Position closeBracketPosition;
};
class CstExprFunction : public CstNode
{
public:
LUAU_CST_RTTI(CstExprFunction)
CstExprFunction();
Position functionKeywordPosition{0, 0};
Position openGenericsPosition{0,0};
AstArray<Position> genericsCommaPositions;
Position closeGenericsPosition{0,0};
AstArray<Position> argsCommaPositions;
Position returnSpecifierPosition{0,0};
};
class CstExprTable : public CstNode
{
public:
LUAU_CST_RTTI(CstExprTable)
enum Separator
{
Comma,
Semicolon,
};
struct Item
{
std::optional<Position> indexerOpenPosition; // '[', only if Kind == General
std::optional<Position> indexerClosePosition; // ']', only if Kind == General
std::optional<Position> equalsPosition; // only if Kind != List
std::optional<Separator> separator; // may be missing for last Item
std::optional<Position> separatorPosition;
};
explicit CstExprTable(const AstArray<Item>& items);
AstArray<Item> items;
};
// TODO: Shared between unary and binary, should we split?
class CstExprOp : public CstNode
{
public:
LUAU_CST_RTTI(CstExprOp)
explicit CstExprOp(Position opPosition);
Position opPosition;
};
class CstExprTypeAssertion : public CstNode
{
public:
LUAU_CST_RTTI(CstExprTypeAssertion)
explicit CstExprTypeAssertion(Position opPosition);
Position opPosition;
};
class CstExprIfElse : public CstNode
{
public:
LUAU_CST_RTTI(CstExprIfElse)
CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf);
Position thenPosition;
Position elsePosition;
bool isElseIf;
};
class CstExprInterpString : public CstNode
{
public:
LUAU_CST_RTTI(CstExprInterpString)
explicit CstExprInterpString(AstArray<AstArray<char>> sourceStrings, AstArray<Position> stringPositions);
AstArray<AstArray<char>> sourceStrings;
AstArray<Position> stringPositions;
};
class CstStatDo : public CstNode
{
public:
LUAU_CST_RTTI(CstStatDo)
explicit CstStatDo(Position endPosition);
Position endPosition;
};
class CstStatRepeat : public CstNode
{
public:
LUAU_CST_RTTI(CstStatRepeat)
explicit CstStatRepeat(Position untilPosition);
Position untilPosition;
};
class CstStatReturn : public CstNode
{
public:
LUAU_CST_RTTI(CstStatReturn)
explicit CstStatReturn(AstArray<Position> commaPositions);
AstArray<Position> commaPositions;
};
class CstStatLocal : public CstNode
{
public:
LUAU_CST_RTTI(CstStatLocal)
CstStatLocal(AstArray<Position> varsCommaPositions, AstArray<Position> valuesCommaPositions);
AstArray<Position> varsCommaPositions;
AstArray<Position> valuesCommaPositions;
};
class CstStatFor : public CstNode
{
public:
LUAU_CST_RTTI(CstStatFor)
CstStatFor(Position equalsPosition, Position endCommaPosition, std::optional<Position> stepCommaPosition);
Position equalsPosition;
Position endCommaPosition;
std::optional<Position> stepCommaPosition;
};
class CstStatForIn : public CstNode
{
public:
LUAU_CST_RTTI(CstStatForIn)
CstStatForIn(AstArray<Position> varsCommaPositions, AstArray<Position> valuesCommaPositions);
AstArray<Position> varsCommaPositions;
AstArray<Position> valuesCommaPositions;
};
class CstStatAssign : public CstNode
{
public:
LUAU_CST_RTTI(CstStatAssign)
CstStatAssign(AstArray<Position> varsCommaPositions, Position equalsPosition, AstArray<Position> valuesCommaPositions);
AstArray<Position> varsCommaPositions;
Position equalsPosition;
AstArray<Position> valuesCommaPositions;
};
class CstStatCompoundAssign : public CstNode
{
public:
LUAU_CST_RTTI(CstStatCompoundAssign)
explicit CstStatCompoundAssign(Position opPosition);
Position opPosition;
};
class CstStatFunction : public CstNode
{
public:
LUAU_CST_RTTI(CstStatFunction)
explicit CstStatFunction(Position functionKeywordPosition);
Position functionKeywordPosition;
};
class CstStatLocalFunction : public CstNode
{
public:
LUAU_CST_RTTI(CstStatLocalFunction)
explicit CstStatLocalFunction(Position localKeywordPosition, Position functionKeywordPosition);
Position localKeywordPosition;
Position functionKeywordPosition;
};
class CstGenericType : public CstNode
{
public:
LUAU_CST_RTTI(CstGenericType)
CstGenericType(std::optional<Position> defaultEqualsPosition);
std::optional<Position> defaultEqualsPosition;
};
class CstGenericTypePack : public CstNode
{
public:
LUAU_CST_RTTI(CstGenericTypePack)
CstGenericTypePack(Position ellipsisPosition, std::optional<Position> defaultEqualsPosition);
Position ellipsisPosition;
std::optional<Position> defaultEqualsPosition;
};
class CstStatTypeAlias : public CstNode
{
public:
LUAU_CST_RTTI(CstStatTypeAlias)
CstStatTypeAlias(
Position typeKeywordPosition,
Position genericsOpenPosition,
AstArray<Position> genericsCommaPositions,
Position genericsClosePosition,
Position equalsPosition
);
Position typeKeywordPosition;
Position genericsOpenPosition;
AstArray<Position> genericsCommaPositions;
Position genericsClosePosition;
Position equalsPosition;
};
class CstStatTypeFunction : public CstNode
{
public:
LUAU_CST_RTTI(CstStatTypeFunction)
CstStatTypeFunction(Position typeKeywordPosition, Position functionKeywordPosition);
Position typeKeywordPosition;
Position functionKeywordPosition;
};
class CstTypeReference : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeReference)
CstTypeReference(
std::optional<Position> prefixPointPosition,
Position openParametersPosition,
AstArray<Position> parametersCommaPositions,
Position closeParametersPosition
);
std::optional<Position> prefixPointPosition;
Position openParametersPosition;
AstArray<Position> parametersCommaPositions;
Position closeParametersPosition;
};
class CstTypeTable : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeTable)
struct Item
{
enum struct Kind
{
Indexer,
Property,
StringProperty,
};
Kind kind;
Position indexerOpenPosition; // '[', only if Kind != Property
Position indexerClosePosition; // ']' only if Kind != Property
Position colonPosition;
std::optional<CstExprTable::Separator> separator; // may be missing for last Item
std::optional<Position> separatorPosition;
CstExprConstantString* stringInfo = nullptr; // only if Kind == StringProperty
};
CstTypeTable(AstArray<Item> items, bool isArray);
AstArray<Item> items;
bool isArray = false;
};
class CstTypeFunction : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeFunction)
CstTypeFunction(
Position openGenericsPosition,
AstArray<Position> genericsCommaPositions,
Position closeGenericsPosition,
Position openArgsPosition,
AstArray<std::optional<Position>> argumentNameColonPositions,
AstArray<Position> argumentsCommaPositions,
Position closeArgsPosition,
Position returnArrowPosition
);
Position openGenericsPosition;
AstArray<Position> genericsCommaPositions;
Position closeGenericsPosition;
Position openArgsPosition;
AstArray<std::optional<Position>> argumentNameColonPositions;
AstArray<Position> argumentsCommaPositions;
Position closeArgsPosition;
Position returnArrowPosition;
};
class CstTypeTypeof : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeTypeof)
CstTypeTypeof(Position openPosition, Position closePosition);
Position openPosition;
Position closePosition;
};
class CstTypeUnion : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeUnion)
CstTypeUnion(std::optional<Position> leadingPosition, AstArray<Position> separatorPositions);
std::optional<Position> leadingPosition;
AstArray<Position> separatorPositions;
};
class CstTypeIntersection : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeIntersection)
explicit CstTypeIntersection(std::optional<Position> leadingPosition, AstArray<Position> separatorPositions);
std::optional<Position> leadingPosition;
AstArray<Position> separatorPositions;
};
class CstTypeSingletonString : public CstNode
{
public:
LUAU_CST_RTTI(CstTypeSingletonString)
CstTypeSingletonString(AstArray<char> sourceString, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth);
AstArray<char> sourceString;
CstExprConstantString::QuoteStyle quoteStyle;
unsigned int blockDepth;
};
class CstTypePackExplicit : public CstNode
{
public:
LUAU_CST_RTTI(CstTypePackExplicit)
CstTypePackExplicit(Position openParenthesesPosition, Position closeParenthesesPosition, AstArray<Position> commaPositions);
Position openParenthesesPosition;
Position closeParenthesesPosition;
AstArray<Position> commaPositions;
};
class CstTypePackGeneric : public CstNode
{
public:
LUAU_CST_RTTI(CstTypePackGeneric)
explicit CstTypePackGeneric(Position ellipsisPosition);
Position ellipsisPosition;
};
} // namespace Luau

View file

@ -87,6 +87,12 @@ struct Lexeme
Reserved_END Reserved_END
}; };
enum struct QuoteStyle
{
Single,
Double,
};
Type type; Type type;
Location location; Location location;
@ -111,6 +117,8 @@ public:
Lexeme(const Location& location, Type type, const char* name); Lexeme(const Location& location, Type type, const char* name);
unsigned int getLength() const; unsigned int getLength() const;
unsigned int getBlockDepth() const;
QuoteStyle getQuoteStyle() const;
std::string toString() const; std::string toString() const;
}; };
@ -179,6 +187,11 @@ public:
static bool fixupQuotedString(std::string& data); static bool fixupQuotedString(std::string& data);
static void fixupMultilineString(std::string& data); static void fixupMultilineString(std::string& data);
unsigned int getOffset() const
{
return offset;
}
private: private:
char peekch() const; char peekch() const;
char peekch(unsigned int lookahead) const; char peekch(unsigned int lookahead) const;

View file

@ -29,6 +29,8 @@ struct ParseOptions
bool allowDeclarationSyntax = false; bool allowDeclarationSyntax = false;
bool captureComments = false; bool captureComments = false;
std::optional<FragmentParseResumeSettings> parseFragment = std::nullopt; std::optional<FragmentParseResumeSettings> parseFragment = std::nullopt;
bool storeCstData = false;
bool noErrorLimit = false;
}; };
} // namespace Luau } // namespace Luau

View file

@ -10,6 +10,7 @@ namespace Luau
{ {
class AstStatBlock; class AstStatBlock;
class CstNode;
class ParseError : public std::exception class ParseError : public std::exception
{ {
@ -55,6 +56,8 @@ struct Comment
Location location; Location location;
}; };
using CstNodeMap = DenseHashMap<AstNode*, CstNode*>;
struct ParseResult struct ParseResult
{ {
AstStatBlock* root; AstStatBlock* root;
@ -64,6 +67,21 @@ struct ParseResult
std::vector<ParseError> errors; std::vector<ParseError> errors;
std::vector<Comment> commentLocations; std::vector<Comment> commentLocations;
CstNodeMap cstNodeMap{nullptr};
};
struct ParseExprResult
{
AstExpr* expr;
size_t lines = 0;
std::vector<HotComment> hotcomments;
std::vector<ParseError> errors;
std::vector<Comment> commentLocations;
CstNodeMap cstNodeMap{nullptr};
}; };
static constexpr const char* kParseNameError = "%error-id%"; static constexpr const char* kParseNameError = "%error-id%";

View file

@ -8,6 +8,7 @@
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Cst.h"
#include <initializer_list> #include <initializer_list>
#include <optional> #include <optional>
@ -62,6 +63,14 @@ public:
ParseOptions options = ParseOptions() ParseOptions options = ParseOptions()
); );
static ParseExprResult parseExpr(
const char* buffer,
std::size_t bufferSize,
AstNameTable& names,
Allocator& allocator,
ParseOptions options = ParseOptions()
);
private: private:
struct Name; struct Name;
struct Binding; struct Binding;
@ -116,7 +125,7 @@ private:
AstStat* parseFor(); AstStat* parseFor();
// funcname ::= Name {`.' Name} [`:' Name] // funcname ::= Name {`.' Name} [`:' Name]
AstExpr* parseFunctionName(Location start_DEPRECATED, bool& hasself, AstName& debugname); AstExpr* parseFunctionName(bool& hasself, AstName& debugname);
// function funcname funcbody // function funcname funcbody
LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray<AstAttr*>& attributes = {nullptr, 0}); LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray<AstAttr*>& attributes = {nullptr, 0});
@ -143,12 +152,14 @@ private:
AstStat* parseReturn(); AstStat* parseReturn();
// type Name `=' Type // type Name `=' Type
AstStat* parseTypeAlias(const Location& start, bool exported); AstStat* parseTypeAlias(const Location& start, bool exported, Position typeKeywordPosition);
// type function Name ... end // type function Name ... end
AstStat* parseTypeFunction(const Location& start, bool exported); AstStat* parseTypeFunction(const Location& start, bool exported, Position typeKeywordPosition);
AstDeclaredClassProp parseDeclaredClassMethod(const AstArray<AstAttr*>& attributes);
AstDeclaredClassProp parseDeclaredClassMethod_DEPRECATED();
AstDeclaredClassProp parseDeclaredClassMethod();
// `declare global' Name: Type | // `declare global' Name: Type |
// `declare function' Name`(' [parlist] `)' [`:` Type] // `declare function' Name`(' [parlist] `)' [`:` Type]
@ -173,14 +184,19 @@ private:
); );
// explist ::= {exp `,'} exp // explist ::= {exp `,'} exp
void parseExprList(TempVector<AstExpr*>& result); void parseExprList(TempVector<AstExpr*>& result, TempVector<Position>* commaPositions = nullptr);
// binding ::= Name [`:` Type] // binding ::= Name [`:` Type]
Binding parseBinding(); Binding parseBinding();
// bindinglist ::= (binding | `...') {`,' bindinglist} // bindinglist ::= (binding | `...') {`,' bindinglist}
// Returns the location of the vararg ..., or std::nullopt if the function is not vararg. // Returns the location of the vararg ..., or std::nullopt if the function is not vararg.
std::tuple<bool, Location, AstTypePack*> parseBindingList(TempVector<Binding>& result, bool allowDot3 = false); std::tuple<bool, Location, AstTypePack*> parseBindingList(
TempVector<Binding>& result,
bool allowDot3 = false,
AstArray<Position>* commaPositions = nullptr,
std::optional<Position> initialCommaPosition = std::nullopt
);
AstType* parseOptionalType(); AstType* parseOptionalType();
@ -196,19 +212,34 @@ private:
// | `(' [TypeList] `)' `->` ReturnType // | `(' [TypeList] `)' `->` ReturnType
// Returns the variadic annotation, if it exists. // Returns the variadic annotation, if it exists.
AstTypePack* parseTypeList(TempVector<AstType*>& result, TempVector<std::optional<AstArgumentName>>& resultNames); AstTypePack* parseTypeList(
TempVector<AstType*>& result,
TempVector<std::optional<AstArgumentName>>& resultNames,
TempVector<Position>* commaPositions = nullptr,
TempVector<std::optional<Position>>* nameColonPositions = nullptr
);
std::optional<AstTypeList> parseOptionalReturnType(); std::optional<AstTypeList> parseOptionalReturnType(Position* returnSpecifierPosition = nullptr);
std::pair<Location, AstTypeList> parseReturnType(); std::pair<Location, AstTypeList> parseReturnType();
AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation); struct TableIndexerResult
{
AstTableIndexer* node;
Position indexerOpenPosition;
Position indexerClosePosition;
Position colonPosition;
};
TableIndexerResult parseTableIndexer(AstTableAccess access, std::optional<Location> accessLocation, Lexeme begin);
// Remove with FFlagLuauStoreCSTData2
AstTableIndexer* parseTableIndexer_DEPRECATED(AstTableAccess access, std::optional<Location> accessLocation, Lexeme begin);
AstTypeOrPack parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes); AstTypeOrPack parseFunctionType(bool allowPack, const AstArray<AstAttr*>& attributes);
AstType* parseFunctionTypeTail( AstType* parseFunctionTypeTail(
const Lexeme& begin, const Lexeme& begin,
const AstArray<AstAttr*>& attributes, const AstArray<AstAttr*>& attributes,
AstArray<AstGenericType> generics, AstArray<AstGenericType*> generics,
AstArray<AstGenericTypePack> genericPacks, AstArray<AstGenericTypePack*> genericPacks,
AstArray<AstType*> params, AstArray<AstType*> params,
AstArray<std::optional<AstArgumentName>> paramNames, AstArray<std::optional<AstArgumentName>> paramNames,
AstTypePack* varargAnnotation AstTypePack* varargAnnotation
@ -259,6 +290,8 @@ private:
// args ::= `(' [explist] `)' | tableconstructor | String // args ::= `(' [explist] `)' | tableconstructor | String
AstExpr* parseFunctionArgs(AstExpr* func, bool self); AstExpr* parseFunctionArgs(AstExpr* func, bool self);
std::optional<CstExprTable::Separator> tableSeparator();
// tableconstructor ::= `{' [fieldlist] `}' // tableconstructor ::= `{' [fieldlist] `}'
// fieldlist ::= field {fieldsep field} [fieldsep] // fieldlist ::= field {fieldsep field} [fieldsep]
// field ::= `[' exp `]' `=' exp | Name `=' exp | exp // field ::= `[' exp `]' `=' exp | Name `=' exp | exp
@ -277,12 +310,21 @@ private:
Name parseIndexName(const char* context, const Position& previous); Name parseIndexName(const char* context, const Position& previous);
// `<' namelist `>' // `<' namelist `>'
std::pair<AstArray<AstGenericType>, AstArray<AstGenericTypePack>> parseGenericTypeList(bool withDefaultValues); std::pair<AstArray<AstGenericType*>, AstArray<AstGenericTypePack*>> parseGenericTypeList(
bool withDefaultValues,
Position* openPosition = nullptr,
AstArray<Position>* commaPositions = nullptr,
Position* closePosition = nullptr
);
// `<' Type[, ...] `>' // `<' Type[, ...] `>'
AstArray<AstTypeOrPack> parseTypeParams(); AstArray<AstTypeOrPack> parseTypeParams(
Position* openingPosition = nullptr,
TempVector<Position>* commaPositions = nullptr,
Position* closingPosition = nullptr
);
std::optional<AstArray<char>> parseCharArray(); std::optional<AstArray<char>> parseCharArray(AstArray<char>* originalString = nullptr);
AstExpr* parseString(); AstExpr* parseString();
AstExpr* parseNumber(); AstExpr* parseNumber();
@ -292,6 +334,9 @@ private:
void restoreLocals(unsigned int offset); void restoreLocals(unsigned int offset);
/// Returns string quote style and block depth
std::pair<CstExprConstantString::QuoteStyle, unsigned int> extractStringDetails();
// check that parser is at lexeme/symbol, move to next lexeme/symbol on success, report failure and continue on failure // check that parser is at lexeme/symbol, move to next lexeme/symbol on success, report failure and continue on failure
bool expectAndConsume(char value, const char* context = nullptr); bool expectAndConsume(char value, const char* context = nullptr);
bool expectAndConsume(Lexeme::Type type, const char* context = nullptr); bool expectAndConsume(Lexeme::Type type, const char* context = nullptr);
@ -435,6 +480,7 @@ private:
std::vector<AstAttr*> scratchAttr; std::vector<AstAttr*> scratchAttr;
std::vector<AstStat*> scratchStat; std::vector<AstStat*> scratchStat;
std::vector<AstArray<char>> scratchString; std::vector<AstArray<char>> scratchString;
std::vector<AstArray<char>> scratchString2;
std::vector<AstExpr*> scratchExpr; std::vector<AstExpr*> scratchExpr;
std::vector<AstExpr*> scratchExprAux; std::vector<AstExpr*> scratchExprAux;
std::vector<AstName> scratchName; std::vector<AstName> scratchName;
@ -442,15 +488,21 @@ private:
std::vector<Binding> scratchBinding; std::vector<Binding> scratchBinding;
std::vector<AstLocal*> scratchLocal; std::vector<AstLocal*> scratchLocal;
std::vector<AstTableProp> scratchTableTypeProps; std::vector<AstTableProp> scratchTableTypeProps;
std::vector<CstTypeTable::Item> scratchCstTableTypeProps;
std::vector<AstType*> scratchType; std::vector<AstType*> scratchType;
std::vector<AstTypeOrPack> scratchTypeOrPack; std::vector<AstTypeOrPack> scratchTypeOrPack;
std::vector<AstDeclaredClassProp> scratchDeclaredClassProps; std::vector<AstDeclaredClassProp> scratchDeclaredClassProps;
std::vector<AstExprTable::Item> scratchItem; std::vector<AstExprTable::Item> scratchItem;
std::vector<CstExprTable::Item> scratchCstItem;
std::vector<AstArgumentName> scratchArgName; std::vector<AstArgumentName> scratchArgName;
std::vector<AstGenericType> scratchGenericTypes; std::vector<AstGenericType*> scratchGenericTypes;
std::vector<AstGenericTypePack> scratchGenericTypePacks; std::vector<AstGenericTypePack*> scratchGenericTypePacks;
std::vector<std::optional<AstArgumentName>> scratchOptArgName; std::vector<std::optional<AstArgumentName>> scratchOptArgName;
std::vector<Position> scratchPosition;
std::vector<std::optional<Position>> scratchOptPosition;
std::string scratchData; std::string scratchData;
CstNodeMap cstNodeMap;
}; };
} // namespace Luau } // namespace Luau

View file

@ -3,9 +3,24 @@
#include "Luau/Common.h" #include "Luau/Common.h"
LUAU_FASTFLAG(LuauDeprecatedAttribute);
namespace Luau namespace Luau
{ {
static bool hasAttributeInArray(const AstArray<AstAttr*> attributes, AstAttr::Type attributeType)
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
for (const auto attribute : attributes)
{
if (attribute->type == attributeType)
return true;
}
return false;
}
static void visitTypeList(AstVisitor* visitor, const AstTypeList& list) static void visitTypeList(AstVisitor* visitor, const AstTypeList& list)
{ {
for (AstType* ty : list.types) for (AstType* ty : list.types)
@ -28,6 +43,38 @@ void AstAttr::visit(AstVisitor* visitor)
int gAstRttiIndex = 0; int gAstRttiIndex = 0;
AstGenericType::AstGenericType(const Location& location, AstName name, AstType* defaultValue)
: AstNode(ClassIndex(), location)
, name(name)
, defaultValue(defaultValue)
{
}
void AstGenericType::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
{
if (defaultValue)
defaultValue->visit(visitor);
}
}
AstGenericTypePack::AstGenericTypePack(const Location& location, AstName name, AstTypePack* defaultValue)
: AstNode(ClassIndex(), location)
, name(name)
, defaultValue(defaultValue)
{
}
void AstGenericTypePack::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
{
if (defaultValue)
defaultValue->visit(visitor);
}
}
AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr)
: AstExpr(ClassIndex(), location) : AstExpr(ClassIndex(), location)
, expr(expr) , expr(expr)
@ -185,8 +232,8 @@ void AstExprIndexExpr::visit(AstVisitor* visitor)
AstExprFunction::AstExprFunction( AstExprFunction::AstExprFunction(
const Location& location, const Location& location,
const AstArray<AstAttr*>& attributes, const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
AstLocal* self, AstLocal* self,
const AstArray<AstLocal*>& args, const AstArray<AstLocal*>& args,
bool vararg, bool vararg,
@ -245,6 +292,13 @@ bool AstExprFunction::hasNativeAttribute() const
return false; return false;
} }
bool AstExprFunction::hasAttribute(const AstAttr::Type attributeType) const
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
return hasAttributeInArray(attributes, attributeType);
}
AstExprTable::AstExprTable(const Location& location, const AstArray<Item>& items) AstExprTable::AstExprTable(const Location& location, const AstArray<Item>& items)
: AstExpr(ClassIndex(), location) : AstExpr(ClassIndex(), location)
, items(items) , items(items)
@ -721,8 +775,8 @@ AstStatTypeAlias::AstStatTypeAlias(
const Location& location, const Location& location,
const AstName& name, const AstName& name,
const Location& nameLocation, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
AstType* type, AstType* type,
bool exported bool exported
) )
@ -740,16 +794,14 @@ void AstStatTypeAlias::visit(AstVisitor* visitor)
{ {
if (visitor->visit(this)) if (visitor->visit(this))
{ {
for (const AstGenericType& el : generics) for (AstGenericType* el : generics)
{ {
if (el.defaultValue) el->visit(visitor);
el.defaultValue->visit(visitor);
} }
for (const AstGenericTypePack& el : genericPacks) for (AstGenericTypePack* el : genericPacks)
{ {
if (el.defaultValue) el->visit(visitor);
el.defaultValue->visit(visitor);
} }
type->visit(visitor); type->visit(visitor);
@ -761,13 +813,15 @@ AstStatTypeFunction::AstStatTypeFunction(
const AstName& name, const AstName& name,
const Location& nameLocation, const Location& nameLocation,
AstExprFunction* body, AstExprFunction* body,
bool exported bool exported,
bool hasErrors
) )
: AstStat(ClassIndex(), location) : AstStat(ClassIndex(), location)
, name(name) , name(name)
, nameLocation(nameLocation) , nameLocation(nameLocation)
, body(body) , body(body)
, exported(exported) , exported(exported)
, hasErrors(hasErrors)
{ {
} }
@ -795,8 +849,8 @@ AstStatDeclareFunction::AstStatDeclareFunction(
const Location& location, const Location& location,
const AstName& name, const AstName& name,
const Location& nameLocation, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& params, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, const AstArray<AstArgumentName>& paramNames,
bool vararg, bool vararg,
@ -822,8 +876,8 @@ AstStatDeclareFunction::AstStatDeclareFunction(
const AstArray<AstAttr*>& attributes, const AstArray<AstAttr*>& attributes,
const AstName& name, const AstName& name,
const Location& nameLocation, const Location& nameLocation,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& params, const AstTypeList& params,
const AstArray<AstArgumentName>& paramNames, const AstArray<AstArgumentName>& paramNames,
bool vararg, bool vararg,
@ -864,6 +918,13 @@ bool AstStatDeclareFunction::isCheckedFunction() const
return false; return false;
} }
bool AstStatDeclareFunction::hasAttribute(AstAttr::Type attributeType) const
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
return hasAttributeInArray(attributes, attributeType);
}
AstStatDeclareClass::AstStatDeclareClass( AstStatDeclareClass::AstStatDeclareClass(
const Location& location, const Location& location,
const AstName& name, const AstName& name,
@ -970,8 +1031,8 @@ void AstTypeTable::visit(AstVisitor* visitor)
AstTypeFunction::AstTypeFunction( AstTypeFunction::AstTypeFunction(
const Location& location, const Location& location,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& argTypes, const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes const AstTypeList& returnTypes
@ -990,8 +1051,8 @@ AstTypeFunction::AstTypeFunction(
AstTypeFunction::AstTypeFunction( AstTypeFunction::AstTypeFunction(
const Location& location, const Location& location,
const AstArray<AstAttr*>& attributes, const AstArray<AstAttr*>& attributes,
const AstArray<AstGenericType>& generics, const AstArray<AstGenericType*>& generics,
const AstArray<AstGenericTypePack>& genericPacks, const AstArray<AstGenericTypePack*>& genericPacks,
const AstTypeList& argTypes, const AstTypeList& argTypes,
const AstArray<std::optional<AstArgumentName>>& argNames, const AstArray<std::optional<AstArgumentName>>& argNames,
const AstTypeList& returnTypes const AstTypeList& returnTypes
@ -1027,6 +1088,13 @@ bool AstTypeFunction::isCheckedFunction() const
return false; return false;
} }
bool AstTypeFunction::hasAttribute(AstAttr::Type attributeType) const
{
LUAU_ASSERT(FFlag::LuauDeprecatedAttribute);
return hasAttributeInArray(attributes, attributeType);
}
AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr) AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, expr(expr) , expr(expr)
@ -1039,6 +1107,16 @@ void AstTypeTypeof::visit(AstVisitor* visitor)
expr->visit(visitor); expr->visit(visitor);
} }
AstTypeOptional::AstTypeOptional(const Location& location)
: AstType(ClassIndex(), location)
{
}
void AstTypeOptional::visit(AstVisitor* visitor)
{
visitor->visit(this);
}
AstTypeUnion::AstTypeUnion(const Location& location, const AstArray<AstType*>& types) AstTypeUnion::AstTypeUnion(const Location& location, const AstArray<AstType*>& types)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, types(types) , types(types)
@ -1091,6 +1169,18 @@ void AstTypeSingletonString::visit(AstVisitor* visitor)
visitor->visit(this); visitor->visit(this);
} }
AstTypeGroup::AstTypeGroup(const Location& location, AstType* type)
: AstType(ClassIndex(), location)
, type(type)
{
}
void AstTypeGroup::visit(AstVisitor* visitor)
{
if (visitor->visit(this))
type->visit(visitor);
}
AstTypeError::AstTypeError(const Location& location, const AstArray<AstType*>& types, bool isMissing, unsigned messageIndex) AstTypeError::AstTypeError(const Location& location, const AstArray<AstType*>& types, bool isMissing, unsigned messageIndex)
: AstType(ClassIndex(), location) : AstType(ClassIndex(), location)
, types(types) , types(types)

268
Ast/src/Cst.cpp Normal file
View file

@ -0,0 +1,268 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Ast.h"
#include "Luau/Cst.h"
#include "Luau/Common.h"
namespace Luau
{
int gCstRttiIndex = 0;
CstExprConstantNumber::CstExprConstantNumber(const AstArray<char>& value)
: CstNode(CstClassIndex())
, value(value)
{
}
CstExprConstantString::CstExprConstantString(AstArray<char> sourceString, QuoteStyle quoteStyle, unsigned int blockDepth)
: CstNode(CstClassIndex())
, sourceString(sourceString)
, quoteStyle(quoteStyle)
, blockDepth(blockDepth)
{
LUAU_ASSERT(blockDepth == 0 || quoteStyle == QuoteStyle::QuotedRaw);
}
CstExprCall::CstExprCall(std::optional<Position> openParens, std::optional<Position> closeParens, AstArray<Position> commaPositions)
: CstNode(CstClassIndex())
, openParens(openParens)
, closeParens(closeParens)
, commaPositions(commaPositions)
{
}
CstExprIndexExpr::CstExprIndexExpr(Position openBracketPosition, Position closeBracketPosition)
: CstNode(CstClassIndex())
, openBracketPosition(openBracketPosition)
, closeBracketPosition(closeBracketPosition)
{
}
CstExprFunction::CstExprFunction() : CstNode(CstClassIndex())
{
}
CstExprTable::CstExprTable(const AstArray<Item>& items)
: CstNode(CstClassIndex())
, items(items)
{
}
CstExprOp::CstExprOp(Position opPosition)
: CstNode(CstClassIndex())
, opPosition(opPosition)
{
}
CstExprTypeAssertion::CstExprTypeAssertion(Position opPosition)
: CstNode(CstClassIndex())
, opPosition(opPosition)
{
}
CstExprIfElse::CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf)
: CstNode(CstClassIndex())
, thenPosition(thenPosition)
, elsePosition(elsePosition)
, isElseIf(isElseIf)
{
}
CstExprInterpString::CstExprInterpString(AstArray<AstArray<char>> sourceStrings, AstArray<Position> stringPositions)
: CstNode(CstClassIndex())
, sourceStrings(sourceStrings)
, stringPositions(stringPositions)
{
}
CstStatDo::CstStatDo(Position endPosition)
: CstNode(CstClassIndex())
, endPosition(endPosition)
{
}
CstStatRepeat::CstStatRepeat(Position untilPosition)
: CstNode(CstClassIndex())
, untilPosition(untilPosition)
{
}
CstStatReturn::CstStatReturn(AstArray<Position> commaPositions)
: CstNode(CstClassIndex())
, commaPositions(commaPositions)
{
}
CstStatLocal::CstStatLocal(AstArray<Position> varsCommaPositions, AstArray<Position> valuesCommaPositions)
: CstNode(CstClassIndex())
, varsCommaPositions(varsCommaPositions)
, valuesCommaPositions(valuesCommaPositions)
{
}
CstStatFor::CstStatFor(Position equalsPosition, Position endCommaPosition, std::optional<Position> stepCommaPosition)
: CstNode(CstClassIndex())
, equalsPosition(equalsPosition)
, endCommaPosition(endCommaPosition)
, stepCommaPosition(stepCommaPosition)
{
}
CstStatForIn::CstStatForIn(AstArray<Position> varsCommaPositions, AstArray<Position> valuesCommaPositions)
: CstNode(CstClassIndex())
, varsCommaPositions(varsCommaPositions)
, valuesCommaPositions(valuesCommaPositions)
{
}
CstStatAssign::CstStatAssign(AstArray<Position> varsCommaPositions, Position equalsPosition, AstArray<Position> valuesCommaPositions)
: CstNode(CstClassIndex())
, varsCommaPositions(varsCommaPositions)
, equalsPosition(equalsPosition)
, valuesCommaPositions(valuesCommaPositions)
{
}
CstStatCompoundAssign::CstStatCompoundAssign(Position opPosition)
: CstNode(CstClassIndex())
, opPosition(opPosition)
{
}
CstStatFunction::CstStatFunction(Position functionKeywordPosition)
: CstNode(CstClassIndex())
, functionKeywordPosition(functionKeywordPosition)
{
}
CstStatLocalFunction::CstStatLocalFunction(Position localKeywordPosition, Position functionKeywordPosition)
: CstNode(CstClassIndex())
, localKeywordPosition(localKeywordPosition)
, functionKeywordPosition(functionKeywordPosition)
{
}
CstGenericType::CstGenericType(std::optional<Position> defaultEqualsPosition)
: CstNode(CstClassIndex())
, defaultEqualsPosition(defaultEqualsPosition)
{
}
CstGenericTypePack::CstGenericTypePack(Position ellipsisPosition, std::optional<Position> defaultEqualsPosition)
: CstNode(CstClassIndex())
, ellipsisPosition(ellipsisPosition)
, defaultEqualsPosition(defaultEqualsPosition)
{
}
CstStatTypeAlias::CstStatTypeAlias(
Position typeKeywordPosition,
Position genericsOpenPosition,
AstArray<Position> genericsCommaPositions,
Position genericsClosePosition,
Position equalsPosition
)
: CstNode(CstClassIndex())
, typeKeywordPosition(typeKeywordPosition)
, genericsOpenPosition(genericsOpenPosition)
, genericsCommaPositions(genericsCommaPositions)
, genericsClosePosition(genericsClosePosition)
, equalsPosition(equalsPosition)
{
}
CstStatTypeFunction::CstStatTypeFunction(Position typeKeywordPosition, Position functionKeywordPosition)
: CstNode(CstClassIndex())
, typeKeywordPosition(typeKeywordPosition)
, functionKeywordPosition(functionKeywordPosition)
{
}
CstTypeReference::CstTypeReference(
std::optional<Position> prefixPointPosition,
Position openParametersPosition,
AstArray<Position> parametersCommaPositions,
Position closeParametersPosition
)
: CstNode(CstClassIndex())
, prefixPointPosition(prefixPointPosition)
, openParametersPosition(openParametersPosition)
, parametersCommaPositions(parametersCommaPositions)
, closeParametersPosition(closeParametersPosition)
{
}
CstTypeTable::CstTypeTable(AstArray<Item> items, bool isArray)
: CstNode(CstClassIndex())
, items(items)
, isArray(isArray)
{
}
CstTypeFunction::CstTypeFunction(
Position openGenericsPosition,
AstArray<Position> genericsCommaPositions,
Position closeGenericsPosition,
Position openArgsPosition,
AstArray<std::optional<Position>> argumentNameColonPositions,
AstArray<Position> argumentsCommaPositions,
Position closeArgsPosition,
Position returnArrowPosition
)
: CstNode(CstClassIndex())
, openGenericsPosition(openGenericsPosition)
, genericsCommaPositions(genericsCommaPositions)
, closeGenericsPosition(closeGenericsPosition)
, openArgsPosition(openArgsPosition)
, argumentNameColonPositions(argumentNameColonPositions)
, argumentsCommaPositions(argumentsCommaPositions)
, closeArgsPosition(closeArgsPosition)
, returnArrowPosition(returnArrowPosition)
{
}
CstTypeTypeof::CstTypeTypeof(Position openPosition, Position closePosition)
: CstNode(CstClassIndex())
, openPosition(openPosition)
, closePosition(closePosition)
{
}
CstTypeUnion::CstTypeUnion(std::optional<Position> leadingPosition, AstArray<Position> separatorPositions)
: CstNode(CstClassIndex())
, leadingPosition(leadingPosition)
, separatorPositions(separatorPositions)
{
}
CstTypeIntersection::CstTypeIntersection(std::optional<Position> leadingPosition, AstArray<Position> separatorPositions)
: CstNode(CstClassIndex())
, leadingPosition(leadingPosition)
, separatorPositions(separatorPositions)
{
}
CstTypeSingletonString::CstTypeSingletonString(AstArray<char> sourceString, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth)
: CstNode(CstClassIndex())
, sourceString(sourceString)
, quoteStyle(quoteStyle)
, blockDepth(blockDepth)
{
LUAU_ASSERT(quoteStyle != CstExprConstantString::QuotedInterp);
}
CstTypePackExplicit::CstTypePackExplicit(Position openParenthesesPosition, Position closeParenthesesPosition, AstArray<Position> commaPositions)
: CstNode(CstClassIndex())
, openParenthesesPosition(openParenthesesPosition)
, closeParenthesesPosition(closeParenthesesPosition)
, commaPositions(commaPositions)
{
}
CstTypePackGeneric::CstTypePackGeneric(Position ellipsisPosition)
: CstNode(CstClassIndex())
, ellipsisPosition(ellipsisPosition)
{
}
} // namespace Luau

View file

@ -8,9 +8,6 @@
#include <limits.h> #include <limits.h>
LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition2)
LUAU_FASTFLAGVARIABLE(LexerFixInterpStringStart)
namespace Luau namespace Luau
{ {
@ -306,16 +303,45 @@ static char unescape(char ch)
} }
} }
unsigned int Lexeme::getBlockDepth() const
{
LUAU_ASSERT(type == Lexeme::RawString || type == Lexeme::BlockComment);
// If we have a well-formed string, we are guaranteed to see 2 `]` characters after the end of the string contents
LUAU_ASSERT(*(data + length) == ']');
unsigned int depth = 0;
do
{
depth++;
} while (*(data + length + depth) != ']');
return depth - 1;
}
Lexeme::QuoteStyle Lexeme::getQuoteStyle() const
{
LUAU_ASSERT(type == Lexeme::QuotedString);
// If we have a well-formed string, we are guaranteed to see a closing delimiter after the string
LUAU_ASSERT(data);
char quote = *(data + length);
if (quote == '\'')
return Lexeme::QuoteStyle::Single;
else if (quote == '"')
return Lexeme::QuoteStyle::Double;
LUAU_ASSERT(!"Unknown quote style");
return Lexeme::QuoteStyle::Double; // unreachable, but required due to compiler warning
}
Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names, Position startPosition) Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names, Position startPosition)
: buffer(buffer) : buffer(buffer)
, bufferSize(bufferSize) , bufferSize(bufferSize)
, offset(0) , offset(0)
, line(FFlag::LexerResumesFromPosition2 ? startPosition.line : 0) , line(startPosition.line)
, lineOffset(FFlag::LexerResumesFromPosition2 ? 0u - startPosition.column : 0) , lineOffset(0u - startPosition.column)
, lexeme( , lexeme((Location(Position(startPosition.line, startPosition.column), 0)), Lexeme::Eof)
(FFlag::LexerResumesFromPosition2 ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)),
Lexeme::Eof
)
, names(names) , names(names)
, skipComments(false) , skipComments(false)
, readNames(true) , readNames(true)
@ -761,7 +787,7 @@ Lexeme Lexer::readNext()
return Lexeme(Location(start, 1), '}'); return Lexeme(Location(start, 1), '}');
} }
return readInterpolatedStringSection(FFlag::LexerFixInterpStringStart ? start : position(), Lexeme::InterpStringMid, Lexeme::InterpStringEnd); return readInterpolatedStringSection(start, Lexeme::InterpStringMid, Lexeme::InterpStringEnd);
} }
case '=': case '=':

File diff suppressed because it is too large Load diff

View file

@ -10,7 +10,7 @@
std::optional<std::string> getCurrentWorkingDirectory(); std::optional<std::string> getCurrentWorkingDirectory();
std::string normalizePath(std::string_view path); std::string normalizePath(std::string_view path);
std::string resolvePath(std::string_view relativePath, std::string_view baseFilePath); std::optional<std::string> resolvePath(std::string_view relativePath, std::string_view baseFilePath);
std::optional<std::string> readFile(const std::string& name); std::optional<std::string> readFile(const std::string& name);
std::optional<std::string> readStdin(); std::optional<std::string> readStdin();
@ -23,7 +23,7 @@ bool isDirectory(const std::string& path);
bool traverseDirectory(const std::string& path, const std::function<void(const std::string& name)>& callback); bool traverseDirectory(const std::string& path, const std::function<void(const std::string& name)>& callback);
std::vector<std::string_view> splitPath(std::string_view path); std::vector<std::string_view> splitPath(std::string_view path);
std::string joinPaths(const std::string& lhs, const std::string& rhs); std::string joinPaths(std::string_view lhs, std::string_view rhs);
std::optional<std::string> getParentPath(const std::string& path); std::optional<std::string> getParentPath(std::string_view path);
std::vector<std::string> getSourceFiles(int argc, char** argv); std::vector<std::string> getSourceFiles(int argc, char** argv);

View file

@ -20,6 +20,7 @@
#endif #endif
#include <string.h> #include <string.h>
#include <string_view>
#ifdef _WIN32 #ifdef _WIN32
static std::wstring fromUtf8(const std::string& path) static std::wstring fromUtf8(const std::string& path)
@ -90,108 +91,76 @@ std::optional<std::string> getCurrentWorkingDirectory()
return std::nullopt; return std::nullopt;
} }
// Returns the normal/canonical form of a path (e.g. "../subfolder/../module.luau" -> "../module.luau")
std::string normalizePath(std::string_view path) std::string normalizePath(std::string_view path)
{ {
return resolvePath(path, ""); const std::vector<std::string_view> components = splitPath(path);
} std::vector<std::string_view> normalizedComponents;
// Takes a path that is relative to the file at baseFilePath and returns the path explicitly rebased onto baseFilePath. const bool isAbsolute = isAbsolutePath(path);
// For absolute paths, baseFilePath will be ignored, and this function will resolve the path to a canonical path:
// (e.g. "/Users/.././Users/johndoe" -> "/Users/johndoe").
std::string resolvePath(std::string_view path, std::string_view baseFilePath)
{
std::vector<std::string_view> pathComponents;
std::vector<std::string_view> baseFilePathComponents;
// Dependent on whether the final resolved path is absolute or relative // 1. Normalize path components
// - if relative (when path and baseFilePath are both relative), resolvedPathPrefix remains empty const size_t startIndex = isAbsolute ? 1 : 0;
// - if absolute (if either path or baseFilePath are absolute), resolvedPathPrefix is "C:\", "/", etc. for (size_t i = startIndex; i < components.size(); i++)
std::string resolvedPathPrefix;
bool isResolvedPathRelative = false;
if (isAbsolutePath(path))
{
// path is absolute, we use path's prefix and ignore baseFilePath
size_t afterPrefix = path.find_first_of("\\/") + 1;
resolvedPathPrefix = path.substr(0, afterPrefix);
pathComponents = splitPath(path.substr(afterPrefix));
}
else
{
size_t afterPrefix = baseFilePath.find_first_of("\\/") + 1;
baseFilePathComponents = splitPath(baseFilePath.substr(afterPrefix));
if (isAbsolutePath(baseFilePath))
{
// path is relative and baseFilePath is absolute, we use baseFilePath's prefix
resolvedPathPrefix = baseFilePath.substr(0, afterPrefix);
}
else
{
// path and baseFilePath are both relative, we do not set a prefix (resolved path will be relative)
isResolvedPathRelative = true;
}
pathComponents = splitPath(path);
}
// Remove filename from components
if (!baseFilePathComponents.empty())
baseFilePathComponents.pop_back();
// Resolve the path by applying pathComponents to baseFilePathComponents
int numPrependedParents = 0;
for (std::string_view component : pathComponents)
{ {
std::string_view component = components[i];
if (component == "..") if (component == "..")
{ {
if (baseFilePathComponents.empty()) if (normalizedComponents.empty())
{ {
if (isResolvedPathRelative) if (!isAbsolute)
numPrependedParents++; // "../" will later be added to the beginning of the resolved path {
normalizedComponents.emplace_back("..");
}
} }
else if (baseFilePathComponents.back() != "..") else if (normalizedComponents.back() == "..")
{ {
baseFilePathComponents.pop_back(); // Resolve cases like "folder/subfolder/../../file" to "file" normalizedComponents.emplace_back("..");
}
else
{
normalizedComponents.pop_back();
} }
} }
else if (component != "." && !component.empty()) else if (!component.empty() && component != ".")
{ {
baseFilePathComponents.push_back(component); normalizedComponents.emplace_back(component);
} }
} }
// Create resolved path prefix for relative paths std::string normalizedPath;
if (isResolvedPathRelative)
// 2. Add correct prefix to formatted path
if (isAbsolute)
{ {
if (numPrependedParents > 0) normalizedPath += components[0];
{ normalizedPath += "/";
resolvedPathPrefix.reserve(numPrependedParents * 3); }
for (int i = 0; i < numPrependedParents; i++) else if (normalizedComponents.empty() || normalizedComponents[0] != "..")
{ {
resolvedPathPrefix += "../"; normalizedPath += "./";
}
}
else
{
resolvedPathPrefix = "./";
}
} }
// Join baseFilePathComponents to form the resolved path // 3. Join path components to form the normalized path
std::string resolvedPath = resolvedPathPrefix; for (auto iter = normalizedComponents.begin(); iter != normalizedComponents.end(); ++iter)
for (auto iter = baseFilePathComponents.begin(); iter != baseFilePathComponents.end(); ++iter)
{ {
if (iter != baseFilePathComponents.begin()) if (iter != normalizedComponents.begin())
resolvedPath += "/"; normalizedPath += "/";
resolvedPath += *iter; normalizedPath += *iter;
} }
if (resolvedPath.size() > resolvedPathPrefix.size() && resolvedPath.back() == '/') if (normalizedPath.size() >= 2 && normalizedPath[normalizedPath.size() - 1] == '.' && normalizedPath[normalizedPath.size() - 2] == '.')
{ normalizedPath += "/";
// Remove trailing '/' if present
resolvedPath.pop_back(); return normalizedPath;
} }
return resolvedPath;
std::optional<std::string> resolvePath(std::string_view path, std::string_view baseFilePath)
{
std::optional<std::string> baseFilePathParent = getParentPath(baseFilePath);
if (!baseFilePathParent)
return std::nullopt;
return normalizePath(joinPaths(*baseFilePathParent, path));
} }
bool hasFileExtension(std::string_view name, const std::vector<std::string>& extensions) bool hasFileExtension(std::string_view name, const std::vector<std::string>& extensions)
@ -416,16 +385,16 @@ std::vector<std::string_view> splitPath(std::string_view path)
return components; return components;
} }
std::string joinPaths(const std::string& lhs, const std::string& rhs) std::string joinPaths(std::string_view lhs, std::string_view rhs)
{ {
std::string result = lhs; std::string result = std::string(lhs);
if (!result.empty() && result.back() != '/' && result.back() != '\\') if (!result.empty() && result.back() != '/' && result.back() != '\\')
result += '/'; result += '/';
result += rhs; result += rhs;
return result; return result;
} }
std::optional<std::string> getParentPath(const std::string& path) std::optional<std::string> getParentPath(std::string_view path)
{ {
if (path == "" || path == "." || path == "/") if (path == "" || path == "." || path == "/")
return std::nullopt; return std::nullopt;
@ -441,7 +410,7 @@ std::optional<std::string> getParentPath(const std::string& path)
return "/"; return "/";
if (slash != std::string::npos) if (slash != std::string::npos)
return path.substr(0, slash); return std::string(path.substr(0, slash));
return ""; return "";
} }
@ -471,10 +440,12 @@ std::vector<std::string> getSourceFiles(int argc, char** argv)
if (argv[i][0] == '-' && argv[i][1] != '\0') if (argv[i][0] == '-' && argv[i][1] != '\0')
continue; continue;
if (isDirectory(argv[i])) std::string normalized = normalizePath(argv[i]);
if (isDirectory(normalized))
{ {
traverseDirectory( traverseDirectory(
argv[i], normalized,
[&](const std::string& name) [&](const std::string& name)
{ {
std::string ext = getExtension(name); std::string ext = getExtension(name);
@ -486,7 +457,7 @@ std::vector<std::string> getSourceFiles(int argc, char** argv)
} }
else else
{ {
files.push_back(argv[i]); files.push_back(normalized);
} }
} }

View file

@ -31,7 +31,7 @@ static void setLuauFlags(bool state)
void setLuauFlagsDefault() void setLuauFlagsDefault()
{ {
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next) for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
if (strncmp(flag->name, "Luau", 4) == 0 && !Luau::isFlagExperimental(flag->name)) if (strncmp(flag->name, "Luau", 4) == 0 && !Luau::isAnalysisFlagExperimental(flag->name))
flag->value = true; flag->value = true;
} }

View file

@ -791,8 +791,6 @@ int replMain(int argc, char** argv)
{ {
Luau::assertHandler() = assertionHandler; Luau::assertHandler() = assertionHandler;
setLuauFlagsDefault();
#ifdef _WIN32 #ifdef _WIN32
SetConsoleOutputCP(CP_UTF8); SetConsoleOutputCP(CP_UTF8);
#endif #endif

View file

@ -1,7 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Repl.h" #include "Luau/Repl.h"
#include "Luau/Flags.h"
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
setLuauFlagsDefault();
return replMain(argc, argv); return replMain(argc, argv);
} }

Some files were not shown because too many files have changed in this diff Show more