From 08c66ef2e17f685e2a2b3195909ac28816feb19a Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 4 Nov 2021 19:07:18 -0700 Subject: [PATCH 001/102] Sync to upstream/release/502 Changes: - Support for time tracing for analysis/compiler (not currently exposed through CLI) - Support for type pack arguments in type aliases (#83) - Basic support for require(path) in luau-analyze - Add a lint warning for table.move with 0 index as part of TableOperation lint - Remove last STL dependency from Luau.VM - Minor VS2022 performance tuning Co-authored-by: Rodactor --- .gitignore | 7 + Analysis/include/Luau/BuiltinDefinitions.h | 3 +- Analysis/include/Luau/Error.h | 1 + Analysis/include/Luau/FileResolver.h | 58 +- Analysis/include/Luau/Module.h | 10 +- Analysis/include/Luau/ModuleResolver.h | 6 - Analysis/include/Luau/RequireTracer.h | 5 +- Analysis/include/Luau/Scope.h | 67 ++ Analysis/include/Luau/Substitution.h | 14 +- Analysis/include/Luau/TypeInfer.h | 56 +- Analysis/include/Luau/TypePack.h | 3 +- Analysis/include/Luau/TypeVar.h | 22 +- Analysis/include/Luau/Unifier.h | 18 +- Analysis/src/AstQuery.cpp | 17 +- Analysis/src/Autocomplete.cpp | 49 +- Analysis/src/BuiltinDefinitions.cpp | 101 +-- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 21 - Analysis/src/Error.cpp | 108 ++- Analysis/src/Frontend.cpp | 51 +- Analysis/src/IostreamHelpers.cpp | 18 +- Analysis/src/JsonEncoder.cpp | 32 +- Analysis/src/Linter.cpp | 19 +- Analysis/src/Module.cpp | 44 +- Analysis/src/RequireTracer.cpp | 215 ++++- Analysis/src/Scope.cpp | 123 +++ Analysis/src/Substitution.cpp | 34 +- Analysis/src/ToString.cpp | 131 ++- Analysis/src/Transpiler.cpp | 38 +- Analysis/src/TypeAttach.cpp | 124 ++- Analysis/src/TypeInfer.cpp | 577 ++++++------ Analysis/src/TypePack.cpp | 13 + Analysis/src/TypeUtils.cpp | 18 +- Analysis/src/TypeVar.cpp | 36 +- Analysis/src/Unifier.cpp | 500 +++++++++-- Ast/include/Luau/Ast.h | 33 +- Ast/include/Luau/DenseHash.h | 5 +- Ast/include/Luau/Parser.h | 8 +- Ast/include/Luau/TimeTrace.h | 223 +++++ Ast/src/Ast.cpp | 37 +- Ast/src/Parser.cpp | 147 ++- Ast/src/TimeTrace.cpp | 248 ++++++ CLI/Analyze.cpp | 18 +- Compiler/src/Compiler.cpp | 10 + Sources.cmake | 5 + VM/src/ldo.cpp | 12 +- VM/src/lgc.cpp | 191 +++- VM/src/ltablib.cpp | 22 - VM/src/lvmexecute.cpp | 12 +- VM/src/lvmload.cpp | 34 +- bench/tests/deltablue.lua | 934 -------------------- tests/Autocomplete.test.cpp | 853 +++++++++--------- tests/Fixture.cpp | 49 + tests/Fixture.h | 2 + tests/Frontend.test.cpp | 31 +- tests/Linter.test.cpp | 9 +- tests/Module.test.cpp | 1 + tests/NonstrictMode.test.cpp | 1 + tests/Parser.test.cpp | 15 + tests/RequireTracer.test.cpp | 68 +- tests/ToString.test.cpp | 3 +- tests/TypeInfer.aliases.test.cpp | 557 ++++++++++++ tests/TypeInfer.provisional.test.cpp | 36 +- tests/TypeInfer.refinements.test.cpp | 72 +- tests/TypeInfer.tables.test.cpp | 230 +++-- tests/TypeInfer.test.cpp | 563 +----------- tests/TypeInfer.tryUnify.test.cpp | 1 + tests/TypeInfer.typePacks.cpp | 366 ++++++++ tests/TypeInfer.unionTypes.test.cpp | 16 +- tests/TypeVar.test.cpp | 1 + tools/tracegraph.py | 95 ++ 70 files changed, 4485 insertions(+), 2962 deletions(-) create mode 100644 .gitignore create mode 100644 Analysis/include/Luau/Scope.h create mode 100644 Analysis/src/Scope.cpp create mode 100644 Ast/include/Luau/TimeTrace.h create mode 100644 Ast/src/TimeTrace.cpp delete mode 100644 bench/tests/deltablue.lua create mode 100644 tests/TypeInfer.aliases.test.cpp create mode 100644 tools/tracegraph.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..0b2422ce --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +^build/ +^coverage/ +^fuzz/luau.pb.* +^crash-* +^default.prof* +^fuzz-* +^luau$ diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 8f17fff6..57a1907a 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -1,7 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "TypeInfer.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" namespace Luau { diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 946bc928..ac6f13e9 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -120,6 +120,7 @@ struct IncorrectGenericParameterCount Name name; TypeFun typeFun; size_t actualParameters; + size_t actualPackParameters; bool operator==(const IncorrectGenericParameterCount& rhs) const; }; diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 71f9464b..a05ec5e9 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -25,51 +25,39 @@ struct SourceCode Type type; }; +struct ModuleInfo +{ + ModuleName name; + bool optional = false; +}; + struct FileResolver { virtual ~FileResolver() {} - /** Fetch the source code associated with the provided ModuleName. - * - * FIXME: This requires a string copy! - * - * @returns The actual Lua code on success. - * @returns std::nullopt if no such file exists. When this occurs, type inference will report an UnknownRequire error. - */ virtual std::optional readSource(const ModuleName& name) = 0; - /** Does the module exist? - * - * Saves a string copy over reading the source and throwing it away. - */ - virtual bool moduleExists(const ModuleName& name) const = 0; + virtual std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) + { + return std::nullopt; + } - virtual std::optional fromAstFragment(AstExpr* expr) const = 0; - - /** Given a valid module name and a string of arbitrary data, figure out the concatenation. - */ - virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; - - /** Goes "up" a level in the hierarchy that the ModuleName represents. - * - * For instances, this is analogous to someInstance.Parent; for paths, this is equivalent to removing the last - * element of the path. Other ModuleName representations may have other ways of doing this. - * - * @returns The parent ModuleName, if one exists. - * @returns std::nullopt if there is no parent for this module name. - */ - virtual std::optional getParentModuleName(const ModuleName& name) const = 0; - - virtual std::optional getHumanReadableModuleName_(const ModuleName& name) const + virtual std::string getHumanReadableModuleName(const ModuleName& name) const { return name; } - virtual std::optional getEnvironmentForModule(const ModuleName& name) const = 0; + virtual std::optional getEnvironmentForModule(const ModuleName& name) const + { + return std::nullopt; + } - /** LanguageService only: - * std::optional fromInstance(Instance* inst) - */ + // DEPRECATED APIS + // These are going to be removed with LuauNewRequireTracer + virtual bool moduleExists(const ModuleName& name) const = 0; + virtual std::optional fromAstFragment(AstExpr* expr) const = 0; + virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; + virtual std::optional getParentModuleName(const ModuleName& name) const = 0; }; struct NullFileResolver : FileResolver @@ -94,10 +82,6 @@ struct NullFileResolver : FileResolver { return std::nullopt; } - std::optional getEnvironmentForModule(const ModuleName& name) const override - { - return std::nullopt; - } }; } // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 413b68f4..d0844835 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -90,10 +90,12 @@ struct Module TypeArena internalTypes; std::vector> scopes; // never empty - std::unordered_map astTypes; - std::unordered_map astExpectedTypes; - std::unordered_map astOriginalCallTypes; - std::unordered_map astOverloadResolvedTypes; + + DenseHashMap astTypes{nullptr}; + DenseHashMap astExpectedTypes{nullptr}; + DenseHashMap astOriginalCallTypes{nullptr}; + DenseHashMap astOverloadResolvedTypes{nullptr}; + std::unordered_map declaredGlobals; ErrorVec errors; Mode mode; diff --git a/Analysis/include/Luau/ModuleResolver.h b/Analysis/include/Luau/ModuleResolver.h index a394a21b..d892ccd7 100644 --- a/Analysis/include/Luau/ModuleResolver.h +++ b/Analysis/include/Luau/ModuleResolver.h @@ -15,12 +15,6 @@ struct Module; using ModulePtr = std::shared_ptr; -struct ModuleInfo -{ - ModuleName name; - bool optional = false; -}; - struct ModuleResolver { virtual ~ModuleResolver() {} diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index e9778876..c25545f5 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -17,12 +17,11 @@ struct AstLocal; struct RequireTraceResult { - DenseHashMap exprs{0}; - DenseHashMap optional{0}; + DenseHashMap exprs{nullptr}; std::vector> requires; }; -RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName); +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName); } // namespace Luau diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h new file mode 100644 index 00000000..45338409 --- /dev/null +++ b/Analysis/include/Luau/Scope.h @@ -0,0 +1,67 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" +#include "Luau/TypeVar.h" + +#include +#include +#include + +namespace Luau +{ + +struct Scope; + +using ScopePtr = std::shared_ptr; + +struct Binding +{ + TypeId typeId; + Location location; + bool deprecated = false; + std::string deprecatedSuggestion; + std::optional documentationSymbol; +}; + +struct Scope +{ + explicit Scope(TypePackId returnType); // root scope + explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. + + const ScopePtr parent; // null for the root + std::unordered_map bindings; + TypePackId returnType; + bool breakOk = false; + std::optional varargPack; + + TypeLevel level; + + std::unordered_map exportedTypeBindings; + std::unordered_map privateTypeBindings; + std::unordered_map typeAliasLocations; + + std::unordered_map> importedTypeBindings; + + std::optional lookup(const Symbol& name); + + std::optional lookupType(const Name& name); + std::optional lookupImportedType(const Name& moduleAlias, const Name& name); + + std::unordered_map privateTypePackBindings; + std::optional lookupPack(const Name& name); + + // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) + std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); + + RefinementMap refinements; + + // For mutually recursive type aliases, it's important that + // they use the same types for the same names. + // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` + // we need that the generic type `T` in both cases is the same, so we use a cache. + std::unordered_map typeAliasTypeParameters; + std::unordered_map typeAliasTypePackParameters; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 6ac868f7..80a14e8f 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -52,8 +52,6 @@ // `T`, and the type of `f` are in the same SCC, which is why `f` gets // replaced. -LUAU_FASTFLAG(DebugLuauTrackOwningArena) - namespace Luau { @@ -188,20 +186,12 @@ struct Substitution : FindDirty template TypeId addType(const T& tv) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(tv); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(tv); } template TypePackId addTypePack(const T& tp) { - TypePackId allocated = currentModule->internalTypes.typePacks.allocate(tp); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addTypePack(TypePackVar{tp}); } }; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index ec2a1a26..d701eb24 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -86,7 +86,10 @@ struct ApplyTypeFunction : Substitution { TypeLevel level; bool encounteredForwardedType; - std::unordered_map arguments; + std::unordered_map typeArguments; + std::unordered_map typePackArguments; + bool ignoreChildren(TypeId ty) override; + bool ignoreChildren(TypePackId tp) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; TypeId clean(TypeId ty) override; @@ -328,7 +331,8 @@ private: TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); - TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location); + TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, + const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. std::pair, std::vector> createGenericTypes( @@ -398,54 +402,6 @@ private: int recursionCount = 0; }; -struct Binding -{ - TypeId typeId; - Location location; - bool deprecated = false; - std::string deprecatedSuggestion; - std::optional documentationSymbol; -}; - -struct Scope -{ - explicit Scope(TypePackId returnType); // root scope - explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. - - const ScopePtr parent; // null for the root - std::unordered_map bindings; - TypePackId returnType; - bool breakOk = false; - std::optional varargPack; - - TypeLevel level; - - std::unordered_map exportedTypeBindings; - std::unordered_map privateTypeBindings; - std::unordered_map typeAliasLocations; - - std::unordered_map> importedTypeBindings; - - std::optional lookup(const Symbol& name); - - std::optional lookupType(const Name& name); - std::optional lookupImportedType(const Name& moduleAlias, const Name& name); - - std::unordered_map privateTypePackBindings; - std::optional lookupPack(const Name& name); - - // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) - std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); - - RefinementMap refinements; - - // For mutually recursive type aliases, it's important that - // they use the same types for the same names. - // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` - // we need that the generic type `T` in both cases is the same, so we use a cache. - std::unordered_map typeAliasParameters; -}; - // Unit test hook void setPrintLine(void (*pl)(const std::string& s)); void resetPrintLine(); diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 0d0adce7..d987d46c 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -117,7 +117,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); -size_t size(const TypePackId tp); +size_t size(TypePackId tp); +bool finite(TypePackId tp); size_t size(const TypePack& tp); std::optional first(TypePackId tp); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 90a28b20..d4e4e491 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -228,6 +228,7 @@ struct TableTypeVar std::map methodDefinitionLocations; std::vector instantiatedTypeParams; + std::vector instantiatedTypePackParams; ModuleName definitionModuleName; std::optional boundTo; @@ -284,8 +285,9 @@ struct ClassTypeVar struct TypeFun { - /// These should all be generic + // These should all be generic std::vector typeParams; + std::vector typePackParams; /** The underlying type. * @@ -293,6 +295,20 @@ struct TypeFun * You must first use TypeChecker::instantiateTypeFun to turn it into a real type. */ TypeId type; + + TypeFun() = default; + TypeFun(std::vector typeParams, TypeId type) + : typeParams(std::move(typeParams)) + , type(type) + { + } + + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + : typeParams(std::move(typeParams)) + , typePackParams(std::move(typePackParams)) + , type(type) + { + } }; // Anything! All static checking is off. @@ -524,8 +540,4 @@ UnionTypeVarIterator end(const UnionTypeVar* utv); using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); -// TEMP: Clip this prototype with FFlag::LuauStringMetatable -std::optional> magicFunctionFormat( - struct TypeChecker& typechecker, const std::shared_ptr& scope, const AstExprCall& expr, ExprResult exprResult); - } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 0ddc3cc0..522914b2 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -36,12 +36,17 @@ struct Unifier Variance variance = Covariant; CountMismatch::Context ctx = CountMismatch::Arg; - std::shared_ptr counters; + UnifierCounters* counters; + UnifierCounters countersData; + + std::shared_ptr counters_DEPRECATED; + InternalErrorReporter* iceHandler; Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters = nullptr); + Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED = nullptr, + UnifierCounters* counters = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId superTy, TypeId subTy); @@ -58,11 +63,13 @@ private: void tryUnifyPrimitives(TypeId superTy, TypeId subTy); void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); + void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void tryUnifyFreeTable(TypeId free, TypeId other); void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection); void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); public: void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); @@ -80,9 +87,9 @@ private: public: // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); - void occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack); + void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack); void occursCheck(TypePackId needle, TypePackId haystack); - void occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack); + void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack); Unifier makeChildUnifier(); @@ -93,6 +100,9 @@ private: [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + + DenseHashSet tempSeenTy{nullptr}; + DenseHashSet tempSeenTp{nullptr}; }; } // namespace Luau diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index d3de1754..0aed34c0 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -2,6 +2,7 @@ #include "Luau/AstQuery.h" #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/ToString.h" @@ -143,8 +144,8 @@ std::optional findTypeAtPosition(const Module& module, const SourceModul { if (auto expr = findExprAtPosition(sourceModule, pos)) { - if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) - return it->second; + if (auto it = module.astTypes.find(expr)) + return *it; } return std::nullopt; @@ -154,8 +155,8 @@ std::optional findExpectedTypeAtPosition(const Module& module, const Sou { if (auto expr = findExprAtPosition(sourceModule, pos)) { - if (auto it = module.astExpectedTypes.find(expr); it != module.astExpectedTypes.end()) - return it->second; + if (auto it = module.astExpectedTypes.find(expr)) + return *it; } return std::nullopt; @@ -322,9 +323,9 @@ std::optional getDocumentationSymbolAtPosition(const Source TypeId matchingOverload = nullptr; if (parentExpr && parentExpr->is()) { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr); it != module.astOverloadResolvedTypes.end()) + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) { - matchingOverload = it->second; + matchingOverload = *it; } } @@ -345,9 +346,9 @@ std::optional getDocumentationSymbolAtPosition(const Source { if (AstExprIndexName* indexName = targetExpr->as()) { - if (auto it = module.astTypes.find(indexName->expr); it != module.astTypes.end()) + if (auto it = module.astTypes.find(indexName->expr)) { - TypeId parentTy = follow(it->second); + TypeId parentTy = follow(*it); if (const TableTypeVar* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index dce92a0c..235abf36 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -210,10 +210,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ return TypeCorrectKind::None; auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return TypeCorrectKind::None; - TypeId expectedType = follow(it->second); + TypeId expectedType = follow(*it); if (canUnify(expectedType, ty)) return TypeCorrectKind::Correct; @@ -682,10 +682,10 @@ static std::optional functionIsExpectedAt(const Module& module, AstNode* n return std::nullopt; auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return std::nullopt; - TypeId expectedType = follow(it->second); + TypeId expectedType = follow(*it); if (const FunctionTypeVar* ftv = get(expectedType)) return true; @@ -784,9 +784,9 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (AstExprCall* exprCall = expr->as()) { - if (auto it = module.astTypes.find(exprCall->func); it != module.astTypes.end()) + if (auto it = module.astTypes.find(exprCall->func)) { - if (const FunctionTypeVar* ftv = get(follow(it->second))) + if (const FunctionTypeVar* ftv = get(follow(*it))) { if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos)) inferredType = *ty; @@ -798,8 +798,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (tailPos != 0) break; - if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) - inferredType = it->second; + if (auto it = module.astTypes.find(expr)) + inferredType = *it; } if (inferredType) @@ -815,10 +815,10 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionTypeVar* { auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return nullptr; - TypeId ty = follow(it->second); + TypeId ty = follow(*it); if (const FunctionTypeVar* ftv = get(ty)) return ftv; @@ -1129,9 +1129,8 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul if (node->is()) { - auto it = module.astTypes.find(node->asExpr()); - if (it != module.astTypes.end()) - autocompleteProps(module, typeArena, it->second, PropIndexType::Point, ancestry, result); + if (auto it = module.astTypes.find(node->asExpr())) + autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result); } else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) return; @@ -1203,13 +1202,13 @@ static std::optional getMethodContainingClass(const ModuleP return std::nullopt; } - auto parentIter = module->astTypes.find(parentExpr); - if (parentIter == module->astTypes.end()) + auto parentIt = module->astTypes.find(parentExpr); + if (!parentIt) { return std::nullopt; } - Luau::TypeId parentType = Luau::follow(parentIter->second); + Luau::TypeId parentType = Luau::follow(*parentIt); if (auto parentClass = Luau::get(parentType)) { @@ -1250,8 +1249,8 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - auto iter = module->astTypes.find(candidate->func); - if (iter == module->astTypes.end()) + auto it = module->astTypes.find(candidate->func); + if (!it) { return std::nullopt; } @@ -1267,7 +1266,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; }; - auto followedId = Luau::follow(iter->second); + auto followedId = Luau::follow(*it); if (auto functionType = Luau::get(followedId)) { return performCallback(functionType); @@ -1316,10 +1315,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto indexName = node->as()) { auto it = module->astTypes.find(indexName->expr); - if (it == module->astTypes.end()) + if (!it) return {}; - TypeId ty = follow(it->second); + TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; if (isString(ty)) @@ -1447,9 +1446,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // If item doesn't have a key, maybe the value is actually the key if (key ? key == node : node->is() && value == node) { - if (auto it = module->astExpectedTypes.find(exprTable); it != module->astExpectedTypes.end()) + if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, typeArena, it->second, PropIndexType::Key, finder.ancestry); + auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry); // Remove keys that are already completed for (const auto& item : exprTable->items) @@ -1485,9 +1484,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) { - if (auto it = module->astTypes.find(idxExpr->expr); it != module->astTypes.end()) + if (auto it = module->astTypes.find(idxExpr->expr)) { - return {autocompleteProps(*module, typeArena, follow(it->second), PropIndexType::Point, finder.ancestry), finder.ancestry}; + return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry}; } } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 68ad5ac9..3b0c2163 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -11,7 +11,7 @@ LUAU_FASTFLAG(LuauParseGenericFunctions) LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauStringMetatable) +LUAU_FASTFLAG(LuauNewRequireTrace) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -218,7 +218,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypePackId anyTypePack = typeChecker.anyTypePack; TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); - TypePackId stringVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{stringType}}); TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{ @@ -255,85 +254,18 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); - if (FFlag::LuauStringMetatable) + std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); + LUAU_ASSERT(stringMetatableTy); + const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); + LUAU_ASSERT(stringMetatableTable); + + auto it = stringMetatableTable->props.find("__index"); + LUAU_ASSERT(it != stringMetatableTable->props.end()); + + addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); + + if (!FFlag::LuauParseGenericFunctions || !FFlag::LuauGenericFunctions) { - std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); - LUAU_ASSERT(stringMetatableTy); - const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); - LUAU_ASSERT(stringMetatableTable); - - auto it = stringMetatableTable->props.find("__index"); - LUAU_ASSERT(it != stringMetatableTable->props.end()); - - TypeId stringLib = it->second.type; - addGlobalBinding(typeChecker, "string", stringLib, "@luau"); - } - - if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) - { - if (!FFlag::LuauStringMetatable) - { - TypeId stringLibTy = getGlobalBinding(typeChecker, "string"); - TableTypeVar* stringLib = getMutable(stringLibTy); - TypeId replArgType = makeUnion( - arena, {stringType, - arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), - makeFunction(arena, std::nullopt, {stringType}, {stringType})}); - TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); - - stringLib->props["gsub"] = makeProperty(gsubFunc, "@luau/global/string.gsub"); - } - } - else - { - if (!FFlag::LuauStringMetatable) - { - TypeId stringToStringType = makeFunction(arena, std::nullopt, {stringType}, {stringType}); - - TypeId gmatchFunc = makeFunction(arena, stringType, {stringType}, {arena.addType(FunctionTypeVar{emptyPack, stringVariadicList})}); - - TypeId replArgType = makeUnion( - arena, {stringType, - arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), - makeFunction(arena, std::nullopt, {stringType}, {stringType})}); - TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); - - TypeId formatFn = arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}); - - TableTypeVar::Props stringLib = { - // FIXME string.byte "can" return a pack of numbers, but only if 2nd or 3rd arguments were supplied - {"byte", {makeFunction(arena, stringType, {optionalNumber, optionalNumber}, {optionalNumber})}}, - // FIXME char takes a variadic pack of numbers - {"char", {makeFunction(arena, std::nullopt, {numberType, optionalNumber, optionalNumber, optionalNumber}, {stringType})}}, - {"find", {makeFunction(arena, stringType, {stringType, optionalNumber, optionalBoolean}, {optionalNumber, optionalNumber})}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(arena, stringType, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {makeFunction(arena, stringType, {stringType, optionalNumber}, {optionalString})}}, - {"rep", {makeFunction(arena, stringType, {numberType}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(arena, stringType, {numberType, optionalNumber}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(arena, stringType, {stringType, optionalString}, - {arena.addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, typeChecker.globalScope->level})})}}, - {"pack", {arena.addType(FunctionTypeVar{ - arena.addTypePack(TypePack{{stringType}, anyTypePack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(arena, stringType, {}, {numberType})}}, - {"unpack", {arena.addType(FunctionTypeVar{ - arena.addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - anyTypePack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - addGlobalBinding(typeChecker, "string", - arena.addType(TableTypeVar{stringLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - } - TableTypeVar::Props debugLib{ {"info", {makeIntersection(arena, { @@ -601,9 +533,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) auto tableLib = getMutable(getGlobalBinding(typeChecker, "table")); attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack); - auto stringLib = getMutable(getGlobalBinding(typeChecker, "string")); - attachMagicFunction(stringLib->props["format"].type, magicFunctionFormat); - attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); } @@ -791,11 +720,11 @@ static std::optional> magicFunctionRequire( return std::nullopt; } - AstExpr* require = expr.args.data[0]; - - if (!checkRequirePath(typechecker, require)) + if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; + const AstExpr* require = FFlag::LuauNewRequireTrace ? &expr : expr.args.data[0]; + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 61a63f06..1e91561a 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -206,27 +206,6 @@ std::string getBuiltinDefinitionSource() graphemes: (string, number?, number?) -> (() -> (number, number)), } - declare string: { - byte: (string, number?, number?) -> ...number, - char: (number, ...number) -> string, - find: (string, string, number?, boolean?) -> (number?, number?), - -- `string.format` has a magic function attached that will provide more type information for literal format strings. - format: (string, A...) -> string, - gmatch: (string, string) -> () -> (...string), - -- gsub is defined in C++ because we don't have syntax for describing a generic table. - len: (string) -> number, - lower: (string) -> string, - match: (string, string, number?) -> string?, - rep: (string, number) -> string, - reverse: (string) -> string, - sub: (string, number, number?) -> string, - upper: (string) -> string, - split: (string, string, string?) -> {string}, - pack: (string, A...) -> string, - packsize: (string) -> number, - unpack: (string, string, number?) -> R..., - } - -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. declare function unpack(tab: {V}, i: number?, j: number?): ...V )"; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 680bcf3f..92fbffc8 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,9 +7,9 @@ #include -LUAU_FASTFLAG(LuauFasterStringifier) +LUAU_FASTFLAG(LuauTypeAliasPacks) -static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) +static std::string wrongNumberOfArgsString_DEPRECATED(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) { std::string s = "expects " + std::to_string(expectedCount) + " "; @@ -41,6 +41,52 @@ static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCo return s; } +static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) +{ + std::string s; + + if (FFlag::LuauTypeAliasPacks) + { + s = "expects "; + + if (isVariadic) + s += "at least "; + + s += std::to_string(expectedCount) + " "; + } + else + { + s = "expects " + std::to_string(expectedCount) + " "; + } + + if (argPrefix) + s += std::string(argPrefix) + " "; + + s += "argument"; + if (expectedCount != 1) + s += "s"; + + s += ", but "; + + if (actualCount == 0) + { + s += "none"; + } + else + { + if (actualCount < expectedCount) + s += "only "; + + s += std::to_string(actualCount); + } + + s += (actualCount == 1) ? " is" : " are"; + + s += " specified"; + + return s; +} + namespace Luau { @@ -128,7 +174,10 @@ struct ErrorConverter else return "Function only returns " + std::to_string(e.expected) + " values. " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + if (FFlag::LuauTypeAliasPacks) + return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + else + return "Argument count mismatch. Function " + wrongNumberOfArgsString_DEPRECATED(e.expected, e.actual); } LUAU_ASSERT(!"Unknown context"); @@ -160,13 +209,16 @@ struct ErrorConverter std::string operator()(const Luau::UnknownRequire& e) const { - return "Unknown require: " + e.modulePath; + if (e.modulePath.empty()) + return "Unknown require: unsupported path"; + else + return "Unknown require: " + e.modulePath; } std::string operator()(const Luau::IncorrectGenericParameterCount& e) const { std::string name = e.name; - if (!e.typeFun.typeParams.empty()) + if (!e.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !e.typeFun.typePackParams.empty())) { name += "<"; bool first = true; @@ -179,10 +231,37 @@ struct ErrorConverter name += toString(t); } + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId t : e.typeFun.typePackParams) + { + if (first) + first = false; + else + name += ", "; + + name += toString(t); + } + } + name += ">"; } - return "Generic type '" + name + "' " + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); + if (FFlag::LuauTypeAliasPacks) + { + if (e.typeFun.typeParams.size() != e.actualParameters) + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); + + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); + } + else + { + return "Generic type '" + name + "' " + + wrongNumberOfArgsString_DEPRECATED(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); + } } std::string operator()(const Luau::SyntaxError& e) const @@ -471,9 +550,26 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size()) return false; + if (FFlag::LuauTypeAliasPacks) + { + if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) + return false; + } + for (size_t i = 0; i < typeFun.typeParams.size(); ++i) + { if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) return false; + } + + if (FFlag::LuauTypeAliasPacks) + { + for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) + { + if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + return false; + } + } return true; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 4d385ec1..b2529840 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,9 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/Common.h" #include "Luau/Config.h" #include "Luau/FileResolver.h" +#include "Luau/Scope.h" #include "Luau/StringUtils.h" +#include "Luau/TimeTrace.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" #include "Luau/Common.h" @@ -19,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false) LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) +LUAU_FASTFLAG(LuauNewRequireTrace) namespace Luau { @@ -69,6 +73,8 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) { + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); + Luau::Allocator allocator; Luau::AstNameTable names(allocator); @@ -350,6 +356,9 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) CheckResult Frontend::check(const ModuleName& name) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + CheckResult checkResult; auto it = sourceNodes.find(name); @@ -479,6 +488,9 @@ CheckResult Frontend::check(const ModuleName& name) bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root) { + LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); + // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search enum Mark { @@ -597,6 +609,9 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + CheckResult checkResult; auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name); @@ -608,6 +623,8 @@ LintResult Frontend::lint(const ModuleName& name, std::optional Frontend::lintFragment(std::string_view source, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lintFragment", "Frontend"); + const Config& config = configResolver->getConfig(""); SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions); @@ -627,6 +644,9 @@ std::pair Frontend::lintFragment(std::string_view sour CheckResult Frontend::check(const SourceModule& module) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + const Config& config = configResolver->getConfig(module.name); Mode mode = module.mode.value_or(config.mode); @@ -648,6 +668,9 @@ CheckResult Frontend::check(const SourceModule& module) LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + const Config& config = configResolver->getConfig(module.name); LintOptions options = enabledLintWarnings.value_or(config.enabledLint); @@ -746,6 +769,9 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) { + LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + auto it = sourceNodes.find(name); if (it != sourceNodes.end() && !it->second.dirty) { @@ -815,6 +841,9 @@ std::pair Frontend::getSourceNode(CheckResult& check */ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions) { + LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + SourceModule sourceModule; double timestamp = getTimestamp(); @@ -864,20 +893,11 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module const auto& exprs = it->second.exprs; - const ModuleName* relativeName = exprs.find(&pathExpr); - if (!relativeName || relativeName->empty()) + const ModuleInfo* info = exprs.find(&pathExpr); + if (!info || (!FFlag::LuauNewRequireTrace && info->name.empty())) return std::nullopt; - if (FFlag::LuauTraceRequireLookupChild) - { - const bool* optional = it->second.optional.find(&pathExpr); - - return {{*relativeName, optional ? *optional : false}}; - } - else - { - return {{*relativeName, false}}; - } + return *info; } const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const @@ -891,12 +911,15 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - return frontend->fileResolver->moduleExists(moduleName); + if (FFlag::LuauNewRequireTrace) + return frontend->sourceNodes.count(moduleName) != 0; + else + return frontend->fileResolver->moduleExists(moduleName); } std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const { - return frontend->fileResolver->getHumanReadableModuleName_(moduleName).value_or(moduleName); + return frontend->fileResolver->getHumanReadableModuleName(moduleName); } ScopePtr Frontend::addEnvironment(const std::string& environmentName) diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 84e9b77f..3b267121 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -2,6 +2,8 @@ #include "Luau/IostreamHelpers.h" #include "Luau/ToString.h" +LUAU_FASTFLAG(LuauTypeAliasPacks) + namespace Luau { @@ -92,7 +94,7 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "IncorrectGenericParameterCount { name = " << error.name; - if (!error.typeFun.typeParams.empty()) + if (!error.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !error.typeFun.typePackParams.empty())) { stream << "<"; bool first = true; @@ -105,6 +107,20 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo stream << toString(t); } + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId t : error.typeFun.typePackParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(t); + } + } + stream << ">"; } diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index a1018297..064accba 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -3,6 +3,9 @@ #include "Luau/Ast.h" #include "Luau/StringUtils.h" +#include "Luau/Common.h" + +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -612,6 +615,12 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstStatTypeAlias", [&]() { PROP(name); PROP(generics); + + if (FFlag::LuauTypeAliasPacks) + { + PROP(genericPacks); + } + PROP(type); PROP(exported); }); @@ -664,13 +673,21 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(struct AstTypeOrPack node) + { + if (node.type) + write(node.type); + else + write(node.typePack); + } + void write(class AstTypeReference* node) { writeNode(node, "AstTypeReference", [&]() { if (node->hasPrefix) PROP(prefix); PROP(name); - PROP(generics); + PROP(parameters); }); } @@ -734,6 +751,13 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(class AstTypePackExplicit* node) + { + writeNode(node, "AstTypePackExplicit", [&]() { + PROP(typeList); + }); + } + void write(class AstTypePackVariadic* node) { writeNode(node, "AstTypePackVariadic", [&]() { @@ -1018,6 +1042,12 @@ struct AstJsonEncoder : public AstVisitor return false; } + bool visit(class AstTypePackExplicit* node) override + { + write(node); + return false; + } + bool visit(class AstTypePackVariadic* node) override { write(node); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index f97f6a4a..bff947a5 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -3,6 +3,7 @@ #include "Luau/AstQuery.h" #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/StringUtils.h" #include "Luau/Common.h" @@ -12,6 +13,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false) +LUAU_FASTFLAGVARIABLE(LuauLinterTableMoveZero, false) namespace Luau { @@ -85,10 +87,10 @@ struct LintContext return std::nullopt; auto it = module->astTypes.find(expr); - if (it == module->astTypes.end()) + if (!it) return std::nullopt; - return it->second; + return *it; } }; @@ -2144,6 +2146,19 @@ private: "wrap it in parentheses to silence"); } + if (FFlag::LuauLinterTableMoveZero && func->index == "move" && node->args.size >= 4) + { + // table.move(t, 0, _, _) + if (isConstant(args[1], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + + // table.move(t, _, _, 0) + else if (isConstant(args[3], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + } + return true; } diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index f1d975fe..df6be767 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -13,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -188,7 +190,7 @@ struct TypePackCloner template void defaultClone(const T& t) { - TypePackId cloned = dest.typePacks.allocate(t); + TypePackId cloned = dest.addTypePack(TypePackVar{t}); seenTypePacks[typePackId] = cloned; } @@ -197,7 +199,7 @@ struct TypePackCloner if (encounteredFreeType) *encounteredFreeType = true; - seenTypePacks[typePackId] = dest.typePacks.allocate(TypePackVar{Unifiable::Error{}}); + seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}}); } void operator()(const Unifiable::Generic& t) @@ -219,13 +221,13 @@ struct TypePackCloner void operator()(const VariadicTypePack& t) { - TypePackId cloned = dest.typePacks.allocate(VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}); + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}}); seenTypePacks[typePackId] = cloned; } void operator()(const TypePack& t) { - TypePackId cloned = dest.typePacks.allocate(TypePack{}); + TypePackId cloned = dest.addTypePack(TypePack{}); TypePack* destTp = getMutable(cloned); LUAU_ASSERT(destTp != nullptr); seenTypePacks[typePackId] = cloned; @@ -241,7 +243,7 @@ struct TypePackCloner template void TypeCloner::defaultClone(const T& t) { - TypeId cloned = dest.typeVars.allocate(t); + TypeId cloned = dest.addType(t); seenTypes[typeId] = cloned; } @@ -250,7 +252,7 @@ void TypeCloner::operator()(const Unifiable::Free& t) if (encounteredFreeType) *encounteredFreeType = true; - seenTypes[typeId] = dest.typeVars.allocate(ErrorTypeVar{}); + seenTypes[typeId] = dest.addType(ErrorTypeVar{}); } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -275,7 +277,7 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) void TypeCloner::operator()(const FunctionTypeVar& t) { - TypeId result = dest.typeVars.allocate(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); + TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); FunctionTypeVar* ftv = getMutable(result); LUAU_ASSERT(ftv != nullptr); @@ -297,7 +299,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) void TypeCloner::operator()(const TableTypeVar& t) { - TypeId result = dest.typeVars.allocate(TableTypeVar{}); + TypeId result = dest.addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(result); LUAU_ASSERT(ttv != nullptr); @@ -323,7 +325,13 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); for (TypeId& arg : ttv->instantiatedTypeParams) - arg = (clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType)); + arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + } if (ttv->state == TableState::Free) { @@ -343,7 +351,7 @@ void TypeCloner::operator()(const TableTypeVar& t) void TypeCloner::operator()(const MetatableTypeVar& t) { - TypeId result = dest.typeVars.allocate(MetatableTypeVar{}); + TypeId result = dest.addType(MetatableTypeVar{}); MetatableTypeVar* mtv = getMutable(result); seenTypes[typeId] = result; @@ -353,7 +361,7 @@ void TypeCloner::operator()(const MetatableTypeVar& t) void TypeCloner::operator()(const ClassTypeVar& t) { - TypeId result = dest.typeVars.allocate(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); + TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); ClassTypeVar* ctv = getMutable(result); seenTypes[typeId] = result; @@ -378,7 +386,7 @@ void TypeCloner::operator()(const AnyTypeVar& t) void TypeCloner::operator()(const UnionTypeVar& t) { - TypeId result = dest.typeVars.allocate(UnionTypeVar{}); + TypeId result = dest.addType(UnionTypeVar{}); seenTypes[typeId] = result; UnionTypeVar* option = getMutable(result); @@ -390,7 +398,7 @@ void TypeCloner::operator()(const UnionTypeVar& t) void TypeCloner::operator()(const IntersectionTypeVar& t) { - TypeId result = dest.typeVars.allocate(IntersectionTypeVar{}); + TypeId result = dest.addType(IntersectionTypeVar{}); seenTypes[typeId] = result; IntersectionTypeVar* option = getMutable(result); @@ -451,8 +459,14 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) { TypeFun result; - for (TypeId param : typeFun.typeParams) - result.typeParams.push_back(clone(param, dest, seenTypes, seenTypePacks, encounteredFreeType)); + for (TypeId ty : typeFun.typeParams) + result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId tp : typeFun.typePackParams) + result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, encounteredFreeType)); + } result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType); diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 5b3997e2..ad4d5ef4 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -5,6 +5,7 @@ #include "Luau/Module.h" LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) +LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace, false) namespace Luau { @@ -12,17 +13,18 @@ namespace Luau namespace { -struct RequireTracer : AstVisitor +struct RequireTracerOld : AstVisitor { - explicit RequireTracer(FileResolver* fileResolver, ModuleName currentModuleName) + explicit RequireTracerOld(FileResolver* fileResolver, const ModuleName& currentModuleName) : fileResolver(fileResolver) - , currentModuleName(std::move(currentModuleName)) + , currentModuleName(currentModuleName) { + LUAU_ASSERT(!FFlag::LuauNewRequireTrace); } FileResolver* const fileResolver; ModuleName currentModuleName; - DenseHashMap locals{0}; + DenseHashMap locals{nullptr}; RequireTraceResult result; std::optional fromAstFragment(AstExpr* expr) @@ -50,9 +52,9 @@ struct RequireTracer : AstVisitor AstExpr* expr = stat->values.data[i]; expr->visit(this); - const ModuleName* name = result.exprs.find(expr); - if (name) - locals[local] = *name; + const ModuleInfo* info = result.exprs.find(expr); + if (info) + locals[local] = info->name; } } @@ -63,7 +65,7 @@ struct RequireTracer : AstVisitor { std::optional name = fromAstFragment(global); if (name) - result.exprs[global] = *name; + result.exprs[global] = {*name}; return false; } @@ -72,7 +74,7 @@ struct RequireTracer : AstVisitor { const ModuleName* name = locals.find(local->local); if (name) - result.exprs[local] = *name; + result.exprs[local] = {*name}; return false; } @@ -81,16 +83,16 @@ struct RequireTracer : AstVisitor { indexName->expr->visit(this); - const ModuleName* name = result.exprs.find(indexName->expr); - if (name) + const ModuleInfo* info = result.exprs.find(indexName->expr); + if (info) { if (indexName->index == "parent" || indexName->index == "Parent") { - if (auto parent = fileResolver->getParentModuleName(*name)) - result.exprs[indexName] = *parent; + if (auto parent = fileResolver->getParentModuleName(info->name)) + result.exprs[indexName] = {*parent}; } else - result.exprs[indexName] = fileResolver->concat(*name, indexName->index.value); + result.exprs[indexName] = {fileResolver->concat(info->name, indexName->index.value)}; } return false; @@ -100,11 +102,11 @@ struct RequireTracer : AstVisitor { indexExpr->expr->visit(this); - const ModuleName* name = result.exprs.find(indexExpr->expr); + const ModuleInfo* info = result.exprs.find(indexExpr->expr); const AstExprConstantString* str = indexExpr->index->as(); - if (name && str) + if (info && str) { - result.exprs[indexExpr] = fileResolver->concat(*name, std::string_view(str->value.data, str->value.size)); + result.exprs[indexExpr] = {fileResolver->concat(info->name, std::string_view(str->value.data, str->value.size))}; } indexExpr->index->visit(this); @@ -129,8 +131,8 @@ struct RequireTracer : AstVisitor AstExprGlobal* globalName = call->func->as(); if (globalName && globalName->name == "require" && call->args.size >= 1) { - if (const ModuleName* moduleName = result.exprs.find(call->args.data[0])) - result.requires.push_back({*moduleName, call->location}); + if (const ModuleInfo* moduleInfo = result.exprs.find(call->args.data[0])) + result.requires.push_back({moduleInfo->name, call->location}); return false; } @@ -143,8 +145,8 @@ struct RequireTracer : AstVisitor if (FFlag::LuauTraceRequireLookupChild && !rootName) { - if (const ModuleName* moduleName = result.exprs.find(indexName->expr)) - rootName = *moduleName; + if (const ModuleInfo* moduleInfo = result.exprs.find(indexName->expr)) + rootName = moduleInfo->name; } if (!rootName) @@ -167,24 +169,183 @@ struct RequireTracer : AstVisitor if (v.end() != std::find(v.begin(), v.end(), '/')) return false; - result.exprs[call] = fileResolver->concat(*rootName, v); + result.exprs[call] = {fileResolver->concat(*rootName, v)}; // 'WaitForChild' can be used on modules that are not awailable at the typecheck time, but will be awailable at runtime // If we fail to find such module, we will not report an UnknownRequire error if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") - result.optional[call] = true; + result.exprs[call].optional = true; return false; } }; +struct RequireTracer : AstVisitor +{ + RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName) + : result(result) + , fileResolver(fileResolver) + , currentModuleName(currentModuleName) + , locals(nullptr) + { + LUAU_ASSERT(FFlag::LuauNewRequireTrace); + } + + bool visit(AstExprTypeAssertion* expr) override + { + // suppress `require() :: any` + return false; + } + + bool visit(AstExprCall* expr) override + { + AstExprGlobal* global = expr->func->as(); + + if (global && global->name == "require" && expr->args.size >= 1) + requires.push_back(expr); + + return true; + } + + bool visit(AstStatLocal* stat) override + { + for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i) + { + AstLocal* local = stat->vars.data[i]; + AstExpr* expr = stat->values.data[i]; + + // track initializing expression to be able to trace modules through locals + locals[local] = expr; + } + + return true; + } + + bool visit(AstStatAssign* stat) override + { + for (size_t i = 0; i < stat->vars.size; ++i) + { + // locals that are assigned don't have a known expression + if (AstExprLocal* expr = stat->vars.data[i]->as()) + locals[expr->local] = nullptr; + } + + return true; + } + + bool visit(AstType* node) override + { + // allow resolving require inside `typeof` annotations + return true; + } + + AstExpr* getDependent(AstExpr* node) + { + if (AstExprLocal* expr = node->as()) + return locals[expr->local]; + else if (AstExprIndexName* expr = node->as()) + return expr->expr; + else if (AstExprIndexExpr* expr = node->as()) + return expr->expr; + else if (AstExprCall* expr = node->as(); expr && expr->self) + return expr->func->as()->expr; + else + return nullptr; + } + + void process() + { + ModuleInfo moduleContext{currentModuleName}; + + // seed worklist with require arguments + work.reserve(requires.size()); + + for (AstExprCall* require: requires) + work.push_back(require->args.data[0]); + + // push all dependent expressions to the work stack; note that the vector is modified during traversal + for (size_t i = 0; i < work.size(); ++i) + if (AstExpr* dep = getDependent(work[i])) + work.push_back(dep); + + // resolve all expressions to a module info + for (size_t i = work.size(); i > 0; --i) + { + AstExpr* expr = work[i - 1]; + + // when multiple expressions depend on the same one we push it to work queue multiple times + if (result.exprs.contains(expr)) + continue; + + std::optional info; + + if (AstExpr* dep = getDependent(expr)) + { + const ModuleInfo* context = result.exprs.find(dep); + + // locals just inherit their dependent context, no resolution required + if (expr->is()) + info = context ? std::optional(*context) : std::nullopt; + else + info = fileResolver->resolveModule(context, expr); + } + else + { + info = fileResolver->resolveModule(&moduleContext, expr); + } + + if (info) + result.exprs[expr] = std::move(*info); + } + + // resolve all requires according to their argument + result.requires.reserve(requires.size()); + + for (AstExprCall* require : requires) + { + AstExpr* arg = require->args.data[0]; + + if (const ModuleInfo* info = result.exprs.find(arg)) + { + result.requires.push_back({info->name, require->location}); + + ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info! + result.exprs[require] = std::move(infoCopy); + } + else + { + result.exprs[require] = {}; // mark require as unresolved + } + } + } + + RequireTraceResult& result; + FileResolver* fileResolver; + ModuleName currentModuleName; + + DenseHashMap locals; + std::vector work; + std::vector requires; +}; + } // anonymous namespace -RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName) +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - RequireTracer tracer{fileResolver, std::move(currentModuleName)}; - root->visit(&tracer); - return tracer.result; + if (FFlag::LuauNewRequireTrace) + { + RequireTraceResult result; + RequireTracer tracer{result, fileResolver, currentModuleName}; + root->visit(&tracer); + tracer.process(); + return result; + } + else + { + RequireTracerOld tracer{fileResolver, currentModuleName}; + root->visit(&tracer); + return tracer.result; + } } } // namespace Luau diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp new file mode 100644 index 00000000..c30db9c2 --- /dev/null +++ b/Analysis/src/Scope.cpp @@ -0,0 +1,123 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" + +namespace Luau +{ + +Scope::Scope(TypePackId returnType) + : parent(nullptr) + , returnType(returnType) + , level(TypeLevel()) +{ +} + +Scope::Scope(const ScopePtr& parent, int subLevel) + : parent(parent) + , returnType(parent->returnType) + , level(parent->level.incr()) +{ + level.subLevel = subLevel; +} + +std::optional Scope::lookup(const Symbol& name) +{ + Scope* scope = this; + + while (scope) + { + auto it = scope->bindings.find(name); + if (it != scope->bindings.end()) + return it->second.typeId; + + scope = scope->parent.get(); + } + + return std::nullopt; +} + +std::optional Scope::lookupType(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->exportedTypeBindings.find(name); + if (it != scope->exportedTypeBindings.end()) + return it->second; + + it = scope->privateTypeBindings.find(name); + if (it != scope->privateTypeBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) +{ + const Scope* scope = this; + while (scope) + { + auto it = scope->importedTypeBindings.find(moduleAlias); + if (it == scope->importedTypeBindings.end()) + { + scope = scope->parent.get(); + continue; + } + + auto it2 = it->second.find(name); + if (it2 == it->second.end()) + { + scope = scope->parent.get(); + continue; + } + + return it2->second; + } + + return std::nullopt; +} + +std::optional Scope::lookupPack(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->privateTypePackBindings.find(name); + if (it != scope->privateTypePackBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) +{ + Scope* scope = this; + + while (scope) + { + for (const auto& [n, binding] : scope->bindings) + { + if (n.local && n.local->name == name.c_str()) + return binding; + else if (n.global.value && n.global == name.c_str()) + return binding; + } + + scope = scope->parent.get(); + + if (!traverseScopeChain) + break; + } + + return std::nullopt; +} + +} // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 7223998a..d861eb3d 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -6,9 +6,11 @@ #include #include -LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 0) +LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) +LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauRankNTypes) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -35,8 +37,15 @@ void Tarjan::visitChildren(TypeId ty, int index) visitChild(ttv->indexer->indexType); visitChild(ttv->indexer->indexResultType); } + for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp); + } } else if (const MetatableTypeVar* mtv = get(ty)) { @@ -332,9 +341,11 @@ std::optional Substitution::substitute(TypeId ty) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - replaceChildren(newTy); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - replaceChildren(newTp); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + replaceChildren(newTp); TypeId newTy = replace(ty); return newTy; } @@ -350,9 +361,11 @@ std::optional Substitution::substitute(TypePackId tp) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - replaceChildren(newTy); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - replaceChildren(newTp); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + replaceChildren(newTp); TypePackId newTp = replace(tp); return newTp; } @@ -382,6 +395,10 @@ TypeId Substitution::clone(TypeId ty) clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; clone.instantiatedTypeParams = ttv->instantiatedTypeParams; + + if (FFlag::LuauTypeAliasPacks) + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; + if (FFlag::LuauSecondTypecheckKnowsTheDataModel) clone.tags = ttv->tags; result = addType(std::move(clone)); @@ -487,8 +504,15 @@ void Substitution::replaceChildren(TypeId ty) ttv->indexer->indexType = replace(ttv->indexer->indexType); ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType); } + for (TypeId& itp : ttv->instantiatedTypeParams) itp = replace(itp); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId& itp : ttv->instantiatedTypePackParams) + itp = replace(itp); + } } else if (MetatableTypeVar* mtv = getMutable(ty)) { diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 9d2f47ba..5651af7e 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ToString.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -9,10 +10,10 @@ #include #include -LUAU_FASTFLAG(LuauToStringFollowsBoundTo) LUAU_FASTFLAG(LuauExtraNilRecovery) LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -59,6 +60,13 @@ struct FindCyclicTypes { for (TypeId itp : ttv.instantiatedTypeParams) visitTypeVar(itp, *this, seen); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv.instantiatedTypePackParams) + visitTypeVar(itp, *this, seen); + } + return exhaustive; } @@ -258,23 +266,60 @@ struct TypeVarStringifier void stringify(TypePackId tp); void stringify(TypePackId tpid, const std::vector>& names); - void stringify(const std::vector& types) + void stringify(const std::vector& types, const std::vector& typePacks) { - if (types.size() == 0) + if (types.size() == 0 && (!FFlag::LuauTypeAliasPacks || typePacks.size() == 0)) return; - if (types.size()) + if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) state.emit("<"); - for (size_t i = 0; i < types.size(); ++i) + if (FFlag::LuauTypeAliasPacks) { - if (i > 0) - state.emit(", "); + bool first = true; - stringify(types[i]); + for (TypeId ty : types) + { + if (!first) + state.emit(", "); + first = false; + + stringify(ty); + } + + bool singleTp = typePacks.size() == 1; + + for (TypePackId tp : typePacks) + { + if (isEmpty(tp) && singleTp) + continue; + + if (!first) + state.emit(", "); + else + first = false; + + if (!singleTp) + state.emit("("); + + stringify(tp); + + if (!singleTp) + state.emit(")"); + } + } + else + { + for (size_t i = 0; i < types.size(); ++i) + { + if (i > 0) + state.emit(", "); + + stringify(types[i]); + } } - if (types.size()) + if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) state.emit(">"); } @@ -388,7 +433,7 @@ struct TypeVarStringifier void operator()(TypeId, const TableTypeVar& ttv) { - if (FFlag::LuauToStringFollowsBoundTo && ttv.boundTo) + if (ttv.boundTo) return stringify(*ttv.boundTo); if (!state.exhaustive) @@ -411,14 +456,14 @@ struct TypeVarStringifier } state.emit(*ttv.name); - stringify(ttv.instantiatedTypeParams); + stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams); return; } if (ttv.syntheticName) { state.result.invalid = true; state.emit(*ttv.syntheticName); - stringify(ttv.instantiatedTypeParams); + stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams); return; } } @@ -900,13 +945,26 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (ttv->instantiatedTypeParams.empty()) + if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) return result; std::vector params; for (TypeId tp : ttv->instantiatedTypeParams) params.push_back(toString(tp)); + if (FFlag::LuauTypeAliasPacks) + { + // Doesn't preserve grouping of multiple type packs + // But this is under a parent block of code that is being removed later + for (TypePackId tp : ttv->instantiatedTypePackParams) + { + std::string content = toString(tp); + + if (!content.empty()) + params.push_back(std::move(content)); + } + } + result.name += "<" + join(params, ", ") + ">"; return result; } @@ -950,30 +1008,37 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (ttv->instantiatedTypeParams.empty()) - return result; - - result.name += "<"; - - bool first = true; - for (TypeId ty : ttv->instantiatedTypeParams) + if (FFlag::LuauTypeAliasPacks) { - if (!first) - result.name += ", "; - else - first = false; - - tvs.stringify(ty); - } - - if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) - { - result.truncated = true; - result.name += "... "; + tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); } else { - result.name += ">"; + if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) + return result; + + result.name += "<"; + + bool first = true; + for (TypeId ty : ttv->instantiatedTypeParams) + { + if (!first) + result.name += ", "; + else + first = false; + + tvs.stringify(ty); + } + + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + { + result.truncated = true; + result.name += "... "; + } + else + { + result.name += ">"; + } } return result; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 462c70ff..1b83ccdc 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace { @@ -280,10 +281,19 @@ struct Printer void visualizeTypePackAnnotation(const AstTypePack& annotation) { - if (const AstTypePackVariadic* variadic = annotation.as()) + if (const AstTypePackVariadic* variadicTp = annotation.as()) { writer.symbol("..."); - visualizeTypeAnnotation(*variadic->variadicType); + visualizeTypeAnnotation(*variadicTp->variadicType); + } + else if (const AstTypePackGeneric* genericTp = annotation.as()) + { + writer.symbol(genericTp->genericName.value); + writer.symbol("..."); + } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + visualizeTypeList(explicitTp->typeList, true); } else { @@ -807,7 +817,7 @@ struct Printer writer.keyword("type"); writer.identifier(a->name.value); - if (a->generics.size > 0) + if (a->generics.size > 0 || (FFlag::LuauTypeAliasPacks && a->genericPacks.size > 0)) { writer.symbol("<"); CommaSeparatorInserter comma(writer); @@ -817,6 +827,17 @@ struct Printer comma(); writer.identifier(o.value); } + + if (FFlag::LuauTypeAliasPacks) + { + for (auto o : a->genericPacks) + { + comma(); + writer.identifier(o.value); + writer.symbol("..."); + } + } + writer.symbol(">"); } writer.maybeSpace(a->type->location.begin, 2); @@ -960,15 +981,20 @@ struct Printer if (const auto& a = typeAnnotation.as()) { writer.write(a->name.value); - if (a->generics.size > 0) + if (a->parameters.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); - for (auto o : a->generics) + for (auto o : a->parameters) { comma(); - visualizeTypeAnnotation(*o); + + if (o.type) + visualizeTypeAnnotation(*o.type); + else + visualizeTypePackAnnotation(*o.typePack); } + writer.symbol(">"); } } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 17c57c84..266c1986 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -5,6 +5,7 @@ #include "Luau/Module.h" #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -12,6 +13,7 @@ #include LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAG(LuauTypeAliasPacks) static char* allocateString(Luau::Allocator& allocator, std::string_view contents) { @@ -33,7 +35,6 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data namespace Luau { - class TypeRehydrationVisitor { mutable std::map seen; @@ -57,6 +58,8 @@ public: { } + AstTypePack* rehydrate(TypePackId tp) const; + AstType* operator()(const PrimitiveTypeVar& ptv) const { switch (ptv.type) @@ -85,16 +88,24 @@ public: if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end()) { - AstArray generics; - generics.size = ttv.instantiatedTypeParams.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstType*) * generics.size)); + AstArray parameters; + parameters.size = ttv.instantiatedTypeParams.size(); + parameters.data = static_cast(allocator->allocate(sizeof(AstTypeOrPack) * parameters.size)); for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i) { - generics.data[i] = Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty); + parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}}; } - return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), generics); + if (FFlag::LuauTypeAliasPacks) + { + for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) + { + parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; + } + } + + return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters); } if (hasSeen(&ttv)) @@ -222,10 +233,17 @@ public: AstTypePack* argTailAnnotation = nullptr; if (argTail) { - TypePackId tail = *argTail; - if (const VariadicTypePack* vtp = get(tail)) + if (FFlag::LuauTypeAliasPacks) { - argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + argTailAnnotation = rehydrate(*argTail); + } + else + { + TypePackId tail = *argTail; + if (const VariadicTypePack* vtp = get(tail)) + { + argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + } } } @@ -255,10 +273,17 @@ public: AstTypePack* retTailAnnotation = nullptr; if (retTail) { - TypePackId tail = *retTail; - if (const VariadicTypePack* vtp = get(tail)) + if (FFlag::LuauTypeAliasPacks) { - retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + retTailAnnotation = rehydrate(*retTail); + } + else + { + TypePackId tail = *retTail; + if (const VariadicTypePack* vtp = get(tail)) + { + retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + } } } @@ -313,6 +338,68 @@ private: const TypeRehydrationOptions& options; }; +class TypePackRehydrationVisitor +{ +public: + TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor) + : allocator(allocator) + , typeVisitor(typeVisitor) + { + } + + AstTypePack* operator()(const BoundTypePack& btp) const + { + return Luau::visit(*this, btp.boundTo->ty); + } + + AstTypePack* operator()(const TypePack& tp) const + { + AstArray head; + head.size = tp.head.size(); + head.data = static_cast(allocator->allocate(sizeof(AstType*) * tp.head.size())); + + for (size_t i = 0; i < tp.head.size(); i++) + head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty); + + AstTypePack* tail = nullptr; + + if (tp.tail) + tail = Luau::visit(*this, (*tp.tail)->ty); + + return allocator->alloc(Location(), AstTypeList{head, tail}); + } + + AstTypePack* operator()(const VariadicTypePack& vtp) const + { + return allocator->alloc(Location(), Luau::visit(typeVisitor, vtp.ty->ty)); + } + + AstTypePack* operator()(const GenericTypePack& gtp) const + { + return allocator->alloc(Location(), AstName(gtp.name.c_str())); + } + + AstTypePack* operator()(const FreeTypePack& gtp) const + { + return allocator->alloc(Location(), AstName("free")); + } + + AstTypePack* operator()(const Unifiable::Error&) const + { + return allocator->alloc(Location(), AstName("Unifiable")); + } + +private: + Allocator* allocator; + const TypeRehydrationVisitor& typeVisitor; +}; + +AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const +{ + TypePackRehydrationVisitor tprv(allocator, *this); + return Luau::visit(tprv, tp->ty); +} + class TypeAttacher : public AstVisitor { public: @@ -406,9 +493,16 @@ public: if (tail) { - TypePackId tailPack = *tail; - if (const VariadicTypePack* vtp = get(tailPack)) - variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); + if (FFlag::LuauTypeAliasPacks) + { + variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail); + } + else + { + TypePackId tailPack = *tail; + if (const VariadicTypePack* vtp = get(tailPack)) + variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); + } } fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation}; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 2216881b..3a1fdfff 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5,21 +5,22 @@ #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" #include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" +#include "Luau/TimeTrace.h" #include #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 0) -LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 0) +LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) -LUAU_FASTFLAGVARIABLE(LuauIndexTablesWithIndexers, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauGenericVariadicsUnification, false) LUAU_FASTFLAG(LuauKnowsTheDataModel3) @@ -27,14 +28,11 @@ LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(LuauImprovedTypeGuardPredicate2, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) -LUAU_FASTFLAG(DebugLuauTrackOwningArena) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauFixTableTypeAliasClone, false) LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false) LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false) LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) @@ -45,6 +43,10 @@ LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) +LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) +LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) +LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -216,9 +218,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , nilType(singletonTypes.nilType) , numberType(singletonTypes.numberType) , stringType(singletonTypes.stringType) - , booleanType( - FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.booleanType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Boolean))) - , threadType(FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.threadType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Thread))) + , booleanType(singletonTypes.booleanType) + , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) , errorType(singletonTypes.errorType) , optionalNumberType(globalTypes.addType(UnionTypeVar{{numberType, nilType}})) @@ -237,6 +238,9 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) { + LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + currentModule.reset(new Module()); currentModule->type = module.type; @@ -1177,44 +1181,61 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { Location location = scope->typeAliasLocations[name]; reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); - bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + + if (FFlag::LuauTypeAliasPacks) + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType}; + else + bindingsMap[name] = TypeFun{binding->typeParams, errorType}; } else { ScopePtr aliasScope = childScope(scope, typealias.location); - std::vector generics; - for (AstName generic : typealias.generics) + if (FFlag::LuauTypeAliasPacks) { - Name n = generic.value; + auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks); - // These generics are the only thing that will ever be added to aliasScope, so we can be certain that - // a collision can only occur when two generic typevars have the same name. - if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) - { - // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. - reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); - } - - TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypeId& cached = scope->typeAliasParameters[n]; - if (!cached) - cached = addType(GenericTypeVar{aliasScope->level, n}); - g = cached; - } - else - g = addType(GenericTypeVar{aliasScope->level, n}); - generics.push_back(g); - aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; + TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; } + else + { + std::vector generics; + for (AstName generic : typealias.generics) + { + Name n = generic.value; - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); - FreeTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), ty}; + // These generics are the only thing that will ever be added to aliasScope, so we can be certain that + // a collision can only occur when two generic typevars have the same name. + if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) + { + // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. + reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); + } + + TypeId g; + if (FFlag::LuauRecursiveTypeParameterRestriction) + { + TypeId& cached = scope->typeAliasTypeParameters[n]; + if (!cached) + cached = addType(GenericTypeVar{aliasScope->level, n}); + g = cached; + } + else + g = addType(GenericTypeVar{aliasScope->level, n}); + generics.push_back(g); + aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; + } + + TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), ty}; + } } } else @@ -1231,6 +1252,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; } + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId tp : binding->typePackParams) + { + auto generic = get(tp); + LUAU_ASSERT(generic); + aliasScope->privateTypePackBindings[generic->name] = tp; + } + } + TypeId ty = (FFlag::LuauRankNTypes ? resolveType(aliasScope, *typealias.type) : resolveType(aliasScope, *typealias.type, true)); if (auto ttv = getMutable(follow(ty))) { @@ -1238,7 +1269,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (ttv->name) { // Copy can be skipped if this is an identical alias - if (!FFlag::LuauFixTableTypeAliasClone || ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams) + if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || + (FFlag::LuauTypeAliasPacks && ttv->instantiatedTypePackParams != binding->typePackParams)) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1249,6 +1281,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias clone.name = name; clone.instantiatedTypeParams = binding->typeParams; + if (FFlag::LuauTypeAliasPacks) + clone.instantiatedTypePackParams = binding->typePackParams; + ty = addType(std::move(clone)); } } @@ -1256,6 +1291,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { ttv->name = name; ttv->instantiatedTypeParams = binding->typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = binding->typePackParams; } } else if (auto mtv = getMutable(follow(ty))) @@ -1280,7 +1318,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0); + LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0)); superTy = lookupType->type; if (FFlag::LuauAddMissingFollow) @@ -1465,7 +1503,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& if (FFlag::LuauStoreMatchingOverloadFnType) { - currentModule->astTypes.try_emplace(&expr, result.type); + if (!currentModule->astTypes.find(&expr)) + currentModule->astTypes[&expr] = result.type; } else { @@ -2193,7 +2232,7 @@ TypeId TypeChecker::checkRelationalOperation( * have a better, more descriptive error teed up. */ Unifier state = mkUnifier(expr.location); - if (!FFlag::LuauEqConstraint || !isEquality) + if (!isEquality) state.tryUnify(lhsType, rhsType); bool needsMetamethod = !isEquality; @@ -2262,7 +2301,7 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && (!FFlag::LuauEqConstraint || !isEquality)) + if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && !isEquality) { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); @@ -2276,18 +2315,6 @@ TypeId TypeChecker::checkRelationalOperation( return errorType; } - if (!FFlag::LuauEqConstraint) - { - if (isEquality) - { - ErrorVec errVec = tryUnify(rhsType, lhsType, expr.location); - if (!state.errors.empty() && !errVec.empty()) - reportError(expr.location, TypeMismatch{lhsType, rhsType}); - } - else - reportErrors(state.errors); - } - return booleanType; } @@ -2443,7 +2470,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates); return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } - else if (FFlag::LuauEqConstraint && (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe)) + else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; @@ -2466,14 +2493,6 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi } else { - // Once we have EqPredicate, we should break this else branch into its' own branch. - // For now, fall through is intentional. - if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } - ExprResult lhs = checkExpr(scope, *expr.left); ExprResult rhs = checkExpr(scope, *expr.right); @@ -2755,12 +2774,6 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return std::pair(resultType, nullptr); } - else if (FFlag::LuauIndexTablesWithIndexers) - { - // We allow t[x] where x:string for tables without an indexer - unify(indexType, stringType, expr.location); - return std::pair(anyType, nullptr); - } else { TypeId resultType = freshType(scope); @@ -3076,6 +3089,13 @@ static Location getEndLocation(const AstExprFunction& function) void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstExprFunction& function) { + LUAU_TIMETRACE_SCOPE("TypeChecker::checkFunctionBody", "TypeChecker"); + + if (function.debugname.value) + LUAU_TIMETRACE_ARGUMENT("name", function.debugname.value); + else + LUAU_TIMETRACE_ARGUMENT("line", std::to_string(function.location.begin.line).c_str()); + if (FunctionTypeVar* funTy = getMutable(ty)) { check(scope, *function.body); @@ -3885,6 +3905,20 @@ std::optional TypeChecker::matchRequire(const AstExprCall& call) TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location) { + LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); + LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); + + if (FFlag::LuauNewRequireTrace && moduleInfo.name.empty()) + { + if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) + { + reportError(TypeError{location, UnknownRequire{}}); + return errorType; + } + + return anyType; + } + ModulePtr module = resolver->getModule(moduleInfo.name); if (!module) { @@ -4472,7 +4506,7 @@ TypeId TypeChecker::freshType(const ScopePtr& scope) TypeId TypeChecker::freshType(TypeLevel level) { - return currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level))); + return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric) @@ -4482,11 +4516,7 @@ TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneri TypeId TypeChecker::DEPRECATED_freshType(TypeLevel level, bool canBeGeneric) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level, canBeGeneric))); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level, canBeGeneric))); } std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) @@ -4506,20 +4536,12 @@ TypeId TypeChecker::addType(const UnionTypeVar& utv) TypeId TypeChecker::addTV(TypeVar&& tv) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(std::move(tv)); } TypePackId TypeChecker::addTypePack(TypePackVar&& tv) { - TypePackId allocated = currentModule->internalTypes.typePacks.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addTypePack(std::move(tv)); } TypePackId TypeChecker::addTypePack(TypePack&& tp) @@ -4578,7 +4600,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_print") { - if (lit->generics.size != 1) + if (lit->parameters.size != 1 || !lit->parameters.data[0].type) { reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); return addType(ErrorTypeVar{}); @@ -4588,7 +4610,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation opts.exhaustive = true; opts.maxTableLength = 0; - TypeId param = resolveType(scope, *lit->generics.data[0]); + TypeId param = resolveType(scope, *lit->parameters.data[0].type); luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); return param; } @@ -4614,18 +4636,86 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return addType(ErrorTypeVar{}); } - if (lit->generics.size == 0 && tf->typeParams.empty()) - return tf->type; - else if (lit->generics.size != tf->typeParams.size()) + if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) { - reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->generics.size}}); + return tf->type; + } + else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) + { + reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); return addType(ErrorTypeVar{}); } + else if (FFlag::LuauTypeAliasPacks) + { + if (!lit->hasParameterList && !tf->typePackParams.empty()) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + return addType(ErrorTypeVar{}); + } + + std::vector typeParams; + std::vector extraTypes; + std::vector typePackParams; + + for (size_t i = 0; i < lit->parameters.size; ++i) + { + if (AstType* type = lit->parameters.data[i].type) + { + TypeId ty = resolveType(scope, *type); + + if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) + typeParams.push_back(ty); + else if (typePackParams.empty()) + extraTypes.push_back(ty); + else + reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); + } + else if (AstTypePack* typePack = lit->parameters.data[i].typePack) + { + TypePackId tp = resolveTypePack(scope, *typePack); + + // If we have collected an implicit type pack, materialize it + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we need more regular types, we can use single element type packs to fill those in + if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) + typeParams.push_back(*first(tp)); + else + typePackParams.push_back(tp); + } + } + + // If we still haven't meterialized an implicit type pack, do it now + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack + if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) + typePackParams.push_back(addTypePack({})); + + if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) + { + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + return addType(ErrorTypeVar{}); + } + + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + { + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + return tf->type; + } + + return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); + } else { std::vector typeParams; - for (AstType* paramAnnot : lit->generics) - typeParams.push_back(resolveType(scope, *paramAnnot)); + + for (const auto& param : lit->parameters) + typeParams.push_back(resolveType(scope, *param.type)); if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) { @@ -4634,7 +4724,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return tf->type; } - return instantiateTypeFun(scope, *tf, typeParams, annotation.location); + return instantiateTypeFun(scope, *tf, typeParams, {}, annotation.location); } } else if (const auto& table = annotation.as()) @@ -4765,6 +4855,18 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack return *genericTy; } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + std::vector types; + + for (auto type : explicitTp->typeList.types) + types.push_back(resolveType(scope, *type)); + + if (auto tailType = explicitTp->typeList.tailType) + return addTypePack(types, resolveTypePack(scope, *tailType)); + + return addTypePack(types); + } else { ice("Unknown AstTypePack kind"); @@ -4799,12 +4901,28 @@ bool ApplyTypeFunction::isDirty(TypePackId tp) return false; } +bool ApplyTypeFunction::ignoreChildren(TypeId ty) +{ + if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(ty)) + return true; + else + return false; +} + +bool ApplyTypeFunction::ignoreChildren(TypePackId tp) +{ + if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(tp)) + return true; + else + return false; +} + TypeId ApplyTypeFunction::clean(TypeId ty) { // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - TypeId& arg = arguments[ty]; + TypeId& arg = typeArguments[ty]; if (arg) return arg; else @@ -4816,17 +4934,37 @@ TypePackId ApplyTypeFunction::clean(TypePackId tp) // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - return addTypePack(FreeTypePack{level}); + if (FFlag::LuauTypeAliasPacks) + { + TypePackId& arg = typePackArguments[tp]; + if (arg) + return arg; + else + return addTypePack(FreeTypePack{level}); + } + else + { + return addTypePack(FreeTypePack{level}); + } } -TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location) +TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, + const std::vector& typePackParams, const Location& location) { - if (tf.typeParams.empty()) + if (tf.typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf.typePackParams.empty())) return tf.type; - applyTypeFunction.arguments.clear(); + applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) - applyTypeFunction.arguments[tf.typeParams[i]] = typeParams[i]; + applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; + + if (FFlag::LuauTypeAliasPacks) + { + applyTypeFunction.typePackArguments.clear(); + for (size_t i = 0; i < tf.typePackParams.size(); ++i) + applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; + } + applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; applyTypeFunction.encounteredForwardedType = false; @@ -4875,6 +5013,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (ttv) { ttv->instantiatedTypeParams = typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = typePackParams; } } else @@ -4890,6 +5031,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } ttv->instantiatedTypeParams = typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = typePackParams; } } @@ -4899,6 +5043,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, std::pair, std::vector> TypeChecker::createGenericTypes( const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { + LUAU_ASSERT(scope->parent); + std::vector generics; for (const AstName& generic : genericNames) { @@ -4912,7 +5058,19 @@ std::pair, std::vector> TypeChecker::createGener reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypeId g = addType(Unifiable::Generic{scope->level, n}); + TypeId g; + if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + { + TypeId& cached = scope->parent->typeAliasTypeParameters[n]; + if (!cached) + cached = addType(GenericTypeVar{scope->level, n}); + g = cached; + } + else + { + g = addType(Unifiable::Generic{scope->level, n}); + } + generics.push_back(g); scope->privateTypeBindings[n] = TypeFun{{}, g}; } @@ -4930,7 +5088,19 @@ std::pair, std::vector> TypeChecker::createGener reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypePackId g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + TypePackId g; + if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + { + TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; + if (!cached) + cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + g = cached; + } + else + { + g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + } + genericPacks.push_back(g); scope->privateTypePackBindings[n] = g; } @@ -5013,13 +5183,8 @@ void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, Refineme else if (auto isaP = get(predicate)) resolve(*isaP, errVec, refis, scope, sense); else if (auto typeguardP = get(predicate)) - { - if (FFlag::LuauImprovedTypeGuardPredicate2) - resolve(*typeguardP, errVec, refis, scope, sense); - else - DEPRECATED_resolve(*typeguardP, errVec, refis, scope, sense); - } - else if (auto eqP = get(predicate); eqP && FFlag::LuauEqConstraint) + resolve(*typeguardP, errVec, refis, scope, sense); + else if (auto eqP = get(predicate)) resolve(*eqP, errVec, refis, scope, sense); else ice("Unhandled predicate kind"); @@ -5145,7 +5310,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } } - else if (FFlag::LuauImprovedTypeGuardPredicate2) + else { auto lctv = get(option); auto rctv = get(isaP.ty); @@ -5159,19 +5324,6 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement if (canUnify(option, isaP.ty, isaP.location).empty() == sense) return isaP.ty; } - else - { - auto lctv = get(option); - auto rctv = get(isaP.ty); - - if (lctv && rctv) - { - if (isSubclass(lctv, rctv) == sense) - return option; - else if (isSubclass(rctv, lctv) == sense) - return isaP.ty; - } - } return std::nullopt; }; @@ -5266,7 +5418,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); auto typeFun = globalScope->lookupType(typeguardP.kind); - if (!typeFun || !typeFun->typeParams.empty()) + if (!typeFun || !typeFun->typeParams.empty() || (FFlag::LuauTypeAliasPacks && !typeFun->typePackParams.empty())) return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); TypeId type = follow(typeFun->type); @@ -5292,7 +5444,8 @@ void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, Error "userdata", // no op. Requires special handling. }; - if (auto typeFun = globalScope->lookupType(typeguardP.kind); typeFun && typeFun->typeParams.empty()) + if (auto typeFun = globalScope->lookupType(typeguardP.kind); + typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty())) { if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) addRefinement(refis, typeguardP.lvalue, typeFun->type); @@ -5319,38 +5472,41 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa return; } - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; - - std::vector lhs = options(*ty); - std::vector rhs = options(eqP.type); - - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) + if (FFlag::LuauEqConstraint) { - addRefinement(refis, eqP.lvalue, eqP.type); - return; - } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + std::optional ty = resolveLValue(refis, scope, eqP.lvalue); + if (!ty) + return; - std::unordered_set set; - for (TypeId left : lhs) - { - for (TypeId right : rhs) + std::vector lhs = options(*ty); + std::vector rhs = options(eqP.type); + + if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); + addRefinement(refis, eqP.lvalue, eqP.type); + return; } + else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + + std::unordered_set set; + for (TypeId left : lhs) + { + for (TypeId right : rhs) + { + // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. + if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) + set.insert(left); + } + } + + if (set.empty()) + return; + + std::vector viable(set.begin(), set.end()); + TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); + addRefinement(refis, eqP.lvalue, result); } - - if (set.empty()) - return; - - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); } bool TypeChecker::isNonstrictMode() const @@ -5379,119 +5535,4 @@ std::vector> TypeChecker::getScopes() const return currentModule->scopes; } -Scope::Scope(TypePackId returnType) - : parent(nullptr) - , returnType(returnType) - , level(TypeLevel()) -{ -} - -Scope::Scope(const ScopePtr& parent, int subLevel) - : parent(parent) - , returnType(parent->returnType) - , level(parent->level.incr()) -{ - level.subLevel = subLevel; -} - -std::optional Scope::lookup(const Symbol& name) -{ - Scope* scope = this; - - while (scope) - { - auto it = scope->bindings.find(name); - if (it != scope->bindings.end()) - return it->second.typeId; - - scope = scope->parent.get(); - } - - return std::nullopt; -} - -std::optional Scope::lookupType(const Name& name) -{ - const Scope* scope = this; - while (true) - { - auto it = scope->exportedTypeBindings.find(name); - if (it != scope->exportedTypeBindings.end()) - return it->second; - - it = scope->privateTypeBindings.find(name); - if (it != scope->privateTypeBindings.end()) - return it->second; - - if (scope->parent) - scope = scope->parent.get(); - else - return std::nullopt; - } -} - -std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) -{ - const Scope* scope = this; - while (scope) - { - auto it = scope->importedTypeBindings.find(moduleAlias); - if (it == scope->importedTypeBindings.end()) - { - scope = scope->parent.get(); - continue; - } - - auto it2 = it->second.find(name); - if (it2 == it->second.end()) - { - scope = scope->parent.get(); - continue; - } - - return it2->second; - } - - return std::nullopt; -} - -std::optional Scope::lookupPack(const Name& name) -{ - const Scope* scope = this; - while (true) - { - auto it = scope->privateTypePackBindings.find(name); - if (it != scope->privateTypePackBindings.end()) - return it->second; - - if (scope->parent) - scope = scope->parent.get(); - else - return std::nullopt; - } -} - -std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) -{ - Scope* scope = this; - - while (scope) - { - for (const auto& [n, binding] : scope->bindings) - { - if (n.local && n.local->name == name.c_str()) - return binding; - else if (n.global.value && n.global == name.c_str()) - return binding; - } - - scope = scope->parent.get(); - - if (!traverseScopeChain) - break; - } - - return std::nullopt; -} - } // namespace Luau diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 5970f304..68a16ef0 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -209,6 +209,19 @@ size_t size(TypePackId tp) return 0; } +bool finite(TypePackId tp) +{ + tp = follow(tp); + + if (auto pack = get(tp)) + return pack->tail ? finite(*pack->tail) : true; + + if (auto pack = get(tp)) + return false; + + return true; +} + size_t size(const TypePack& tp) { size_t result = tp.head.size(); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index b9f50978..0d9d91e0 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -1,11 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeUtils.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" -LUAU_FASTFLAG(LuauStringMetatable) - namespace Luau { @@ -13,21 +12,6 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa { type = follow(type); - if (!FFlag::LuauStringMetatable) - { - if (const PrimitiveTypeVar* primType = get(type)) - { - if (primType->type != PrimitiveTypeVar::String || "__index" != entry) - return std::nullopt; - - auto it = globalScope->bindings.find(AstName{"string"}); - if (it != globalScope->bindings.end()) - return it->second.typeId; - else - return std::nullopt; - } - } - std::optional metatable = getMetatable(type); if (!metatable) return std::nullopt; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 111f4f53..e963fc74 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,11 +19,9 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) -LUAU_FASTFLAGVARIABLE(LuauToStringFollowsBoundTo, false) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAGVARIABLE(LuauStringMetatable, false) LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -193,27 +191,11 @@ bool isOptional(TypeId ty) bool isTableIntersection(TypeId ty) { - if (FFlag::LuauImprovedTypeGuardPredicate2) - { - if (!get(follow(ty))) - return false; - - std::vector parts = flattenIntersection(ty); - return std::all_of(parts.begin(), parts.end(), getTableType); - } - else - { - if (const IntersectionTypeVar* itv = get(ty)) - { - for (TypeId part : itv->parts) - { - if (getTableType(follow(part))) - return true; - } - } - + if (!get(follow(ty))) return false; - } + + std::vector parts = flattenIntersection(ty); + return std::all_of(parts.begin(), parts.end(), getTableType); } bool isOverloadedFunction(TypeId ty) @@ -236,7 +218,7 @@ std::optional getMetatable(TypeId type) else if (const ClassTypeVar* classType = get(type)) return classType->metatable; else if (const PrimitiveTypeVar* primitiveType = get(type); - FFlag::LuauStringMetatable && primitiveType && primitiveType->metatable) + primitiveType && primitiveType->metatable) { LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); return primitiveType->metatable; @@ -871,6 +853,12 @@ void StateDot::visitChildren(TypeId ty, int index) } for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp, index, "typeParam"); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } } else if (const MetatableTypeVar* mtv = get(ty)) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 89c3f80c..117cbc28 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -3,23 +3,25 @@ #include "Luau/Common.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" +#include "Luau/TimeTrace.h" #include LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 0); -LUAU_FASTFLAGVARIABLE(LuauLogTableTypeVarBoundTo, false) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); LUAU_FASTFLAGVARIABLE(LuauDontMutatePersistentFunctions, false) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauStringMetatable) LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) +LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) namespace Luau { @@ -43,21 +45,23 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , globalScope(std::move(globalScope)) , location(location) , variance(variance) - , counters(std::make_shared()) + , counters(&countersData) + , counters_DEPRECATED(std::make_shared()) , iceHandler(iceHandler) { LUAU_ASSERT(iceHandler); } Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters) + Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) : types(types) , mode(mode) , globalScope(std::move(globalScope)) , log(seen) , location(location) , variance(variance) - , counters(counters ? counters : std::make_shared()) + , counters(counters ? counters : &countersData) + , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) , iceHandler(iceHandler) { LUAU_ASSERT(iceHandler); @@ -65,16 +69,26 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::v void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - counters->iterationCount = 0; + if (FFlag::LuauTypecheckOpts) + counters->iterationCount = 0; + else + counters_DEPRECATED->iterationCount = 0; + return tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); - ++counters->iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + if (FFlag::LuauTypecheckOpts) + ++counters->iterationCount; + else + ++counters_DEPRECATED->iterationCount; + + if (FInt::LuauTypeInferIterationLimit > 0 && + FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -440,7 +454,11 @@ ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunction void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - counters->iterationCount = 0; + if (FFlag::LuauTypecheckOpts) + counters->iterationCount = 0; + else + counters_DEPRECATED->iterationCount = 0; + return tryUnify_(superTp, subTp, isFunctionCall); } @@ -450,10 +468,16 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall */ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); - ++counters->iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + if (FFlag::LuauTypecheckOpts) + ++counters->iterationCount; + else + ++counters_DEPRECATED->iterationCount; + + if (FInt::LuauTypeInferIterationLimit > 0 && + FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -762,9 +786,210 @@ struct Resetter void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - std::unique_ptr resetter; + if (!FFlag::LuauTableSubtypingVariance) + return DEPRECATED_tryUnifyTables(left, right, isIntersection); - resetter.reset(new Resetter{&variance}); + TableTypeVar* lt = getMutable(left); + TableTypeVar* rt = getMutable(right); + if (!lt || !rt) + ice("passed non-table types to unifyTables"); + + std::vector missingProperties; + std::vector extraProperties; + + // Reminder: left is the supertype, right is the subtype. + // Width subtyping: any property in the supertype must be in the subtype, + // and the types must agree. + for (const auto& [name, prop] : lt->props) + { + const auto& r = rt->props.find(name); + if (r != rt->props.end()) + { + // TODO: read-only properties don't need invariance + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, r->second.type); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (rt->indexer && isString(rt->indexer->indexType)) + { + // TODO: read-only indexers don't need invariance + // TODO: really we should only allow this if prop.type is optional. + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, rt->indexer->indexResultType); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (isOptional(prop.type) || get(follow(prop.type))) + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + {} + else if (rt->state == TableState::Free) + { + log(rt); + rt->props[name] = prop; + } + else + missingProperties.push_back(name); + } + + for (const auto& [name, prop] : rt->props) + { + if (lt->props.count(name)) + { + // If both lt and rt contain the property, then + // we're done since we already unified them above + } + else if (lt->indexer && isString(lt->indexer->indexType)) + { + // TODO: read-only indexers don't need invariance + // TODO: really we should only allow this if prop.type is optional. + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, lt->indexer->indexResultType); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (lt->state == TableState::Unsealed) + { + // TODO: this case is unsound when variance is Invariant, but without it lua-apps fails to typecheck. + // TODO: file a JIRA + // TODO: hopefully readonly/writeonly properties will fix this. + Property clone = prop; + clone.type = deeplyOptional(clone.type); + log(lt); + lt->props[name] = clone; + } + else if (variance == Covariant) + {} + else if (isOptional(prop.type) || get(follow(prop.type))) + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + {} + else if (lt->state == TableState::Free) + { + log(lt); + lt->props[name] = prop; + } + else + extraProperties.push_back(name); + } + + // Unify indexers + if (lt->indexer && rt->indexer) + { + // TODO: read-only indexers don't need invariance + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify(*lt->indexer, *rt->indexer); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (lt->indexer) + { + if (rt->state == TableState::Unsealed || rt->state == TableState::Free) + { + // passing/assigning a table without an indexer to something that has one + // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. + // TODO: we only need to do this if the supertype's indexer is read/write + // since that can add indexed elements. + log(rt); + rt->indexer = lt->indexer; + } + } + else if (rt->indexer && variance == Invariant) + { + // Symmetric if we are invariant + if (lt->state == TableState::Unsealed || lt->state == TableState::Free) + { + log(lt); + lt->indexer = rt->indexer; + } + } + + if (!missingProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + return; + } + + if (!extraProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + return; + } + + /* + * TypeVars are commonly cyclic, so it is entirely possible + * for unifying a property of a table to change the table itself! + * We need to check for this and start over if we notice this occurring. + * + * I believe this is guaranteed to terminate eventually because this will + * only happen when a free table is bound to another table. + */ + if (lt->boundTo || rt->boundTo) + return tryUnify_(left, right); + + if (lt->state == TableState::Free) + { + log(lt); + lt->boundTo = right; + } + else if (rt->state == TableState::Free) + { + log(rt); + rt->boundTo = left; + } +} + +TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) +{ + ty = follow(ty); + if (get(ty)) + return ty; + else if (isOptional(ty)) + return ty; + else if (const TableTypeVar* ttv = get(ty)) + { + TypeId& result = seen[ty]; + if (result) + return result; + result = types->addType(*ttv); + TableTypeVar* resultTtv = getMutable(result); + for (auto& [name, prop] : resultTtv->props) + prop.type = deeplyOptional(prop.type, seen); + return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});; + } + else + return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }}); +} + +void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +{ + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance); + Resetter resetter{&variance}; variance = Invariant; TableTypeVar* lt = getMutable(left); @@ -894,10 +1119,7 @@ void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) if (!freeTable->boundTo && otherTable->state != TableState::Free) { - if (FFlag::LuauLogTableTypeVarBoundTo) - log(freeTable); - else - log(freeTypeId); + log(freeTable); freeTable->boundTo = otherTypeId; } } @@ -1196,9 +1418,11 @@ void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& sub tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); } -static void queueTypePack( +static void queueTypePack_DEPRECATED( std::vector& queue, std::unordered_set& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { + LUAU_ASSERT(!FFlag::LuauTypecheckOpts); + while (true) { if (FFlag::LuauAddMissingFollow) @@ -1244,6 +1468,55 @@ static void queueTypePack( } } +static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOpts); + + while (true) + { + if (FFlag::LuauAddMissingFollow) + a = follow(a); + + if (seenTypePacks.find(a)) + break; + seenTypePacks.insert(a); + + if (FFlag::LuauAddMissingFollow) + { + if (get(a)) + { + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + else if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } + } + else + { + if (get(a)) + { + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + + if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } + } + } +} + void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool reversed, int subOffset) { const VariadicTypePack* lv = get(superTp); @@ -1297,9 +1570,11 @@ void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool rever } } -static void tryUnifyWithAny( +static void tryUnifyWithAny_DEPRECATED( std::vector& queue, Unifier& state, std::unordered_set& seenTypePacks, TypeId anyType, TypePackId anyTypePack) { + LUAU_ASSERT(!FFlag::LuauTypecheckOpts); + std::unordered_set seen; while (!queue.empty()) @@ -1310,6 +1585,59 @@ static void tryUnifyWithAny( continue; seen.insert(ty); + if (get(ty)) + { + state.log(ty); + *asMutable(ty) = BoundTypeVar{anyType}; + } + else if (auto fun = get(ty)) + { + queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = get(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = get(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (get(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = get(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = get(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. + } +} + +static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, + TypeId anyType, TypePackId anyTypePack) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOpts); + + while (!queue.empty()) + { + TypeId ty = follow(queue.back()); + queue.pop_back(); + if (seen.find(ty)) + continue; + seen.insert(ty); + if (get(ty)) { state.log(ty); @@ -1354,14 +1682,33 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { LUAU_ASSERT(get(any) || get(any)); + if (FFlag::LuauTypecheckOpts) + { + // These types are not visited in general loop below + if (get(ty) || get(ty) || get(ty)) + return; + } + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - std::unordered_set seenTypePacks; - std::vector queue = {ty}; + if (FFlag::LuauTypecheckOpts) + { + std::vector queue = {ty}; - Luau::tryUnifyWithAny(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); + tempSeenTy.clear(); + tempSeenTp.clear(); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP); + } + else + { + std::unordered_set seenTypePacks; + std::vector queue = {ty}; + + Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); + } } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) @@ -1370,12 +1717,26 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) const TypeId anyTy = singletonTypes.errorType; - std::unordered_set seenTypePacks; - std::vector queue; + if (FFlag::LuauTypecheckOpts) + { + std::vector queue; - queueTypePack(queue, seenTypePacks, *this, ty, any); + tempSeenTy.clear(); + tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, seenTypePacks, anyTy, any); + queueTypePack(queue, tempSeenTp, *this, ty, any); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any); + } + else + { + std::unordered_set seenTypePacks; + std::vector queue; + + queueTypePack_DEPRECATED(queue, seenTypePacks, *this, ty, any); + + Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, anyTy, any); + } } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) @@ -1387,21 +1748,6 @@ std::optional Unifier::findMetatableEntry(TypeId type, std::string entry { type = follow(type); - if (!FFlag::LuauStringMetatable) - { - if (const PrimitiveTypeVar* primType = get(type)) - { - if (primType->type != PrimitiveTypeVar::String || "__index" != entry) - return std::nullopt; - - auto found = globalScope->bindings.find(AstName{"string"}); - if (found == globalScope->bindings.end()) - return std::nullopt; - else - return found->second.typeId; - } - } - std::optional metatable = getMetatable(type); if (!metatable) return std::nullopt; @@ -1427,21 +1773,36 @@ std::optional Unifier::findMetatableEntry(TypeId type, std::string entry void Unifier::occursCheck(TypeId needle, TypeId haystack) { - std::unordered_set seen; - return occursCheck(seen, needle, haystack); + std::unordered_set seen_DEPRECATED; + + if (FFlag::LuauTypecheckOpts) + tempSeenTy.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack) +void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); needle = follow(needle); haystack = follow(haystack); - if (seen.end() != seen.find(haystack)) - return; + if (FFlag::LuauTypecheckOpts) + { + if (seen.find(haystack)) + return; - seen.insert(haystack); + seen.insert(haystack); + } + else + { + if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) + return; + + seen_DEPRECATED.insert(haystack); + } if (get(needle)) return; @@ -1458,7 +1819,7 @@ void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeI } auto check = [&](TypeId tv) { - occursCheck(seen, needle, tv); + occursCheck(seen_DEPRECATED, seen, needle, tv); }; if (get(haystack)) @@ -1488,19 +1849,33 @@ void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeI void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { - std::unordered_set seen; - return occursCheck(seen, needle, haystack); + std::unordered_set seen_DEPRECATED; + + if (FFlag::LuauTypecheckOpts) + tempSeenTp.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack) +void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) { needle = follow(needle); haystack = follow(haystack); - if (seen.find(haystack) != seen.end()) - return; + if (FFlag::LuauTypecheckOpts) + { + if (seen.find(haystack)) + return; - seen.insert(haystack); + seen.insert(haystack); + } + else + { + if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) + return; + + seen_DEPRECATED.insert(haystack); + } if (get(needle)) return; @@ -1508,7 +1883,8 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl if (!get(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); while (!get(haystack)) { @@ -1528,8 +1904,8 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl { if (auto f = get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); + occursCheck(seen_DEPRECATED, seen, needle, f->argTypes); + occursCheck(seen_DEPRECATED, seen, needle, f->retType); } } } @@ -1546,7 +1922,7 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters}; + return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index df38cfec..a2189f7b 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -264,6 +264,10 @@ public: { return false; } + virtual bool visit(class AstTypePackExplicit* node) + { + return visit((class AstTypePack*)node); + } virtual bool visit(class AstTypePackVariadic* node) { return visit((class AstTypePack*)node); @@ -930,12 +934,14 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported); + AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, + AstType* type, bool exported); void visit(AstVisitor* visitor) override; AstName name; AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -1007,19 +1013,28 @@ public: } }; +// Don't have Luau::Variant available, it's a bit of an overhead, but a plain struct is nice to use +struct AstTypeOrPack +{ + AstType* type = nullptr; + AstTypePack* typePack = nullptr; +}; + class AstTypeReference : public AstType { public: LUAU_RTTI(AstTypeReference) - AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics = {}); + AstTypeReference(const Location& location, std::optional prefix, AstName name, bool hasParameterList = false, + const AstArray& parameters = {}); void visit(AstVisitor* visitor) override; bool hasPrefix; + bool hasParameterList; AstName prefix; AstName name; - AstArray generics; + AstArray parameters; }; struct AstTableProp @@ -1152,6 +1167,18 @@ public: } }; +class AstTypePackExplicit : public AstTypePack +{ +public: + LUAU_RTTI(AstTypePackExplicit) + + AstTypePackExplicit(const Location& location, AstTypeList typeList); + + void visit(AstVisitor* visitor) override; + + AstTypeList typeList; +}; + class AstTypePackVariadic : public AstTypePack { public: diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index 02924e88..a7b2515a 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -136,7 +136,10 @@ public: const Key& key = ItemInterface::getKey(data[i]); if (!eq(key, empty_key)) - *newtable.insert_unsafe(key) = data[i]; + { + Item* item = newtable.insert_unsafe(key); + *item = std::move(data[i]); + } } LUAU_ASSERT(count == newtable.count); diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index e6ebd503..42c64dc9 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -218,13 +218,14 @@ private: AstTableIndexer* parseTableIndexerAnnotation(); - AstType* parseFunctionTypeAnnotation(); + AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); AstType* parseTableTypeAnnotation(); - AstType* parseSimpleTypeAnnotation(); + AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack); + AstTypeOrPack parseTypeOrPackAnnotation(); AstType* parseTypeAnnotation(TempVector& parts, const Location& begin); AstType* parseTypeAnnotation(); @@ -284,7 +285,7 @@ private: std::pair, AstArray> parseGenericTypeListIfFFlagParseGenericFunctions(); // `<' typeAnnotation[, ...] `>' - AstArray parseTypeParams(); + AstArray parseTypeParams(); AstExpr* parseString(); @@ -413,6 +414,7 @@ private: std::vector scratchLocal; std::vector scratchTableTypeProps; std::vector scratchAnnotation; + std::vector scratchTypeOrPackAnnotation; std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h new file mode 100644 index 00000000..641dfd3c --- /dev/null +++ b/Ast/include/Luau/TimeTrace.h @@ -0,0 +1,223 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Common.h" + +#include + +#include + +LUAU_FASTFLAG(DebugLuauTimeTracing) + +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ +uint32_t getClockMicroseconds(); + +struct Token +{ + const char* name; + const char* category; +}; + +enum class EventType : uint8_t +{ + Enter, + Leave, + + ArgName, + ArgValue, +}; + +struct Event +{ + EventType type; + uint16_t token; + + union + { + uint32_t microsec; // 1 hour trace limit + uint32_t dataPos; + } data; +}; + +struct GlobalContext; +struct ThreadContext; + +GlobalContext& getGlobalContext(); + +uint16_t createToken(GlobalContext& context, const char* name, const char* category); +uint32_t createThread(GlobalContext& context, ThreadContext* threadContext); +void releaseThread(GlobalContext& context, ThreadContext* threadContext); +void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& events, const std::vector& data); + +struct ThreadContext +{ + ThreadContext() + : globalContext(getGlobalContext()) + { + threadId = createThread(globalContext, this); + } + + ~ThreadContext() + { + if (!events.empty()) + flushEvents(); + + releaseThread(globalContext, this); + } + + void flushEvents() + { + static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace"); + + events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}}); + + TimeTrace::flushEvents(globalContext, threadId, events, data); + + events.clear(); + data.clear(); + + events.push_back({EventType::Leave, 0, {getClockMicroseconds()}}); + } + + void eventEnter(uint16_t token) + { + eventEnter(token, getClockMicroseconds()); + } + + void eventEnter(uint16_t token, uint32_t microsec) + { + events.push_back({EventType::Enter, token, {microsec}}); + } + + void eventLeave() + { + eventLeave(getClockMicroseconds()); + } + + void eventLeave(uint32_t microsec) + { + events.push_back({EventType::Leave, 0, {microsec}}); + + if (events.size() > kEventFlushLimit) + flushEvents(); + } + + void eventArgument(const char* name, const char* value) + { + uint32_t pos = uint32_t(data.size()); + data.insert(data.end(), name, name + strlen(name) + 1); + events.push_back({EventType::ArgName, 0, {pos}}); + + pos = uint32_t(data.size()); + data.insert(data.end(), value, value + strlen(value) + 1); + events.push_back({EventType::ArgValue, 0, {pos}}); + } + + GlobalContext& globalContext; + uint32_t threadId; + std::vector events; + std::vector data; + + static constexpr size_t kEventFlushLimit = 8192; +}; + +ThreadContext& getThreadContext(); + +struct Scope +{ + explicit Scope(ThreadContext& context, uint16_t token) + : context(context) + { + if (!FFlag::DebugLuauTimeTracing) + return; + + context.eventEnter(token); + } + + ~Scope() + { + if (!FFlag::DebugLuauTimeTracing) + return; + + context.eventLeave(); + } + + ThreadContext& context; +}; + +struct OptionalTailScope +{ + explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold) + : context(context) + , token(token) + , threshold(threshold) + { + if (!FFlag::DebugLuauTimeTracing) + return; + + pos = uint32_t(context.events.size()); + microsec = getClockMicroseconds(); + } + + ~OptionalTailScope() + { + if (!FFlag::DebugLuauTimeTracing) + return; + + if (pos == context.events.size()) + { + uint32_t curr = getClockMicroseconds(); + + if (curr - microsec > threshold) + { + context.eventEnter(token, microsec); + context.eventLeave(curr); + } + } + } + + ThreadContext& context; + uint16_t token; + uint32_t threshold; + uint32_t microsec; + uint32_t pos; +}; + +LUAU_NOINLINE std::pair createScopeData(const char* name, const char* category); + +} // namespace TimeTrace +} // namespace Luau + +// Regular scope +#define LUAU_TIMETRACE_SCOPE(name, category) \ + static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) + +// A scope without nested scopes that may be skipped if the time it took is less than the threshold +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ + static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) + +// Extra key/value data can be added to regular scopes +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + if (FFlag::DebugLuauTimeTracing) \ + lttScopeStatic.second.eventArgument(name, value); \ + } while (false) + +#else + +#define LUAU_TIMETRACE_SCOPE(name, category) +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + } while (false) + +#endif diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index fff1537d..b1209faa 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -641,10 +641,12 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported) : AstStat(ClassIndex(), location) , name(name) , generics(generics) + , genericPacks(genericPacks) , type(type) , exported(exported) { @@ -729,12 +731,14 @@ void AstStatError::visit(AstVisitor* visitor) } } -AstTypeReference::AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics) +AstTypeReference::AstTypeReference( + const Location& location, std::optional prefix, AstName name, bool hasParameterList, const AstArray& parameters) : AstType(ClassIndex(), location) , hasPrefix(bool(prefix)) + , hasParameterList(hasParameterList) , prefix(prefix ? *prefix : AstName()) , name(name) - , generics(generics) + , parameters(parameters) { } @@ -742,8 +746,13 @@ void AstTypeReference::visit(AstVisitor* visitor) { if (visitor->visit(this)) { - for (AstType* generic : generics) - generic->visit(visitor); + for (const AstTypeOrPack& param : parameters) + { + if (param.type) + param.type->visit(visitor); + else + param.typePack->visit(visitor); + } } } @@ -849,6 +858,24 @@ void AstTypeError::visit(AstVisitor* visitor) } } +AstTypePackExplicit::AstTypePackExplicit(const Location& location, AstTypeList typeList) + : AstTypePack(ClassIndex(), location) + , typeList(typeList) +{ +} + +void AstTypePackExplicit::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstType* type : typeList.types) + type->visit(visitor); + + if (typeList.tailType) + typeList.tailType->visit(visitor); + } +} + AstTypePackVariadic::AstTypePackVariadic(const Location& location, AstType* variadicType) : AstTypePack(ClassIndex(), location) , variadicType(variadicType) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 6672efe8..40026d8b 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/TimeTrace.h" + #include // Warning: If you are introducing new syntax, ensure that it is behind a separate @@ -13,6 +15,8 @@ LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) +LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) +LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) namespace Luau { @@ -148,6 +152,8 @@ static bool shouldParseTypePackAnnotation(Lexer& lexer) ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options) { + LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser"); + Parser p(buffer, bufferSize, names, allocator); try @@ -769,14 +775,14 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - // TODO: support generic type pack parameters in type aliases CLI-39907 auto [generics, genericPacks] = parseGenericTypeList(); expectAndConsume('=', "type alias"); AstType* type = parseTypeAnnotation(); - return allocator.alloc(Location(start, type->location), name->name, generics, type, exported); + return allocator.alloc( + Location(start, type->location), name->name, generics, FFlag::LuauTypeAliasPacks ? genericPacks : AstArray{}, type, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -1333,7 +1339,7 @@ AstType* Parser::parseTableTypeAnnotation() // ReturnType ::= TypeAnnotation | `(' TypeList `)' // FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstType* Parser::parseFunctionTypeAnnotation() +AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1364,14 +1370,23 @@ AstType* Parser::parseFunctionTypeAnnotation() matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; - // Not a function at all. Just a parenthesized type. - if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) - return params[0]; - AstArray paramTypes = copy(params); + + // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element + if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) + { + if (allowPack) + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; + else + return {params[0], {}}; + } + + if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack) + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; + AstArray> paramNames = copy(names); - return parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation); + return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, @@ -1421,7 +1436,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); isUnion = true; } else if (c == '?') @@ -1434,7 +1449,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); isIntersection = true; } else @@ -1462,6 +1477,30 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location ParseError::raise(begin, "Composite type was not an intersection or union."); } +AstTypeOrPack Parser::parseTypeOrPackAnnotation() +{ + unsigned int oldRecursionCount = recursionCounter; + incrementRecursionCounter("type annotation"); + + Location begin = lexer.current().location; + + TempVector parts(scratchAnnotation); + + auto [type, typePack] = parseSimpleTypeAnnotation(true); + + if (typePack) + { + LUAU_ASSERT(!type); + return {{}, typePack}; + } + + parts.push_back(type); + + recursionCounter = oldRecursionCount; + + return {parseTypeAnnotation(parts, begin), {}}; +} + AstType* Parser::parseTypeAnnotation() { unsigned int oldRecursionCount = recursionCounter; @@ -1470,7 +1509,7 @@ AstType* Parser::parseTypeAnnotation() Location begin = lexer.current().location; TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); recursionCounter = oldRecursionCount; @@ -1479,7 +1518,7 @@ AstType* Parser::parseTypeAnnotation() // typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstType* Parser::parseSimpleTypeAnnotation() +AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1488,7 +1527,7 @@ AstType* Parser::parseSimpleTypeAnnotation() if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); - return allocator.alloc(begin, std::nullopt, nameNil); + return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } else if (lexer.current().type == Lexeme::Name) { @@ -1514,22 +1553,41 @@ AstType* Parser::parseSimpleTypeAnnotation() expectMatchAndConsume(')', typeofBegin); - return allocator.alloc(Location(begin, end), expr); + return {allocator.alloc(Location(begin, end), expr), {}}; } - AstArray generics = parseTypeParams(); + if (FFlag::LuauParseTypePackTypeParameters) + { + bool hasParameters = false; + AstArray parameters{}; - Location end = lexer.previousLocation(); + if (lexer.current().type == '<') + { + hasParameters = true; + parameters = parseTypeParams(); + } - return allocator.alloc(Location(begin, end), prefix, name.name, generics); + Location end = lexer.previousLocation(); + + return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; + } + else + { + AstArray generics = parseTypeParams(); + + Location end = lexer.previousLocation(); + + // false in 'hasParameterList' as it is not used without FFlagLuauTypeAliasPacks + return {allocator.alloc(Location(begin, end), prefix, name.name, false, generics), {}}; + } } else if (lexer.current().type == '{') { - return parseTableTypeAnnotation(); + return {parseTableTypeAnnotation(), {}}; } else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<')) { - return parseFunctionTypeAnnotation(); + return parseFunctionTypeAnnotation(allowPack); } else { @@ -1538,7 +1596,7 @@ AstType* Parser::parseSimpleTypeAnnotation() // For a missing type annoation, capture 'space' between last token and the next one location = Location(lexer.previousLocation().end, lexer.current().location.begin); - return reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()); + return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } @@ -2312,18 +2370,59 @@ std::pair, AstArray> Parser::parseGenericTypeList() return {generics, genericPacks}; } -AstArray Parser::parseTypeParams() +AstArray Parser::parseTypeParams() { - TempVector result{scratchAnnotation}; + TempVector parameters{scratchTypeOrPackAnnotation}; if (lexer.current().type == '<') { Lexeme begin = lexer.current(); nextLexeme(); + bool seenPack = false; while (true) { - result.push_back(parseTypeAnnotation()); + if (FFlag::LuauParseTypePackTypeParameters) + { + if (shouldParseTypePackAnnotation(lexer)) + { + seenPack = true; + + auto typePack = parseTypePackAnnotation(); + + if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them + parameters.push_back({{}, typePack}); + } + else if (lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (typePack) + { + seenPack = true; + + if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them + parameters.push_back({{}, typePack}); + } + else + { + parameters.push_back({type, {}}); + } + } + else if (lexer.current().type == '>' && parameters.empty()) + { + break; + } + else + { + parameters.push_back({parseTypeAnnotation(), {}}); + } + } + else + { + parameters.push_back({parseTypeAnnotation(), {}}); + } + if (lexer.current().type == ',') nextLexeme(); else @@ -2333,7 +2432,7 @@ AstArray Parser::parseTypeParams() expectMatchAndConsume('>', begin); } - return copy(result); + return copy(parameters); } AstExpr* Parser::parseString() diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp new file mode 100644 index 00000000..e6aab20e --- /dev/null +++ b/Ast/src/TimeTrace.cpp @@ -0,0 +1,248 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TimeTrace.h" + +#include "Luau/StringUtils.h" + +#include +#include + +#include + +#ifdef _WIN32 +#include +#endif + +#ifdef __APPLE__ +#include +#include +#endif + +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false) + +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ +static double getClockPeriod() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceFrequency(&result); + return 1.0 / double(result.QuadPart); +#elif defined(__APPLE__) + mach_timebase_info_data_t result = {}; + mach_timebase_info(&result); + return double(result.numer) / double(result.denom) * 1e-9; +#elif defined(__linux__) + return 1e-9; +#else + return 1.0 / double(CLOCKS_PER_SEC); +#endif +} + +static double getClockTimestamp() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceCounter(&result); + return double(result.QuadPart); +#elif defined(__APPLE__) + return double(mach_absolute_time()); +#elif defined(__linux__) + timespec now; + clock_gettime(CLOCK_MONOTONIC, &now); + return now.tv_sec * 1e9 + now.tv_nsec; +#else + return double(clock()); +#endif +} + +uint32_t getClockMicroseconds() +{ + static double period = getClockPeriod() * 1e6; + static double start = getClockTimestamp(); + + return uint32_t((getClockTimestamp() - start) * period); +} + +struct GlobalContext +{ + GlobalContext() = default; + ~GlobalContext() + { + // Ideally we would want all ThreadContext destructors to run + // But in VS, not all thread_local object instances are destroyed + for (ThreadContext* context : threads) + context->flushEvents(); + + if (traceFile) + fclose(traceFile); + } + + std::mutex mutex; + std::vector threads; + uint32_t nextThreadId = 0; + std::vector tokens; + FILE* traceFile = nullptr; +}; + +GlobalContext& getGlobalContext() +{ + static GlobalContext context; + return context; +} + +uint16_t createToken(GlobalContext& context, const char* name, const char* category) +{ + std::scoped_lock lock(context.mutex); + + LUAU_ASSERT(context.tokens.size() < 64 * 1024); + + context.tokens.push_back({name, category}); + return uint16_t(context.tokens.size() - 1); +} + +uint32_t createThread(GlobalContext& context, ThreadContext* threadContext) +{ + std::scoped_lock lock(context.mutex); + + context.threads.push_back(threadContext); + + return ++context.nextThreadId; +} + +void releaseThread(GlobalContext& context, ThreadContext* threadContext) +{ + std::scoped_lock lock(context.mutex); + + if (auto it = std::find(context.threads.begin(), context.threads.end(), threadContext); it != context.threads.end()) + context.threads.erase(it); +} + +void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& events, const std::vector& data) +{ + std::scoped_lock lock(context.mutex); + + if (!context.traceFile) + { + context.traceFile = fopen("trace.json", "w"); + + if (!context.traceFile) + return; + + fprintf(context.traceFile, "[\n"); + } + + std::string temp; + const unsigned tempReserve = 64 * 1024; + temp.reserve(tempReserve); + + const char* rawData = data.data(); + + // Formatting state + bool unfinishedEnter = false; + bool unfinishedArgs = false; + + for (const Event& ev : events) + { + switch (ev.type) + { + case EventType::Enter: + { + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + Token& token = context.tokens[ev.token]; + + formatAppend(temp, R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)", token.name, token.category, + ev.data.microsec, threadId); + unfinishedEnter = true; + } + break; + case EventType::Leave: + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + formatAppend(temp, + R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)" + "\n", + ev.data.microsec, threadId); + break; + case EventType::ArgName: + LUAU_ASSERT(unfinishedEnter); + + if (!unfinishedArgs) + { + formatAppend(temp, R"(, "args": { "%s": )", rawData + ev.data.dataPos); + unfinishedArgs = true; + } + else + { + formatAppend(temp, R"(, "%s": )", rawData + ev.data.dataPos); + } + break; + case EventType::ArgValue: + LUAU_ASSERT(unfinishedArgs); + formatAppend(temp, R"("%s")", rawData + ev.data.dataPos); + break; + } + + // Don't want to hit the string capacity and reallocate + if (temp.size() > tempReserve - 1024) + { + fwrite(temp.data(), 1, temp.size(), context.traceFile); + temp.clear(); + } + } + + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + fwrite(temp.data(), 1, temp.size(), context.traceFile); + fflush(context.traceFile); +} + +ThreadContext& getThreadContext() +{ + thread_local ThreadContext context; + return context; +} + +std::pair createScopeData(const char* name, const char* category) +{ + uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category); + return {token, Luau::TimeTrace::getThreadContext()}; +} +} // namespace TimeTrace +} // namespace Luau + +#endif diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 920502b8..ed0552d7 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -111,11 +111,24 @@ struct CliFileResolver : Luau::FileResolver return Luau::SourceCode{*source, Luau::SourceCode::Module}; } + std::optional resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override + { + if (Luau::AstExprConstantString* expr = node->as()) + { + Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".lua"; + + return {{name}}; + } + + return std::nullopt; + } + bool moduleExists(const Luau::ModuleName& name) const override { return !!readFile(name); } + std::optional fromAstFragment(Luau::AstExpr* expr) const override { return std::nullopt; @@ -130,11 +143,6 @@ struct CliFileResolver : Luau::FileResolver { return std::nullopt; } - - std::optional getEnvironmentForModule(const Luau::ModuleName& name) const override - { - return std::nullopt; - } }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 022eccb7..797ee20d 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -4,6 +4,7 @@ #include "Luau/Parser.h" #include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" +#include "Luau/TimeTrace.h" #include #include @@ -137,6 +138,11 @@ struct Compiler uint32_t compileFunction(AstExprFunction* func) { + LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); + + if (func->debugname.value) + LUAU_TIMETRACE_ARGUMENT("name", func->debugname.value); + LUAU_ASSERT(!functions.contains(func)); LUAU_ASSERT(regTop == 0 && stackSize == 0 && localStack.empty() && upvals.empty()); @@ -3686,6 +3692,8 @@ struct Compiler void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) { + LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler"); + Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block table imports @@ -3748,6 +3756,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder) { + LUAU_TIMETRACE_SCOPE("compile", "Compiler"); + Allocator allocator; AstNameTable names(allocator); ParseResult result = Parser::parse(source.c_str(), source.size(), names, allocator, parseOptions); diff --git a/Sources.cmake b/Sources.cmake index 6f96f6ab..83ed5230 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -9,6 +9,7 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/ParseOptions.h Ast/include/Luau/Parser.h Ast/include/Luau/StringUtils.h + Ast/include/Luau/TimeTrace.h Ast/src/Ast.cpp Ast/src/Confusables.cpp @@ -16,6 +17,7 @@ target_sources(Luau.Ast PRIVATE Ast/src/Location.cpp Ast/src/Parser.cpp Ast/src/StringUtils.cpp + Ast/src/TimeTrace.cpp ) # Luau.Compiler Sources @@ -46,6 +48,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Predicate.h Analysis/include/Luau/RecursionCounter.h Analysis/include/Luau/RequireTracer.h + Analysis/include/Luau/Scope.h Analysis/include/Luau/Substitution.h Analysis/include/Luau/Symbol.h Analysis/include/Luau/TopoSortStatements.h @@ -75,6 +78,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Module.cpp Analysis/src/Predicate.cpp Analysis/src/RequireTracer.cpp + Analysis/src/Scope.cpp Analysis/src/Substitution.cpp Analysis/src/Symbol.cpp Analysis/src/TopoSortStatements.cpp @@ -188,6 +192,7 @@ if(TARGET Luau.UnitTest) tests/TopoSort.test.cpp tests/ToString.test.cpp tests/Transpiler.test.cpp + tests/TypeInfer.aliases.test.cpp tests/TypeInfer.annotations.test.cpp tests/TypeInfer.builtins.test.cpp tests/TypeInfer.classes.test.cpp diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index ee4962a5..39a61597 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -8,9 +8,13 @@ #include "lmem.h" #include "lvm.h" -#include - +#if LUA_USE_LONGJMP #include +#include +#else +#include +#endif + #include LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) @@ -51,8 +55,8 @@ l_noret luaD_throw(lua_State* L, int errcode) longjmp(jb->buf, 1); } - if (L->global->panic) - L->global->panic(L, errcode); + if (L->global->cb.panic) + L->global->cb.panic(L, errcode); abort(); } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 9b040fb5..510a9f54 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -16,6 +16,8 @@ LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false) LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false) LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false) +LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) + LUAU_FASTFLAG(LuauArrayBoundary) #define GC_SWEEPMAX 40 @@ -810,6 +812,133 @@ static size_t singlestep(lua_State* L) return cost; } +static size_t gcstep(lua_State* L, size_t limit) +{ + size_t cost = 0; + global_State* g = L->global; + switch (g->gcstate) + { + case GCSpause: + { + markroot(L); /* start a new collection */ + break; + } + case GCSpropagate: + { + if (FFlag::LuauRescanGrayAgain) + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) + { + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; + + g->gcstate = GCSpropagateagain; + } + } + else + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) /* no more `gray' objects */ + { + double starttimestamp = lua_clock(); + + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + atomic(L); /* finish mark phase */ + + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } + } + break; + } + case GCSpropagateagain: + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) /* no more `gray' objects */ + { + double starttimestamp = lua_clock(); + + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + atomic(L); /* finish mark phase */ + + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } + break; + } + case GCSsweepstring: + { + while (g->sweepstrgc < g->strt.size && cost < limit) + { + size_t traversedcount = 0; + sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount); + + g->gcstats.currcycle.sweepitems += traversedcount; + cost += GC_SWEEPCOST; + } + + // nothing more to sweep? + if (g->sweepstrgc >= g->strt.size) + { + // sweep string buffer list and preserve used string count + uint32_t nuse = L->global->strt.nuse; + + size_t traversedcount = 0; + sweepwholelist(L, &g->strbufgc, &traversedcount); + + L->global->strt.nuse = nuse; + + g->gcstats.currcycle.sweepitems += traversedcount; + g->gcstate = GCSsweep; // end sweep-string phase + } + break; + } + case GCSsweep: + { + while (*g->sweepgc && cost < limit) + { + size_t traversedcount = 0; + g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); + + g->gcstats.currcycle.sweepitems += traversedcount; + cost += GC_SWEEPMAX * GC_SWEEPCOST; + } + + if (*g->sweepgc == NULL) + { /* nothing more to sweep? */ + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ + } + break; + } + default: + LUAU_ASSERT(0); + } + return cost; +} + static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCycleStats* cyclestats) { // adjust for error using Proportional-Integral controller @@ -878,33 +1007,40 @@ void luaC_step(lua_State* L, bool assist) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (assist) - g->gcstats.currcycle.assistwork += lim; - else - g->gcstats.currcycle.explicitwork += lim; - int lastgcstate = g->gcstate; double lastttimestamp = lua_clock(); - // always perform at least one single step - do + if (FFlag::LuauConsolidatedStep) { - lim -= singlestep(L); + size_t work = gcstep(L, lim); - // if we have switched to a different state, capture the duration of last stage - // this way we reduce the number of timer calls we make - if (lastgcstate != g->gcstate) + if (assist) + g->gcstats.currcycle.assistwork += work; + else + g->gcstats.currcycle.explicitwork += work; + } + else + { + // always perform at least one single step + do { - GC_INTERRUPT(lastgcstate); + lim -= singlestep(L); - double now = lua_clock(); + // if we have switched to a different state, capture the duration of last stage + // this way we reduce the number of timer calls we make + if (lastgcstate != g->gcstate) + { + GC_INTERRUPT(lastgcstate); - recordGcStateTime(g, lastgcstate, now - lastttimestamp, assist); + double now = lua_clock(); - lastttimestamp = now; - lastgcstate = g->gcstate; - } - } while (lim > 0 && g->gcstate != GCSpause); + recordGcStateTime(g, lastgcstate, now - lastttimestamp, assist); + + lastttimestamp = now; + lastgcstate = g->gcstate; + } + } while (lim > 0 && g->gcstate != GCSpause); + } recordGcStateTime(g, lastgcstate, lua_clock() - lastttimestamp, assist); @@ -931,7 +1067,14 @@ void luaC_step(lua_State* L, bool assist) g->GCthreshold -= debt; } - GC_INTERRUPT(g->gcstate); + if (FFlag::LuauConsolidatedStep) + { + GC_INTERRUPT(lastgcstate); + } + else + { + GC_INTERRUPT(g->gcstate); + } } void luaC_fullgc(lua_State* L) @@ -957,7 +1100,10 @@ void luaC_fullgc(lua_State* L) while (g->gcstate != GCSpause) { LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); - singlestep(L); + if (FFlag::LuauConsolidatedStep) + gcstep(L, SIZE_MAX); + else + singlestep(L); } finishGcCycleStats(g); @@ -968,7 +1114,10 @@ void luaC_fullgc(lua_State* L) markroot(L); while (g->gcstate != GCSpause) { - singlestep(L); + if (FFlag::LuauConsolidatedStep) + gcstep(L, SIZE_MAX); + else + singlestep(L); } /* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */ shrinkbuffersfull(L); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 090e183f..de5788eb 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -9,14 +9,8 @@ #include "ldebug.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry, false) - LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false) -bool lua_telemetry_table_move_oob_src_from = false; -bool lua_telemetry_table_move_oob_src_to = false; -bool lua_telemetry_table_move_oob_dst = false; - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -202,22 +196,6 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); - if (DFFlag::LuauTableMoveTelemetry) - { - int nf = lua_objlen(L, 1); - int nt = lua_objlen(L, tt); - - // source index range must be in bounds in source table unless the table is empty (permits 1..#t moves) - if (!(f == 1 || (f >= 1 && f <= nf))) - lua_telemetry_table_move_oob_src_from = true; - if (!(e == nf || (e >= 1 && e <= nf))) - lua_telemetry_table_move_oob_src_to = true; - - // destination index must be in bounds in dest table or be exactly at the first empty element (permits concats) - if (!(t == nt + 1 || (t >= 1 && t <= nt + 1))) - lua_telemetry_table_move_oob_dst = true; - } - if (e >= f) { /* otherwise, nothing to move */ luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move"); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 5f0ee922..eed2862b 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,8 +16,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauLoopUseSafeenv, false) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -292,10 +290,6 @@ inline bool luau_skipstep(uint8_t op) return op == LOP_PREPVARARGS || op == LOP_BREAK; } -// declared in lbaselib.cpp, needed to support cases when pairs/ipairs have been replaced via setfenv -LUAI_FUNC int luaB_inext(lua_State* L); -LUAI_FUNC int luaB_next(lua_State* L); - template static void luau_execute(lua_State* L) { @@ -2223,8 +2217,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // fast-path: ipairs/inext - bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_inext; - if (safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) + if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } @@ -2304,8 +2297,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // fast-path: pairs/next - bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_next; - if (safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) + if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 0a232342..b932a85b 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -12,7 +12,32 @@ #include -#include +// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens +template +struct TempBuffer +{ + lua_State* L; + T* data; + size_t count; + + TempBuffer(lua_State* L, size_t count) + : L(L) + , data(luaM_newarray(L, count, T, 0)) + , count(count) + { + } + + ~TempBuffer() + { + luaM_freearray(L, data, count, T, 0); + } + + T& operator[](size_t index) + { + LUAU_ASSERT(index < count); + return data[index]; + } +}; void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) { @@ -67,7 +92,7 @@ static unsigned int readVarInt(const char* data, size_t size, size_t& offset) return result; } -static TString* readString(std::vector& strings, const char* data, size_t size, size_t& offset) +static TString* readString(TempBuffer& strings, const char* data, size_t size, size_t& offset) { unsigned int id = readVarInt(data, size, offset); @@ -133,6 +158,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size } // pause GC for the duration of deserialization - some objects we're creating aren't rooted + // TODO: if an allocation error happens mid-load, we do not unpause GC! size_t GCthreshold = L->global->GCthreshold; L->global->GCthreshold = SIZE_MAX; @@ -144,7 +170,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size // string table unsigned int stringCount = readVarInt(data, size, offset); - std::vector strings(stringCount); + TempBuffer strings(L, stringCount); for (unsigned int i = 0; i < stringCount; ++i) { @@ -156,7 +182,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size // proto table unsigned int protoCount = readVarInt(data, size, offset); - std::vector protos(protoCount); + TempBuffer protos(L, protoCount); for (unsigned int i = 0; i < protoCount; ++i) { diff --git a/bench/tests/deltablue.lua b/bench/tests/deltablue.lua deleted file mode 100644 index ecf246d3..00000000 --- a/bench/tests/deltablue.lua +++ /dev/null @@ -1,934 +0,0 @@ -local bench = script and require(script.Parent.bench_support) or require("bench_support") - --- Copyright 2008 the V8 project authors. All rights reserved. --- Copyright 1996 John Maloney and Mario Wolczko. - --- This program is free software; you can redistribute it and/or modify --- it under the terms of the GNU General Public License as published by --- the Free Software Foundation; either version 2 of the License, or --- (at your option) any later version. --- --- This program is distributed in the hope that it will be useful, --- but WITHOUT ANY WARRANTY; without even the implied warranty of --- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the --- GNU General Public License for more details. --- --- You should have received a copy of the GNU General Public License --- along with this program; if not, write to the Free Software --- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - - --- This implementation of the DeltaBlue benchmark is derived --- from the Smalltalk implementation by John Maloney and Mario --- Wolczko. Some parts have been translated directly, whereas --- others have been modified more aggresively to make it feel --- more like a JavaScript program. - - --- --- A JavaScript implementation of the DeltaBlue constraint-solving --- algorithm, as described in: --- --- "The DeltaBlue Algorithm: An Incremental Constraint Hierarchy Solver" --- Bjorn N. Freeman-Benson and John Maloney --- January 1990 Communications of the ACM, --- also available as University of Washington TR 89-08-06. --- --- Beware: this benchmark is written in a grotesque style where --- the constraint model is built by side-effects from constructors. --- I've kept it this way to avoid deviating too much from the original --- implementation. --- - -function class(base) - local T = {} - T.__index = T - - if base then - T.super = base - setmetatable(T, base) - end - - function T.new(...) - local O = {} - setmetatable(O, T) - O:constructor(...) - return O - end - - return T -end - -local planner - ---- O b j e c t M o d e l --- - -local function alert (...) print(...) end - -local OrderedCollection = class() - -function OrderedCollection:constructor() - self.elms = {} -end - -function OrderedCollection:add(elm) - self.elms[#self.elms + 1] = elm -end - -function OrderedCollection:at (index) - return self.elms[index] -end - -function OrderedCollection:size () - return #self.elms -end - -function OrderedCollection:removeFirst () - local e = self.elms[#self.elms] - self.elms[#self.elms] = nil - return e -end - -function OrderedCollection:remove (elm) - local index = 0 - local skipped = 0 - - for i = 1, #self.elms do - local value = self.elms[i] - if value ~= elm then - self.elms[index] = value - index = index + 1 - else - skipped = skipped + 1 - end - end - - local l = #self.elms - for i = 1, skipped do self.elms[l - i + 1] = nil end -end - --- --- S t r e n g t h --- - --- --- Strengths are used to measure the relative importance of constraints. --- New strengths may be inserted in the strength hierarchy without --- disrupting current constraints. Strengths cannot be created outside --- this class, so pointer comparison can be used for value comparison. --- - -local Strength = class() - -function Strength:constructor(strengthValue, name) - self.strengthValue = strengthValue - self.name = name -end - -function Strength.stronger (s1, s2) - return s1.strengthValue < s2.strengthValue -end - -function Strength.weaker (s1, s2) - return s1.strengthValue > s2.strengthValue -end - -function Strength.weakestOf (s1, s2) - return Strength.weaker(s1, s2) and s1 or s2 -end - -function Strength.strongest (s1, s2) - return Strength.stronger(s1, s2) and s1 or s2 -end - -function Strength:nextWeaker () - local v = self.strengthValue - if v == 0 then return Strength.WEAKEST - elseif v == 1 then return Strength.WEAK_DEFAULT - elseif v == 2 then return Strength.NORMAL - elseif v == 3 then return Strength.STRONG_DEFAULT - elseif v == 4 then return Strength.PREFERRED - elseif v == 5 then return Strength.REQUIRED - end -end - --- Strength constants. -Strength.REQUIRED = Strength.new(0, "required"); -Strength.STONG_PREFERRED = Strength.new(1, "strongPreferred"); -Strength.PREFERRED = Strength.new(2, "preferred"); -Strength.STRONG_DEFAULT = Strength.new(3, "strongDefault"); -Strength.NORMAL = Strength.new(4, "normal"); -Strength.WEAK_DEFAULT = Strength.new(5, "weakDefault"); -Strength.WEAKEST = Strength.new(6, "weakest"); - --- --- C o n s t r a i n t --- - --- --- An abstract class representing a system-maintainable relationship --- (or "constraint") between a set of variables. A constraint supplies --- a strength instance variable; concrete subclasses provide a means --- of storing the constrained variables and other information required --- to represent a constraint. --- - -local Constraint = class () - -function Constraint:constructor(strength) - self.strength = strength -end - --- --- Activate this constraint and attempt to satisfy it. --- -function Constraint:addConstraint () - self:addToGraph() - planner:incrementalAdd(self) -end - --- --- Attempt to find a way to enforce this constraint. If successful, --- record the solution, perhaps modifying the current dataflow --- graph. Answer the constraint that this constraint overrides, if --- there is one, or nil, if there isn't. --- Assume: I am not already satisfied. --- -function Constraint:satisfy (mark) - self:chooseMethod(mark) - if not self:isSatisfied() then - if self.strength == Strength.REQUIRED then - alert("Could not satisfy a required constraint!") - end - return nil - end - self:markInputs(mark) - local out = self:output() - local overridden = out.determinedBy - if overridden ~= nil then overridden:markUnsatisfied() end - out.determinedBy = self - if not planner:addPropagate(self, mark) then alert("Cycle encountered") end - out.mark = mark - return overridden -end - -function Constraint:destroyConstraint () - if self:isSatisfied() - then planner:incrementalRemove(self) - else self:removeFromGraph() - end -end - --- --- Normal constraints are not input constraints. An input constraint --- is one that depends on external state, such as the mouse, the --- keybord, a clock, or some arbitraty piece of imperative code. --- -function Constraint:isInput () - return false -end - - --- --- U n a r y C o n s t r a i n t --- - --- --- Abstract superclass for constraints having a single possible output --- variable. --- - -local UnaryConstraint = class(Constraint) - -function UnaryConstraint:constructor (v, strength) - UnaryConstraint.super.constructor(self, strength) - self.myOutput = v - self.satisfied = false - self:addConstraint() -end - --- --- Adds this constraint to the constraint graph --- -function UnaryConstraint:addToGraph () - self.myOutput:addConstraint(self) - self.satisfied = false -end - --- --- Decides if this constraint can be satisfied and records that --- decision. --- -function UnaryConstraint:chooseMethod (mark) - self.satisfied = (self.myOutput.mark ~= mark) - and Strength.stronger(self.strength, self.myOutput.walkStrength); -end - --- --- Returns true if this constraint is satisfied in the current solution. --- -function UnaryConstraint:isSatisfied () - return self.satisfied; -end - -function UnaryConstraint:markInputs (mark) - -- has no inputs -end - --- --- Returns the current output variable. --- -function UnaryConstraint:output () - return self.myOutput -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this constraint. Assume --- this constraint is satisfied. --- -function UnaryConstraint:recalculate () - self.myOutput.walkStrength = self.strength - self.myOutput.stay = not self:isInput() - if self.myOutput.stay then - self:execute() -- Stay optimization - end -end - --- --- Records that this constraint is unsatisfied --- -function UnaryConstraint:markUnsatisfied () - self.satisfied = false -end - -function UnaryConstraint:inputsKnown () - return true -end - -function UnaryConstraint:removeFromGraph () - if self.myOutput ~= nil then - self.myOutput:removeConstraint(self) - end - self.satisfied = false -end - --- --- S t a y C o n s t r a i n t --- - --- --- Variables that should, with some level of preference, stay the same. --- Planners may exploit the fact that instances, if satisfied, will not --- change their output during plan execution. This is called "stay --- optimization". --- - -local StayConstraint = class(UnaryConstraint) - -function StayConstraint:constructor(v, str) - StayConstraint.super.constructor(self, v, str) -end - -function StayConstraint:execute () - -- Stay constraints do nothing -end - --- --- E d i t C o n s t r a i n t --- - --- --- A unary input constraint used to mark a variable that the client --- wishes to change. --- - -local EditConstraint = class (UnaryConstraint) - -function EditConstraint:constructor(v, str) - EditConstraint.super.constructor(self, v, str) -end - --- --- Edits indicate that a variable is to be changed by imperative code. --- -function EditConstraint:isInput () - return true -end - -function EditConstraint:execute () - -- Edit constraints do nothing -end - --- --- B i n a r y C o n s t r a i n t --- - -local Direction = {} -Direction.NONE = 0 -Direction.FORWARD = 1 -Direction.BACKWARD = -1 - --- --- Abstract superclass for constraints having two possible output --- variables. --- - -local BinaryConstraint = class(Constraint) - -function BinaryConstraint:constructor(var1, var2, strength) - BinaryConstraint.super.constructor(self, strength); - self.v1 = var1 - self.v2 = var2 - self.direction = Direction.NONE - self:addConstraint() -end - - --- --- Decides if this constraint can be satisfied and which way it --- should flow based on the relative strength of the variables related, --- and record that decision. --- -function BinaryConstraint:chooseMethod (mark) - if self.v1.mark == mark then - self.direction = (self.v2.mark ~= mark and Strength.stronger(self.strength, self.v2.walkStrength)) and Direction.FORWARD or Direction.NONE - end - if self.v2.mark == mark then - self.direction = (self.v1.mark ~= mark and Strength.stronger(self.strength, self.v1.walkStrength)) and Direction.BACKWARD or Direction.NONE - end - if Strength.weaker(self.v1.walkStrength, self.v2.walkStrength) then - self.direction = Strength.stronger(self.strength, self.v1.walkStrength) and Direction.BACKWARD or Direction.NONE - else - self.direction = Strength.stronger(self.strength, self.v2.walkStrength) and Direction.FORWARD or Direction.BACKWARD - end -end - --- --- Add this constraint to the constraint graph --- -function BinaryConstraint:addToGraph () - self.v1:addConstraint(self) - self.v2:addConstraint(self) - self.direction = Direction.NONE -end - --- --- Answer true if this constraint is satisfied in the current solution. --- -function BinaryConstraint:isSatisfied () - return self.direction ~= Direction.NONE -end - --- --- Mark the input variable with the given mark. --- -function BinaryConstraint:markInputs (mark) - self:input().mark = mark -end - --- --- Returns the current input variable --- -function BinaryConstraint:input () - return (self.direction == Direction.FORWARD) and self.v1 or self.v2 -end - --- --- Returns the current output variable --- -function BinaryConstraint:output () - return (self.direction == Direction.FORWARD) and self.v2 or self.v1 -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this --- constraint. Assume this constraint is satisfied. --- -function BinaryConstraint:recalculate () - local ihn = self:input() - local out = self:output() - out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength); - out.stay = ihn.stay - if out.stay then self:execute() end -end - --- --- Record the fact that self constraint is unsatisfied. --- -function BinaryConstraint:markUnsatisfied () - self.direction = Direction.NONE -end - -function BinaryConstraint:inputsKnown (mark) - local i = self:input() - return i.mark == mark or i.stay or i.determinedBy == nil -end - -function BinaryConstraint:removeFromGraph () - if (self.v1 ~= nil) then self.v1:removeConstraint(self) end - if (self.v2 ~= nil) then self.v2:removeConstraint(self) end - self.direction = Direction.NONE -end - --- --- S c a l e C o n s t r a i n t --- - --- --- Relates two variables by the linear scaling relationship: "v2 = --- (v1 * scale) + offset". Either v1 or v2 may be changed to maintain --- this relationship but the scale factor and offset are considered --- read-only. --- - -local ScaleConstraint = class (BinaryConstraint) - -function ScaleConstraint:constructor(src, scale, offset, dest, strength) - self.direction = Direction.NONE - self.scale = scale - self.offset = offset - ScaleConstraint.super.constructor(self, src, dest, strength) -end - - --- --- Adds this constraint to the constraint graph. --- -function ScaleConstraint:addToGraph () - ScaleConstraint.super.addToGraph(self) - self.scale:addConstraint(self) - self.offset:addConstraint(self) -end - -function ScaleConstraint:removeFromGraph () - ScaleConstraint.super.removeFromGraph(self) - if (self.scale ~= nil) then self.scale:removeConstraint(self) end - if (self.offset ~= nil) then self.offset:removeConstraint(self) end -end - -function ScaleConstraint:markInputs (mark) - ScaleConstraint.super.markInputs(self, mark); - self.offset.mark = mark - self.scale.mark = mark -end - --- --- Enforce this constraint. Assume that it is satisfied. --- -function ScaleConstraint:execute () - if self.direction == Direction.FORWARD then - self.v2.value = self.v1.value * self.scale.value + self.offset.value - else - self.v1.value = (self.v2.value - self.offset.value) / self.scale.value - end -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this constraint. Assume --- this constraint is satisfied. --- -function ScaleConstraint:recalculate () - local ihn = self:input() - local out = self:output() - out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength) - out.stay = ihn.stay and self.scale.stay and self.offset.stay - if out.stay then self:execute() end -end - --- --- E q u a l i t y C o n s t r a i n t --- - --- --- Constrains two variables to have the same value. --- - -local EqualityConstraint = class (BinaryConstraint) - -function EqualityConstraint:constructor(var1, var2, strength) - EqualityConstraint.super.constructor(self, var1, var2, strength) -end - - --- --- Enforce this constraint. Assume that it is satisfied. --- -function EqualityConstraint:execute () - self:output().value = self:input().value -end - --- --- V a r i a b l e --- - --- --- A constrained variable. In addition to its value, it maintain the --- structure of the constraint graph, the current dataflow graph, and --- various parameters of interest to the DeltaBlue incremental --- constraint solver. --- -local Variable = class () - -function Variable:constructor(name, initialValue) - self.value = initialValue or 0 - self.constraints = OrderedCollection.new() - self.determinedBy = nil - self.mark = 0 - self.walkStrength = Strength.WEAKEST - self.stay = true - self.name = name -end - --- --- Add the given constraint to the set of all constraints that refer --- this variable. --- -function Variable:addConstraint (c) - self.constraints:add(c) -end - --- --- Removes all traces of c from this variable. --- -function Variable:removeConstraint (c) - self.constraints:remove(c) - if self.determinedBy == c then - self.determinedBy = nil - end -end - --- --- P l a n n e r --- - --- --- The DeltaBlue planner --- -local Planner = class() -function Planner:constructor() - self.currentMark = 0 -end - --- --- Attempt to satisfy the given constraint and, if successful, --- incrementally update the dataflow graph. Details: If satifying --- the constraint is successful, it may override a weaker constraint --- on its output. The algorithm attempts to resatisfy that --- constraint using some other method. This process is repeated --- until either a) it reaches a variable that was not previously --- determined by any constraint or b) it reaches a constraint that --- is too weak to be satisfied using any of its methods. The --- variables of constraints that have been processed are marked with --- a unique mark value so that we know where we've been. This allows --- the algorithm to avoid getting into an infinite loop even if the --- constraint graph has an inadvertent cycle. --- -function Planner:incrementalAdd (c) - local mark = self:newMark() - local overridden = c:satisfy(mark) - while overridden ~= nil do - overridden = overridden:satisfy(mark) - end -end - --- --- Entry point for retracting a constraint. Remove the given --- constraint and incrementally update the dataflow graph. --- Details: Retracting the given constraint may allow some currently --- unsatisfiable downstream constraint to be satisfied. We therefore collect --- a list of unsatisfied downstream constraints and attempt to --- satisfy each one in turn. This list is traversed by constraint --- strength, strongest first, as a heuristic for avoiding --- unnecessarily adding and then overriding weak constraints. --- Assume: c is satisfied. --- -function Planner:incrementalRemove (c) - local out = c:output() - c:markUnsatisfied() - c:removeFromGraph() - local unsatisfied = self:removePropagateFrom(out) - local strength = Strength.REQUIRED - repeat - for i = 1, unsatisfied:size() do - local u = unsatisfied:at(i) - if u.strength == strength then - self:incrementalAdd(u) - end - end - strength = strength:nextWeaker() - until strength == Strength.WEAKEST -end - --- --- Select a previously unused mark value. --- -function Planner:newMark () - self.currentMark = self.currentMark + 1 - return self.currentMark -end - --- --- Extract a plan for resatisfaction starting from the given source --- constraints, usually a set of input constraints. This method --- assumes that stay optimization is desired; the plan will contain --- only constraints whose output variables are not stay. Constraints --- that do no computation, such as stay and edit constraints, are --- not included in the plan. --- Details: The outputs of a constraint are marked when it is added --- to the plan under construction. A constraint may be appended to --- the plan when all its input variables are known. A variable is --- known if either a) the variable is marked (indicating that has --- been computed by a constraint appearing earlier in the plan), b) --- the variable is 'stay' (i.e. it is a constant at plan execution --- time), or c) the variable is not determined by any --- constraint. The last provision is for past states of history --- variables, which are not stay but which are also not computed by --- any constraint. --- Assume: sources are all satisfied. --- -local Plan -- FORWARD DECLARATION -function Planner:makePlan (sources) - local mark = self:newMark() - local plan = Plan.new() - local todo = sources - while todo:size() > 0 do - local c = todo:removeFirst() - if c:output().mark ~= mark and c:inputsKnown(mark) then - plan:addConstraint(c) - c:output().mark = mark - self:addConstraintsConsumingTo(c:output(), todo) - end - end - return plan -end - --- --- Extract a plan for resatisfying starting from the output of the --- given constraints, usually a set of input constraints. --- -function Planner:extractPlanFromConstraints (constraints) - local sources = OrderedCollection.new() - for i = 1, constraints:size() do - local c = constraints:at(i) - if c:isInput() and c:isSatisfied() then - -- not in plan already and eligible for inclusion - sources:add(c) - end - end - return self:makePlan(sources) -end - --- --- Recompute the walkabout strengths and stay flags of all variables --- downstream of the given constraint and recompute the actual --- values of all variables whose stay flag is true. If a cycle is --- detected, remove the given constraint and answer --- false. Otherwise, answer true. --- Details: Cycles are detected when a marked variable is --- encountered downstream of the given constraint. The sender is --- assumed to have marked the inputs of the given constraint with --- the given mark. Thus, encountering a marked node downstream of --- the output constraint means that there is a path from the --- constraint's output to one of its inputs. --- -function Planner:addPropagate (c, mark) - local todo = OrderedCollection.new() - todo:add(c) - while todo:size() > 0 do - local d = todo:removeFirst() - if d:output().mark == mark then - self:incrementalRemove(c) - return false - end - d:recalculate() - self:addConstraintsConsumingTo(d:output(), todo) - end - return true -end - - --- --- Update the walkabout strengths and stay flags of all variables --- downstream of the given constraint. Answer a collection of --- unsatisfied constraints sorted in order of decreasing strength. --- -function Planner:removePropagateFrom (out) - out.determinedBy = nil - out.walkStrength = Strength.WEAKEST - out.stay = true - local unsatisfied = OrderedCollection.new() - local todo = OrderedCollection.new() - todo:add(out) - while todo:size() > 0 do - local v = todo:removeFirst() - for i = 1, v.constraints:size() do - local c = v.constraints:at(i) - if not c:isSatisfied() then unsatisfied:add(c) end - end - local determining = v.determinedBy - for i = 1, v.constraints:size() do - local next = v.constraints:at(i); - if next ~= determining and next:isSatisfied() then - next:recalculate() - todo:add(next:output()) - end - end - end - return unsatisfied -end - -function Planner:addConstraintsConsumingTo (v, coll) - local determining = v.determinedBy - local cc = v.constraints - for i = 1, cc:size() do - local c = cc:at(i) - if c ~= determining and c:isSatisfied() then - coll:add(c) - end - end -end - --- --- P l a n --- - --- --- A Plan is an ordered list of constraints to be executed in sequence --- to resatisfy all currently satisfiable constraints in the face of --- one or more changing inputs. --- -Plan = class() -function Plan:constructor() - self.v = OrderedCollection.new() -end - -function Plan:addConstraint (c) - self.v:add(c) -end - -function Plan:size () - return self.v:size() -end - -function Plan:constraintAt (index) - return self.v:at(index) -end - -function Plan:execute () - for i = 1, self:size() do - local c = self:constraintAt(i) - c:execute() - end -end - --- --- M a i n --- - --- --- This is the standard DeltaBlue benchmark. A long chain of equality --- constraints is constructed with a stay constraint on one end. An --- edit constraint is then added to the opposite end and the time is --- measured for adding and removing this constraint, and extracting --- and executing a constraint satisfaction plan. There are two cases. --- In case 1, the added constraint is stronger than the stay --- constraint and values must propagate down the entire length of the --- chain. In case 2, the added constraint is weaker than the stay --- constraint so it cannot be accomodated. The cost in this case is, --- of course, very low. Typical situations lie somewhere between these --- two extremes. --- -local function chainTest(n) - planner = Planner.new() - local prev = nil - local first = nil - local last = nil - - -- Build chain of n equality constraints - for i = 0, n do - local name = "v" .. i; - local v = Variable.new(name) - if prev ~= nil then EqualityConstraint.new(prev, v, Strength.REQUIRED) end - if i == 0 then first = v end - if i == n then last = v end - prev = v - end - - StayConstraint.new(last, Strength.STRONG_DEFAULT) - local edit = EditConstraint.new(first, Strength.PREFERRED) - local edits = OrderedCollection.new() - edits:add(edit) - local plan = planner:extractPlanFromConstraints(edits) - for i = 0, 99 do - first.value = i - plan:execute() - if last.value ~= i then - alert("Chain test failed.") - end - end -end - -local function change(v, newValue) - local edit = EditConstraint.new(v, Strength.PREFERRED) - local edits = OrderedCollection.new() - edits:add(edit) - local plan = planner:extractPlanFromConstraints(edits) - for i = 1, 10 do - v.value = newValue - plan:execute() - end - edit:destroyConstraint() -end - --- --- This test constructs a two sets of variables related to each --- other by a simple linear transformation (scale and offset). The --- time is measured to change a variable on either side of the --- mapping and to change the scale and offset factors. --- -local function projectionTest(n) - planner = Planner.new(); - local scale = Variable.new("scale", 10); - local offset = Variable.new("offset", 1000); - local src = nil - local dst = nil; - - local dests = OrderedCollection.new(); - for i = 0, n - 1 do - src = Variable.new("src" .. i, i); - dst = Variable.new("dst" .. i, i); - dests:add(dst); - StayConstraint.new(src, Strength.NORMAL); - ScaleConstraint.new(src, scale, offset, dst, Strength.REQUIRED); - end - - change(src, 17) - if dst.value ~= 1170 then alert("Projection 1 failed") end - change(dst, 1050) - if src.value ~= 5 then alert("Projection 2 failed") end - change(scale, 5) - for i = 0, n - 2 do - if dests:at(i + 1).value ~= i * 5 + 1000 then - alert("Projection 3 failed") - end - end - change(offset, 2000) - for i = 0, n - 2 do - if dests:at(i + 1).value ~= i * 5 + 2000 then - alert("Projection 4 failed") - end - end -end - -function test() - local t0 = os.clock() - chainTest(1000); - projectionTest(1000); - local t1 = os.clock() - return t1-t0 -end - -bench.runCode(test, "deltablue") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 9cd642cf..07910a0a 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -23,19 +23,17 @@ static std::optional nullCallback(std::string tag, std::op return std::nullopt; } -struct ACFixture : Fixture +template +struct ACFixtureImpl : BaseType { AutocompleteResult autocomplete(unsigned row, unsigned column) { - return Luau::autocomplete(frontend, "MainModule", Position{row, column}, nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); } AutocompleteResult autocomplete(char marker) { - auto i = markerPosition.find(marker); - LUAU_ASSERT(i != markerPosition.end()); - const Position& pos = i->second; - return Luau::autocomplete(frontend, "MainModule", pos, nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), nullCallback); } CheckResult check(const std::string& source) @@ -45,16 +43,18 @@ struct ACFixture : Fixture filteredSource.reserve(source.size()); Position curPos(0, 0); + char prevChar{}; for (char c : source) { - if (c == '@' && !filteredSource.empty()) + if (prevChar == '@') { - char prevChar = filteredSource.back(); - filteredSource.pop_back(); - curPos.column--; // Adjust column position since we removed a character from the output - LUAU_ASSERT("Illegal marker character" && prevChar >= '0' && prevChar <= '9'); - LUAU_ASSERT("Duplicate marker found" && markerPosition.count(prevChar) == 0); - markerPosition.insert(std::pair{prevChar, curPos}); + LUAU_ASSERT("Illegal marker character" && c >= '0' && c <= '9'); + LUAU_ASSERT("Duplicate marker found" && markerPosition.count(c) == 0); + markerPosition.insert(std::pair{c, curPos}); + } + else if (c == '@') + { + // skip the '@' character } else { @@ -69,22 +69,39 @@ struct ACFixture : Fixture curPos.column++; } } + prevChar = c; } + LUAU_ASSERT("Digit expected after @ symbol" && prevChar != '@'); return Fixture::check(filteredSource); } + const Position& getPosition(char marker) const + { + auto i = markerPosition.find(marker); + LUAU_ASSERT(i != markerPosition.end()); + return i->second; + } + // Maps a marker character (0-9 inclusive) to a position in the source code. std::map markerPosition; }; +struct ACFixture : ACFixtureImpl +{ +}; + +struct UnfrozenACFixture : ACFixtureImpl +{ +}; + TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") { - check(" "); + check(" @1"); - auto ac = autocomplete(0, 1); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.empty()); CHECK(ac.entryMap.count("table")); @@ -93,26 +110,26 @@ TEST_CASE_FIXTURE(ACFixture, "empty_program") TEST_CASE_FIXTURE(ACFixture, "local_initializer") { - check("local a = "); + check("local a = @1"); - auto ac = autocomplete(0, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); CHECK(ac.entryMap.count("math")); } TEST_CASE_FIXTURE(ACFixture, "leave_numbers_alone") { - check("local a = 3.1"); + check("local a = 3.@11"); - auto ac = autocomplete(0, 12); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "user_defined_globals") { - check("local myLocal = 4; "); + check("local myLocal = 4; @1"); - auto ac = autocomplete(0, 19); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); CHECK(ac.entryMap.count("table")); @@ -124,20 +141,20 @@ TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") check(R"( local myLocal = 4 function abc() - local myInnerLocal = 1 - +@1 local myInnerLocal = 1 +@2 end - )"); +@3 )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); CHECK(!ac.entryMap.count("myInnerLocal")); - ac = autocomplete(4, 0); + ac = autocomplete('2'); CHECK(ac.entryMap.count("myLocal")); CHECK(ac.entryMap.count("myInnerLocal")); - ac = autocomplete(6, 0); + ac = autocomplete('3'); CHECK(ac.entryMap.count("myLocal")); CHECK(!ac.entryMap.count("myInnerLocal")); } @@ -146,10 +163,10 @@ TEST_CASE_FIXTURE(ACFixture, "recursive_function") { check(R"( function foo() - end +@1 end )"); - auto ac = autocomplete(2, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("foo")); } @@ -158,11 +175,11 @@ TEST_CASE_FIXTURE(ACFixture, "nested_recursive_function") check(R"( local function outer() local function inner() - end +@1 end end )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("inner")); CHECK(ac.entryMap.count("outer")); } @@ -171,11 +188,11 @@ TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") { check(R"( local function abc() - +@1 end )"); - auto ac = autocomplete(2, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); CHECK(ac.entryMap.count("table")); @@ -183,11 +200,11 @@ TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") check(R"( local abc = function() - +@1 end )"); - ac = autocomplete(2, 0); + ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); // FIXME: This is actually incorrect! CHECK(ac.entryMap.count("table")); @@ -202,9 +219,9 @@ TEST_CASE_FIXTURE(ACFixture, "global_functions_are_not_scoped_lexically") end end - )"); +@1 )"); - auto ac = autocomplete(6, 0); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.empty()); CHECK(ac.entryMap.count("abc")); @@ -220,9 +237,9 @@ TEST_CASE_FIXTURE(ACFixture, "local_functions_fall_out_of_scope") end end - )"); +@1 )"); - auto ac = autocomplete(6, 0); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(!ac.entryMap.count("abc")); @@ -233,10 +250,10 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") check(R"( function abc(test) - end +@1 end )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("test")); } @@ -244,11 +261,10 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") TEST_CASE_FIXTURE(ACFixture, "get_member_completions") { check(R"( - local a = table. -- Line 1 - -- | Column 23 + local a = table.@1 )"); - auto ac = autocomplete(1, 24); + auto ac = autocomplete('1'); CHECK_EQ(16, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); @@ -260,10 +276,10 @@ TEST_CASE_FIXTURE(ACFixture, "nested_member_completions") { check(R"( local tbl = { abc = { def = 1234, egh = false } } - tbl.abc. + tbl.abc. @1 )"); - auto ac = autocomplete(2, 17); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("def")); CHECK(ac.entryMap.count("egh")); @@ -274,10 +290,10 @@ TEST_CASE_FIXTURE(ACFixture, "unsealed_table") check(R"( local tbl = {} tbl.prop = 5 - tbl. + tbl.@1 )"); - auto ac = autocomplete(3, 12); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("prop")); } @@ -288,10 +304,10 @@ TEST_CASE_FIXTURE(ACFixture, "unsealed_table_2") local tbl = {} local inner = { prop = 5 } tbl.inner = inner - tbl.inner. + tbl.inner. @1 )"); - auto ac = autocomplete(4, 19); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("prop")); } @@ -302,10 +318,10 @@ TEST_CASE_FIXTURE(ACFixture, "cyclic_table") local abc = {} local def = { abc = abc } abc.def = def - abc.def. + abc.def. @1 )"); - auto ac = autocomplete(4, 17); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); } @@ -315,11 +331,11 @@ TEST_CASE_FIXTURE(ACFixture, "table_union") type t1 = { a1 : string, b2 : number } type t2 = { b2 : string, c3 : string } function func(abc : t1 | t2) - abc. + abc. @1 end )"); - auto ac = autocomplete(4, 18); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("b2")); } @@ -330,11 +346,11 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") type t1 = { a1 : string, b2 : number } type t2 = { b2 : string, c3 : string } function func(abc : t1 & t2) - abc. + abc. @1 end )"); - auto ac = autocomplete(4, 18); + auto ac = autocomplete('1'); CHECK_EQ(3, ac.entryMap.size()); CHECK(ac.entryMap.count("a1")); CHECK(ac.entryMap.count("b2")); @@ -344,20 +360,19 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") TEST_CASE_FIXTURE(ACFixture, "get_string_completions") { check(R"( - local a = ("foo"): -- Line 1 - -- | Column 26 + local a = ("foo"):@1 )"); - auto ac = autocomplete(1, 26); + auto ac = autocomplete('1'); CHECK_EQ(17, ac.entryMap.size()); } TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_new_statement") { - check(""); + check("@1"); - auto ac = autocomplete(0, 0); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -366,12 +381,12 @@ TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_new_statement") TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_the_very_start_of_the_script") { - check(R"( + check(R"(@1 function aaa() end )"); - auto ac = autocomplete(0, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); } @@ -382,11 +397,11 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") local game = { GetService=function(s) return 'hello' end } function a() - game: + game: @1 end )"); - auto ac = autocomplete(4, 19); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -396,10 +411,10 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") TEST_CASE_FIXTURE(ACFixture, "method_call_inside_if_conditional") { check(R"( - if table: + if table: @1 )"); - auto ac = autocomplete(1, 19); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(ac.entryMap.count("concat")); @@ -411,12 +426,12 @@ TEST_CASE_FIXTURE(ACFixture, "statement_between_two_statements") check(R"( function getmyscripts() end - g + g@1 getmyscripts() )"); - auto ac = autocomplete(3, 9); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -431,11 +446,11 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") function B() local A = {two=2} - A + A @1 end )"); - auto ac = autocomplete(6, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("A")); @@ -448,12 +463,12 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") TEST_CASE_FIXTURE(ACFixture, "recommend_statement_starting_keywords") { - check(""); - auto ac = autocomplete(0, 0); + check("@1"); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("local")); - check("local i = "); - auto ac2 = autocomplete(0, 10); + check("local i = @1"); + auto ac2 = autocomplete('1'); CHECK(!ac2.entryMap.count("local")); } @@ -464,9 +479,9 @@ TEST_CASE_FIXTURE(ACFixture, "do_not_overwrite_context_sensitive_kws") end - )"); +@1 )"); - auto ac = autocomplete(5, 0); + auto ac = autocomplete('1'); AutocompleteEntry entry = ac.entryMap["continue"]; CHECK(entry.kind == AutocompleteEntryKind::Binding); @@ -480,11 +495,11 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") function foo:bar() end --[[ - foo: + foo:@1 ]] )"); - auto ac = autocomplete(6, 16); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -492,10 +507,10 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_the_end_of_a_comment") { check(R"( - --!strict + --!strict@1 )"); - auto ac = autocomplete(1, 17); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -505,10 +520,10 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; check(R"( - --[[ + --[[ @1 )"); - auto ac = autocomplete(1, 13); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -517,129 +532,129 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co { ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check("--[["); + check("--[[@1"); - auto ac = autocomplete(0, 4); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") { check(R"( - for x = + for x @1= )"); - auto ac1 = autocomplete(1, 14); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - for x = 1 + for x =@1 1 )"); - auto ac2 = autocomplete(1, 15); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("do"), 0); CHECK_EQ(ac2.entryMap.count("end"), 0); check(R"( - for x = 1, 2 + for x = 1,@1 2 )"); - auto ac3 = autocomplete(1, 18); + auto ac3 = autocomplete('1'); CHECK_EQ(1, ac3.entryMap.size()); CHECK_EQ(ac3.entryMap.count("do"), 1); check(R"( - for x = 1, 2, + for x = 1, @12, )"); - auto ac4 = autocomplete(1, 19); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.count("do"), 0); CHECK_EQ(ac4.entryMap.count("end"), 0); check(R"( - for x = 1, 2, 5 + for x = 1, 2, @15 )"); - auto ac5 = autocomplete(1, 22); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.count("do"), 1); CHECK_EQ(ac5.entryMap.count("end"), 0); check(R"( - for x = 1, 2, 5 f + for x = 1, 2, 5 f@1 )"); - auto ac6 = autocomplete(1, 25); + auto ac6 = autocomplete('1'); CHECK_EQ(ac6.entryMap.size(), 1); CHECK_EQ(ac6.entryMap.count("do"), 1); check(R"( - for x = 1, 2, 5 do + for x = 1, 2, 5 do @1 )"); - auto ac7 = autocomplete(1, 32); + auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("end"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") { check(R"( - for + for @1 )"); - auto ac1 = autocomplete(1, 12); + auto ac1 = autocomplete('1'); CHECK_EQ(0, ac1.entryMap.size()); check(R"( - for x + for x@1 @2 )"); - auto ac2 = autocomplete(1, 13); + auto ac2 = autocomplete('1'); CHECK_EQ(0, ac2.entryMap.size()); - auto ac2a = autocomplete(1, 14); + auto ac2a = autocomplete('2'); CHECK_EQ(1, ac2a.entryMap.size()); CHECK_EQ(1, ac2a.entryMap.count("in")); check(R"( - for x in y + for x in y@1 )"); - auto ac3 = autocomplete(1, 18); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("table"), 1); CHECK_EQ(ac3.entryMap.count("do"), 0); check(R"( - for x in y + for x in y @1 )"); - auto ac4 = autocomplete(1, 19); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.size(), 1); CHECK_EQ(ac4.entryMap.count("do"), 1); check(R"( - for x in f f + for x in f f@1 )"); - auto ac5 = autocomplete(1, 20); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.size(), 1); CHECK_EQ(ac5.entryMap.count("do"), 1); check(R"( - for x in y do + for x in y do @1 )"); - auto ac6 = autocomplete(1, 23); + auto ac6 = autocomplete('1'); CHECK_EQ(ac6.entryMap.count("in"), 0); CHECK_EQ(ac6.entryMap.count("table"), 1); CHECK_EQ(ac6.entryMap.count("end"), 1); CHECK_EQ(ac6.entryMap.count("function"), 1); check(R"( - for x in y do e + for x in y do e@1 )"); - auto ac7 = autocomplete(1, 23); + auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("in"), 0); CHECK_EQ(ac7.entryMap.count("table"), 1); CHECK_EQ(ac7.entryMap.count("end"), 1); @@ -649,33 +664,33 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") { check(R"( - while + while@1 )"); - auto ac1 = autocomplete(1, 13); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - while true + while true @1 )"); - auto ac2 = autocomplete(1, 19); + auto ac2 = autocomplete('1'); CHECK_EQ(1, ac2.entryMap.size()); CHECK_EQ(ac2.entryMap.count("do"), 1); check(R"( - while true do + while true do @1 )"); - auto ac3 = autocomplete(1, 23); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("end"), 1); check(R"( - while true d + while true d@1 )"); - auto ac4 = autocomplete(1, 20); + auto ac4 = autocomplete('1'); CHECK_EQ(1, ac4.entryMap.size()); CHECK_EQ(ac4.entryMap.count("do"), 1); } @@ -683,10 +698,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") { check(R"( - if + if @1 )"); - auto ac1 = autocomplete(1, 13); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("then"), 0); CHECK_EQ(ac1.entryMap.count("function"), 1); // FIXME: This is kind of dumb. It is technically syntactically valid but you can never do anything interesting with this. @@ -696,10 +711,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - if x + if x @1 )"); - auto ac2 = autocomplete(1, 14); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("then"), 1); CHECK_EQ(ac2.entryMap.count("function"), 0); CHECK_EQ(ac2.entryMap.count("else"), 0); @@ -707,20 +722,20 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac2.entryMap.count("end"), 0); check(R"( - if x t + if x t@1 )"); - auto ac3 = autocomplete(1, 14); + auto ac3 = autocomplete('1'); CHECK_EQ(1, ac3.entryMap.size()); CHECK_EQ(ac3.entryMap.count("then"), 1); check(R"( if x then - +@1 end )"); - auto ac4 = autocomplete(2, 0); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.count("then"), 0); CHECK_EQ(ac4.entryMap.count("else"), 1); CHECK_EQ(ac4.entryMap.count("function"), 1); @@ -729,11 +744,11 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") check(R"( if x then - t + t@1 end )"); - auto ac4a = autocomplete(2, 13); + auto ac4a = autocomplete('1'); CHECK_EQ(ac4a.entryMap.count("then"), 0); CHECK_EQ(ac4a.entryMap.count("table"), 1); CHECK_EQ(ac4a.entryMap.count("else"), 1); @@ -741,12 +756,12 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") check(R"( if x then - +@1 elseif x then end )"); - auto ac5 = autocomplete(2, 0); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.count("then"), 0); CHECK_EQ(ac5.entryMap.count("function"), 1); CHECK_EQ(ac5.entryMap.count("else"), 0); @@ -757,10 +772,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_in_repeat") { check(R"( - repeat + repeat @1 )"); - auto ac = autocomplete(1, 16); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); CHECK_EQ(ac.entryMap.count("until"), 1); } @@ -769,48 +784,48 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_expression") { check(R"( repeat - until + until @1 )"); - auto ac = autocomplete(2, 16); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); } TEST_CASE_FIXTURE(ACFixture, "local_names") { check(R"( - local ab + local ab@1 )"); - auto ac1 = autocomplete(1, 16); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.size(), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); check(R"( - local ab, cd + local ab, cd@1 )"); - auto ac2 = autocomplete(1, 20); + auto ac2 = autocomplete('1'); CHECK(ac2.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_fn_exprs") { check(R"( - local function f() + local function f() @1 )"); - auto ac = autocomplete(1, 28); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("end"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_lambda") { check(R"( - local a = function() local bar = foo en + local a = function() local bar = foo en@1 )"); - auto ac = autocomplete(1, 47); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("end"), 1); } @@ -818,10 +833,10 @@ TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") { check(R"( repeat - for x + for x @1 )"); - auto ac1 = autocomplete(2, 18); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("in"), 1); CHECK_EQ(ac1.entryMap.count("until"), 0); } @@ -829,112 +844,112 @@ TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_repeat_middle_keyword") { check(R"( - repeat + repeat @1 )"); - auto ac1 = autocomplete(1, 15); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); CHECK_EQ(ac1.entryMap.count("until"), 1); check(R"( - repeat f f + repeat f f@1 )"); - auto ac2 = autocomplete(1, 18); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("function"), 1); CHECK_EQ(ac2.entryMap.count("until"), 1); check(R"( repeat - u + u@1 until )"); - auto ac3 = autocomplete(2, 13); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("until"), 0); } TEST_CASE_FIXTURE(ACFixture, "local_function") { check(R"( - local f + local f@1 )"); - auto ac1 = autocomplete(1, 15); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.size(), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); check(R"( - local f, cd + local f@1, cd )"); - auto ac2 = autocomplete(1, 15); + auto ac2 = autocomplete('1'); CHECK(ac2.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "local_function") { check(R"( - local function + local function @1 )"); - auto ac = autocomplete(1, 23); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( - local function s + local function @1s@2 )"); - ac = autocomplete(1, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); - ac = autocomplete(1, 24); + ac = autocomplete('2'); CHECK(ac.entryMap.empty()); check(R"( - local function () + local function @1()@2 )"); - ac = autocomplete(1, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); - ac = autocomplete(1, 25); + ac = autocomplete('2'); CHECK(ac.entryMap.count("end")); check(R"( - local function something + local function something@1 )"); - ac = autocomplete(1, 32); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( local tbl = {} - function tbl.something() end + function tbl.something@1() end )"); - ac = autocomplete(2, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "local_function_params") { check(R"( - local function abc(def) + local function @1a@2bc(@3d@4ef)@5 @6 )"); - CHECK(autocomplete(1, 23).entryMap.empty()); - CHECK(autocomplete(1, 24).entryMap.empty()); - CHECK(autocomplete(1, 27).entryMap.empty()); - CHECK(autocomplete(1, 28).entryMap.empty()); - CHECK(!autocomplete(1, 31).entryMap.empty()); + CHECK(autocomplete('1').entryMap.empty()); + CHECK(autocomplete('2').entryMap.empty()); + CHECK(autocomplete('3').entryMap.empty()); + CHECK(autocomplete('4').entryMap.empty()); + CHECK(!autocomplete('5').entryMap.empty()); - CHECK(!autocomplete(1, 32).entryMap.empty()); + CHECK(!autocomplete('6').entryMap.empty()); check(R"( local function abc(def) - end +@1 end )"); for (unsigned int i = 23; i < 31; ++i) @@ -943,16 +958,16 @@ TEST_CASE_FIXTURE(ACFixture, "local_function_params") } CHECK(!autocomplete(1, 32).entryMap.empty()); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("abc"), 1); CHECK_EQ(ac2.entryMap.count("def"), 1); check(R"( - local function abc(def, ghi) + local function abc(def, ghi@1) end )"); - auto ac3 = autocomplete(1, 35); + auto ac3 = autocomplete('1'); CHECK(ac3.entryMap.empty()); } @@ -981,48 +996,48 @@ TEST_CASE_FIXTURE(ACFixture, "global_function_params") check(R"( function abc(def) - +@1 end )"); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("abc"), 1); CHECK_EQ(ac2.entryMap.count("def"), 1); check(R"( - function abc(def, ghi) + function abc(def, ghi@1) end )"); - auto ac3 = autocomplete(1, 29); + auto ac3 = autocomplete('1'); CHECK(ac3.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "arguments_to_global_lambda") { check(R"( - abc = function(def, ghi) + abc = function(def, ghi@1) end )"); - auto ac = autocomplete(1, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "function_expr_params") { check(R"( - abc = function(def) + abc = function(def) @1 )"); for (unsigned int i = 20; i < 27; ++i) { CHECK(autocomplete(1, i).entryMap.empty()); } - CHECK(!autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete('1').entryMap.empty()); check(R"( - abc = function(def) + abc = function(def) @1 end )"); @@ -1030,25 +1045,25 @@ TEST_CASE_FIXTURE(ACFixture, "function_expr_params") { CHECK(autocomplete(1, i).entryMap.empty()); } - CHECK(!autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete('1').entryMap.empty()); check(R"( abc = function(def) - +@1 end )"); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("def"), 1); } TEST_CASE_FIXTURE(ACFixture, "local_initializer") { check(R"( - local a = t + local a = t@1 )"); - auto ac = autocomplete(1, 19); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); CHECK_EQ(ac.entryMap.count("true"), 1); } @@ -1056,20 +1071,20 @@ TEST_CASE_FIXTURE(ACFixture, "local_initializer") TEST_CASE_FIXTURE(ACFixture, "local_initializer_2") { check(R"( - local a= + local a=@1 )"); - auto ac = autocomplete(1, 16); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); } TEST_CASE_FIXTURE(ACFixture, "get_member_completions") { check(R"( - local a = 12.3 + local a = 12.@13 )"); - auto ac = autocomplete(1, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } @@ -1083,21 +1098,21 @@ TEST_CASE_FIXTURE(ACFixture, "sometimes_the_metatable_is_an_error") return setmetatable({x=6}, X) -- oops! end local t = T.new() - t. + t. @1 )"); - autocomplete(8, 12); + autocomplete('1'); // Don't crash! } TEST_CASE_FIXTURE(ACFixture, "local_types_builtin") { check(R"( -local a: n +local a: n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1108,23 +1123,23 @@ TEST_CASE_FIXTURE(ACFixture, "private_types") check(R"( do type num = number - local a: nu - local b: num + local a: n@1u + local b: nu@2m end -local a: nu +local a: nu@3 )"); - auto ac = autocomplete(3, 14); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(4, 15); + ac = autocomplete('2'); CHECK(ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(6, 11); + ac = autocomplete('3'); CHECK(!ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); @@ -1136,11 +1151,11 @@ TEST_CASE_FIXTURE(ACFixture, "type_scoping_easy") type Table = { a: number, b: number } do type Table = { x: string, y: string } - local a: T + local a: T@1 end )"); - auto ac = autocomplete(4, 14); + auto ac = autocomplete('1'); REQUIRE(ac.entryMap.count("Table")); REQUIRE(ac.entryMap["Table"].type); @@ -1198,11 +1213,11 @@ local a: aaa. TEST_CASE_FIXTURE(ACFixture, "argument_types") { check(R"( -local function f(a: n +local function f(a: n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1211,11 +1226,11 @@ local b: string = "don't trip" TEST_CASE_FIXTURE(ACFixture, "return_types") { check(R"( -local function f(a: number): n +local function f(a: number): n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 30); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1225,10 +1240,10 @@ TEST_CASE_FIXTURE(ACFixture, "as_types") { check(R"( local a: any = 5 -local b: number = (a :: n +local b: number = (a :: n@1 )"); - auto ac = autocomplete(2, 25); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1237,34 +1252,34 @@ local b: number = (a :: n TEST_CASE_FIXTURE(ACFixture, "function_type_types") { check(R"( -local a: (n -local b: (number, (n -local c: (number, (number) -> n -local d: (number, (number) -> (number, n -local e: (n: n +local a: (n@1 +local b: (number, (n@2 +local c: (number, (number) -> n@3 +local d: (number, (number) -> (number, n@4 +local e: (n: n@5 )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(2, 20); + ac = autocomplete('2'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(3, 31); + ac = autocomplete('3'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(4, 40); + ac = autocomplete('4'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(5, 14); + ac = autocomplete('5'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1276,11 +1291,11 @@ TEST_CASE_FIXTURE(ACFixture, "generic_types") ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); check(R"( -function f(a: T +function f(a: T@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 25); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("Tee")); } @@ -1293,10 +1308,10 @@ local function target(a: number, b: string) return a + #b end local one = 4 local two = "hello" -return target(o +return target(o@1 )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1307,10 +1322,10 @@ local function target(a: number, b: string) return a + #b end local one = 4 local two = "hello" -return target(one, t +return target(one, t@1 )"); - ac = autocomplete(5, 20); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1321,10 +1336,10 @@ return target(one, t local function target(a: number, b: string) return a + #b end local a = { one = 4, two = "hello" } -return target(a. +return target(a.@1 )"); - ac = autocomplete(4, 16); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1334,10 +1349,10 @@ return target(a. local function target(a: number, b: string) return a + #b end local a = { one = 4, two = "hello" } -return target(a.one, a. +return target(a.one, a.@1 )"); - ac = autocomplete(4, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1348,10 +1363,10 @@ return target(a.one, a. local function target(a: string?) return #b end local a = { one = 4, two = "hello" } -return target(a. +return target(a.@1 )"); - ac = autocomplete(4, 16); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1363,10 +1378,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_in_table") check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } -local b: Foo = { a = a. +local b: Foo = { a = a.@1 )"); - auto ac = autocomplete(3, 23); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1375,10 +1390,10 @@ local b: Foo = { a = a. check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } -local b: Foo = { b = a. +local b: Foo = { b = a.@1 )"); - ac = autocomplete(3, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1392,10 +1407,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1406,10 +1421,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end local function bar2(a: string) return a .. 'x' end -return target(bar1, b +return target(bar1, b@1 )"); - ac = autocomplete(5, 21); + ac = autocomplete('1'); CHECK(ac.entryMap.count("bar2")); CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1420,10 +1435,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number): (...number) return -a, a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - ac = autocomplete(5, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1433,69 +1448,69 @@ return target(b TEST_CASE_FIXTURE(ACFixture, "type_correct_local_type_suggestion") { check(R"( -local b: s = "str" +local b: s@1 = "str" )"); - auto ac = autocomplete(1, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f() return "str" end -local b: s = f() +local b: s@1 = f() )"); - ac = autocomplete(2, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: s, c: n = "str", 2 +local b: s@1, c: n@2 = "str", 2 )"); - ac = autocomplete(1, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(1, 16); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f() return 1, "str", 3 end -local a: b, b: n, c: s, d: n = false, f() +local a: b@1, b: n@2, c: s@3, d: n@4 = false, f() )"); - ac = autocomplete(2, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("boolean")); CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 16); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 22); + ac = autocomplete('3'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 28); + ac = autocomplete('4'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f(): ...number return 1, 2, 3 end -local a: boolean, b: n = false, f() +local a: boolean, b: n@1 = false, f() )"); - ac = autocomplete(2, 22); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1504,46 +1519,46 @@ local a: boolean, b: n = false, f() TEST_CASE_FIXTURE(ACFixture, "type_correct_function_type_suggestion") { check(R"( -local b: (n) -> number = function(a: number, b: string) return a + #b end +local b: (n@1) -> number = function(a: number, b: string) return a + #b end )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, s = function(a: number, b: string) return a + #b end +local b: (number, s@1 = function(a: number, b: string) return a + #b end )"); - ac = autocomplete(1, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, string) -> b = function(a: number, b: string): boolean return a + #b == 0 end +local b: (number, string) -> b@1 = function(a: number, b: string): boolean return a + #b == 0 end )"); - ac = autocomplete(1, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("boolean")); CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, ...s) = function(a: number, ...: string) return a end +local b: (number, ...s@1) = function(a: number, ...: string) return a end )"); - ac = autocomplete(1, 22); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" end +local b: (number) -> ...s@1 = function(a: number): ...string return "a", "b", "c" end )"); - ac = autocomplete(1, 25); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1552,24 +1567,24 @@ local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" TEST_CASE_FIXTURE(ACFixture, "type_correct_full_type_suggestion") { check(R"( -local b: = "str" +local b:@1 @2= "str" )"); - auto ac = autocomplete(1, 8); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(1, 9); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: = function(a: number) return -a end +local b: @1= function(a: number) return -a end )"); - ac = autocomplete(1, 9); + ac = autocomplete('1'); CHECK(ac.entryMap.count("(number) -> number")); CHECK(ac.entryMap["(number) -> number"].typeCorrect == TypeCorrectKind::Correct); @@ -1580,12 +1595,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion") check(R"( local function target(a: number, b: string) return a + #b end -local function d(a: n, b) +local function d(a: n@1, b) return target(a, b) end )"); - auto ac = autocomplete(3, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1593,12 +1608,12 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a, b: s) +local function d(a, b: s@1) return target(a, b) end )"); - ac = autocomplete(3, 24); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1606,17 +1621,17 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a: , b) +local function d(a:@1 @2, b) return target(a, b) end )"); - ac = autocomplete(3, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 20); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1624,17 +1639,17 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a, b: ): number +local function d(a, b: @1)@2: number return target(a, b) end )"); - ac = autocomplete(3, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 24); + ac = autocomplete('2'); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::None); } @@ -1644,10 +1659,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion") check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: +local x = target(function(a: @1 )"); - auto ac = autocomplete(3, 29); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1655,10 +1670,10 @@ local x = target(function(a: check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: n +local x = target(function(a: n@1 )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1666,17 +1681,17 @@ local x = target(function(a: n check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: n, b: ) +local x = target(function(a: n@1, b: @2) return a + #b end) )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 35); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1684,12 +1699,12 @@ end) check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(a: n) +local x = target(function(a: n@1) return a end )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1700,12 +1715,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(...:n) +local x = target(function(...:n@1) return a end )"); - auto ac = autocomplete(3, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1713,12 +1728,12 @@ end check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(a:number, b:number, ...:) +local x = target(function(a:number, b:number, ...:@1) return a + b end )"); - ac = autocomplete(3, 50); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1729,12 +1744,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion") check(R"( local function target(callback: () -> number) return callback() end -local x = target(function(): n +local x = target(function(): n@1 return 1 end )"); - auto ac = autocomplete(3, 30); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1742,12 +1757,12 @@ end check(R"( local function target(callback: () -> (number, number)) return callback() end -local x = target(function(): (number, n +local x = target(function(): (number, n@1 return 1, 2 end )"); - ac = autocomplete(3, 39); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1758,12 +1773,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion" check(R"( local function target(callback: () -> ...number) return callback() end -local x = target(function(): ...n +local x = target(function(): ...n@1 return 1, 2, 3 end )"); - auto ac = autocomplete(3, 33); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1771,12 +1786,12 @@ end check(R"( local function target(callback: () -> ...number) return callback() end -local x = target(function(): (number, number, ...n +local x = target(function(): (number, number, ...n@1 return 1, 2, 3 end )"); - ac = autocomplete(3, 50); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1787,10 +1802,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion_opt check(R"( local function target(callback: nil | (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: +local x = target(function(a: @1 )"); - auto ac = autocomplete(3, 29); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1803,21 +1818,21 @@ local t = {} t.x = 5 function t:target(callback: (a: number, b: string) -> number) return callback(self.x, "hello") end -local x = t:target(function(a: , b: ) end) -local y = t.target(t, function(a: number, b: ) end) +local x = t:target(function(a: @1, b:@2 ) end) +local y = t.target(t, function(a: number, b: @3) end) )"); - auto ac = autocomplete(5, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(5, 35); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(6, 45); + ac = autocomplete('3'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1899,26 +1914,26 @@ TEST_CASE_FIXTURE(ACFixture, "do_not_suggest_synthetic_table_name") { check(R"( local foo = { a = 1, b = 2 } -local bar: = foo +local bar: @1= foo )"); - auto ac = autocomplete(2, 11); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.count("foo")); } -// CLI-45692: Remove UnfrozenFixture here -TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_function_no_parenthesis") +// CLI-45692: Remove UnfrozenACFixture here +TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis") { check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - auto ac = autocomplete(frontend, "MainModule", Position{5, 15}, nullCallback); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); @@ -1930,16 +1945,16 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_sealed_table") { check(R"( local function f(a: { x: number, y: number }) return a.x + a.y end -local fp: = f +local fp: @1= f )"); - auto ac = autocomplete(2, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// CLI-45692: Remove UnfrozenFixture here -TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_keywords") +// CLI-45692: Remove UnfrozenACFixture here +TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords") { check(R"( local function a(x: boolean) end @@ -1951,33 +1966,33 @@ local function e(x: ((number) -> string) & ((boolean) -> number)) end local tru = {} local ni = false -local ac = a(t) -local bc = b(n) -local cc = c(f) -local dc = d(f) -local ec = e(f) +local ac = a(t@1) +local bc = b(n@2) +local cc = c(f@3) +local dc = d(f@4) +local ec = e(f@5) )"); - auto ac = autocomplete(frontend, "MainModule", Position{10, 14}, nullCallback); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("tru")); CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{11, 14}, nullCallback); + ac = autocomplete('2'); CHECK(ac.entryMap.count("ni")); CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{12, 14}, nullCallback); + ac = autocomplete('3'); CHECK(ac.entryMap.count("false")); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{13, 14}, nullCallback); + ac = autocomplete('4'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{14, 14}, nullCallback); + ac = autocomplete('5'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); } @@ -1988,10 +2003,10 @@ local target: ((number) -> string) & ((string) -> number)) local one = 4 local two = "hello" -return target(o) +return target(o@1) )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2002,10 +2017,10 @@ local target: ((number) -> string) & ((number) -> number)) local one = 4 local two = "hello" -return target(o) +return target(o@1) )"); - ac = autocomplete(5, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2016,10 +2031,10 @@ local target: ((number, number) -> string) & ((string) -> number)) local one = 4 local two = "hello" -return target(1, o) +return target(1, o@1) )"); - ac = autocomplete(5, 18); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2032,10 +2047,10 @@ TEST_CASE_FIXTURE(ACFixture, "optional_members") local a = { x = 2, y = 3 } type A = typeof(a) local b: A? = a -return b. +return b.@1 )"); - auto ac = autocomplete(4, 9); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); @@ -2045,10 +2060,10 @@ return b. local a = { x = 2, y = 3 } type A = typeof(a) local b: nil | A = a -return b. +return b.@1 )"); - ac = autocomplete(4, 9); + ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); @@ -2056,10 +2071,10 @@ return b. check(R"( local b: nil | nil -return b. +return b.@1 )"); - ac = autocomplete(2, 9); + ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -2067,26 +2082,26 @@ return b. TEST_CASE_FIXTURE(ACFixture, "no_function_name_suggestions") { check(R"( -function na +function na@1 )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( -local function +local function @1 )"); - ac = autocomplete(1, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( -local function na +local function na@1 )"); - ac = autocomplete(1, 17); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } @@ -2095,20 +2110,20 @@ TEST_CASE_FIXTURE(ACFixture, "skip_current_local") { check(R"( local other = 1 -local name = na +local name = na@1 )"); - auto ac = autocomplete(2, 15); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.count("name")); CHECK(ac.entryMap.count("other")); check(R"( local other = 1 -local name, test = na +local name, test = na@1 )"); - ac = autocomplete(2, 21); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("name")); CHECK(!ac.entryMap.count("test")); @@ -2119,26 +2134,26 @@ TEST_CASE_FIXTURE(ACFixture, "keyword_members") { check(R"( local a = { done = 1, forever = 2 } -local b = a.do -local c = a.for -local d = a. +local b = a.do@1 +local c = a.for@2 +local d = a.@3 do end )"); - auto ac = autocomplete(2, 14); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); CHECK(ac.entryMap.count("forever")); - ac = autocomplete(3, 15); + ac = autocomplete('2'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); CHECK(ac.entryMap.count("forever")); - ac = autocomplete(4, 12); + ac = autocomplete('3'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2150,10 +2165,10 @@ TEST_CASE_FIXTURE(ACFixture, "keyword_methods") check(R"( local a = {} function a:done() end -local b = a:do +local b = a:do@1 )"); - auto ac = autocomplete(3, 14); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2247,29 +2262,29 @@ local elsewhere = false local doover = false local endurance = true -if 1 then -else +if 1 then@1 +else@2 end -while false do +while false do@3 end -repeat +repeat@4 until )"); - auto ac = autocomplete(6, 9); + auto ac = autocomplete('1'); CHECK(ac.entryMap.size() == 1); CHECK(ac.entryMap.count("then")); - ac = autocomplete(7, 4); + ac = autocomplete('2'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); - ac = autocomplete(10, 14); + ac = autocomplete('3'); CHECK(ac.entryMap.count("do")); - ac = autocomplete(13, 6); + ac = autocomplete('4'); CHECK(ac.entryMap.count("do")); // FIXME: ideally we want to handle start and end of all statements as well @@ -2284,11 +2299,11 @@ local elsewhere = false if true then return 1 -el +el@1 end )"); - auto ac = autocomplete(5, 2); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); CHECK(ac.entryMap.count("elsewhere") == 0); @@ -2300,11 +2315,11 @@ if true then return 1 else return 2 -el +el@1 end )"); - ac = autocomplete(7, 2); + ac = autocomplete('1'); CHECK(ac.entryMap.count("else") == 0); CHECK(ac.entryMap.count("elseif") == 0); CHECK(ac.entryMap.count("elsewhere")); @@ -2316,10 +2331,10 @@ if true then print("1") elif true then print("2") -el +el@1 end )"); - ac = autocomplete(7, 2); + ac = autocomplete('1'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); CHECK(ac.entryMap.count("elsewhere")); @@ -2360,30 +2375,30 @@ TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") { check(R"( type Test = { first: number, second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - auto ac = autocomplete(2, 19); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Intersection check(R"( type Test = { first: number } & { second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Union check(R"( type Test = { first: number, second: number } | { second: number, third: number } -local t: Test = { s } +local t: Test = { s@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("second")); CHECK(!ac.entryMap.count("first")); CHECK(!ac.entryMap.count("third")); @@ -2391,60 +2406,60 @@ local t: Test = { s } // No parenthesis suggestion check(R"( type Test = { first: (number) -> number, second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap["first"].parens == ParenthesesRecommendation::None); // When key is changed check(R"( type Test = { first: number, second: number } -local t: Test = { f = 2 } +local t: Test = { f@1 = 2 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Alternative key syntax check(R"( type Test = { first: number, second: number } -local t: Test = { ["f"] } +local t: Test = { ["f@1"] } )"); - ac = autocomplete(2, 21); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Not an alternative key syntax check(R"( type Test = { first: number, second: number } -local t: Test = { "f" } +local t: Test = { "f@1" } )"); - ac = autocomplete(2, 20); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("first")); CHECK(!ac.entryMap.count("second")); // Skip keys that are already defined check(R"( type Test = { first: number, second: number } -local t: Test = { first = 2, s } +local t: Test = { first = 2, s@1 } )"); - ac = autocomplete(2, 30); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Don't skip active key check(R"( type Test = { first: number, second: number } -local t: Test = { first } +local t: Test = { first@1 } )"); - ac = autocomplete(2, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); @@ -2452,22 +2467,22 @@ local t: Test = { first } check(R"( local t = { { first = 5, second = 10 }, - { f } + { f@1 } } )"); - ac = autocomplete(3, 7); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); check(R"( local t = { [2] = { first = 5, second = 10 }, - [5] = { f } + [5] = { f@1 } } )"); - ac = autocomplete(3, 13); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); } @@ -2502,15 +2517,15 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") local temp = false local even = true; local a = true -a = if t1@emp then t -a = if temp t2@ -a = if temp then e3@ -a = if temp then even e4@ -a = if temp then even elseif t5@ -a = if temp then even elseif true t6@ -a = if temp then even elseif true then t7@ -a = if temp then even elseif true then temp e8@ -a = if temp then even elseif true then temp else e9@ +a = if t@1emp then t +a = if temp t@2 +a = if temp then e@3 +a = if temp then even e@4 +a = if temp then even elseif t@5 +a = if temp then even elseif true t@6 +a = if temp then even elseif true then t@7 +a = if temp then even elseif true then temp e@8 +a = if temp then even elseif true then temp else e@9 )"); auto ac = autocomplete('1'); @@ -2573,4 +2588,20 @@ a = if temp then even elseif true then temp else e9@ } } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") +{ + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + check(R"( +type A = () -> T... +local a: A<(number, s@1> + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 26bc77f7..29c33f7c 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -32,6 +32,55 @@ std::optional TestFileResolver::fromAstFragment(AstExpr* expr) const return std::nullopt; } +std::optional TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr) +{ + if (AstExprGlobal* g = expr->as()) + { + if (g->name == "game") + return ModuleInfo{"game"}; + if (g->name == "workspace") + return ModuleInfo{"workspace"}; + if (g->name == "script") + return context ? std::optional(*context) : std::nullopt; + } + else if (AstExprIndexName* i = expr->as(); i && context) + { + if (i->index == "Parent") + { + std::string_view view = context->name; + size_t lastSeparatorIndex = view.find_last_of('/'); + + if (lastSeparatorIndex == std::string_view::npos) + return std::nullopt; + + return ModuleInfo{ModuleName(view.substr(0, lastSeparatorIndex)), context->optional}; + } + else + { + return ModuleInfo{context->name + '/' + i->index.value, context->optional}; + } + } + else if (AstExprIndexExpr* i = expr->as(); i && context) + { + if (AstExprConstantString* index = i->index->as()) + { + return ModuleInfo{context->name + '/' + std::string(index->value.data, index->value.size), context->optional}; + } + } + else if (AstExprCall* call = expr->as(); call && call->self && call->args.size >= 1 && context) + { + if (AstExprConstantString* index = call->args.data[0]->as()) + { + AstName func = call->func->as()->index; + + if (func == "GetService" && context->name == "game") + return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)}; + } + } + + return std::nullopt; +} + ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const { return lhs + "/" + ModuleName(rhs); diff --git a/tests/Fixture.h b/tests/Fixture.h index c6294b01..1480a7f6 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -65,6 +65,8 @@ struct TestFileResolver } std::optional fromAstFragment(AstExpr* expr) const override; + std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override; std::optional getParentModuleName(const ModuleName& name) const override; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 3f33a5d1..fbfec636 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -58,6 +58,35 @@ struct NaiveFileResolver : NullFileResolver return std::nullopt; } + std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override + { + if (AstExprGlobal* g = expr->as()) + { + if (g->name == "Modules") + return ModuleInfo{"Modules"}; + + if (g->name == "game") + return ModuleInfo{"game"}; + } + else if (AstExprIndexName* i = expr->as()) + { + if (context) + return ModuleInfo{context->name + '/' + i->index.value, context->optional}; + } + else if (AstExprCall* call = expr->as(); call && call->self && call->args.size >= 1 && context) + { + if (AstExprConstantString* index = call->args.data[0]->as()) + { + AstName func = call->func->as()->index; + + if (func == "GetService" && context->name == "game") + return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)}; + } + } + + return std::nullopt; + } + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override { return lhs + "/" + ModuleName(rhs); @@ -528,7 +557,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "ignore_require_to_nonexistent_file") { fileResolver.source["Modules/A"] = R"( local Modules = script - local B = require(Modules.B :: any) + local B = require(Modules.B) :: any )"; CheckResult result = frontend.check("Modules/A"); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index c8eff399..a9ed139f 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1400,6 +1400,8 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { + ScopedFastFlag sff("LuauLinterTableMoveZero", true); + LintResult result = lintTyped(R"( local t = {} local tt = {} @@ -1417,9 +1419,12 @@ table.remove(t, 0) table.remove(t, #t-1) table.insert(t, string.find("hello", "h")) + +table.move(t, 0, #t, 1, tt) +table.move(t, 1, #t, 0, tt) )"); - REQUIRE_EQ(result.warnings.size(), 6); + REQUIRE_EQ(result.warnings.size(), 8); CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); @@ -1429,6 +1434,8 @@ table.insert(t, string.find("hello", "h")) "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[5].text, "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); + CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 1b146ed2..18f55d2c 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Fixture.h" diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index f3c76d55..931a8403 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index cb03a7bd..a80718e4 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2519,4 +2519,19 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") } } +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + AstStat* stat = parse(R"( +type Packed = () -> T... + +type A = Packed +type B = Packed<...number> +type C = Packed<(number, X...)> + )"); + REQUIRE(stat != nullptr); +} + TEST_SUITE_END(); diff --git a/tests/RequireTracer.test.cpp b/tests/RequireTracer.test.cpp index cbd4af29..b9fd04d6 100644 --- a/tests/RequireTracer.test.cpp +++ b/tests/RequireTracer.test.cpp @@ -57,6 +57,7 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") { AstStatBlock* block = parse(R"( local m = workspace.Foo.Bar.Baz + require(m) )"); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); @@ -70,22 +71,22 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") AstExprIndexName* value = loc->values.data[0]->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value]); + CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value].name); value = value->expr->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo/Bar", result.exprs[value]); + CHECK_EQ("workspace/Foo/Bar", result.exprs[value].name); value = value->expr->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo", result.exprs[value]); + CHECK_EQ("workspace/Foo", result.exprs[value].name); AstExprGlobal* workspace = value->expr->as(); REQUIRE(workspace); REQUIRE(result.exprs.contains(workspace)); - CHECK_EQ("workspace", result.exprs[workspace]); + CHECK_EQ("workspace", result.exprs[workspace].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") @@ -93,9 +94,10 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") AstStatBlock* block = parse(R"( local m = workspace.Foo.Bar.Baz local n = m.Quux + require(n) )"); - REQUIRE_EQ(2, block->body.size); + REQUIRE_EQ(3, block->body.size); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); @@ -104,13 +106,13 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") REQUIRE_EQ(1, local->vars.size); REQUIRE(result.exprs.contains(local->values.data[0])); - CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]]); + CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") { AstStatBlock* block = parse(R"( - local M = require(workspace.Game.Thing, workspace.Something.Else) + local M = require(workspace.Game.Thing) )"); REQUIRE_EQ(1, block->body.size); @@ -124,52 +126,9 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") AstExprCall* call = local->values.data[0]->as(); REQUIRE(call != nullptr); - REQUIRE_EQ(2, call->args.size); - - CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]]); - CHECK_EQ("workspace/Something/Else", result.exprs[call->args.data[1]]); -} - -TEST_CASE_FIXTURE(RequireTracerFixture, "follow_GetService_calls") -{ - AstStatBlock* block = parse(R"( - local R = game:GetService('ReplicatedStorage').Roact - local Roact = require(R) - )"); - REQUIRE_EQ(2, block->body.size); - - RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); - - AstStatLocal* local = block->body.data[0]->as(); - REQUIRE(local != nullptr); - - CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[local->values.data[0]]); - - AstStatLocal* local2 = block->body.data[1]->as(); - REQUIRE(local2 != nullptr); - REQUIRE_EQ(1, local2->values.size); - - AstExprCall* call = local2->values.data[0]->as(); - REQUIRE(call != nullptr); REQUIRE_EQ(1, call->args.size); - CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[call->args.data[0]]); -} - -TEST_CASE_FIXTURE(RequireTracerFixture, "follow_WaitForChild_calls") -{ - ScopedFastFlag luauTraceRequireLookupChild("LuauTraceRequireLookupChild", true); - - AstStatBlock* block = parse(R"( -local A = require(workspace:WaitForChild('ReplicatedStorage').Content) -local B = require(workspace:FindFirstChild('ReplicatedFirst').Data) - )"); - - RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); - - REQUIRE_EQ(2, result.requires.size()); - CHECK_EQ("workspace/ReplicatedStorage/Content", result.requires[0].first); - CHECK_EQ("workspace/ReplicatedFirst/Data", result.requires[1].first); + CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") @@ -200,22 +159,23 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") REQUIRE(call != nullptr); REQUIRE_EQ(1, call->args.size); - CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]]); + CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "follow_string_indexexpr") { AstStatBlock* block = parse(R"( local R = game["Test"] + require(R) )"); - REQUIRE_EQ(1, block->body.size); + REQUIRE_EQ(2, block->body.size); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); AstStatLocal* local = block->body.data[0]->as(); REQUIRE(local != nullptr); - CHECK_EQ("game/Test", result.exprs[local->values.data[0]]); + CHECK_EQ("game/Test", result.exprs[local->values.data[0]].name); } TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index d7d68c46..e18bf7cd 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Fixture.h" @@ -416,8 +417,6 @@ function foo(a, b) return a(b) end TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_TypePack") { - ScopedFastFlag sff{"LuauToStringFollowsBoundTo", true}; - TypeVar tv1{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tv1); ttv->state = TableState::Sealed; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp new file mode 100644 index 00000000..045f0230 --- /dev/null +++ b/tests/TypeInfer.aliases.test.cpp @@ -0,0 +1,557 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeAliases"); + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type F = () -> F? + local function f() + return f + end + + local g: F = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") +{ + CheckResult result = check(R"( + --!strict + type Node = { Parent: Node?; } + local node: Node; + node.Parent = 1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Node?", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: a, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = "lo", i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: b, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = 5, i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_ERRORS(result); + + // We had a UAF in this example caused by not cloning type function arguments + ModulePtr module = frontend.moduleResolver.getModule("MainModule"); + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes); + freeze(module->interfaceTypes); + module->internalTypes.clear(); + module->astTypes.clear(); + + // Make sure the error strings don't include "VALUELESS" + for (auto error : module->errors) + CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); +} + +TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") +{ + CheckResult result = check(R"( + type Pair = {first: T, second: U} + local a: Pair + local b: Pair + + a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK_EQ("Pair", toString(tm->wantedType)); + CHECK_EQ("Pair", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") +{ + CheckResult result = check(R"( + type A = number + type A = string -- Redefinition of type 'A', previously defined at line 1 + local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = Table + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Wrapped", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = (Table) -> string + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +// Check that recursive intersection type doesn't generate an OOM +TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") +{ + CheckResult result = check(R"( + function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any + end + type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) + _(_) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") +{ + CheckResult result = check(R"( + local foo: Id = 1 + type Id = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") +{ + const std::string code = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb = aa + )"; + + const std::string expected = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type A = () -> (number, B) + type B = () -> (string, A) + local a: A + local b: B + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); + CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "generic_param_remap") +{ + const std::string code = R"( + -- An example of a forwarded use of a type that has different type arguments than parameters + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb = aa + )"; + + const std::string expected = R"( + + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") +{ + CheckResult result = check(R"( + export type Foo = number + type Foo = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto dtd = get(result.errors[0]); + REQUIRE(dtd); + CHECK_EQ(dtd->name, "Foo"); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") +{ + ScopedFastFlag sffs3{"LuauGenericFunctions", true}; + ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + type Node = { value: T, child: Node? } + + local function visitor(node: Node?) + local a: Node + + if node then + a = node.child -- Observe the output of the error message. + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto e = get(result.errors[0]); + CHECK_EQ("Node?", toString(e->givenType)); + CHECK_EQ("Node", toString(e->wantedType)); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") +{ + fileResolver.source["workspace/A"] = R"( + export type myvec2 = {x: number, y: number} + return {} + )"; + + fileResolver.source["workspace/B"] = R"( + export type myvec3 = {x: number, y: number, z: number} + return {} + )"; + + fileResolver.source["workspace/C"] = R"( + local Foo, Bar = require(workspace.A), require(workspace.B) + + local a: Foo.myvec2 + local b: Bar.myvec3 + )"; + + CheckResult result = frontend.check("workspace/C"); + LUAU_REQUIRE_NO_ERRORS(result); + ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; + + REQUIRE(m != nullptr); + + std::optional aTypeId = lookupName(m->getModuleScope(), "a"); + REQUIRE(aTypeId); + const Luau::TableTypeVar* aType = get(follow(*aTypeId)); + REQUIRE(aType); + REQUIRE(aType->props.size() == 2); + + std::optional bTypeId = lookupName(m->getModuleScope(), "b"); + REQUIRE(bTypeId); + const Luau::TableTypeVar* bType = get(follow(*bTypeId)); + REQUIRE(bType); + REQUIRE(bType->props.size() == 3); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") +{ + CheckResult result = check("type t10 = typeof(table)"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); + CHECK_EQ(toString(ty), "table"); + + const TableTypeVar* ttv = get(ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +local c: Cool = { a = 1, b = "s" } +type NotCool = Cool +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +type NotCool = Cool +local c: Cool = { a = 1, b = "s" } +local d: NotCool = { a = 1, b = "s" } +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + ty = requireType("d"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "NotCool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") +{ + CheckResult result = check(R"( +local c = { a = 1, b = "s" } +type Cool = typeof(c) +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK_EQ(ttv->name, "Cool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") +{ + fileResolver.source["game/A"] = R"( +export type X = { a: number, b: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(follow(*ty1), follow(*ty2)); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") +{ + fileResolver.source["game/A"] = R"( +export type X = { a: T, b: U, C: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); + + bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); + CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return function(obj) return true end +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return {a = 1, b = function(obj) return true end} +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") +{ + CheckResult result = check(R"( + type Tree = { data: T, children: Forest } + type Forest = {Tree} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- OK because forwarded types are used with their parameters. + type Tree = { data: T, children: Forest } + type Forest = {Tree<{T}>} + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- Not OK because forwarded types are used with different types than their parameters. + type Forest = {Tree<{T}>} + type Tree = { data: T, children: Forest } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") +{ + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") +{ + CheckResult result = check(R"( + function f(x) return x[1] end + -- x has type X? for a free type variable X + local x = f ({}) + type ContainsFree = { this: a, that: typeof(x) } + type ContainsContainsFree = { that: ContainsFree } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") +{ + ScopedFastFlag sff1{"LuauSubstitutionDontReplaceIgnoredTypes", true}; + + CheckResult result = check(R"( + type Array = { [number]: T } + type Tuple = Array + + local p: Tuple + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{number | string}", toString(requireType("p"), {true})); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 46496fdb..8bcb0242 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -30,6 +30,8 @@ TEST_SUITE_BEGIN("ProvisionalTests"); */ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") { + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + const std::string code = R"( function f(a) if type(a) == "boolean" then @@ -41,11 +43,11 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") )"; const std::string expected = R"( - function f(a:{fn:()->(free)}): () + function f(a:{fn:()->(free,free...)}): () if type(a) == 'boolean'then local a1:boolean=a elseif a.fn()then - local a2:{fn:()->(free)}=a + local a2:{fn:()->(free,free...)}=a end end )"; @@ -231,16 +233,7 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") local r2 = b == a )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| x: string |}?' could not be converted into 'number | string'"); - CHECK_EQ(toString(result.errors[1]), "Type 'number | string' could not be converted into '{| x: string |}?'"); - } + LUAU_REQUIRE_NO_ERRORS(result); } // Belongs in TypeInfer.refinements.test.cpp. @@ -542,6 +535,25 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doct } } +TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") +{ + CheckResult result = check(R"( + --!strict + local function setNumber(t: { p: number? }, x:number) t.p = x end + local function getString(t: { p: string? }):string return t.p or "" end + -- This shouldn't type-check! + local function oh(x:number): string + local t: {} = {} + setNumber(t, x) + return getString(t) + end + local s: string = oh(37) + )"); + + // Really this should return an error, but it doesn't + LUAU_REQUIRE_NO_ERRORS(result); +} + // Should be in TypeInfer.tables.test.cpp // It's unsound to instantiate tables containing generic methods, // since mutating properties means table properties should be invariant. diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f2ba0ddc..31739cdc 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Fixture.h" @@ -6,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauWeakEqConstraint) -LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) LUAU_FASTFLAG(LuauOrPredicate) using namespace Luau; @@ -199,16 +199,8 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - if (FFlag::LuauImprovedTypeGuardPredicate2) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'string' could not be converted into 'boolean'", toString(result.errors[0])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -526,8 +518,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x) if type(x) == "vector" then @@ -544,8 +534,6 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local t = {"hello"} local v = t[2] @@ -573,8 +561,6 @@ TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true" TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | number | boolean) if type(x) ~= "string" then @@ -593,8 +579,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | {x: number} | {y: boolean}) if type(x) == "table" then @@ -613,8 +597,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function weird(x: string | ((number) -> string)) if type(x) == "function" then @@ -698,8 +680,6 @@ struct RefinementClassFixture : Fixture TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(vec) local X, Y, Z = vec.X, vec.Y, vec.Z @@ -726,8 +706,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: Instance | Vector3) if typeof(x) == "Vector3" then @@ -746,8 +724,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | number | Instance | Vector3) if type(x) == "userdata" then @@ -766,10 +742,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") { - ScopedFastFlag sffs[] = { - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; + ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; CheckResult result = check(R"( local function f(x: Part | Folder | string) @@ -789,10 +762,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") { - ScopedFastFlag sffs[] = { - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; + ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; CheckResult result = check(R"( local function f(x: Part | Folder | Instance | string | Vector3 | any) @@ -812,10 +782,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( --!nonstrict @@ -839,7 +806,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") { ScopedFastFlag sffs[] = { {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, {"LuauTypeGuardPeelsAwaySubclasses", true}, }; @@ -861,8 +827,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( type XYCoord = {x: number} & {y: number} local function f(t: XYCoord?) @@ -882,8 +846,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( type SomeOverloadedFunction = ((number) -> string) & ((string) -> number) local function f(g: SomeOverloadedFunction?) @@ -903,8 +865,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(t: {x: number}) if type(t) ~= "table" then @@ -999,10 +959,7 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") TEST_CASE_FIXTURE(Fixture, "either_number_or_string") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local function f(x: any) @@ -1036,10 +993,7 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local a: (number | string)? @@ -1057,10 +1011,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. CheckResult result = check(R"( @@ -1081,10 +1032,7 @@ TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local function f(a: string | number | boolean) diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 1d1b2fae..b7f0dc7b 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -46,6 +46,21 @@ TEST_CASE_FIXTURE(Fixture, "augment_table") CHECK(tType->props.find("foo") != tType->props.end()); } +TEST_CASE_FIXTURE(Fixture, "augment_nested_table") +{ + CheckResult result = check("local t = { p = {} } t.p.foo = 'bar'"); + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* tType = getMutable(requireType("t")); + REQUIRE(tType != nullptr); + + REQUIRE(tType->props.find("p") != tType->props.end()); + const TableTypeVar* pType = get(tType->props["p"].type); + REQUIRE(pType != nullptr); + + CHECK(pType->props.find("foo") != pType->props.end()); +} + TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); @@ -260,6 +275,8 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local a = {} a.x = 99 @@ -272,10 +289,11 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; - UnknownProperty* error = get(err); + MissingProperties* error = get(err); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); - CHECK_EQ(error->key, "y"); + CHECK_EQ("y", error->properties[0]); // TODO(rblanckaert): Revist when we can bind self at function creation time // CHECK_EQ(err.location, Location(Position{5, 19}, Position{5, 25})); @@ -328,6 +346,8 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( --!strict function foo(o) @@ -340,14 +360,17 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") LUAU_REQUIRE_ERROR_COUNT(1, result); - UnknownProperty* error = get(result.errors[0]); + MissingProperties* error = get(result.errors[0]); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); - CHECK_EQ("baz", error->key); + CHECK_EQ("baz", error->properties[0]); } TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local T = {} T.bar = 'hello' @@ -359,8 +382,11 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; - UnknownProperty* error = get(err); + MissingProperties* error = get(err); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); + + CHECK_EQ("baz", error->properties[0]); // TODO(rblanckaert): Revist when we can bind self at function creation time /* @@ -448,6 +474,73 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") dumpErrors(result); } +TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + local t = { u = {} } + t = { u = { p = 37 } } + t = { u = { q = "hi" } } + local x = t.u.p + local y = t.u.q + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number?", toString(requireType("x"))); + CHECK_EQ("string?", toString(requireType("y"))); +} + +TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_call") +{ + CheckResult result = check(R"( + --!strict + function get(x) return x.opts["MYOPT"] end + function set(x,y) x.opts["MYOPT"] = y end + local t = { opts = {} } + set(t,37) + local x = get(t) + )"); + + // Currently this errors but it shouldn't, since set only needs write access + // TODO: file a JIRA for this + LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ("number?", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "width_subtyping") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + function f(x : { q : number }) + x.q = 8 + end + local t : { q : number, r : string } = { q = 8, r = "hi" } + f(t) + local x : string = t.r + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "width_subtyping_needs_covariance") +{ + CheckResult result = check(R"( + --!strict + function f(x : { p : { q : number }}) + x.p = { q = 8, r = 5 } + end + local t : { p : { q : number, r : string } } = { p = { q = 8, r = "hi" } } + f(t) -- Shouldn't typecheck + local x : string = t.p.r -- x is 5 + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "infer_array") { CheckResult result = check(R"( @@ -676,16 +769,27 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sealed_table_value_must_not_infer_an_indexer") +TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_NO_ERRORS(result); +} - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm != nullptr); +TEST_CASE_FIXTURE(Fixture, "array_factory_function") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + function empty() return {} end + local array: {string} = empty() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "sealed_table_indexers_must_unify") @@ -756,37 +860,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_ CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_string") -{ - ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); - - CheckResult result = check(R"( - local t: { a: string } - function f(x: string) return t[x] end - local a = f("a") - local b = f("b") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.anyType, *requireType("a")); - CHECK_EQ(*typeChecker.anyType, *requireType("b")); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_number") -{ - ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); - - CheckResult result = check(R"( - local t = { a = true } - function f(x: number) return t[x] end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); -} - TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer") { CheckResult result = check(R"( @@ -1392,6 +1465,8 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end foo({ a = 1 }) @@ -1402,8 +1477,21 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("string", toString(tm->wantedType, o)); - CHECK_EQ("number", toString(tm->givenType, o)); + CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); + CHECK_EQ("{| a: number |}", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") +{ + CheckResult result = check(R"( + local function foo(a: {[string]: number, a: string}, i: string) + return a[i] + end + local hi: number = foo({ a = "hi" }, "a") -- shouldn't typecheck since at runtime hi is "hi" + )"); + + // This typechecks but shouldn't + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors") @@ -1446,22 +1534,32 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") { CheckResult result = check(R"( - local vec3 = {x = 1, y = 2, z = 3} - local vec1 = {x = 1} + local vec3 = {{x = 1, y = 2, z = 3}} + local vec1 = {{x = 1}} vec1 = vec3 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - MissingProperties* mp = get(result.errors[0]); - REQUIRE(mp); - CHECK_EQ(mp->context, MissingProperties::Extra); - REQUIRE_EQ(2, mp->properties.size()); - CHECK_EQ(mp->properties[0], "y"); - CHECK_EQ(mp->properties[1], "z"); - CHECK_EQ("vec1", toString(mp->superType)); - CHECK_EQ("vec3", toString(mp->subType)); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("vec1", toString(tm->wantedType)); + CHECK_EQ("vec3", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + local vec3 = {x = 1, y = 2, z = 3} + local vec1 = {x = 1} + + vec1 = vec3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short") @@ -1824,4 +1922,32 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + local buttons = {} + table.insert(buttons, { a = 1 }) + table.insert(buttons, { a = 2, b = true }) + table.insert(buttons, { a = 3 }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + local buttons = {} + table.insert(buttons, { a = 1 }) + table.insert(buttons, { a = 2, b = true }) + table.insert(buttons, { a = 3 }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 37333b19..b75878b7 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -3,6 +3,7 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" @@ -978,23 +979,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - type F = () -> F? - local function f() - return f - end - - local g: F = f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); -} - // TODO: File a Jira about this /* TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") @@ -1257,23 +1241,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") REQUIRE_EQ(follow(*methodArg), follow(arg)); } -TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") -{ - CheckResult result = check(R"( - --!strict - type Node = { Parent: Node?; } - local node: Node; - node.Parent = 1 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("Node?", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") { CheckResult result = check(R"( @@ -2591,48 +2558,6 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") -{ - CheckResult result = check(R"( - --!strict - type T = { f: a, g: U } - type U = { h: a, i: T? } - local x: T = { f = 37, g = { h = 5, i = nil } } - x.g.i = x - local y: T = { f = "hi", g = { h = "lo", i = nil } } - y.g.i = y - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") -{ - CheckResult result = check(R"( - --!strict - type T = { f: a, g: U } - type U = { h: b, i: T? } - local x: T = { f = 37, g = { h = 5, i = nil } } - x.g.i = x - local y: T = { f = "hi", g = { h = 5, i = nil } } - y.g.i = y - )"); - - LUAU_REQUIRE_ERRORS(result); - - // We had a UAF in this example caused by not cloning type function arguments - ModulePtr module = frontend.moduleResolver.getModule("MainModule"); - unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); - freeze(module->interfaceTypes); - module->internalTypes.clear(); - module->astTypes.clear(); - - // Make sure the error strings don't include "VALUELESS" - for (auto error : module->errors) - CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); -} - TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") { // CLI-30902 @@ -3369,16 +3294,7 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") end )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable") @@ -3388,18 +3304,8 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable print((x == true and (x .. "y")) .. 1) )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ("Type 'boolean' could not be converted into 'number | string'", toString(result.errors[0])); - CHECK_EQ("Type 'boolean | string' could not be converted into 'number | string'", toString(result.errors[1])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") @@ -3511,25 +3417,6 @@ _(...)(...,setfenv,_):_G() )"); } -TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") -{ - CheckResult result = check(R"( - type Pair = {first: T, second: U} - local a: Pair - local b: Pair - - a = b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - - CHECK_EQ("Pair", toString(tm->wantedType)); - CHECK_EQ("Pair", toString(tm->givenType)); -} - TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") { // this has a risk of creating cyclic type packs, causing infinite loops / OOMs @@ -3639,17 +3526,6 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") )"); } -TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") -{ - CheckResult result = check(R"( - type A = number - type A = string -- Redefinition of type 'A', previously defined at line 1 - local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery") { CheckResult result = check(R"( @@ -3752,38 +3628,6 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); } -TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") -{ - CheckResult result = check(R"( - type Table = { a: T } - type Wrapped = Table - local l: Wrapped = 2 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("Wrapped", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") -{ - CheckResult result = check(R"( - type Table = { a: T } - type Wrapped = (Table) -> string - local l: Wrapped = 2 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") { CheckResult result = check(R"( @@ -3909,19 +3753,6 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); } -// Check that recursive intersection type doesn't generate an OOM -TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") -{ - CheckResult result = check(R"( - function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any - end - type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) - _(_) - )"); - - LUAU_REQUIRE_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") { // In non-strict mode, global definition is still allowed @@ -3974,16 +3805,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") LUAU_REQUIRE_ERROR_COUNT(2, result); } -TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") -{ - CheckResult result = check(R"( - local foo: Id = 1 - type Id = T - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") { CheckResult result = check(R"( @@ -4014,81 +3835,6 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") -{ - const std::string code = R"( - type A = {v:T, b:B} - type B = {v:T, a:A} - local aa:A - local bb = aa - )"; - - const std::string expected = R"( - type A = {v:T, b:B} - type B = {v:T, a:A} - local aa:A - local bb:A=aa - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - CheckResult result = check(code); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - type A = () -> (number, B) - type B = () -> (string, A) - local a: A - local b: B - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); - CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "generic_param_remap") -{ - const std::string code = R"( - -- An example of a forwarded use of a type that has different type arguments than parameters - type A = {t:T, u:U, next:A?} - local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } - local bb = aa - )"; - - const std::string expected = R"( - - type A = {t:T, u:U, next:A?} - local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } - local bb:A=aa - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - CheckResult result = check(code); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") -{ - CheckResult result = check(R"( - export type Foo = number - type Foo = number - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto dtd = get(result.errors[0]); - REQUIRE(dtd); - CHECK_EQ(dtd->name, "Foo"); -} - TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") { CheckResult result = check(R"( @@ -4193,30 +3939,6 @@ TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") REQUIRE_MESSAGE(get(e) != nullptr, "Expected UnknownSymbol, but got " << e); } -TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") -{ - ScopedFastFlag sffs3{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - - CheckResult result = check(R"( - type Node = { value: T, child: Node? } - - local function visitor(node: Node?) - local a: Node - - if node then - a = node.child -- Observe the output of the error message. - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto e = get(result.errors[0]); - CHECK_EQ("Node?", toString(e->givenType)); - CHECK_EQ("Node", toString(e->wantedType)); -} - TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") { CheckResult result = check(R"( @@ -4272,181 +3994,6 @@ local tbl: string = require(game.A) CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") -{ - fileResolver.source["workspace/A"] = R"( - export type myvec2 = {x: number, y: number} - return {} - )"; - - fileResolver.source["workspace/B"] = R"( - export type myvec3 = {x: number, y: number, z: number} - return {} - )"; - - fileResolver.source["workspace/C"] = R"( - local Foo, Bar = require(workspace.A), require(workspace.B) - - local a: Foo.myvec2 - local b: Bar.myvec3 - )"; - - CheckResult result = frontend.check("workspace/C"); - LUAU_REQUIRE_NO_ERRORS(result); - ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; - - REQUIRE(m != nullptr); - - std::optional aTypeId = lookupName(m->getModuleScope(), "a"); - REQUIRE(aTypeId); - const Luau::TableTypeVar* aType = get(follow(*aTypeId)); - REQUIRE(aType); - REQUIRE(aType->props.size() == 2); - - std::optional bTypeId = lookupName(m->getModuleScope(), "b"); - REQUIRE(bTypeId); - const Luau::TableTypeVar* bType = get(follow(*bTypeId)); - REQUIRE(bType); - REQUIRE(bType->props.size() == 3); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") -{ - CheckResult result = check("type t10 = typeof(table)"); - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); - CHECK_EQ(toString(ty), "table"); - - const TableTypeVar* ttv = get(ty); - REQUIRE(ttv); - - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") -{ - CheckResult result = check(R"( -type Cool = { a: number, b: string } -local c: Cool = { a = 1, b = "s" } -type NotCool = Cool -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "Cool"); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") -{ - CheckResult result = check(R"( -type Cool = { a: number, b: string } -type NotCool = Cool -local c: Cool = { a = 1, b = "s" } -local d: NotCool = { a = 1, b = "s" } -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "Cool"); - - ty = requireType("d"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "NotCool"); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") -{ - CheckResult result = check(R"( -local c = { a = 1, b = "s" } -type Cool = typeof(c) -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK_EQ(ttv->name, "Cool"); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") -{ - ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; - - fileResolver.source["game/A"] = R"( -export type X = { a: number, b: X? } -return {} - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - std::optional ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - std::optional ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(follow(*ty1), follow(*ty2)); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") -{ - ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; - - fileResolver.source["game/A"] = R"( -export type X = { a: T, b: U, C: X? } -return {} - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - std::optional ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - std::optional ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); - - bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); - CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); -} - TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") { CheckResult result = check(R"( @@ -4560,32 +4107,6 @@ local c = a(2) -- too many arguments CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") -{ - CheckResult result = check(R"( -function get() - return function(obj) return true end -end - -export type f = typeof(get()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") -{ - CheckResult result = check(R"( -function get() - return {a = 1, b = function(obj) return true end} -end - -export type f = typeof(get()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "custom_require_global") { CheckResult result = check(R"( @@ -4768,8 +4289,6 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") { - ScopedFastFlag sff{"LuauLogTableTypeVarBoundTo", true}; - fileResolver.source["Module/Backend/Types"] = R"( export type Fiber = { return_: Fiber? @@ -4849,8 +4368,8 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") ModulePtr module = getMainModule(); auto it = module->astOverloadResolvedTypes.find(parentExpr); - REQUIRE(it != module->astOverloadResolvedTypes.end()); - CHECK_EQ(toString(it->second), "(number) -> number"); + REQUIRE(it); + CHECK_EQ(toString(*it), "(number) -> number"); } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") @@ -5013,76 +4532,6 @@ g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") -{ - CheckResult result = check(R"( - type Tree = { data: T, children: Forest } - type Forest = {Tree} - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- OK because forwarded types are used with their parameters. - type Tree = { data: T, children: Forest } - type Forest = {Tree<{T}>} - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- Not OK because forwarded types are used with different types than their parameters. - type Forest = {Tree<{T}>} - type Tree = { data: T, children: Forest } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") -{ - CheckResult result = check(R"( - type Tree1 = { data: T, children: {Tree2} } - type Tree2 = { data: U, children: {Tree1} } - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - type Tree1 = { data: T, children: {Tree2} } - type Tree2 = { data: U, children: {Tree1} } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") -{ - CheckResult result = check(R"( - function f(x) return x[1] end - -- x has type X? for a free type variable X - local x = f ({}) - type ContainsFree = { this: a, that: typeof(x) } - type ContainsContainsFree = { that: ContainsFree } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 91ac9f06..1f4b63ef 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 5f7f2847..3e1dedd4 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -294,4 +294,370 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed = (T...) -> T... +local a: Packed<> +local b: Packed +local c: Packed + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tf = lookupType("Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "(T...) -> (T...)"); + CHECK_EQ(toString(requireType("a")), "() -> ()"); + CHECK_EQ(toString(requireType("b")), "(number) -> number"); + CHECK_EQ(toString(requireType("c")), "(string, number) -> (string, number)"); + + result = check(R"( +-- (U..., T) cannot be parsed right now +type Packed = { f: (a: T, U...) -> (T, U...) } +local a: Packed +local b: Packed +local c: Packed + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + tf = lookupType("Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Packed"); + CHECK_EQ(toString(*tf, {true}), "{| f: (T, U...) -> (T, U...) |}"); + + auto ttvA = get(requireType("a")); + REQUIRE(ttvA); + CHECK_EQ(toString(requireType("a")), "Packed"); + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); + REQUIRE(ttvA->instantiatedTypeParams.size() == 1); + REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); + CHECK_EQ(toString(ttvA->instantiatedTypePackParams[0], {true}), ""); + + auto ttvB = get(requireType("b")); + REQUIRE(ttvB); + CHECK_EQ(toString(requireType("b")), "Packed"); + CHECK_EQ(toString(requireType("b"), {true}), "{| f: (string, number) -> (string, number) |}"); + REQUIRE(ttvB->instantiatedTypeParams.size() == 1); + REQUIRE(ttvB->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvB->instantiatedTypeParams[0], {true}), "string"); + CHECK_EQ(toString(ttvB->instantiatedTypePackParams[0], {true}), "number"); + + auto ttvC = get(requireType("c")); + REQUIRE(ttvC); + CHECK_EQ(toString(requireType("c")), "Packed"); + CHECK_EQ(toString(requireType("c"), {true}), "{| f: (string, number, boolean) -> (string, number, boolean) |}"); + REQUIRE(ttvC->instantiatedTypeParams.size() == 1); + REQUIRE(ttvC->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvC->instantiatedTypeParams[0], {true}), "string"); + CHECK_EQ(toString(ttvC->instantiatedTypePackParams[0], {true}), "number, boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + fileResolver.source["game/A"] = R"( +export type Packed = { a: T, b: (U...) -> () } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +local a: Import.Packed +local b: Import.Packed +local c: Import.Packed +local d: { a: typeof(c) } + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + auto tf = lookupImportedType("Import", "Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Packed"); + CHECK_EQ(toString(*tf, {true}), "{| a: T, b: (U...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: number, b: () -> () |}"); + CHECK_EQ(toString(requireType("b"), {true}), "{| a: string, b: (number) -> () |}"); + CHECK_EQ(toString(requireType("c"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + CHECK_EQ(toString(requireType("d")), "{| a: Packed |}"); +} + +TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + fileResolver.source["game/A"] = R"( +export type Packed = { a: T, b: (U...) -> () } +return {} + )"; + + CheckResult cResult = check(R"( +local Import = require(game.A) +type Alias = Import.Packed +local a: Alias + +type B = Import.Packed +type C = Import.Packed + )"); + LUAU_REQUIRE_NO_ERRORS(cResult); + + auto tf = lookupType("Alias"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Alias"); + CHECK_EQ(toString(*tf, {true}), "{| a: S, b: (T, R...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + + tf = lookupType("B"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "B"); + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (X...) -> () |}"); + + tf = lookupType("C"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "C"); + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (number, X...) -> () |}"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed1 = (T...) -> (T...) +type Packed2 = (Packed1, T...) -> (Packed1, T...) +type Packed3 = (Packed2, T...) -> (Packed2, T...) +type Packed4 = (Packed3, T...) -> (Packed3, T...) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tf = lookupType("Packed4"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...) -> " + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = (T...) -> (string, T...) + +type D = X<...number> +type E = X<(number, ...string)> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("D")), "(...number) -> (string, ...number)"); + CHECK_EQ(toString(*lookupType("E")), "(number, ...string) -> (string, number, ...string)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Y = (T...) -> (U...) +type A = Y +type B = Y<(number, ...string), S...> + +type Z = (T) -> (U...) +type E = Z +type F = Z + +type W = (T, U...) -> (T, V...) +type H = W +type I = W + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)"); + CHECK_EQ(toString(*lookupType("B")), "(number, ...string) -> (S...)"); + + CHECK_EQ(toString(*lookupType("E")), "(number) -> (S...)"); + CHECK_EQ(toString(*lookupType("F")), "(number) -> (string, S...)"); + + CHECK_EQ(toString(*lookupType("H")), "(number, S...) -> (number, R...)"); + CHECK_EQ(toString(*lookupType("I")), "(number, string, S...) -> (number, R...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = (T...) -> (T...) + +type A = X<(S...)> +type B = X<()> +type C = X<(number)> +type D = X<(number, string)> +type E = X<(...number)> +type F = X<(string, ...number)> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)"); + CHECK_EQ(toString(*lookupType("B")), "() -> ()"); + CHECK_EQ(toString(*lookupType("C")), "(number) -> number"); + CHECK_EQ(toString(*lookupType("D")), "(number, string) -> (number, string)"); + CHECK_EQ(toString(*lookupType("E")), "(...number) -> (...number)"); + CHECK_EQ(toString(*lookupType("F")), "(string, ...number) -> (string, ...number)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Y = (T...) -> (U...) + +type A = Y<(number, string), (boolean)> +type B = Y<(), ()> +type C = Y<...string, (number, S...)> +type D = Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(number, string) -> boolean"); + CHECK_EQ(toString(*lookupType("B")), "() -> ()"); + CHECK_EQ(toString(*lookupType("C")), "(...string) -> (number, S...)"); + CHECK_EQ(toString(*lookupType("D")), "(X...) -> (number, string, X...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + ScopedFastFlag luauInstantiatedTypeParamRecursion("LuauInstantiatedTypeParamRecursion", true); // For correct toString block + + CheckResult result = check(R"( +type Y = { f: (T...) -> (U...) } + +local a: Y<(number, string), (boolean)> +local b: Y<(), ()> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (boolean)>"); + CHECK_EQ(toString(requireType("b")), "Y<(), ()>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = () -> T +type Y = (T) -> U + +type A = X<(number)> +type B = Y<(number), (boolean)> +type C = Y<(number), boolean> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "() -> number"); + CHECK_EQ(toString(*lookupType("B")), "(number) -> boolean"); + CHECK_EQ(toString(*lookupType("C")), "(number) -> boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed = (T, U) -> (V...) +local b: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects at least 2 type arguments, but only 1 is specified"); + + result = check(R"( +type Packed = (T, U) -> () +type B = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 0 type pack arguments, but 1 is specified"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameters must come before type pack parameters"); + + result = check(R"( +type Packed = (T) -> U +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type arguments, but only 1 is specified"); + + result = check(R"( +type Packed = (T...) -> T... +local a: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but none are specified"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but only 1 is specified"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ae4d836b..037144e2 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -237,21 +237,7 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") local z = a == c )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.booleanType, *requireType("x")); - CHECK_EQ(*typeChecker.booleanType, *requireType("y")); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("(number | string)?", toString(*tm->wantedType)); - CHECK_EQ("boolean | number", toString(*tm->givenType)); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "optional_union_members") diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 98ce9f93..a679e3fd 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tools/tracegraph.py b/tools/tracegraph.py new file mode 100644 index 00000000..a46423e7 --- /dev/null +++ b/tools/tracegraph.py @@ -0,0 +1,95 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a trace event file, this tool generates a flame graph based on the event scopes present in the file +# The result of analysis is a .svg file which can be viewed in a browser + +import sys +import svg +import json + +class Node(svg.Node): + def __init__(self): + svg.Node.__init__(self) + self.caption = "" + self.description = "" + self.ticks = 0 + + def text(self): + return self.caption + + def title(self): + return self.caption + + def details(self, root): + return "{} ({:,} usec, {:.1%}); self: {:,} usec".format(self.description, self.width, self.width / root.width, self.ticks) + +with open(sys.argv[1]) as f: + dump = f.read() + +root = Node() + +# Finish the file +if not dump.endswith("]"): + dump += "{}]" + +data = json.loads(dump) + +stacks = {} + +for l in data: + if len(l) == 0: + continue + + # Track stack of each thread, but aggregate values together + tid = l["tid"] + + if not tid in stacks: + stacks[tid] = [] + stack = stacks[tid] + + if l["ph"] == 'B': + stack.append(l) + elif l["ph"] == 'E': + node = root + + for e in stack: + caption = e["name"] + description = '' + + if "args" in e: + for arg in e["args"]: + if len(description) != 0: + description += ", " + + description += "{}: {}".format(arg, e["args"][arg]) + + child = node.child(caption + description) + child.caption = caption + child.description = description + + node = child + + begin = stack[-1] + + ticks = l["ts"] - begin["ts"] + rawticks = ticks + + # Flame graph requires ticks without children duration + if "childts" in begin: + ticks -= begin["childts"] + + node.ticks += int(ticks) + + stack.pop() + + if len(stack): + parent = stack[-1] + + if "childts" in parent: + parent["childts"] += rawticks + else: + parent["childts"] = rawticks + +svg.layout(root, lambda n: n.ticks) +svg.display(root, "Flame Graph", "hot", flip = True) From 34cf695fbc35eb435dcd9fb85c3b98234fdd266c Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 4 Nov 2021 19:42:00 -0700 Subject: [PATCH 002/102] Sync to upstream/release/503 - A series of major optimizations to type checking performance on complex programs/types (up to two orders of magnitude speedup for programs involving huge tagged unions) - Fix a few issues encountered by UBSAN (and maybe fix s390x builds) - Fix gcc-11 test builds - Fix a rare corner case where luau_load wouldn't wake inactive threads which could result in a use-after-free due to GC - Fix CLI crash when error object that's not a string escapes to top level --- Analysis/include/Luau/BuiltinDefinitions.h | 1 - Analysis/include/Luau/Quantify.h | 14 + Analysis/include/Luau/ToString.h | 2 + Analysis/include/Luau/TopoSortStatements.h | 1 + Analysis/include/Luau/TxnLog.h | 30 +- Analysis/include/Luau/TypeInfer.h | 7 +- Analysis/include/Luau/TypeVar.h | 9 +- Analysis/include/Luau/Unifier.h | 20 +- Analysis/include/Luau/UnifierSharedState.h | 44 ++ Analysis/include/Luau/VisitTypeVar.h | 59 +- Analysis/src/Autocomplete.cpp | 5 +- Analysis/src/BuiltinDefinitions.cpp | 12 - Analysis/src/Error.cpp | 21 +- Analysis/src/Frontend.cpp | 11 +- Analysis/src/Module.cpp | 18 +- Analysis/src/Quantify.cpp | 90 +++ Analysis/src/RequireTracer.cpp | 6 +- Analysis/src/ToString.cpp | 25 +- Analysis/src/TopoSortStatements.cpp | 25 + Analysis/src/TxnLog.cpp | 37 +- Analysis/src/TypeAttach.cpp | 93 ++- Analysis/src/TypeInfer.cpp | 203 ++--- Analysis/src/TypeVar.cpp | 86 ++- Analysis/src/Unifier.cpp | 349 ++++++++- Ast/include/Luau/TimeTrace.h | 16 +- Ast/src/Parser.cpp | 2 +- Ast/src/TimeTrace.cpp | 5 +- CLI/Repl.cpp | 24 +- Compiler/include/Luau/Bytecode.h | 4 +- Compiler/src/Compiler.cpp | 14 +- Makefile | 2 + Sources.cmake | 3 + VM/include/lualib.h | 2 +- VM/src/lapi.cpp | 4 +- VM/src/laux.cpp | 2 +- VM/src/lcorolib.cpp | 72 +- VM/src/ldo.cpp | 14 +- VM/src/ldo.h | 2 +- VM/src/lfunc.cpp | 2 +- VM/src/lgc.cpp | 333 ++++---- VM/src/lgc.h | 8 +- VM/src/lmem.cpp | 2 +- VM/src/lstring.cpp | 2 +- VM/src/lvmload.cpp | 4 +- bench/tests/chess.lua | 849 +++++++++++++++++++++ bench/tests/shootout/scimark.lua | 2 +- tests/Autocomplete.test.cpp | 24 +- tests/Compiler.test.cpp | 84 +- tests/Conformance.test.cpp | 13 + tests/IostreamOptional.h | 7 +- tests/Linter.test.cpp | 14 +- tests/TypeInfer.aliases.test.cpp | 50 ++ tests/TypeInfer.builtins.test.cpp | 2 +- tests/TypeInfer.classes.test.cpp | 2 - tests/TypeInfer.generics.test.cpp | 21 + tests/TypeInfer.provisional.test.cpp | 25 +- tests/TypeInfer.refinements.test.cpp | 11 +- tests/TypeInfer.tables.test.cpp | 2 +- tests/TypeInfer.test.cpp | 62 +- tests/TypeInfer.tryUnify.test.cpp | 10 +- tests/TypeInfer.typePacks.cpp | 6 +- tests/TypeInfer.unionTypes.test.cpp | 25 +- tests/TypeVar.test.cpp | 60 ++ tests/conformance/closure.lua | 2 +- tests/conformance/coroutine.lua | 2 +- tests/conformance/gc.lua | 2 +- tests/conformance/locals.lua | 2 +- tests/conformance/math.lua | 2 +- tests/conformance/pm.lua | 4 +- tests/conformance/tmerror.lua | 15 + tools/gdb-printers.py | 8 +- 71 files changed, 2304 insertions(+), 687 deletions(-) create mode 100644 Analysis/include/Luau/Quantify.h create mode 100644 Analysis/include/Luau/UnifierSharedState.h create mode 100644 Analysis/src/Quantify.cpp create mode 100644 bench/tests/chess.lua create mode 100644 tests/conformance/tmerror.lua diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 57a1907a..07d897b2 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -34,7 +34,6 @@ TypeId makeFunction( // Polymorphic std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); void attachMagicFunction(TypeId ty, MagicFunction fn); -void attachFunctionTag(TypeId ty, std::string constraint); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName); diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h new file mode 100644 index 00000000..f46df146 --- /dev/null +++ b/Analysis/include/Luau/Quantify.h @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeVar.h" + +namespace Luau +{ + +struct Module; +using ModulePtr = std::shared_ptr; + +void quantify(ModulePtr module, TypeId ty, TypeLevel level); + +} // namespace Luau diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 0897ec85..e5683fc4 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -69,4 +69,6 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); void dump(TypeId ty); void dump(TypePackId ty); +std::string generateName(size_t n); + } // namespace Luau diff --git a/Analysis/include/Luau/TopoSortStatements.h b/Analysis/include/Luau/TopoSortStatements.h index 751694f0..4a4acfa3 100644 --- a/Analysis/include/Luau/TopoSortStatements.h +++ b/Analysis/include/Luau/TopoSortStatements.h @@ -12,6 +12,7 @@ struct AstArray; class AstStat; bool containsFunctionCall(const AstStat& stat); +bool containsFunctionCallOrReturn(const AstStat& stat); bool isFunction(const AstStat& stat); void toposort(std::vector& stats); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 055441ce..322abd19 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -3,19 +3,37 @@ #include "Luau/TypeVar.h" +LUAU_FASTFLAG(LuauShareTxnSeen); + namespace Luau { // Log of where what TypeIds we are rebinding and what they used to be struct TxnLog { - TxnLog() = default; - - explicit TxnLog(const std::vector>& seen) - : seen(seen) + TxnLog() + : originalSeenSize(0) + , ownedSeen() + , sharedSeen(&ownedSeen) { } + explicit TxnLog(std::vector>* sharedSeen) + : originalSeenSize(sharedSeen->size()) + , ownedSeen() + , sharedSeen(sharedSeen) + { + } + + explicit TxnLog(const std::vector>& ownedSeen) + : originalSeenSize(ownedSeen.size()) + , ownedSeen(ownedSeen) + , sharedSeen(nullptr) + { + // This is deprecated! + LUAU_ASSERT(!FFlag::LuauShareTxnSeen); + } + TxnLog(const TxnLog&) = delete; TxnLog& operator=(const TxnLog&) = delete; @@ -38,9 +56,11 @@ private: std::vector> typeVarChanges; std::vector> typePackChanges; std::vector>> tableChanges; + size_t originalSeenSize; public: - std::vector> seen; // used to avoid infinite recursion when types are cyclic + std::vector> ownedSeen; // used to avoid infinite recursion when types are cyclic + std::vector>* sharedSeen; // shared with all the descendent logs }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index d701eb24..9d62fef0 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -11,6 +11,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" +#include "Luau/UnifierSharedState.h" #include #include @@ -121,7 +122,7 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatForIn& forin); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); - void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false); + void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0, bool forwardDeclare = false); void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); @@ -336,7 +337,7 @@ private: // Note: `scope` must be a fresh scope. std::pair, std::vector> createGenericTypes( - const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -383,6 +384,8 @@ public: std::function prepareModuleScope; InternalErrorReporter* iceHandler; + UnifierSharedState unifierState; + public: const TypeId nilType; const TypeId numberType; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index d4e4e491..9611e881 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -405,7 +405,7 @@ const std::string* getName(TypeId type); // Checks whether a union contains all types of another union. bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub); -// Checks if a type conains generic type binders +// Checks if a type contains generic type binders bool isGeneric(const TypeId ty); // Checks if a type may be instantiated to one containing generic type binders @@ -540,4 +540,11 @@ UnionTypeVarIterator end(const UnionTypeVar* utv); using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); +void attachTag(TypeId ty, const std::string& tagName); +void attachTag(Property& prop, const std::string& tagName); + +bool hasTag(TypeId ty, const std::string& tagName); +bool hasTag(const Property& prop, const std::string& tagName); +bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work. + } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 522914b2..56632e33 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -6,6 +6,7 @@ #include "Luau/TxnLog.h" #include "Luau/TypeInfer.h" #include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header. +#include "Luau/UnifierSharedState.h" #include @@ -41,11 +42,14 @@ struct Unifier std::shared_ptr counters_DEPRECATED; - InternalErrorReporter* iceHandler; + UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED = nullptr, + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, + UnifierCounters* counters = nullptr); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, UnifierCounters* counters = nullptr); // Test whether the two type vars unify. Never commits the result. @@ -69,7 +73,8 @@ private: void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); - TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + void cacheResult(TypeId superTy, TypeId subTy); public: void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); @@ -101,8 +106,9 @@ private: [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); - DenseHashSet tempSeenTy{nullptr}; - DenseHashSet tempSeenTp{nullptr}; + // Remove with FFlagLuauCacheUnifyTableResults + DenseHashSet tempSeenTy_DEPRECATED{nullptr}; + DenseHashSet tempSeenTp_DEPRECATED{nullptr}; }; } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h new file mode 100644 index 00000000..f252a004 --- /dev/null +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -0,0 +1,44 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ +struct InternalErrorReporter; + +struct TypeIdPairHash +{ + size_t hashOne(Luau::TypeId key) const + { + return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9); + } + + size_t operator()(const std::pair& x) const + { + return hashOne(x.first) ^ (hashOne(x.second) << 1); + } +}; + +struct UnifierSharedState +{ + UnifierSharedState(InternalErrorReporter* iceHandler) + : iceHandler(iceHandler) + { + } + + InternalErrorReporter* iceHandler; + + DenseHashSet seenAny{nullptr}; + DenseHashMap skipCacheForType{nullptr}; + DenseHashSet, TypeIdPairHash> cachedUnify{{nullptr, nullptr}}; + + DenseHashSet tempSeenTy{nullptr}; + DenseHashSet tempSeenTp{nullptr}; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index df0bd420..a866655c 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -1,9 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include "Luau/TypeVar.h" #include "Luau/TypePack.h" +LUAU_FASTFLAG(LuauCacheUnifyTableResults) + namespace Luau { @@ -32,17 +35,33 @@ inline bool hasSeen(std::unordered_set& seen, const void* tv) return !seen.insert(ttv).second; } +inline bool hasSeen(DenseHashSet& seen, const void* tv) +{ + void* ttv = const_cast(tv); + + if (seen.contains(ttv)) + return true; + + seen.insert(ttv); + return false; +} + inline void unsee(std::unordered_set& seen, const void* tv) { void* ttv = const_cast(tv); seen.erase(ttv); } -template -void visit(TypePackId tp, F& f, std::unordered_set& seen); +inline void unsee(DenseHashSet& seen, const void* tv) +{ + // When DenseHashSet is used for 'visitOnce', where don't forget visited elements +} -template -void visit(TypeId ty, F& f, std::unordered_set& seen) +template +void visit(TypePackId tp, F& f, Set& seen); + +template +void visit(TypeId ty, F& f, Set& seen) { if (visit_detail::hasSeen(seen, ty)) { @@ -79,15 +98,23 @@ void visit(TypeId ty, F& f, std::unordered_set& seen) else if (auto ttv = get(ty)) { + // Some visitors want to see bound tables, that's why we visit the original type if (apply(ty, *ttv, seen, f)) { - for (auto& [_name, prop] : ttv->props) - visit(prop.type, f, seen); - - if (ttv->indexer) + if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo) { - visit(ttv->indexer->indexType, f, seen); - visit(ttv->indexer->indexResultType, f, seen); + visit(*ttv->boundTo, f, seen); + } + else + { + for (auto& [_name, prop] : ttv->props) + visit(prop.type, f, seen); + + if (ttv->indexer) + { + visit(ttv->indexer->indexType, f, seen); + visit(ttv->indexer->indexResultType, f, seen); + } } } } @@ -140,8 +167,8 @@ void visit(TypeId ty, F& f, std::unordered_set& seen) visit_detail::unsee(seen, ty); } -template -void visit(TypePackId tp, F& f, std::unordered_set& seen) +template +void visit(TypePackId tp, F& f, Set& seen) { if (visit_detail::hasSeen(seen, tp)) { @@ -182,6 +209,7 @@ void visit(TypePackId tp, F& f, std::unordered_set& seen) visit_detail::unsee(seen, tp); } + } // namespace visit_detail template @@ -197,4 +225,11 @@ void visitTypeVar(TID ty, F& f) visit_detail::visit(ty, f, seen); } +template +void visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) +{ + seen.clear(); + visit_detail::visit(ty, f, seen); +} + } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 235abf36..3c43c808 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -196,7 +196,8 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) { InternalErrorReporter iceReporter; - Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter); + UnifierSharedState unifierState(&iceReporter); + Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); unifier.tryUnify(expectedType, actualType); @@ -1460,7 +1461,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M result.erase(std::string(stringKey->value.data, stringKey->value.size)); } - // If we know for sure that a key is being written, do not offer general epxression suggestions + // If we know for sure that a key is being written, do not offer general expression suggestions if (!key) autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position, result); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 3b0c2163..f6f2363c 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -106,18 +106,6 @@ void attachMagicFunction(TypeId ty, MagicFunction fn) LUAU_ASSERT(!"Got a non functional type"); } -void attachFunctionTag(TypeId ty, std::string tag) -{ - if (auto ftv = getMutable(ty)) - { - ftv->tags.emplace_back(std::move(tag)); - } - else - { - LUAU_ASSERT(!"Got a non functional type"); - } -} - Property makeProperty(TypeId ty, std::optional documentationSymbol) { return { diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 92fbffc8..04d91444 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -158,21 +158,20 @@ struct ErrorConverter std::string operator()(const Luau::CountMismatch& e) const { + const std::string expectedS = e.expected == 1 ? "" : "s"; + const std::string actualS = e.actual == 1 ? "" : "s"; + const std::string actualVerb = e.actual == 1 ? "is" : "are"; + switch (e.context) { case CountMismatch::Return: - { - const std::string expectedS = e.expected == 1 ? "" : "s"; - const std::string actualS = e.actual == 1 ? "is" : "are"; - return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + actualS + - " returned here"; - } + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + + std::to_string(e.actual) + " " + actualVerb + " returned here"; case CountMismatch::Result: - if (e.expected > e.actual) - return "Function returns " + std::to_string(e.expected) + " values but there are only " + std::to_string(e.expected) + - " values to unpack them into."; - else - return "Function only returns " + std::to_string(e.expected) + " values. " + std::to_string(e.actual) + " are required here"; + // It is alright if right hand side produces more values than the + // left hand side accepts. In this context consider only the opposite case. + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: if (FFlag::LuauTypeAliasPacks) return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b2529840..5e7af50c 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -23,6 +23,7 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau { @@ -248,7 +249,7 @@ struct RequireCycle // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) // However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true) std::vector getRequireCycles( - const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) + const FileResolver* resolver, const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) { std::vector result; @@ -282,9 +283,9 @@ std::vector getRequireCycles( if (top == start) { for (const SourceNode* node : path) - cycle.push_back(node->name); + cycle.push_back(resolver->getHumanReadableModuleName(node->name)); - cycle.push_back(top->name); + cycle.push_back(resolver->getHumanReadableModuleName(top->name)); break; } } @@ -404,7 +405,7 @@ CheckResult Frontend::check(const ModuleName& name) // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // all correct programs must be acyclic so this code triggers rarely if (cycleDetected) - requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck); + requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck); // This is used by the type checker to replace the resulting type of cyclic modules with any sourceModule.cyclic = !requireCycles.empty(); @@ -458,6 +459,8 @@ CheckResult Frontend::check(const ModuleName& name) module->astTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); + if (FFlag::LuauClearScopes) + module->scopes.resize(1); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index df6be767..2fd95896 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false) namespace Luau { @@ -299,6 +300,14 @@ void TypeCloner::operator()(const FunctionTypeVar& t) void TypeCloner::operator()(const TableTypeVar& t) { + // If table is now bound to another one, we ignore the content of the original + if (FFlag::LuauCloneBoundTables && t.boundTo) + { + TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + seenTypes[typeId] = boundTo; + return; + } + TypeId result = dest.addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(result); LUAU_ASSERT(ttv != nullptr); @@ -321,8 +330,11 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)}; - if (t.boundTo) - ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + if (!FFlag::LuauCloneBoundTables) + { + if (t.boundTo) + ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + } for (TypeId& arg : ttv->instantiatedTypeParams) arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); @@ -335,7 +347,7 @@ void TypeCloner::operator()(const TableTypeVar& t) if (ttv->state == TableState::Free) { - if (!t.boundTo) + if (FFlag::LuauCloneBoundTables || !t.boundTo) { if (encounteredFreeType) *encounteredFreeType = true; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp new file mode 100644 index 00000000..bf6d81aa --- /dev/null +++ b/Analysis/src/Quantify.cpp @@ -0,0 +1,90 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Quantify.h" + +#include "Luau/VisitTypeVar.h" + +namespace Luau +{ + +struct Quantifier +{ + ModulePtr module; + TypeLevel level; + std::vector generics; + std::vector genericPacks; + + Quantifier(ModulePtr module, TypeLevel level) + : module(module) + , level(level) + { + } + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + if (!level.subsumes(ftv.level)) + return false; + + *asMutable(ty) = GenericTypeVar{level}; + generics.push_back(ty); + + return false; + } + + template + bool operator()(TypeId ty, const T& t) + { + return true; + } + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + TableTypeVar& ttv = *getMutable(ty); + + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) + return false; + if (!level.subsumes(ttv.level)) + return false; + + if (ttv.state == TableState::Free) + ttv.state = TableState::Generic; + else if (ttv.state == TableState::Unsealed) + ttv.state = TableState::Sealed; + + ttv.level = level; + + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + if (!level.subsumes(ftp.level)) + return false; + + *asMutable(tp) = GenericTypePack{level}; + genericPacks.push_back(tp); + return true; + } +}; + +void quantify(ModulePtr module, TypeId ty, TypeLevel level) +{ + Quantifier q{std::move(module), level}; + visitTypeVar(ty, q); + + FunctionTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; +} + +} // namespace Luau diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index ad4d5ef4..95910b56 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -171,7 +171,7 @@ struct RequireTracerOld : AstVisitor result.exprs[call] = {fileResolver->concat(*rootName, v)}; - // 'WaitForChild' can be used on modules that are not awailable at the typecheck time, but will be awailable at runtime + // 'WaitForChild' can be used on modules that are not available at the typecheck time, but will be available at runtime // If we fail to find such module, we will not report an UnknownRequire error if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") result.exprs[call].optional = true; @@ -182,7 +182,7 @@ struct RequireTracerOld : AstVisitor struct RequireTracer : AstVisitor { - RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName) + RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName) : result(result) , fileResolver(fileResolver) , currentModuleName(currentModuleName) @@ -260,7 +260,7 @@ struct RequireTracer : AstVisitor // seed worklist with require arguments work.reserve(requires.size()); - for (AstExprCall* require: requires) + for (AstExprCall* require : requires) work.push_back(require->args.data[0]); // push all dependent expressions to the work stack; note that the vector is modified during traversal diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5651af7e..cd8180db 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauExtraNilRecovery) LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) LUAU_FASTFLAG(LuauTypeAliasPacks) @@ -159,15 +158,6 @@ struct StringifierState seen.erase(iter); } - static std::string generateName(size_t i) - { - std::string n; - n = char('a' + i % 26); - if (i >= 26) - n += std::to_string(i / 26); - return n; - } - std::string getName(TypeId ty) { const size_t s = result.nameMap.typeVars.size(); @@ -584,8 +574,7 @@ struct TypeVarStringifier std::vector results = {}; for (auto el : &uv) { - if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) - el = follow(el); + el = follow(el); if (isNil(el)) { @@ -649,8 +638,7 @@ struct TypeVarStringifier std::vector results = {}; for (auto el : uv.parts) { - if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) - el = follow(el); + el = follow(el); std::string saved = std::move(state.result.name); @@ -1204,4 +1192,13 @@ void dump(TypePackId ty) printf("%s\n", toString(ty, opts).c_str()); } +std::string generateName(size_t i) +{ + std::string n; + n = char('a' + i % 26); + if (i >= 26) + n += std::to_string(i / 26); + return n; +} + } // namespace Luau diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 2d356384..dba694be 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -298,8 +298,15 @@ struct ArcCollector : public AstVisitor struct ContainsFunctionCall : public AstVisitor { + bool alsoReturn = false; bool result = false; + ContainsFunctionCall() = default; + explicit ContainsFunctionCall(bool alsoReturn) + : alsoReturn(alsoReturn) + { + } + bool visit(AstExpr*) override { return !result; // short circuit if result is true @@ -318,6 +325,17 @@ struct ContainsFunctionCall : public AstVisitor return false; } + bool visit(AstStatReturn* stat) override + { + if (alsoReturn) + { + result = true; + return false; + } + else + return AstVisitor::visit(stat); + } + bool visit(AstExprFunction*) override { return false; @@ -479,6 +497,13 @@ bool containsFunctionCall(const AstStat& stat) return cfc.result; } +bool containsFunctionCallOrReturn(const AstStat& stat) +{ + detail::ContainsFunctionCall cfc{true}; + const_cast(stat).visit(&cfc); + return cfc.result; +} + bool isFunction(const AstStat& stat) { return stat.is() || stat.is(); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 702d0ca2..383bb050 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false) + namespace Luau { @@ -33,6 +35,12 @@ void TxnLog::rollback() for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) std::swap(it->first->boundTo, it->second); + + if (FFlag::LuauShareTxnSeen) + { + LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); + sharedSeen->resize(originalSeenSize); + } } void TxnLog::concat(TxnLog rhs) @@ -46,27 +54,44 @@ void TxnLog::concat(TxnLog rhs) tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); rhs.tableChanges.clear(); - seen.swap(rhs.seen); - rhs.seen.clear(); + if (!FFlag::LuauShareTxnSeen) + { + ownedSeen.swap(rhs.ownedSeen); + rhs.ownedSeen.clear(); + } } bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair)); + if (FFlag::LuauShareTxnSeen) + return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); + else + return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair)); } void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - seen.push_back(sortedPair); + if (FFlag::LuauShareTxnSeen) + sharedSeen->push_back(sortedPair); + else + ownedSeen.push_back(sortedPair); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - LUAU_ASSERT(sortedPair == seen.back()); - seen.pop_back(); + if (FFlag::LuauShareTxnSeen) + { + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); + } + else + { + LUAU_ASSERT(sortedPair == ownedSeen.back()); + ownedSeen.pop_back(); + } } } // namespace Luau diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 266c1986..49f8e0ca 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -6,6 +6,7 @@ #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" +#include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -33,14 +34,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data return result; } +using SyntheticNames = std::unordered_map; + namespace Luau { + +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen) +{ + size_t s = syntheticNames->size(); + char*& n = (*syntheticNames)[&gen]; + if (!n) + { + std::string str = gen.explicitName ? gen.name : generateName(s); + n = static_cast(allocator->allocate(str.size() + 1)); + strcpy(n, str.c_str()); + } + + return n; +} + class TypeRehydrationVisitor { - mutable std::map seen; - mutable int count = 0; + std::map seen; + int count = 0; - bool hasSeen(const void* tv) const + bool hasSeen(const void* tv) { void* ttv = const_cast(tv); auto it = seen.find(ttv); @@ -52,15 +70,16 @@ class TypeRehydrationVisitor } public: - TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions()) + TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions()) : allocator(alloc) + , syntheticNames(syntheticNames) , options(options) { } - AstTypePack* rehydrate(TypePackId tp) const; + AstTypePack* rehydrate(TypePackId tp); - AstType* operator()(const PrimitiveTypeVar& ptv) const + AstType* operator()(const PrimitiveTypeVar& ptv) { switch (ptv.type) { @@ -78,11 +97,11 @@ public: return nullptr; } } - AstType* operator()(const AnyTypeVar&) const + AstType* operator()(const AnyTypeVar&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); } - AstType* operator()(const TableTypeVar& ttv) const + AstType* operator()(const TableTypeVar& ttv) { RecursionCounter counter(&count); @@ -144,12 +163,12 @@ public: return allocator->alloc(Location(), props, indexer); } - AstType* operator()(const MetatableTypeVar& mtv) const + AstType* operator()(const MetatableTypeVar& mtv) { return Luau::visit(*this, mtv.table->ty); } - AstType* operator()(const ClassTypeVar& ctv) const + AstType* operator()(const ClassTypeVar& ctv) { RecursionCounter counter(&count); @@ -176,7 +195,7 @@ public: return allocator->alloc(Location(), props); } - AstType* operator()(const FunctionTypeVar& ftv) const + AstType* operator()(const FunctionTypeVar& ftv) { RecursionCounter counter(&count); @@ -253,10 +272,12 @@ public: size_t i = 0; for (const auto& el : ftv.argNames) { + std::optional* arg = &argNames.data[i++]; + if (el) - argNames.data[i++] = {AstName(el->name.c_str()), el->location}; + new (arg) std::optional(AstArgumentName(AstName(el->name.c_str()), el->location)); else - argNames.data[i++] = {}; + new (arg) std::optional(); } AstArray returnTypes; @@ -290,23 +311,23 @@ public: return allocator->alloc( Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); } - AstType* operator()(const Unifiable::Error&) const + AstType* operator()(const Unifiable::Error&) { return allocator->alloc(Location(), std::nullopt, AstName("Unifiable")); } - AstType* operator()(const GenericTypeVar& gtv) const + AstType* operator()(const GenericTypeVar& gtv) { - return allocator->alloc(Location(), std::nullopt, AstName(gtv.name.c_str())); + return allocator->alloc(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv))); } - AstType* operator()(const Unifiable::Bound& bound) const + AstType* operator()(const Unifiable::Bound& bound) { return Luau::visit(*this, bound.boundTo->ty); } - AstType* operator()(Unifiable::Free ftv) const + AstType* operator()(const FreeTypeVar& ftv) { return allocator->alloc(Location(), std::nullopt, AstName("free")); } - AstType* operator()(const UnionTypeVar& uv) const + AstType* operator()(const UnionTypeVar& uv) { AstArray unionTypes; unionTypes.size = uv.options.size(); @@ -317,7 +338,7 @@ public: } return allocator->alloc(Location(), unionTypes); } - AstType* operator()(const IntersectionTypeVar& uv) const + AstType* operator()(const IntersectionTypeVar& uv) { AstArray intersectionTypes; intersectionTypes.size = uv.parts.size(); @@ -328,23 +349,28 @@ public: } return allocator->alloc(Location(), intersectionTypes); } - AstType* operator()(const LazyTypeVar& ltv) const + AstType* operator()(const LazyTypeVar& ltv) { return allocator->alloc(Location(), std::nullopt, AstName("")); } private: Allocator* allocator; + SyntheticNames* syntheticNames; const TypeRehydrationOptions& options; }; class TypePackRehydrationVisitor { public: - TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor) + TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor) : allocator(allocator) + , syntheticNames(syntheticNames) , typeVisitor(typeVisitor) { + LUAU_ASSERT(allocator); + LUAU_ASSERT(syntheticNames); + LUAU_ASSERT(typeVisitor); } AstTypePack* operator()(const BoundTypePack& btp) const @@ -359,7 +385,7 @@ public: head.data = static_cast(allocator->allocate(sizeof(AstType*) * tp.head.size())); for (size_t i = 0; i < tp.head.size(); i++) - head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty); + head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty); AstTypePack* tail = nullptr; @@ -371,12 +397,12 @@ public: AstTypePack* operator()(const VariadicTypePack& vtp) const { - return allocator->alloc(Location(), Luau::visit(typeVisitor, vtp.ty->ty)); + return allocator->alloc(Location(), Luau::visit(*typeVisitor, vtp.ty->ty)); } AstTypePack* operator()(const GenericTypePack& gtp) const { - return allocator->alloc(Location(), AstName(gtp.name.c_str())); + return allocator->alloc(Location(), AstName(getName(allocator, syntheticNames, gtp))); } AstTypePack* operator()(const FreeTypePack& gtp) const @@ -391,12 +417,13 @@ public: private: Allocator* allocator; - const TypeRehydrationVisitor& typeVisitor; + SyntheticNames* syntheticNames; + TypeRehydrationVisitor* typeVisitor; }; -AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const +AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) { - TypePackRehydrationVisitor tprv(allocator, *this); + TypePackRehydrationVisitor tprv(allocator, syntheticNames, this); return Luau::visit(tprv, tp->ty); } @@ -431,7 +458,7 @@ public: { if (!type) return nullptr; - return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty); + return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty); } AstArray typeAstPack(TypePackId type) @@ -443,7 +470,7 @@ public: result.data = static_cast(allocator->allocate(sizeof(AstType*) * v.size())); for (size_t i = 0; i < v.size(); ++i) { - result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty); + result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty); } return result; } @@ -495,7 +522,7 @@ public: { if (FFlag::LuauTypeAliasPacks) { - variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail); + variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); } else { @@ -515,6 +542,7 @@ public: private: Module& module; Allocator* allocator; + SyntheticNames syntheticNames; }; void attachTypeData(SourceModule& source, Module& result) @@ -525,7 +553,8 @@ void attachTypeData(SourceModule& source, Module& result) AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options) { - return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty); + SyntheticNames syntheticNames; + return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty); } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 3a1fdfff..38e2e527 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" +#include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Substitution.h" @@ -33,18 +34,16 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false) -LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false) LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false) LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false) LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) -LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) +LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAG(LuauNewRequireTrace) LUAU_FASTFLAG(LuauTypeAliasPacks) @@ -215,6 +214,7 @@ static bool isMetamethod(const Name& name) TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler) : resolver(resolver) , iceHandler(iceHandler) + , unifierState(iceHandler) , nilType(singletonTypes.nilType) , numberType(singletonTypes.numberType) , stringType(singletonTypes.stringType) @@ -370,13 +370,18 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) return; } + int subLevel = 0; + std::vector sorted(block.body.data, block.body.data + block.body.size); toposort(sorted); for (const auto& stat : sorted) { if (const auto& typealias = stat->as()) - check(scope, *typealias, true); + { + check(scope, *typealias, subLevel, true); + ++subLevel; + } } auto protoIter = sorted.begin(); @@ -399,8 +404,6 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } }; - int subLevel = 0; - while (protoIter != sorted.end()) { // protoIter walks forward @@ -416,7 +419,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // ``` // These both call each other, so `f` will be ordered before `g`, so the call to `g` // is typechecked before `g` has had its body checked. For this reason, there's three - // types for each functuion: before its body is checked, during checking its body, + // types for each function: before its body is checked, during checking its body, // and after its body is checked. // // We currently treat the before-type and the during-type as the same, @@ -433,7 +436,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (containsFunctionCall(**protoIter)) + if (FFlag::LuauQuantifyInPlace2 ? containsFunctionCallOrReturn(**protoIter) : containsFunctionCall(**protoIter)) { while (checkIter != protoIter) { @@ -1080,7 +1083,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); // If in nonstrict mode and allowing redefinition of global function, restore the previous definition type - // in case this function has a differing signature. The signature discrepency will be caught in checkBlock. + // in case this function has a differing signature. The signature discrepancy will be caught in checkBlock. if (previouslyDefined) globalBindings[name] = oldBinding; else @@ -1161,7 +1164,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location}; } -void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare) +void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare) { // This function should be called at most twice for each type alias. // Once with forwardDeclare, and once without. @@ -1189,11 +1192,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { - ScopePtr aliasScope = childScope(scope, typealias.location); + ScopePtr aliasScope = + FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); if (FFlag::LuauTypeAliasPacks) { - auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks); + auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); FreeTypeVar* ftv = getMutable(ty); @@ -1418,7 +1422,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo { ScopePtr funScope = childFunctionScope(scope, global.location); - auto [generics, genericPacks] = createGenericTypes(funScope, global, global.generics, global.genericPacks); + auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks); TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); @@ -1610,25 +1614,11 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn if (std::optional ty = resolveLValue(scope, *lvalue)) return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; - if (FFlag::LuauExtraNilRecovery) - lhsType = stripFromNilAndReport(lhsType, expr.expr->location); + lhsType = stripFromNilAndReport(lhsType, expr.expr->location); if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) return {*ty}; - if (!FFlag::LuauMissingUnionPropertyError) - reportError(expr.indexLocation, UnknownProperty{lhsType, expr.index.value}); - - if (!FFlag::LuauExtraNilRecovery) - { - // Try to recover using a union without 'nil' options - if (std::optional strippedUnion = tryStripUnionFromNil(lhsType)) - { - if (std::optional ty = getIndexTypeFromType(scope, *strippedUnion, name, expr.location, false)) - return {*ty}; - } - } - return {errorType}; } @@ -1694,61 +1684,37 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (const UnionTypeVar* utv = get(type)) { - if (FFlag::LuauMissingUnionPropertyError) + std::vector goodOptions; + std::vector badOptions; + + for (TypeId t : utv) { - std::vector goodOptions; - std::vector badOptions; + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - for (TypeId t : utv) - { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) - goodOptions.push_back(*ty); - else - badOptions.push_back(t); - } - - if (!badOptions.empty()) - { - if (addErrors) - { - if (goodOptions.empty()) - reportError(location, UnknownProperty{type, name}); - else - reportError(location, MissingUnionProperty{type, badOptions, name}); - } - return std::nullopt; - } - - std::vector result = reduceUnion(goodOptions); - - if (result.size() == 1) - return result[0]; - - return addType(UnionTypeVar{std::move(result)}); + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + goodOptions.push_back(*ty); + else + badOptions.push_back(t); } - else + + if (!badOptions.empty()) { - std::vector options; - - for (TypeId t : utv->options) + if (addErrors) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) - options.push_back(*ty); + if (goodOptions.empty()) + reportError(location, UnknownProperty{type, name}); else - return std::nullopt; + reportError(location, MissingUnionProperty{type, badOptions, name}); } - - std::vector result = reduceUnion(options); - - if (result.size() == 1) - return result[0]; - - return addType(UnionTypeVar{std::move(result)}); + return std::nullopt; } + + std::vector result = reduceUnion(goodOptions); + + if (result.size() == 1) + return result[0]; + + return addType(UnionTypeVar{std::move(result)}); } else if (const IntersectionTypeVar* itv = get(type)) { @@ -1765,7 +1731,7 @@ std::optional TypeChecker::getIndexTypeFromType( // If no parts of the intersection had the property we looked up for, it never existed at all. if (parts.empty()) { - if (FFlag::LuauMissingUnionPropertyError && addErrors) + if (addErrors) reportError(location, UnknownProperty{type, name}); return std::nullopt; } @@ -1779,7 +1745,7 @@ std::optional TypeChecker::getIndexTypeFromType( return addType(IntersectionTypeVar{result}); } - if (FFlag::LuauMissingUnionPropertyError && addErrors) + if (addErrors) reportError(location, UnknownProperty{type, name}); return std::nullopt; @@ -2062,8 +2028,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn case AstExprUnary::Len: tablify(operandType); - if (FFlag::LuauExtraNilRecovery) - operandType = stripFromNilAndReport(operandType, expr.location); + operandType = stripFromNilAndReport(operandType, expr.location); if (get(operandType)) return {errorType}; @@ -2635,8 +2600,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope Name name = expr.index.value; - if (FFlag::LuauExtraNilRecovery) - lhs = stripFromNilAndReport(lhs, expr.expr->location); + lhs = stripFromNilAndReport(lhs, expr.expr->location); if (TableTypeVar* lhsTable = getMutableTableType(lhs)) { @@ -2710,8 +2674,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope TypeId exprType = checkExpr(scope, *expr.expr).type; tablify(exprType); - if (FFlag::LuauExtraNilRecovery) - exprType = stripFromNilAndReport(exprType, expr.expr->location); + exprType = stripFromNilAndReport(exprType, expr.expr->location); TypeId indexType = checkExpr(scope, *expr.index).type; @@ -2738,10 +2701,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { - if (FFlag::LuauExtraNilRecovery) - reportError(TypeError{expr.expr->location, NotATable{exprType}}); - else - reportError(TypeError{expr.location, NotATable{exprType}}); + reportError(TypeError{expr.expr->location, NotATable{exprType}}); return std::pair(errorType, nullptr); } @@ -2910,7 +2870,7 @@ std::pair TypeChecker::checkFunctionSignature( if (FFlag::LuauGenericFunctions) { - std::tie(generics, genericPacks) = createGenericTypes(funScope, expr, expr.generics, expr.genericPacks); + std::tie(generics, genericPacks) = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); } TypePackId retPack; @@ -3016,9 +2976,6 @@ std::pair TypeChecker::checkFunctionSignature( if (expectedArgsCurr != expectedArgsEnd) { argType = *expectedArgsCurr; - - if (!FFlag::LuauInferFunctionArgsFix) - ++expectedArgsCurr; } else if (auto expectedArgsTail = expectedArgsCurr.tail()) { @@ -3034,7 +2991,7 @@ std::pair TypeChecker::checkFunctionSignature( funScope->bindings[local] = {argType, local->location}; argTypes.push_back(argType); - if (FFlag::LuauInferFunctionArgsFix && expectedArgsCurr != expectedArgsEnd) + if (expectedArgsCurr != expectedArgsEnd) ++expectedArgsCurr; } @@ -3170,7 +3127,7 @@ void TypeChecker::checkArgumentList( const ScopePtr& scope, Unifier& state, TypePackId argPack, TypePackId paramPack, const std::vector& argLocations) { /* Important terminology refresher: - * A function requires paramaters. + * A function requires parameters. * To call a function, you supply arguments. */ TypePackIterator argIter = begin(argPack); @@ -3402,8 +3359,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A if (!FFlag::LuauRankNTypes) instantiate(scope, selfType, expr.func->location); - if (FFlag::LuauExtraNilRecovery) - selfType = stripFromNilAndReport(selfType, expr.func->location); + selfType = stripFromNilAndReport(selfType, expr.func->location); if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) { @@ -3412,34 +3368,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } else { - if (!FFlag::LuauMissingUnionPropertyError) - reportError(indexExpr->indexLocation, UnknownProperty{selfType, indexExpr->index.value}); - - if (!FFlag::LuauExtraNilRecovery) - { - // Try to recover using a union without 'nil' options - if (std::optional strippedUnion = tryStripUnionFromNil(selfType)) - { - if (std::optional propTy = getIndexTypeFromType(scope, *strippedUnion, indexExpr->index.value, expr.location, false)) - { - selfType = *strippedUnion; - - functionType = *propTy; - actualFunctionType = instantiate(scope, functionType, expr.func->location); - } - } - - if (!actualFunctionType) - { - functionType = errorType; - actualFunctionType = errorType; - } - } - else - { - functionType = errorType; - actualFunctionType = errorType; - } + functionType = errorType; + actualFunctionType = errorType; } } else @@ -3555,8 +3485,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& errors) { - if (FFlag::LuauExtraNilRecovery) - fn = stripFromNilAndReport(fn, expr.func->location); + fn = stripFromNilAndReport(fn, expr.func->location); if (get(fn)) { @@ -4283,6 +4212,12 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) return ty; + if (FFlag::LuauQuantifyInPlace2) + { + Luau::quantify(currentModule, ty, scope->level); + return ty; + } + quantification.level = scope->level; quantification.generics.clear(); quantification.genericPacks.clear(); @@ -4491,12 +4426,12 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) Unifier TypeChecker::mkUnifier(const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, iceHandler}; + return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState}; } Unifier TypeChecker::mkUnifier(const std::vector>& seen, const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, iceHandler}; + return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, unifierState}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -4753,7 +4688,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (FFlag::LuauGenericFunctions) { - std::tie(generics, genericPacks) = createGenericTypes(funcScope, annotation, func->generics, func->genericPacks); + std::tie(generics, genericPacks) = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); } // TODO: better error message CLI-39912 @@ -5041,10 +4976,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } std::pair, std::vector> TypeChecker::createGenericTypes( - const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) + const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); + const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; + std::vector generics; for (const AstName& generic : genericNames) { @@ -5063,12 +5000,12 @@ std::pair, std::vector> TypeChecker::createGener { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) - cached = addType(GenericTypeVar{scope->level, n}); + cached = addType(GenericTypeVar{level, n}); g = cached; } else { - g = addType(Unifiable::Generic{scope->level, n}); + g = addType(Unifiable::Generic{level, n}); } generics.push_back(g); @@ -5093,12 +5030,12 @@ std::pair, std::vector> TypeChecker::createGener { TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); g = cached; } else { - g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); } genericPacks.push_back(g); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index e963fc74..e82f7519 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -22,6 +22,7 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) namespace Luau { @@ -217,8 +218,7 @@ std::optional getMetatable(TypeId type) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) return classType->metatable; - else if (const PrimitiveTypeVar* primitiveType = get(type); - primitiveType && primitiveType->metatable) + else if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) { LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); return primitiveType->metatable; @@ -1490,4 +1490,86 @@ std::vector filterMap(TypeId type, TypeIdPredicate predicate) return {}; } +static Tags* getTags(TypeId ty) +{ + ty = follow(ty); + + if (auto ftv = getMutable(ty)) + return &ftv->tags; + else if (auto ttv = getMutable(ty)) + return &ttv->tags; + else if (auto ctv = getMutable(ty)) + return &ctv->tags; + + return nullptr; +} + +void attachTag(TypeId ty, const std::string& tagName) +{ + if (!FFlag::LuauRefactorTagging) + { + if (auto ftv = getMutable(ty)) + { + ftv->tags.emplace_back(tagName); + } + else + { + LUAU_ASSERT(!"Got a non functional type"); + } + } + else + { + if (auto tags = getTags(ty)) + tags->push_back(tagName); + else + LUAU_ASSERT(!"This TypeId does not support tags"); + } +} + +void attachTag(Property& prop, const std::string& tagName) +{ + LUAU_ASSERT(FFlag::LuauRefactorTagging); + + prop.tags.push_back(tagName); +} + +// We would ideally not expose this because it could cause a footgun. +// If the Base class has a tag and you ask if Derived has that tag, it would return false. +// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it. +bool hasTag(const Tags& tags, const std::string& tagName) +{ + LUAU_ASSERT(FFlag::LuauRefactorTagging); + return std::find(tags.begin(), tags.end(), tagName) != tags.end(); +} + +bool hasTag(TypeId ty, const std::string& tagName) +{ + ty = follow(ty); + + // We special case classes because getTags only returns a pointer to one vector of tags. + // But classes has multiple vector of tags, represented throughout the hierarchy. + if (auto ctv = get(ty)) + { + while (ctv) + { + if (hasTag(ctv->tags, tagName)) + return true; + else if (!ctv->parent) + return false; + + ctv = get(*ctv->parent); + LUAU_ASSERT(ctv); + } + } + else if (auto tags = getTags(ty)) + return hasTag(*tags, tagName); + + return false; +} + +bool hasTag(const Property& prop, const std::string& tagName) +{ + return hasTag(prop.tags, tagName); +} + } // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 117cbc28..2539650a 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -7,6 +7,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/VisitTypeVar.h" #include @@ -22,9 +23,99 @@ LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) +LUAU_FASTFLAG(LuauShareTxnSeen); +LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) namespace Luau { +struct SkipCacheForType +{ + SkipCacheForType(const DenseHashMap& skipCacheForType) + : skipCacheForType(skipCacheForType) + { + } + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const BoundTypeVar& btv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const GenericTypeVar& btv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + TableTypeVar& ttv = *getMutable(ty); + + if (ttv.boundTo) + { + result = true; + return false; + } + + if (ttv.state != TableState::Sealed) + { + result = true; + return false; + } + + return true; + } + + template + bool operator()(TypeId ty, const T& t) + { + const bool* prev = skipCacheForType.find(ty); + + if (prev && *prev) + { + result = true; + return false; + } + + return true; + } + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + result = true; + return false; + } + + bool operator()(TypePackId tp, const BoundTypePack& ftp) + { + result = true; + return false; + } + + bool operator()(TypePackId tp, const GenericTypePack& ftp) + { + result = true; + return false; + } + + const DenseHashMap& skipCacheForType; + bool result = false; +}; static std::optional hasUnificationTooComplex(const ErrorVec& errors) { @@ -39,7 +130,7 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) , globalScope(std::move(globalScope)) @@ -47,24 +138,39 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , variance(variance) , counters(&countersData) , counters_DEPRECATED(std::make_shared()) - , iceHandler(iceHandler) + , sharedState(sharedState) { - LUAU_ASSERT(iceHandler); + LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) : types(types) , mode(mode) , globalScope(std::move(globalScope)) - , log(seen) + , log(ownedSeen) , location(location) , variance(variance) , counters(counters ? counters : &countersData) , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) - , iceHandler(iceHandler) + , sharedState(sharedState) { - LUAU_ASSERT(iceHandler); + LUAU_ASSERT(sharedState.iceHandler); +} + +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) + : types(types) + , mode(mode) + , globalScope(std::move(globalScope)) + , log(sharedSeen) + , location(location) + , variance(variance) + , counters(counters ? counters : &countersData) + , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) + , sharedState(sharedState) +{ + LUAU_ASSERT(sharedState.iceHandler); } void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) @@ -74,7 +180,7 @@ void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool i else counters_DEPRECATED->iterationCount = 0; - return tryUnify_(superTy, subTy, isFunctionCall, isIntersection); + tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) @@ -206,6 +312,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (get(subTy) || get(subTy)) return tryUnifyWithAny(subTy, superTy); + bool cacheEnabled = FFlag::LuauCacheUnifyTableResults && !isFunctionCall && !isIntersection; + auto& cache = sharedState.cachedUnify; + + // What if the types are immutable and we proved their relation before + if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) + return; + // If we have seen this pair of types before, we are currently recursing into cyclic types. // Here, we assume that the types unify. If they do not, we will find out as we roll back // the stack. @@ -257,6 +370,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (FFlag::LuauUnionHeuristic) { + bool found = false; + const std::string* subName = getName(subTy); if (subName) { @@ -264,6 +379,21 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { const std::string* optionName = getName(uv->options[i]); if (optionName && *optionName == *subName) + { + found = true; + startIndex = i; + break; + } + } + } + + if (!found && cacheEnabled) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[i]; + + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) { startIndex = i; break; @@ -311,8 +441,25 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool bool found = false; std::optional unificationTooComplex; - for (TypeId type : uv->parts) + size_t startIndex = 0; + + if (cacheEnabled) { + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[i]; + + if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) + { + startIndex = i; + break; + } + } + } + + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; Unifier innerState = makeChildUnifier(); innerState.tryUnify_(superTy, type, isFunctionCall); @@ -342,8 +489,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool tryUnifyFunctions(superTy, subTy, isFunctionCall); else if (get(superTy) && get(subTy)) + { tryUnifyTables(superTy, subTy, isIntersection); + if (cacheEnabled && errors.empty()) + cacheResult(superTy, subTy); + } + // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. else if (get(superTy)) tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false); @@ -364,6 +516,41 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool log.popSeen(superTy, subTy); } +void Unifier::cacheResult(TypeId superTy, TypeId subTy) +{ + LUAU_ASSERT(FFlag::LuauCacheUnifyTableResults); + + bool* superTyInfo = sharedState.skipCacheForType.find(superTy); + + if (superTyInfo && *superTyInfo) + return; + + bool* subTyInfo = sharedState.skipCacheForType.find(subTy); + + if (subTyInfo && *subTyInfo) + return; + + auto skipCacheFor = [this](TypeId ty) { + SkipCacheForType visitor{sharedState.skipCacheForType}; + visitTypeVarOnce(ty, visitor, sharedState.seenAny); + + sharedState.skipCacheForType[ty] = visitor.result; + + return visitor.result; + }; + + if (!superTyInfo && skipCacheFor(superTy)) + return; + + if (!subTyInfo && skipCacheFor(subTy)) + return; + + sharedState.cachedUnify.insert({superTy, subTy}); + + if (variance == Invariant) + sharedState.cachedUnify.insert({subTy, superTy}); +} + struct WeirdIter { TypePackId packId; @@ -459,7 +646,7 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall else counters_DEPRECATED->iterationCount = 0; - return tryUnify_(superTp, subTp, isFunctionCall); + tryUnify_(superTp, subTp, isFunctionCall); } /* @@ -650,11 +837,11 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal } // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking a return value, we swap these to produce - // the expected error message. + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. size_t expectedSize = size(superTp); size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result || ctx == CountMismatch::Return) + if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); @@ -797,6 +984,40 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) std::vector missingProperties; std::vector extraProperties; + // Optimization: First test that the property sets are compatible without doing any recursive unification + if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer && rt->state != TableState::Free) + { + for (const auto& [propName, superProp] : lt->props) + { + auto subIter = rt->props.find(propName); + if (subIter == rt->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) + missingProperties.push_back(propName); + } + + if (!missingProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + return; + } + } + + // And vice versa if we're invariant + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free) + { + for (const auto& [propName, subProp] : rt->props) + { + auto superIter = lt->props.find(propName); + if (superIter == lt->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) + extraProperties.push_back(propName); + } + + if (!extraProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + return; + } + } + // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. @@ -833,9 +1054,10 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) innerState.log.rollback(); } else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 - // TODO: should isOptional(anyType) be true? - {} + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + { + } else if (rt->state == TableState::Free) { log(rt); @@ -878,11 +1100,13 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) lt->props[name] = clone; } else if (variance == Covariant) - {} + { + } else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 - // TODO: should isOptional(anyType) be true? - {} + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + { + } else if (lt->state == TableState::Free) { log(lt); @@ -980,10 +1204,10 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see TableTypeVar* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) prop.type = deeplyOptional(prop.type, seen); - return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});; + return types->addType(UnionTypeVar{{singletonTypes.nilType, result}}); } else - return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }}); + return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}}); } void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) @@ -1247,7 +1471,7 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio // If the superTy/left is an immediate part of an intersection type, do not do extra-property check. // Otherwise, we would falsely generate an extra-property-error for 's' in this code: // local a: {n: number} & {s: string} = {n=1, s=""} - // When checking agaist the table '{n: number}'. + // When checking against the table '{n: number}'. if (!isIntersection && lt->state != TableState::Unsealed && !lt->indexer) { // Check for extra properties in the subTy @@ -1697,10 +1921,20 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { std::vector queue = {ty}; - tempSeenTy.clear(); - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); + } + else + { + tempSeenTy_DEPRECATED.clear(); + tempSeenTp_DEPRECATED.clear(); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, singletonTypes.anyType, anyTP); + } } else { @@ -1721,12 +1955,24 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { std::vector queue; - tempSeenTy.clear(); - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - queueTypePack(queue, tempSeenTp, *this, ty, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); - Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); + } + else + { + tempSeenTy_DEPRECATED.clear(); + tempSeenTp_DEPRECATED.clear(); + + queueTypePack(queue, tempSeenTp_DEPRECATED, *this, ty, any); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, anyTy, any); + } } else { @@ -1775,10 +2021,20 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack) { std::unordered_set seen_DEPRECATED; - if (FFlag::LuauTypecheckOpts) - tempSeenTy.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + if (FFlag::LuauTypecheckOpts) + sharedState.tempSeenTy.clear(); - return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack); + return occursCheck(seen_DEPRECATED, sharedState.tempSeenTy, needle, haystack); + } + else + { + if (FFlag::LuauTypecheckOpts) + tempSeenTy_DEPRECATED.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTy_DEPRECATED, needle, haystack); + } } void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) @@ -1851,10 +2107,20 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { std::unordered_set seen_DEPRECATED; - if (FFlag::LuauTypecheckOpts) - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + if (FFlag::LuauTypecheckOpts) + sharedState.tempSeenTp.clear(); - return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack); + return occursCheck(seen_DEPRECATED, sharedState.tempSeenTp, needle, haystack); + } + else + { + if (FFlag::LuauTypecheckOpts) + tempSeenTp_DEPRECATED.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTp_DEPRECATED, needle, haystack); + } } void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) @@ -1922,7 +2188,10 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters}; + if (FFlag::LuauShareTxnSeen) + return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; + else + return Unifier{types, mode, globalScope, log.ownedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; } bool Unifier::isNonstrictMode() const @@ -1940,12 +2209,12 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId void Unifier::ice(const std::string& message, const Location& location) { - iceHandler->ice(message, location); + sharedState.iceHandler->ice(message, location); } void Unifier::ice(const std::string& message) { - iceHandler->ice(message); + sharedState.iceHandler->ice(message); } } // namespace Luau diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 641dfd3c..503eca61 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -194,20 +194,20 @@ LUAU_NOINLINE std::pair createScopeDa } // namespace Luau // Regular scope -#define LUAU_TIMETRACE_SCOPE(name, category) \ +#define LUAU_TIMETRACE_SCOPE(name, category) \ static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) // A scope without nested scopes that may be skipped if the time it took is less than the threshold -#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) // Extra key/value data can be added to regular scopes -#define LUAU_TIMETRACE_ARGUMENT(name, value) \ - do \ - { \ - if (FFlag::DebugLuauTimeTracing) \ +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + if (FFlag::DebugLuauTimeTracing) \ lttScopeStatic.second.eventArgument(name, value); \ } while (false) @@ -216,8 +216,8 @@ LUAU_NOINLINE std::pair createScopeDa #define LUAU_TIMETRACE_SCOPE(name, category) #define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) #define LUAU_TIMETRACE_ARGUMENT(name, value) \ - do \ - { \ + do \ + { \ } while (false) #endif diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 40026d8b..846bc0ba 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1593,7 +1593,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { Location location = lexer.current().location; - // For a missing type annoation, capture 'space' between last token and the next one + // For a missing type annotation, capture 'space' between last token and the next one location = Location(lexer.previousLocation().end, lexer.current().location.begin); return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index e6aab20e..ded50e53 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -77,7 +77,10 @@ struct GlobalContext // Ideally we would want all ThreadContext destructors to run // But in VS, not all thread_local object instances are destroyed for (ThreadContext* context : threads) - context->flushEvents(); + { + if (!context->events.empty()) + context->flushEvents(); + } if (traceFile) fclose(traceFile); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 6baa21ea..4968d080 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -169,7 +169,17 @@ static std::string runCode(lua_State* L, const std::string& source) } else { - std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(T, -1); + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(T, -1)) + { + error = str; + } + error += "\nstack backtrace:\n"; error += lua_debugtrace(T); @@ -322,7 +332,17 @@ static bool runFile(const char* name, lua_State* GL) } else { - std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(L, -1); + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(L, -1)) + { + error = str; + } + error += "\nstacktrace:\n"; error += lua_debugtrace(L); diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 07be2e74..4b03ed1c 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -208,14 +208,14 @@ enum LuauOpcode LOP_MODK, LOP_POWK, - // AND, OR: perform `and` or `or` operation (selecting first or second register based on whether the first one is truthful) and put the result into target register + // AND, OR: perform `and` or `or` operation (selecting first or second register based on whether the first one is truthy) and put the result into target register // A: target register // B: source register 1 // C: source register 2 LOP_AND, LOP_OR, - // ANDK, ORK: perform `and` or `or` operation (selecting source register or constant based on whether the source register is truthful) and put the result into target register + // ANDK, ORK: perform `and` or `or` operation (selecting source register or constant based on whether the source register is truthy) and put the result into target register // A: target register // B: source register // C: constant table index (0..255) diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 797ee20d..7750a1d9 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -718,9 +718,9 @@ struct Compiler } // compile expr to target temp register - // if the expr (or not expr if onlyTruth is false) is truthful, jump via skipJump - // if the expr (or not expr if onlyTruth is false) is falseful, fall through (target isn't guaranteed to be updated in this case) - // if target is omitted, then the jump behavior is the same - skipJump or fallthrough depending on the truthfulness of the expression + // if the expr (or not expr if onlyTruth is false) is truthy, jump via skipJump + // if the expr (or not expr if onlyTruth is false) is falsy, fall through (target isn't guaranteed to be updated in this case) + // if target is omitted, then the jump behavior is the same - skipJump or fallthrough depending on the truthiness of the expression void compileConditionValue(AstExpr* node, const uint8_t* target, std::vector& skipJump, bool onlyTruth) { // Optimization: we don't need to compute constant values @@ -728,7 +728,7 @@ struct Compiler if (cv && cv->type != Constant::Type_Unknown) { - // note that we only need to compute the value if it's truthful; otherwise we cal fall through + // note that we only need to compute the value if it's truthy; otherwise we cal fall through if (cv->isTruthful() == onlyTruth) { if (target) @@ -747,7 +747,7 @@ struct Compiler case AstExprBinary::And: case AstExprBinary::Or: { - // disambiguation: there's 4 cases (we only need truthful or falseful results based on onlyTruth) + // disambiguation: there's 4 cases (we only need truthy or falsy results based on onlyTruth) // onlyTruth = 1: a and b transforms to a ? b : dontcare // onlyTruth = 1: a or b transforms to a ? a : a // onlyTruth = 0: a and b transforms to !a ? a : b @@ -791,8 +791,8 @@ struct Compiler if (target) { // since target is a temp register, we'll initialize it to 1, and then jump if the comparison is true - // if the comparison is false, we'll fallthrough and target will still be 1 but target has unspecified value for falseful results - // when we only care about falseful values instead of truthful values, the process is the same but with flipped conditionals + // if the comparison is false, we'll fallthrough and target will still be 1 but target has unspecified value for falsy results + // when we only care about falsy values instead of truthy values, the process is the same but with flipped conditionals bytecode.emitABC(LOP_LOADB, *target, onlyTruth ? 1 : 0, 0); } diff --git a/Makefile b/Makefile index 0056870b..7788251d 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +.SUFFIXES: MAKEFLAGS+=-r -j8 COMMA=, @@ -107,6 +108,7 @@ coverage: $(TESTS_TARGET) rm default.profraw default-flags.profraw llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests llvm-cov report -ignore-filename-regex=\(tests\|extern\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests + llvm-cov export -format lcov --instr-profile default.profdata build/coverage/luau-tests >coverage.info format: find . -name '*.h' -or -name '*.cpp' | xargs clang-format -i diff --git a/Sources.cmake b/Sources.cmake index 83ed5230..c30cf77d 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -46,6 +46,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Module.h Analysis/include/Luau/ModuleResolver.h Analysis/include/Luau/Predicate.h + Analysis/include/Luau/Quantify.h Analysis/include/Luau/RecursionCounter.h Analysis/include/Luau/RequireTracer.h Analysis/include/Luau/Scope.h @@ -63,6 +64,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeVar.h Analysis/include/Luau/Unifiable.h Analysis/include/Luau/Unifier.h + Analysis/include/Luau/UnifierSharedState.h Analysis/include/Luau/Variant.h Analysis/include/Luau/VisitTypeVar.h @@ -77,6 +79,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Linter.cpp Analysis/src/Module.cpp Analysis/src/Predicate.cpp + Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp Analysis/src/Substitution.cpp diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 7a09ae9f..30cffaff 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -76,7 +76,7 @@ struct luaL_Buffer char buffer[LUA_BUFFERSIZE]; }; -// when internal buffer storage is exhaused, a mutable string value 'storage' will be placed on the stack +// when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack // in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) // with the exception of luaL_addvalue that expects the value at the top and string buffer further away (top-2) // functions that accept a 'boxloc' support string buffer placement at any location in the stack diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 01315360..f2e97c66 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -1153,7 +1151,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) luaC_checkGC(L); luaC_checkthreadsleep(L); Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); - memcpy(u->data + sz, &dtor, sizeof(dtor)); + memcpy(&u->data + sz, &dtor, sizeof(dtor)); setuvalue(L, L->top, u); api_incr_top(L); return u->data; diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index e37618f7..2a684ee4 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -313,7 +313,7 @@ static size_t getnextbuffersize(lua_State* L, size_t currentsize, size_t desired { size_t newsize = currentsize + currentsize / 2; - // check for size oveflow + // check for size overflow if (SIZE_MAX - desiredsize < currentsize) luaL_error(L, "buffer too large"); diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index c0b50b96..9724c0e7 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,8 +5,6 @@ #include "lstate.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false) - #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -17,7 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false) static const char* const statnames[] = {"running", "suspended", "normal", "dead"}; -static int costatus(lua_State* L, lua_State* co) +static int auxstatus(lua_State* L, lua_State* co) { if (co == L) return CO_RUN; @@ -25,7 +23,7 @@ static int costatus(lua_State* L, lua_State* co) return CO_SUS; if (co->status == LUA_BREAK) return CO_NOR; - if (co->status != 0) /* some error occured */ + if (co->status != 0) /* some error occurred */ return CO_DEAD; if (co->ci != co->base_ci) /* does it have frames? */ return CO_NOR; @@ -34,11 +32,11 @@ static int costatus(lua_State* L, lua_State* co) return CO_SUS; /* initial state */ } -static int luaB_costatus(lua_State* L) +static int costatus(lua_State* L) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); - lua_pushstring(L, statnames[costatus(L, co)]); + lua_pushstring(L, statnames[auxstatus(L, co)]); return 1; } @@ -47,7 +45,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg) // error handling for edge cases if (co->status != LUA_YIELD) { - int status = costatus(L, co); + int status = auxstatus(L, co); if (status != CO_SUS) { lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]); @@ -115,7 +113,7 @@ static int auxresumecont(lua_State* L, lua_State* co) } } -static int luaB_coresumefinish(lua_State* L, int r) +static int coresumefinish(lua_State* L, int r) { if (r < 0) { @@ -131,7 +129,7 @@ static int luaB_coresumefinish(lua_State* L, int r) } } -static int luaB_coresumey(lua_State* L) +static int coresumey(lua_State* L) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); @@ -141,10 +139,10 @@ static int luaB_coresumey(lua_State* L) if (r == CO_STATUS_BREAK) return interruptThread(L, co); - return luaB_coresumefinish(L, r); + return coresumefinish(L, r); } -static int luaB_coresumecont(lua_State* L, int status) +static int coresumecont(lua_State* L, int status) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); @@ -155,10 +153,10 @@ static int luaB_coresumecont(lua_State* L, int status) int r = auxresumecont(L, co); - return luaB_coresumefinish(L, r); + return coresumefinish(L, r); } -static int luaB_auxwrapfinish(lua_State* L, int r) +static int auxwrapfinish(lua_State* L, int r) { if (r < 0) { @@ -173,7 +171,7 @@ static int luaB_auxwrapfinish(lua_State* L, int r) return r; } -static int luaB_auxwrapy(lua_State* L) +static int auxwrapy(lua_State* L) { lua_State* co = lua_tothread(L, lua_upvalueindex(1)); int narg = cast_int(L->top - L->base); @@ -182,10 +180,10 @@ static int luaB_auxwrapy(lua_State* L) if (r == CO_STATUS_BREAK) return interruptThread(L, co); - return luaB_auxwrapfinish(L, r); + return auxwrapfinish(L, r); } -static int luaB_auxwrapcont(lua_State* L, int status) +static int auxwrapcont(lua_State* L, int status) { lua_State* co = lua_tothread(L, lua_upvalueindex(1)); @@ -195,62 +193,52 @@ static int luaB_auxwrapcont(lua_State* L, int status) int r = auxresumecont(L, co); - return luaB_auxwrapfinish(L, r); + return auxwrapfinish(L, r); } -static int luaB_cocreate(lua_State* L) +static int cocreate(lua_State* L) { luaL_checktype(L, 1, LUA_TFUNCTION); lua_State* NL = lua_newthread(L); - - if (FFlag::LuauPreferXpush) - { - lua_xpush(L, NL, 1); // push function on top of NL - } - else - { - lua_pushvalue(L, 1); /* move function to top */ - lua_xmove(L, NL, 1); /* move function from L to NL */ - } - + lua_xpush(L, NL, 1); // push function on top of NL return 1; } -static int luaB_cowrap(lua_State* L) +static int cowrap(lua_State* L) { - luaB_cocreate(L); + cocreate(L); - lua_pushcfunction(L, luaB_auxwrapy, NULL, 1, luaB_auxwrapcont); + lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont); return 1; } -static int luaB_yield(lua_State* L) +static int coyield(lua_State* L) { int nres = cast_int(L->top - L->base); return lua_yield(L, nres); } -static int luaB_corunning(lua_State* L) +static int corunning(lua_State* L) { if (lua_pushthread(L)) lua_pushnil(L); /* main thread is not a coroutine */ return 1; } -static int luaB_yieldable(lua_State* L) +static int coyieldable(lua_State* L) { lua_pushboolean(L, lua_isyieldable(L)); return 1; } static const luaL_Reg co_funcs[] = { - {"create", luaB_cocreate}, - {"running", luaB_corunning}, - {"status", luaB_costatus}, - {"wrap", luaB_cowrap}, - {"yield", luaB_yield}, - {"isyieldable", luaB_yieldable}, + {"create", cocreate}, + {"running", corunning}, + {"status", costatus}, + {"wrap", cowrap}, + {"yield", coyield}, + {"isyieldable", coyieldable}, {NULL, NULL}, }; @@ -258,7 +246,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); - lua_pushcfunction(L, luaB_coresumey, "resume", 0, luaB_coresumecont); + lua_pushcfunction(L, coresumey, "resume", 0, coresumecont); lua_setfield(L, -2, "resume"); return 1; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 39a61597..328b47e6 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -18,6 +18,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) +LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) /* ** {====================================================== @@ -536,7 +537,13 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e status = LUA_ERRERR; } - // an error occured, check if we have a protected error callback + if (FFlag::LuauCcallRestoreFix) + { + // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + L->nCcalls = oldnCcalls; + } + + // an error occurred, check if we have a protected error callback if (L->global->cb.debugprotectederror) { L->global->cb.debugprotectederror(L); @@ -549,7 +556,10 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); /* close eventual pending closures */ seterrorobj(L, status, oldtop); - L->nCcalls = oldnCcalls; + if (!FFlag::LuauCcallRestoreFix) + { + L->nCcalls = oldnCcalls; + } L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 4fe1c341..72807f0f 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -37,7 +37,7 @@ /* results from luaD_precall */ #define PCRLUA 0 /* initiated a call to a Lua function */ #define PCRC 1 /* did a call to a C function */ -#define PCRYIELD 2 /* C funtion yielded */ +#define PCRYIELD 2 /* C function yielded */ /* type of protected functions, to be ran by `runprotected' */ typedef void (*Pfunc)(lua_State* L, void* ud); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 0b543026..64878569 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -76,7 +76,7 @@ UpVal* luaF_findupval(lua_State* L, StkId level) if (p->v == level) { /* found a corresponding upvalue? */ if (isdead(g, obj2gco(p))) /* is it dead? */ - changewhite(obj2gco(p)); /* ressurect it */ + changewhite(obj2gco(p)); /* resurrect it */ return p; } pp = &p->next; diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 510a9f54..6553009f 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -12,11 +12,9 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false) LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) -LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false) -LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false) LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) +LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) LUAU_FASTFLAG(LuauArrayBoundary) @@ -66,13 +64,18 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, g->gcstats.currcycle.marktime += seconds; // atomic step had to be performed during the switch and it's tracked separately - if (g->gcstate == GCSsweepstring) + if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring) g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime; break; + case GCSatomic: + g->gcstats.currcycle.atomictime += seconds; + break; case GCSsweepstring: case GCSsweep: g->gcstats.currcycle.sweeptime += seconds; break; + default: + LUAU_ASSERT(!"Unexpected GC state"); } if (assist) @@ -183,33 +186,15 @@ static int traversetable(global_State* g, Table* h) if (h->metatable) markobject(g, cast_to(Table*, h->metatable)); - if (FFlag::LuauShrinkWeakTables) + /* is there a weak mode? */ + if (const char* modev = gettablemode(g, h)) { - /* is there a weak mode? */ - if (const char* modev = gettablemode(g, h)) - { - weakkey = (strchr(modev, 'k') != NULL); - weakvalue = (strchr(modev, 'v') != NULL); - if (weakkey || weakvalue) - { /* is really weak? */ - h->gclist = g->weak; /* must be cleared after GC, ... */ - g->weak = obj2gco(h); /* ... so put in the appropriate list */ - } - } - } - else - { - const TValue* mode = gfasttm(g, h->metatable, TM_MODE); - if (mode && ttisstring(mode)) - { /* is there a weak mode? */ - const char* modev = svalue(mode); - weakkey = (strchr(modev, 'k') != NULL); - weakvalue = (strchr(modev, 'v') != NULL); - if (weakkey || weakvalue) - { /* is really weak? */ - h->gclist = g->weak; /* must be cleared after GC, ... */ - g->weak = obj2gco(h); /* ... so put in the appropriate list */ - } + weakkey = (strchr(modev, 'k') != NULL); + weakvalue = (strchr(modev, 'v') != NULL); + if (weakkey || weakvalue) + { /* is really weak? */ + h->gclist = g->weak; /* must be cleared after GC, ... */ + g->weak = obj2gco(h); /* ... so put in the appropriate list */ } } @@ -297,7 +282,7 @@ static void traversestack(global_State* g, lua_State* l, bool clearstack) for (StkId o = l->stack; o < l->top; o++) markvalue(g, o); /* final traversal? */ - if (g->gcstate == GCSatomic || (FFlag::LuauGcFullSkipInactiveThreads && clearstack)) + if (g->gcstate == GCSatomic || clearstack) { StkId stack_end = l->stack + l->stacksize; for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */ @@ -336,28 +321,16 @@ static size_t propagatemark(global_State* g) lua_State* th = gco2th(o); g->gray = th->gclist; - if (FFlag::LuauGcFullSkipInactiveThreads) + LUAU_ASSERT(!luaC_threadsleeping(th)); + + // threads that are executing and the main thread are not deactivated + bool active = luaC_threadactive(th) || th == th->global->mainthread; + + if (!active && g->gcstate == GCSpropagate) { - LUAU_ASSERT(!luaC_threadsleeping(th)); + traversestack(g, th, /* clearstack= */ true); - // threads that are executing and the main thread are not deactivated - bool active = luaC_threadactive(th) || th == th->global->mainthread; - - if (!active && g->gcstate == GCSpropagate) - { - traversestack(g, th, /* clearstack= */ true); - - l_setbit(th->stackstate, THREAD_SLEEPINGBIT); - } - else - { - th->gclist = g->grayagain; - g->grayagain = o; - - black2gray(o); - - traversestack(g, th, /* clearstack= */ false); - } + l_setbit(th->stackstate, THREAD_SLEEPINGBIT); } else { @@ -385,12 +358,14 @@ static size_t propagatemark(global_State* g) } } -static void propagateall(global_State* g) +static size_t propagateall(global_State* g) { + size_t work = 0; while (g->gray) { - propagatemark(g); + work += propagatemark(g); } + return work; } /* @@ -415,11 +390,14 @@ static int isobjcleared(GCObject* o) /* ** clear collected entries from weaktables */ -static void cleartable(lua_State* L, GCObject* l) +static size_t cleartable(lua_State* L, GCObject* l) { + size_t work = 0; while (l) { Table* h = gco2h(l); + work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); + int i = h->sizearray; while (i--) { @@ -433,50 +411,36 @@ static void cleartable(lua_State* L, GCObject* l) { LuaNode* n = gnode(h, i); - if (FFlag::LuauShrinkWeakTables) + // non-empty entry? + if (!ttisnil(gval(n))) { - // non-empty entry? - if (!ttisnil(gval(n))) - { - // can we clear key or value? - if (iscleared(gkey(n)) || iscleared(gval(n))) - { - setnilvalue(gval(n)); /* remove value ... */ - removeentry(n); /* remove entry from table */ - } - else - { - activevalues++; - } - } - } - else - { - if (!ttisnil(gval(n)) && /* non-empty entry? */ - (iscleared(gkey(n)) || iscleared(gval(n)))) + // can we clear key or value? + if (iscleared(gkey(n)) || iscleared(gval(n))) { setnilvalue(gval(n)); /* remove value ... */ removeentry(n); /* remove entry from table */ } + else + { + activevalues++; + } } } - if (FFlag::LuauShrinkWeakTables) + if (const char* modev = gettablemode(L->global, h)) { - if (const char* modev = gettablemode(L->global, h)) + // are we allowed to shrink this weak table? + if (strchr(modev, 's')) { - // are we allowed to shrink this weak table? - if (strchr(modev, 's')) - { - // shrink at 37.5% occupancy - if (activevalues < sizenode(h) * 3 / 8) - luaH_resizehash(L, h, activevalues); - } + // shrink at 37.5% occupancy + if (activevalues < sizenode(h) * 3 / 8) + luaH_resizehash(L, h, activevalues); } } l = h->gclist; } + return work; } static void shrinkstack(lua_State* L) @@ -655,37 +619,49 @@ static void markroot(lua_State* L) g->gcstate = GCSpropagate; } -static void remarkupvals(global_State* g) +static size_t remarkupvals(global_State* g) { - UpVal* uv; - for (uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + size_t work = 0; + for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) { + work += sizeof(UpVal); LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); if (isgray(obj2gco(uv))) markvalue(g, uv->v); } + return work; } -static void atomic(lua_State* L) +static size_t atomic(lua_State* L) { global_State* g = L->global; - g->gcstate = GCSatomic; + size_t work = 0; + + if (FFlag::LuauSeparateAtomic) + { + LUAU_ASSERT(g->gcstate == GCSatomic); + } + else + { + g->gcstate = GCSatomic; + } + /* remark occasional upvalues of (maybe) dead threads */ - remarkupvals(g); + work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ - propagateall(g); + work += propagateall(g); /* remark weak tables */ g->gray = g->weak; g->weak = NULL; LUAU_ASSERT(!iswhite(obj2gco(g->mainthread))); markobject(g, L); /* mark running thread */ markmt(g); /* mark basic metatables (again) */ - propagateall(g); + work += propagateall(g); /* remark gray again */ g->gray = g->grayagain; g->grayagain = NULL; - propagateall(g); - cleartable(L, g->weak); /* remove collected objects from weak tables */ + work += propagateall(g); + work += cleartable(L, g->weak); /* remove collected objects from weak tables */ g->weak = NULL; /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); @@ -693,7 +669,12 @@ static void atomic(lua_State* L) g->sweepgc = &g->rootgc; g->gcstate = GCSsweepstring; - GC_INTERRUPT(GCSatomic); + if (!FFlag::LuauSeparateAtomic) + { + GC_INTERRUPT(GCSatomic); + } + + return work; } static size_t singlestep(lua_State* L) @@ -705,46 +686,24 @@ static size_t singlestep(lua_State* L) case GCSpause: { markroot(L); /* start a new collection */ + LUAU_ASSERT(g->gcstate == GCSpropagate); break; } case GCSpropagate: { - if (FFlag::LuauRescanGrayAgain) + if (g->gray) { - if (g->gray) - { - g->gcstats.currcycle.markitems++; + g->gcstats.currcycle.markitems++; - cost = propagatemark(g); - } - else - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; - - g->gcstate = GCSpropagateagain; - } + cost = propagatemark(g); } else { - if (g->gray) - { - g->gcstats.currcycle.markitems++; + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; - cost = propagatemark(g); - } - else /* no more `gray' objects */ - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSpropagateagain; } break; } @@ -758,17 +717,34 @@ static size_t singlestep(lua_State* L) } else /* no more `gray' objects */ { - double starttimestamp = lua_clock(); + if (FFlag::LuauSeparateAtomic) + { + g->gcstate = GCSatomic; + } + else + { + double starttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - atomic(L); /* finish mark phase */ + atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } } break; } + case GCSatomic: + { + g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + cost = atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); + break; + } case GCSsweepstring: { size_t traversedcount = 0; @@ -806,7 +782,7 @@ static size_t singlestep(lua_State* L) break; } default: - LUAU_ASSERT(0); + LUAU_ASSERT(!"Unexpected GC state"); } return cost; @@ -821,48 +797,25 @@ static size_t gcstep(lua_State* L, size_t limit) case GCSpause: { markroot(L); /* start a new collection */ + LUAU_ASSERT(g->gcstate == GCSpropagate); break; } case GCSpropagate: { - if (FFlag::LuauRescanGrayAgain) + while (g->gray && cost < limit) { - while (g->gray && cost < limit) - { - g->gcstats.currcycle.markitems++; + g->gcstats.currcycle.markitems++; - cost += propagatemark(g); - } - - if (!g->gray) - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; - - g->gcstate = GCSpropagateagain; - } + cost += propagatemark(g); } - else + + if (!g->gray) { - while (g->gray && cost < limit) - { - g->gcstats.currcycle.markitems++; + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; - cost += propagatemark(g); - } - - if (!g->gray) /* no more `gray' objects */ - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSpropagateagain; } break; } @@ -877,17 +830,34 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { - double starttimestamp = lua_clock(); + if (FFlag::LuauSeparateAtomic) + { + g->gcstate = GCSatomic; + } + else + { + double starttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - atomic(L); /* finish mark phase */ + atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } } break; } + case GCSatomic: + { + g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + cost = atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); + break; + } case GCSsweepstring: { while (g->sweepstrgc < g->strt.size && cost < limit) @@ -934,7 +904,7 @@ static size_t gcstep(lua_State* L, size_t limit) break; } default: - LUAU_ASSERT(0); + LUAU_ASSERT(!"Unexpected GC state"); } return cost; } @@ -1008,7 +978,7 @@ void luaC_step(lua_State* L, bool assist) startGcCycleStats(g); int lastgcstate = g->gcstate; - double lastttimestamp = lua_clock(); + double lasttimestamp = lua_clock(); if (FFlag::LuauConsolidatedStep) { @@ -1034,15 +1004,15 @@ void luaC_step(lua_State* L, bool assist) double now = lua_clock(); - recordGcStateTime(g, lastgcstate, now - lastttimestamp, assist); + recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist); - lastttimestamp = now; + lasttimestamp = now; lastgcstate = g->gcstate; } } while (lim > 0 && g->gcstate != GCSpause); } - recordGcStateTime(g, lastgcstate, lua_clock() - lastttimestamp, assist); + recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); // at the end of the last cycle if (g->gcstate == GCSpause) @@ -1084,7 +1054,7 @@ void luaC_fullgc(lua_State* L) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (g->gcstate <= GCSpropagateagain) + if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) { /* reset sweep marks to sweep all elements (returning them to white) */ g->sweepstrgc = 0; @@ -1095,7 +1065,7 @@ void luaC_fullgc(lua_State* L) g->weak = NULL; g->gcstate = GCSsweepstring; } - LUAU_ASSERT(g->gcstate != GCSpause && g->gcstate != GCSpropagate && g->gcstate != GCSpropagateagain); + LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); /* finish any pending sweep phase */ while (g->gcstate != GCSpause) { @@ -1143,14 +1113,11 @@ void luaC_fullgc(lua_State* L) void luaC_barrierupval(lua_State* L, GCObject* v) { - if (FFlag::LuauGcFullSkipInactiveThreads) - { - global_State* g = L->global; - LUAU_ASSERT(iswhite(v) && !isdead(g, v)); + global_State* g = L->global; + LUAU_ASSERT(iswhite(v) && !isdead(g, v)); - if (keepinvariant(g)) - reallymarkobject(g, v); - } + if (keepinvariant(g)) + reallymarkobject(g, v); } void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v) @@ -1778,7 +1745,7 @@ int64_t luaC_allocationrate(lua_State* L) global_State* g = L->global; const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms - if (g->gcstate <= GCSpropagateagain) + if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) { double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; diff --git a/VM/src/lgc.h b/VM/src/lgc.h index dc780bba..f434e506 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -6,8 +6,6 @@ #include "lobject.h" #include "lstate.h" -LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) - /* ** Possible states of the Garbage Collector */ @@ -25,10 +23,10 @@ LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) ** still-black objects. The invariant is restored when sweep ends and ** all objects are white again. */ -#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain) +#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic) /* -** some userful bit tricks +** some useful bit tricks */ #define resetbits(x, m) ((x) &= cast_to(uint8_t, ~(m))) #define setbits(x, m) ((x) |= (m)) @@ -147,4 +145,4 @@ LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); LUAI_FUNC int64_t luaC_allocationrate(lua_State* L); LUAI_FUNC void luaC_wakethread(lua_State* L); -LUAI_FUNC const char* luaC_statename(int state); \ No newline at end of file +LUAI_FUNC const char* luaC_statename(int state); diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 2759f3b8..d8b265cb 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -199,7 +199,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass) if (page->freeNext >= 0) { - block = page->data + page->freeNext; + block = &page->data + page->freeNext; ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); page->freeNext -= page->blockSize; diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index d77e17c9..18ee1cda 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -226,7 +226,7 @@ void luaS_freeudata(lua_State* L, Udata* u) void (*dtor)(void*) = nullptr; if (u->tag == UTAG_IDTOR) - memcpy(&dtor, u->data + u->len - sizeof(dtor), sizeof(dtor)); + memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); else if (u->tag) dtor = L->global->udatagc[u->tag]; diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index b932a85b..a168b652 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,7 @@ #include // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens -template +template struct TempBuffer { lua_State* L; @@ -346,6 +346,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size uint32_t mainid = readVarInt(data, size, offset); Proto* main = protos[mainid]; + luaC_checkthreadsleep(L); + Closure* cl = luaF_newLclosure(L, 0, envt, main); setclvalue(L, L->top, cl); incr_top(L); diff --git a/bench/tests/chess.lua b/bench/tests/chess.lua new file mode 100644 index 00000000..87b9abfd --- /dev/null +++ b/bench/tests/chess.lua @@ -0,0 +1,849 @@ + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +local RANKS = "12345678" +local FILES = "abcdefgh" +local PieceSymbols = "PpRrNnBbQqKk" +local UnicodePieces = {"♙", "♟", "♖", "♜", "♘", "♞", "♗", "♝", "♕", "♛", "♔", "♚"} +local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1" + +-- +-- Lua 5.2 Compat +-- + +if not table.create then + function table.create(n, v) + local result = {} + for i=1,n do result[i] = v end + return result + end +end + +if not table.move then + function table.move(a, from, to, start, target) + local dx = start - from + for i=from,to do + target[i+dx] = a[i] + end + end +end + + +-- +-- Utils +-- + +local function square(s) + return RANKS:find(s:sub(2,2)) * 8 + FILES:find(s:sub(1,1)) - 9 +end + +local function squareName(n) + local file = n % 8 + local rank = (n-file)/8 + return FILES:sub(file+1,file+1) .. RANKS:sub(rank+1,rank+1) +end + +local function moveName(v ) + local from = bit32.extract(v, 6, 6) + local to = bit32.extract(v, 0, 6) + local piece = bit32.extract(v, 20, 4) + local captured = bit32.extract(v, 25, 4) + + local move = PieceSymbols:sub(piece,piece) .. ' ' .. squareName(from) .. (captured ~= 0 and 'x' or '-') .. squareName(to) + + if bit32.extract(v,14) == 1 then + if to > from then + return "O-O" + else + return "O-O-O" + end + end + + local promote = bit32.extract(v,15,4) + if promote ~= 0 then + move = move .. "=" .. PieceSymbols:sub(promote,promote) + end + return move +end + +local function ucimove(m) + local mm = squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6)) + local promote = bit32.extract(m,15,4) + if promote > 0 then + mm = mm .. PieceSymbols:sub(promote,promote):lower() + end + return mm +end + +local _utils = {squareName, moveName} + +-- +-- Bitboards +-- + +local Bitboard = {} + + +function Bitboard:toString() + local out = {} + + local src = self.h + for x=7,0,-1 do + table.insert(out, RANKS:sub(x+1,x+1)) + table.insert(out, " ") + local bit = bit32.lshift(1,(x%4) * 8) + for x=0,7 do + if bit32.band(src, bit) ~= 0 then + table.insert(out, "x ") + else + table.insert(out, "- ") + end + bit = bit32.lshift(bit, 1) + end + if x == 4 then + src = self.l + end + table.insert(out, "\n") + end + table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n') + table.insert(out, '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h) + return table.concat(out) +end + + +function Bitboard.from(l ,h ) + return setmetatable({l=l, h=h}, Bitboard) +end + +Bitboard.zero = Bitboard.from(0,0) +Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF) + +local Rank1 = Bitboard.from(0x000000FF, 0) +local Rank3 = Bitboard.from(0x00FF0000, 0) +local Rank6 = Bitboard.from(0, 0x0000FF00) +local Rank8 = Bitboard.from(0, 0xFF000000) +local FileA = Bitboard.from(0x01010101, 0x01010101) +local FileB = Bitboard.from(0x02020202, 0x02020202) +local FileC = Bitboard.from(0x04040404, 0x04040404) +local FileD = Bitboard.from(0x08080808, 0x08080808) +local FileE = Bitboard.from(0x10101010, 0x10101010) +local FileF = Bitboard.from(0x20202020, 0x20202020) +local FileG = Bitboard.from(0x40404040, 0x40404040) +local FileH = Bitboard.from(0x80808080, 0x80808080) + +local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH} + +-- These masks are filled out below for all files +local RightMasks = {FileH} +local LeftMasks = {FileA} + + + +local function popcnt32(i) + i = i - bit32.band(bit32.rshift(i,1), 0x55555555) + i = bit32.band(i, 0x33333333) + bit32.band(bit32.rshift(i,2), 0x33333333) + return bit32.rshift(bit32.band(i + bit32.rshift(i,4), 0x0F0F0F0F) * 0x01010101, 24) +end + +function Bitboard:up() + return self:lshift(8) +end + +function Bitboard:down() + return self:rshift(8) +end + +function Bitboard:right() + return self:band(FileH:inverse()):lshift(1) +end + +function Bitboard:left() + return self:band(FileA:inverse()):rshift(1) +end + +function Bitboard:move(x,y) + local out = self + + if x < 0 then out = out:bandnot(RightMasks[-x]):lshift(-x) end + if x > 0 then out = out:bandnot(LeftMasks[x]):rshift(x) end + + if y < 0 then out = out:rshift(-8 * y) end + if y > 0 then out = out:lshift(8 * y) end + return out +end + + +function Bitboard:popcnt() + return popcnt32(self.l) + popcnt32(self.h) +end + +function Bitboard:band(other ) + return Bitboard.from(bit32.band(self.l,other.l), bit32.band(self.h, other.h)) +end + +function Bitboard:bandnot(other ) + return Bitboard.from(bit32.band(self.l,bit32.bnot(other.l)), bit32.band(self.h, bit32.bnot(other.h))) +end + +function Bitboard:bandempty(other ) + return bit32.band(self.l,other.l) == 0 and bit32.band(self.h, other.h) == 0 +end + +function Bitboard:bor(other ) + return Bitboard.from(bit32.bor(self.l,other.l), bit32.bor(self.h, other.h)) +end + +function Bitboard:bxor(other ) + return Bitboard.from(bit32.bxor(self.l,other.l), bit32.bxor(self.h, other.h)) +end + +function Bitboard:inverse() + return Bitboard.from(bit32.bxor(self.l,0xFFFFFFFF), bit32.bxor(self.h, 0xFFFFFFFF)) +end + +function Bitboard:empty() + return self.h == 0 and self.l == 0 +end + +function Bitboard:ctz() + local target = self.l + local offset = 0 + local result = 0 + + if target == 0 then + target = self.h + result = 32 + end + + if target == 0 then + return 64 + end + + while bit32.extract(target, offset) == 0 do + offset = offset + 1 + end + + return result + offset +end + +function Bitboard:ctzafter(start) + start = start + 1 + if start < 32 then + for i=start,31 do + if bit32.extract(self.l, i) == 1 then return i end + end + end + for i=math.max(32,start),63 do + if bit32.extract(self.h, i-32) == 1 then return i end + end + return 64 +end + + +function Bitboard:lshift(amt) + assert(amt >= 0) + if amt == 0 then return self end + + if amt > 31 then + return Bitboard.from(0, bit32.lshift(self.l, amt-31)) + end + + local l = bit32.lshift(self.l, amt) + local h = bit32.bor( + bit32.lshift(self.h, amt), + bit32.extract(self.l, 32-amt, amt) + ) + return Bitboard.from(l, h) +end + +function Bitboard:rshift(amt) + assert(amt >= 0) + if amt == 0 then return self end + local h = bit32.rshift(self.h, amt) + local l = bit32.bor( + bit32.rshift(self.l, amt), + bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt) + ) + return Bitboard.from(l, h) +end + +function Bitboard:index(i) + if i > 31 then + return bit32.extract(self.h, i - 32) + else + return bit32.extract(self.l, i) + end +end + +function Bitboard:set(i , v) + if i > 31 then + return Bitboard.from(self.l, bit32.replace(self.h, v, i - 32)) + else + return Bitboard.from(bit32.replace(self.l, v, i), self.h) + end +end + +function Bitboard:isolate(i) + return self:band(Bitboard.some(i)) +end + +function Bitboard.some(idx ) + return Bitboard.zero:set(idx, 1) +end + +Bitboard.__index = Bitboard +Bitboard.__tostring = Bitboard.toString + +for i=2,8 do + RightMasks[i] = RightMasks[i-1]:rshift(1):bor(FileH) + LeftMasks[i] = LeftMasks[i-1]:lshift(1):bor(FileA) +end +-- +-- Board +-- + +local Board = {} + + +function Board.new() + local boards = table.create(12, Bitboard.zero) + boards.ocupied = Bitboard.zero + boards.white = Bitboard.zero + boards.black = Bitboard.zero + boards.unocupied = Bitboard.full + boards.ep = Bitboard.zero + boards.castle = Bitboard.zero + boards.toMove = 1 + boards.hm = 0 + boards.moves = 0 + boards.material = 0 + + return setmetatable(boards, Board) +end + +function Board.fromFen(fen ) + local b = Board.new() + local i = 0 + local rank = 7 + local file = 0 + + while true do + i = i + 1 + local p = fen:sub(i,i) + if p == '/' then + rank = rank - 1 + file = 0 + elseif tonumber(p) ~= nil then + file = file + tonumber(p) + else + local pidx = PieceSymbols:find(p) + if pidx == nil then break end + b[pidx] = b[pidx]:set(rank*8+file, 1) + file = file + 1 + end + end + + + local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i) + if move == nil then print(fen:sub(i)) end + b.toMove = move == 'w' and 1 or 2 + + if ep ~= "-" then + b.ep = Bitboard.some(square(ep)) + end + + if castle ~= "-" then + local oo = Bitboard.zero + if castle:find("K") then + oo = oo:set(7, 1) + end + if castle:find("Q") then + oo = oo:set(0, 1) + end + if castle:find("k") then + oo = oo:set(63, 1) + end + if castle:find("q") then + oo = oo:set(56, 1) + end + + b.castle = oo + end + + b.hm = hm + b.moves = m + + b:updateCache() + return b + +end + +function Board:index(idx ) + if self.white:index(idx) == 1 then + for p=1,12,2 do + if self[p]:index(idx) == 1 then + return p + end + end + else + for p=2,12,2 do + if self[p]:index(idx) == 1 then + return p + end + end + end + + return 0 +end + +function Board:updateCache() + for i=1,11,2 do + self.white = self.white:bor(self[i]) + self.black = self.black:bor(self[i+1]) + end + + self.ocupied = self.black:bor(self.white) + self.unocupied = self.ocupied:inverse() + self.material = + 100*self[1]:popcnt() - 100*self[2]:popcnt() + + 500*self[3]:popcnt() - 500*self[4]:popcnt() + + 300*self[5]:popcnt() - 300*self[6]:popcnt() + + 300*self[7]:popcnt() - 300*self[8]:popcnt() + + 900*self[9]:popcnt() - 900*self[10]:popcnt() + +end + +function Board:fen() + local out = {} + local s = 0 + local idx = 56 + for i=0,63 do + if i % 8 == 0 and i > 0 then + idx = idx - 16 + if s > 0 then + table.insert(out, '' .. s) + s = 0 + end + table.insert(out, '/') + end + local p = self:index(idx) + if p == 0 then + s = s + 1 + else + if s > 0 then + table.insert(out, '' .. s) + s = 0 + end + table.insert(out, PieceSymbols:sub(p,p)) + end + + idx = idx + 1 + end + if s > 0 then + table.insert(out, '' .. s) + end + + table.insert(out, self.toMove == 1 and ' w ' or ' b ') + if self.castle:empty() then + table.insert(out, '-') + else + if self.castle:index(7) == 1 then table.insert(out, 'K') end + if self.castle:index(0) == 1 then table.insert(out, 'Q') end + if self.castle:index(63) == 1 then table.insert(out, 'k') end + if self.castle:index(56) == 1 then table.insert(out, 'q') end + end + + table.insert(out, ' ') + if self.ep:empty() then + table.insert(out, '-') + else + table.insert(out, squareName(self.ep:ctz())) + end + + table.insert(out, ' ' .. self.hm) + table.insert(out, ' ' .. self.moves) + + return table.concat(out) +end + +function Board:pmoves(idx) + return self:generate(idx) +end + +function Board:pcaptures(idx) + return self:generate(idx):band(self.ocupied) +end + +local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}} +local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}} +local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}} +local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}} + +function Board:generate(idx) + local piece = self:index(idx) + local r = Bitboard.some(idx) + local out = Bitboard.zero + local type = bit32.rshift(piece - 1, 1) + local cancapture = piece % 2 == 1 and self.black or self.white + + if piece == 0 then return Bitboard.zero end + + if type == 0 then + -- Pawn + local d = -(piece*2 - 3) + local movetwo = piece == 1 and Rank3 or Rank6 + + out = out:bor(r:move(0,d):band(self.unocupied)) + out = out:bor(out:band(movetwo):move(0,d):band(self.unocupied)) + + local captures = r:move(0,d) + captures = captures:right():bor(captures:left()) + + if not captures:bandempty(self.ep) then + out = out:bor(self.ep) + end + + captures = captures:band(cancapture) + out = out:bor(captures) + + return out + elseif type == 5 then + -- King + for x=-1,1,1 do + for y = -1,1,1 do + local w = r:move(x,y) + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + end + end + end + elseif type == 2 then + -- Knight + for _,j in ipairs(KNIGHT_MOVES) do + local w = r:move(j[1],j[2]) + + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + end + end + else + -- Sliders (Rook, Bishop, Queen) + local slides + if type == 1 then + slides = ROOK_SLIDES + elseif type == 3 then + slides = BISHOP_SLIDES + else + slides = QUEEN_SLIDES + end + + for _, op in ipairs(slides) do + local w = r + for i=1,7 do + w = w:move(op[1], op[2]) + if w:empty() then break end + + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + break + end + end + end + end + + + return out +end + +-- 0-5 - From Square +-- 6-11 - To Square +-- 12 - is Check +-- 13 - Is EnPassent +-- 14 - Is Castle +-- 15-19 - Promotion Piece +-- 20-24 - Moved Pice +-- 25-29 - Captured Piece + + +function Board:toString(mark ) + local out = {} + for x=8,1,-1 do + table.insert(out, RANKS:sub(x,x) .. " ") + + for y=1,8 do + local n = 8*x+y-9 + local i = self:index(n) + if i == 0 then + table.insert(out, '-') + else + -- out = out .. PieceSymbols:sub(i,i) + table.insert(out, UnicodePieces[i]) + end + if mark ~= nil and mark:index(n) ~= 0 then + table.insert(out, ')') + elseif mark ~= nil and n < 63 and y < 8 and mark:index(n+1) ~= 0 then + table.insert(out, '(') + else + table.insert(out, ' ') + end + end + + table.insert(out, "\n") + end + table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n') + table.insert(out, (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material/100) .. "\n") + return table.concat(out) +end + +function Board:moveList() + local tm = self.toMove == 1 and self.white or self.black + local castle_rank = self.toMove == 1 and Rank1 or Rank8 + local out = {} + local function emit(id) + if not self:applyMove(id):illegalyChecked() then + table.insert(out, id) + end + end + + local cr = tm:band(self.castle):band(castle_rank) + if not cr:empty() then + local p = self.toMove == 1 and 11 or 12 + local tcolor = self.toMove == 1 and self.black or self.white + local kidx = self[p]:ctz() + + + local castle = bit32.replace(0, p, 20, 4) + castle = bit32.replace(castle, kidx, 6, 6) + castle = bit32.replace(castle, 1, 14) + + + local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank) + local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p]) + if + not cr:bandempty(FileA) and + mustbeemptyl:bandempty(self.ocupied) and + not self:isSquareThreatened(cantbethreatened, tcolor) + then + emit(bit32.replace(castle, kidx - 2, 0, 6)) + end + + + local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank) + if + not cr:bandempty(FileH) and + mustbeemptyr:bandempty(self.ocupied) and + not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor) + then + emit(bit32.replace(castle, kidx + 2, 0, 6)) + end + end + + local sq = tm:ctz() + repeat + local p = self:index(sq) + local moves = self:pmoves(sq) + + while not moves:empty() do + local m = moves:ctz() + moves = moves:set(m, 0) + local id = bit32.replace(m, sq, 6, 6) + id = bit32.replace(id, p, 20, 4) + local mbb = Bitboard.some(m) + if not self.ocupied:bandempty(mbb) then + id = bit32.replace(id, self:index(m), 25, 4) + end + + -- Check if pawn needs to be promoted + if p == 1 and m >= 8*7 then + for i=3,9,2 do + emit(bit32.replace(id, i, 15, 4)) + end + elseif p == 2 and m < 8 then + for i=4,10,2 do + emit(bit32.replace(id, i, 15, 4)) + end + else + emit(id) + end + end + sq = tm:ctzafter(sq) + until sq == 64 + return out +end + +function Board:illegalyChecked() + local target = self.toMove == 1 and self[PieceSymbols:find("k")] or self[PieceSymbols:find("K")] + return self:isSquareThreatened(target, self.toMove == 1 and self.white or self.black) +end + +function Board:isSquareThreatened(target , color ) + local tm = color + local sq = tm:ctz() + repeat + local moves = self:pmoves(sq) + if not moves:bandempty(target) then + return true + end + sq = color:ctzafter(sq) + until sq == 64 + return false +end + +function Board:perft(depth ) + if depth == 0 then return 1 end + if depth == 1 then + return #self:moveList() + end + local result = 0 + for k,m in ipairs(self:moveList()) do + local c = self:applyMove(m):perft(depth - 1) + if c == 0 then + -- Perft only counts leaf nodes at target depth + -- result = result + 1 + else + result = result + c + end + end + return result +end + + +function Board:applyMove(move ) + local out = Board.new() + table.move(self, 1, 12, 1, out) + local from = bit32.extract(move, 6, 6) + local to = bit32.extract(move, 0, 6) + local promote = bit32.extract(move, 15, 4) + local piece = self:index(from) + local captured = self:index(to) + local tom = Bitboard.some(to) + local isCastle = bit32.extract(move, 14) + + if piece % 2 == 0 then + out.moves = self.moves + 1 + end + + if captured == 1 or piece < 3 then + out.hm = 0 + else + out.hm = self.hm + 1 + end + out.castle = self.castle + out.toMove = self.toMove == 1 and 2 or 1 + + if isCastle == 1 then + local rank = piece == 11 and Rank1 or Rank8 + local colorOffset = piece - 11 + + out[3 + colorOffset] = out[3 + colorOffset]:bandnot(from < to and FileH or FileA) + out[3 + colorOffset] = out[3 + colorOffset]:bor((from < to and FileF or FileD):band(rank)) + + out[piece] = (from < to and FileG or FileC):band(rank) + out.castle = out.castle:bandnot(rank) + out:updateCache() + return out + end + + if piece < 3 then + local dist = math.abs(to - from) + -- Pawn moved two squares, set ep square + if dist == 16 then + out.ep = Bitboard.some((from + to) / 2) + end + + -- Remove enpasent capture + if not tom:bandempty(self.ep) then + if piece == 1 then + out[2] = out[2]:bandnot(self.ep:down()) + end + if piece == 2 then + out[1] = out[1]:bandnot(self.ep:up()) + end + end + end + + if piece == 3 or piece == 4 then + out.castle = out.castle:set(from, 0) + end + + if piece > 10 then + local rank = piece == 11 and Rank1 or Rank8 + out.castle = out.castle:bandnot(rank) + end + + out[piece] = out[piece]:set(from, 0) + if promote == 0 then + out[piece] = out[piece]:set(to, 1) + else + out[promote] = out[promote]:set(to, 1) + end + if captured ~= 0 then + out[captured] = out[captured]:set(to, 0) + end + + out:updateCache() + return out +end + +Board.__index = Board +Board.__tostring = Board.toString +-- +-- Main +-- + +local failures = 0 +local function test(fen, ply, target) + local b = Board.fromFen(fen) + if b:fen() ~= fen then + print("FEN MISMATCH", fen, b:fen()) + failures = failures + 1 + return + end + + local found = b:perft(ply) + if found ~= target then + print(fen, "Found", found, "target", target) + failures = failures + 1 + for k,v in pairs(b:moveList()) do + print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply-1) or '1')) + end + --error("Test Failure") + else + print("OK", found, fen) + end +end + +-- From https://www.chessprogramming.org/Perft_Results +-- If interpreter, computers, or algorithm gets too fast +-- feel free to go deeper + +local testCases = {} +local function addTest(...) table.insert(testCases, {...}) end + +addTest(StartingFen, 3, 8902) +addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039) +addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812) +addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467) +addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486) +addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079) + + +local function chess() + for k,v in ipairs(testCases) do + test(v[1],v[2],v[3]) + end +end + +bench.runCode(chess, "chess") diff --git a/bench/tests/shootout/scimark.lua b/bench/tests/shootout/scimark.lua index 41d97bb8..ad0557b1 100644 --- a/bench/tests/shootout/scimark.lua +++ b/bench/tests/shootout/scimark.lua @@ -30,7 +30,7 @@ ------------------------------------------------------------------------------ ------------------------------------------------------------------------------ --- Modificatin to be compatible with Lua 5.3 +-- Modification to be compatible with Lua 5.3 ------------------------------------------------------------------------------ local bench = script and require(script.Parent.bench_support) or require("bench_support") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 07910a0a..44b8362d 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1596,7 +1596,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion") local function target(a: number, b: string) return a + #b end local function d(a: n@1, b) - return target(a, b) + return target(a, b) end )"); @@ -1609,7 +1609,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a, b: s@1) - return target(a, b) + return target(a, b) end )"); @@ -1622,7 +1622,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a:@1 @2, b) - return target(a, b) + return target(a, b) end )"); @@ -1640,7 +1640,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a, b: @1)@2: number - return target(a, b) + return target(a, b) end )"); @@ -1682,7 +1682,7 @@ local x = target(function(a: n@1 local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end local x = target(function(a: n@1, b: @2) - return a + #b + return a + #b end) )"); @@ -1700,7 +1700,7 @@ end) local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(a: n@1) - return a + return a end )"); @@ -1716,7 +1716,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(...:n@1) - return a + return a end )"); @@ -1729,7 +1729,7 @@ end local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(a:number, b:number, ...:@1) - return a + b + return a + b end )"); @@ -1745,7 +1745,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion") local function target(callback: () -> number) return callback() end local x = target(function(): n@1 - return 1 + return 1 end )"); @@ -1758,7 +1758,7 @@ end local function target(callback: () -> (number, number)) return callback() end local x = target(function(): (number, n@1 - return 1, 2 + return 1, 2 end )"); @@ -1774,7 +1774,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion" local function target(callback: () -> ...number) return callback() end local x = target(function(): ...n@1 - return 1, 2, 3 + return 1, 2, 3 end )"); @@ -1787,7 +1787,7 @@ end local function target(callback: () -> ...number) return callback() end local x = target(function(): (number, number, ...n@1 - return 1, 2, 3 + return 1, 2, 3 end )"); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 54a31a68..bbac3302 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -768,11 +768,11 @@ TEST_CASE("CaptureSelf") local MaterialsListClass = {} function MaterialsListClass:_MakeToolTip(guiElement, text) - local function updateTooltipPosition() - self._tweakingTooltipFrame = 5 - end + local function updateTooltipPosition() + self._tweakingTooltipFrame = 5 + end - updateTooltipPosition() + updateTooltipPosition() end return MaterialsListClass @@ -2001,14 +2001,14 @@ TEST_CASE("UpvaluesLoopsBytecode") { CHECK_EQ("\n" + compileFunction(R"( function test() - for i=1,10 do + for i=1,10 do i = i - foo(function() return i end) - if bar then - break - end - end - return 0 + foo(function() return i end) + if bar then + break + end + end + return 0 end )", 1), @@ -2035,14 +2035,14 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction(R"( function test() - for i in ipairs(data) do + for i in ipairs(data) do i = i - foo(function() return i end) - if bar then - break - end - end - return 0 + foo(function() return i end) + if bar then + break + end + end + return 0 end )", 1), @@ -2068,17 +2068,17 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction(R"( function test() - local i = 0 - while i < 5 do - local j + local i = 0 + while i < 5 do + local j j = i - foo(function() return j end) - i = i + 1 - if bar then - break - end - end - return 0 + foo(function() return j end) + i = i + 1 + if bar then + break + end + end + return 0 end )", 1), @@ -2105,17 +2105,17 @@ RETURN R1 1 CHECK_EQ("\n" + compileFunction(R"( function test() - local i = 0 - repeat - local j + local i = 0 + repeat + local j j = i - foo(function() return j end) - i = i + 1 - if bar then - break - end - until i < 5 - return 0 + foo(function() return j end) + i = i + 1 + if bar then + break + end + until i < 5 + return 0 end )", 1), @@ -2304,10 +2304,10 @@ local Value1, Value2, Value3 = ... local Table = {} Table.SubTable["Key"] = { - Key1 = Value1, - Key2 = Value2, - Key3 = Value3, - Key4 = true, + Key1 = Value1, + Key2 = Value2, + Key3 = Value3, + Key4 = true, } )"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 5a697a49..06b3c523 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -801,4 +801,17 @@ TEST_CASE("IfElseExpression") runConformance("ifelseexpr.lua"); } +TEST_CASE("TagMethodError") +{ + ScopedFastFlag sff{"LuauCcallRestoreFix", true}; + + runConformance("tmerror.lua", [](lua_State* L) { + auto* cb = lua_callbacks(L); + + cb->debugprotectederror = [](lua_State* L) { + CHECK(lua_isyieldable(L)); + }; + }); +} + TEST_SUITE_END(); diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index 9f874899..e55b5b0c 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -2,6 +2,9 @@ #pragma once #include +#include + +namespace std { inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) { @@ -9,10 +12,12 @@ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) } template -std::ostream& operator<<(std::ostream& lhs, const std::optional& t) +auto operator<<(std::ostream& lhs, const std::optional& t) -> decltype(lhs << *t) // SFINAE to only instantiate << for supported types { if (t) return lhs << *t; else return lhs << "none"; } + +} // namespace std diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index a9ed139f..37f1b60b 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -791,13 +791,13 @@ TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings") { LintResult result = lint(R"(--!strict type InputData = { - id: number, - inputType: EnumItem, - inputState: EnumItem, - updated: number, - position: Vector3, - keyCode: EnumItem, - name: string + id: number, + inputType: EnumItem, + inputState: EnumItem, + updated: number, + position: Vector3, + keyCode: EnumItem, + name: string } )"); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 045f0230..f580604c 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -554,4 +554,54 @@ TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") CHECK_EQ("{number | string}", toString(requireType("p"), {true})); } +/* + * We had a problem where all type aliases would be prototyped into a child scope that happened + * to have the same level. This caused a problem where, if a sibling function referred to that + * type alias in its type signature, it would erroneously be quantified away, even though it doesn't + * actually belong to the function. + * + * We solved this by ascribing a unique subLevel to each prototyped alias. + */ +TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases") +{ + CheckResult result = check(R"( + --!strict + + local KeyPool = {} + + local function newkey(pool: KeyPool, index) + return {} + end + + function newKeyPool() + local pool = { + available = {} :: {Key}, + } + + return setmetatable(pool, KeyPool) + end + + export type KeyPool = typeof(newKeyPool()) + export type Key = typeof(newkey(newKeyPool(), 1)) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * We keep a cache of type alias onto TypeVar to prevent infinite types from + * being constructed via recursive or corecursive aliases. We have to adjust + * the TypeLevels of those generic TypeVars so that the unifier doesn't think + * they have improperly leaked out of their scope. + */ +TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_scope_if_they_are_reused_in_multiple_aliases") +{ + CheckResult result = check(R"( + type Array = {T} + type Exclude = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index e8974776..17e32e9f 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -359,7 +359,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_o CHECK_EQ(typeChecker.stringType, requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "table_insert_corrrectly_infers_type_of_array_3_args_overload") +TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_3_args_overload") { CheckResult result = check(R"( local t = {} diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 6da33a08..eabf7e65 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -437,8 +437,6 @@ TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_orde TEST_CASE_FIXTURE(ClassFixture, "optional_class_field_access_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local b: Vector2? = nil local a = b.X + b.Z diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 3a04a18f..581375a1 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -695,4 +695,25 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") CHECK(requireType("y1") == requireType("y2")); } +TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") +{ + ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true}; + ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true}; + + CheckResult result = check(R"( +local exports = {} +local nested = {} + +nested.name = function(t, k) + local a = t.x.y + return rawget(t, k) +end + +exports.nested = nested +return exports + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 8bcb0242..419da8ad 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -9,12 +9,13 @@ #include LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; TEST_SUITE_BEGIN("ProvisionalTests"); -// These tests check for behavior that differes from the final behavior we'd +// These tests check for behavior that differs from the final behavior we'd // like to have. They serve to document the current state of the typechecker. // When making future improvements, its very likely these tests will break and // will need to be replaced. @@ -42,7 +43,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expected = R"( + const std::string old_expected = R"( function f(a:{fn:()->(free,free...)}): () if type(a) == 'boolean'then local a1:boolean=a @@ -51,7 +52,21 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end end )"; - CHECK_EQ(expected, decorateWithTypes(code)); + + const std::string expected = R"( + function f(a:{fn:()->(a,b...)}): () + if type(a) == 'boolean'then + local a1:boolean=a + elseif a.fn()then + local a2:{fn:()->(a,b...)}=a + end + end + )"; + + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ(expected, decorateWithTypes(code)); + else + CHECK_EQ(old_expected, decorateWithTypes(code)); } TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") @@ -263,8 +278,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { - ScopedFastInt sffi{"LuauTarjanChildLimit", 50}; - ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 50}; + ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; + ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1}; CheckResult result = check(R"LUA( local Result diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 31739cdc..36dcaa95 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,6 +8,7 @@ LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauOrPredicate) +LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -698,10 +699,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); + else + CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + else + CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" } TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index b7f0dc7b..f1451a81 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -617,7 +617,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") REQUIRE_EQ(indexer.indexType, typeChecker.numberType); - REQUIRE(nullptr != get(indexer.indexResultType)); + REQUIRE(nullptr != get(follow(indexer.indexResultType))); } TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index b75878b7..45381757 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -180,7 +180,12 @@ TEST_CASE_FIXTURE(Fixture, "expr_statement") TEST_CASE_FIXTURE(Fixture, "generic_function") { - CheckResult result = check("function id(x) return x end local a = id(55) local b = id(nil)"); + CheckResult result = check(R"( + function id(x) return x end + local a = id(55) + local b = id(nil) + )"); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*typeChecker.numberType, *requireType("a")); @@ -406,7 +411,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right for p in primes2() do print(p) end -- mismatch in argument types, prime_iter takes {}, number, we are given {}, string - for p in primes3() do print(p) end -- no errror + for p in primes3() do print(p) end -- no error )"); LUAU_REQUIRE_ERROR_COUNT(2, result); @@ -1889,7 +1894,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") REQUIRE_EQ(2, argVec.size()); - const FunctionTypeVar* fType = get(argVec[0]); + const FunctionTypeVar* fType = get(follow(argVec[0])); REQUIRE(fType != nullptr); std::vector fArgs = flatten(fType->argTypes).first; @@ -1926,7 +1931,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") REQUIRE_EQ(6, argVec.size()); - const FunctionTypeVar* fType = get(argVec[0]); + const FunctionTypeVar* fType = get(follow(argVec[0])); REQUIRE(fType != nullptr); } @@ -2549,7 +2554,7 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") --!strict local x = nil function f() g() end - -- make sure print(x) doen't get toposorted here, breaking the mutual block + -- make sure print(x) doesn't get toposorted here, breaking the mutual block function g() x = f end print(x) )"); @@ -2987,7 +2992,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") CHECK_EQ(us->name, "a"); } -TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indeces") +TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indices") { CheckResult result = check(R"( local key @@ -3176,7 +3181,24 @@ TEST_CASE_FIXTURE(Fixture, "too_many_return_values") CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); - CHECK(acm->context == CountMismatch::Result); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 2); +} + +TEST_CASE_FIXTURE(Fixture, "ignored_return_values") +{ + CheckResult result = check(R"( + --!strict + + function f() + return 55, "" + end + + local a = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); } TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") @@ -3194,6 +3216,8 @@ TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); CHECK_EQ(acm->context, CountMismatch::Return); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 1); } TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") @@ -3823,10 +3847,10 @@ local T: any T = {} T.__index = T function T.new(...) - local self = {} - setmetatable(self, T) - self:construct(...) - return self + local self = {} + setmetatable(self, T) + self:construct(...) + return self end function T:construct(index) end @@ -4049,11 +4073,11 @@ function n:Clone() end local m = {} function m.a(x) - x:Clone() + x:Clone() end function m.b() - m.a(n) + m.a(n) end return m @@ -4374,8 +4398,6 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") { - ScopedFastFlag luauInferFunctionArgsFix("LuauInferFunctionArgsFix", true); - // Simple direct arg to arg propagation CheckResult result = check(R"( type Table = { x: number, y: number } @@ -4385,7 +4407,7 @@ f(function(a) return a.x + a.y end) LUAU_REQUIRE_NO_ERRORS(result); - // An optional funciton is accepted, but since we already provide a function, nil can be ignored + // An optional function is accepted, but since we already provide a function, nil can be ignored result = check(R"( type Table = { x: number, y: number } local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end @@ -4413,7 +4435,7 @@ f(function(a: number, b, c) return c and a + b or b - a end) LUAU_REQUIRE_NO_ERRORS(result); - // Anonymous function has a varyadic pack + // Anonymous function has a variadic pack result = check(R"( type Table = { x: number, y: number } local function f(a: (Table) -> number) return a({x = 1, y = 2}) end @@ -4432,7 +4454,7 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); CHECK_EQ("Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'", toString(result.errors[0])); - // Infer from varyadic packs into elements + // Infer from variadic packs into elements result = check(R"( function f(a: (...number) -> number) return a(1, 2) end f(function(a, b) return a + b end) @@ -4440,7 +4462,7 @@ f(function(a, b) return a + b end) LUAU_REQUIRE_NO_ERRORS(result); - // Infer from varyadic packs into varyadic packs + // Infer from variadic packs into variadic packs result = check(R"( type Table = { x: number, y: number } function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end @@ -4662,7 +4684,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") { ScopedFastFlag sffs[] = { {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - {"LuauExtraNilRecovery", true}, }; CheckResult result = check(R"( @@ -4679,7 +4700,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") { ScopedFastFlag sffs[] = { {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - {"LuauExtraNilRecovery", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 1f4b63ef..1192a8ac 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -8,6 +8,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauQuantifyInPlace2); + using namespace Luau; struct TryUnifyFixture : Fixture @@ -15,7 +17,8 @@ struct TryUnifyFixture : Fixture TypeArena arena; ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; - Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, &iceHandler}; + UnifierSharedState unifierState{&iceHandler}; + Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, unifierState}; }; TEST_SUITE_BEGIN("TryUnifyTests"); @@ -139,7 +142,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails" )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("(number) -> boolean", toString(requireType("f"))); + else + CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 3e1dedd4..8dab2605 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -98,10 +98,10 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function") std::vector applyArgs = flatten(applyType->argTypes).first; REQUIRE_EQ(3, applyArgs.size()); - const FunctionTypeVar* fType = get(applyArgs[0]); + const FunctionTypeVar* fType = get(follow(applyArgs[0])); REQUIRE(fType != nullptr); - const FunctionTypeVar* gType = get(applyArgs[1]); + const FunctionTypeVar* gType = get(follow(applyArgs[1])); REQUIRE(gType != nullptr); std::vector gArgs = flatten(gType->argTypes).first; @@ -285,7 +285,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") { CheckResult result = check(R"( local _ = function():((...any)->(...any),()->()) - return function() end, function() end + return function() end, function() end end for y in _() do end diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 037144e2..34c25a9f 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -52,7 +52,7 @@ TEST_CASE_FIXTURE(Fixture, "allow_more_specific_assign") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign") +TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign") { CheckResult result = check(R"( local a:number = 10 @@ -63,7 +63,7 @@ TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign2") +TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign2") { CheckResult result = check(R"( local a:number? = 10 @@ -181,8 +181,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property") TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") { - ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); - CheckResult result = check(R"( type A = {x: number} type B = {} @@ -242,8 +240,6 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") TEST_CASE_FIXTURE(Fixture, "optional_union_members") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = { a = { x = 1, y = 2 }, b = 3 } type A = typeof(a) @@ -259,8 +255,6 @@ local c = bf.a.y TEST_CASE_FIXTURE(Fixture, "optional_union_functions") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = {} function a.foo(x:number, y:number) return x + y end @@ -276,8 +270,6 @@ local c = b.foo(1, 2) TEST_CASE_FIXTURE(Fixture, "optional_union_methods") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = {} function a:foo(x:number, y:number) return x + y end @@ -310,8 +302,6 @@ return f() TEST_CASE_FIXTURE(Fixture, "optional_field_access_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = { x: number } local b: A? = { x = 2 } @@ -327,8 +317,6 @@ local d = b.y TEST_CASE_FIXTURE(Fixture, "optional_index_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = {number} local a: A? = {1, 2, 3} @@ -341,8 +329,6 @@ local b = a[1] TEST_CASE_FIXTURE(Fixture, "optional_call_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = (number) -> number local a: A? = function(a) return -a end @@ -355,8 +341,6 @@ local b = a(4) TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = { x: number } local a: A? = { x = 2 } @@ -378,8 +362,6 @@ a.x = 2 TEST_CASE_FIXTURE(Fixture, "optional_length_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = {number} local a: A? = {1, 2, 3} @@ -392,9 +374,6 @@ local b = #a TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); - CheckResult result = check(R"( type A = { x: number, y: number } type B = { x: number, y: number } diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index a679e3fd..930c1a39 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -265,4 +265,64 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result)); } +TEST_CASE("tagging_tables") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar ttv{TableTypeVar{}}; + CHECK(!Luau::hasTag(&ttv, "foo")); + Luau::attachTag(&ttv, "foo"); + CHECK(Luau::hasTag(&ttv, "foo")); +} + +TEST_CASE("tagging_classes") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + CHECK(!Luau::hasTag(&base, "foo")); + Luau::attachTag(&base, "foo"); + CHECK(Luau::hasTag(&base, "foo")); +} + +TEST_CASE("tagging_subclasses") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}}; + + CHECK(!Luau::hasTag(&base, "foo")); + CHECK(!Luau::hasTag(&derived, "foo")); + + Luau::attachTag(&base, "foo"); + CHECK(Luau::hasTag(&base, "foo")); + CHECK(Luau::hasTag(&derived, "foo")); + + Luau::attachTag(&derived, "bar"); + CHECK(!Luau::hasTag(&base, "bar")); + CHECK(Luau::hasTag(&derived, "bar")); +} + +TEST_CASE("tagging_functions") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypePackVar empty{TypePack{}}; + TypeVar ftv{FunctionTypeVar{&empty, &empty}}; + CHECK(!Luau::hasTag(&ftv, "foo")); + Luau::attachTag(&ftv, "foo"); + CHECK(Luau::hasTag(&ftv, "foo")); +} + +TEST_CASE("tagging_props") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + Property prop{}; + CHECK(!Luau::hasTag(prop, "foo")); + Luau::attachTag(prop, "foo"); + CHECK(Luau::hasTag(prop, "foo")); +} + TEST_SUITE_END(); diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index 79f8d9c2..aac42c56 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -319,7 +319,7 @@ end assert(a == 5^4) --- access to locals of collected corroutines +-- access to locals of collected coroutines local C = {}; setmetatable(C, {__mode = "kv"}) local x = coroutine.wrap (function () local a = 10 diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 73c3833d..75329642 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -185,7 +185,7 @@ end assert(a == 5^4) --- access to locals of collected corroutines +-- access to locals of collected coroutines local C = {}; setmetatable(C, {__mode = "kv"}) local x = coroutine.wrap (function () local a = 10 diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index fd4b4de1..4263dfda 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -277,7 +277,7 @@ do assert(getmetatable(o) == tt) -- create new objects during GC local a = 'xuxu'..(10+3)..'joao', {} - ___Glob = o -- ressurect object! + ___Glob = o -- resurrect object! newproxy(o) -- creates a new one with same metatable print(">>> closing state " .. "<<<\n") end diff --git a/tests/conformance/locals.lua b/tests/conformance/locals.lua index cbe5f92d..2d8d004b 100644 --- a/tests/conformance/locals.lua +++ b/tests/conformance/locals.lua @@ -117,7 +117,7 @@ if rawget(_G, "querytab") then local t = querytab(a) for k,_ in pairs(a) do a[k] = nil end - collectgarbage() -- restore GC and collect dead fiels in `a' + collectgarbage() -- restore GC and collect dead fields in `a' for i=0,t-1 do local k = querytab(a, i) assert(k == nil or type(k) == 'number' or k == 'alo') diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 5e8b9398..d5bca44f 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -172,7 +172,7 @@ end a = nil --- testing implicit convertions +-- testing implicit conversions local a,b = '10', '20' assert(a*b == 200 and a+b == 30 and a-b == -10 and a/b == 0.5 and -b == -20) diff --git a/tests/conformance/pm.lua b/tests/conformance/pm.lua index 9a113964..263759ac 100644 --- a/tests/conformance/pm.lua +++ b/tests/conformance/pm.lua @@ -21,9 +21,9 @@ a,b = string.find('alo', '') assert(a == 1 and b == 0) a,b = string.find('a\0o a\0o a\0o', 'a', 1) -- first position assert(a == 1 and b == 1) -a,b = string.find('a\0o a\0o a\0o', 'a\0o', 2) -- starts in the midle +a,b = string.find('a\0o a\0o a\0o', 'a\0o', 2) -- starts in the middle assert(a == 5 and b == 7) -a,b = string.find('a\0o a\0o a\0o', 'a\0o', 9) -- starts in the midle +a,b = string.find('a\0o a\0o a\0o', 'a\0o', 9) -- starts in the middle assert(a == 9 and b == 11) a,b = string.find('a\0a\0a\0a\0\0ab', '\0ab', 2); -- finds at the end assert(a == 9 and b == 11); diff --git a/tests/conformance/tmerror.lua b/tests/conformance/tmerror.lua new file mode 100644 index 00000000..1ad4dd16 --- /dev/null +++ b/tests/conformance/tmerror.lua @@ -0,0 +1,15 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes + +-- Generate an error (i.e. throw an exception) inside a tag method which is indirectly +-- called via pcall. +-- This test is meant to detect a regression in handling errors inside a tag method + +local testtable = {} +setmetatable(testtable, { __index = function() error("Error") end }) + +pcall(function() + testtable.missingmethod() +end) + +return('OK') diff --git a/tools/gdb-printers.py b/tools/gdb-printers.py index c711c5e2..017b9f95 100644 --- a/tools/gdb-printers.py +++ b/tools/gdb-printers.py @@ -11,9 +11,9 @@ class VariantPrinter: return type.name + " [" + str(value) + "]" def match_printer(val): - type = val.type.strip_typedefs() - if type.name and type.name.startswith('Luau::Variant<'): - return VariantPrinter(val) - return None + type = val.type.strip_typedefs() + if type.name and type.name.startswith('Luau::Variant<'): + return VariantPrinter(val) + return None gdb.pretty_printers.append(match_printer) From 82d74e6f73fa8d9a81c4c932c668543c08c27597 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 18:12:39 -0800 Subject: [PATCH 003/102] Sync to upstream/release/504 --- .gitignore | 1 + Analysis/include/Luau/Error.h | 13 +- Analysis/include/Luau/FileResolver.h | 2 +- Analysis/include/Luau/Transpiler.h | 3 +- Analysis/include/Luau/TypeInfer.h | 8 +- Analysis/include/Luau/TypePack.h | 20 +- Analysis/include/Luau/TypeVar.h | 45 +- Analysis/include/Luau/Unifiable.h | 3 - Analysis/include/Luau/Unifier.h | 2 +- Analysis/src/Autocomplete.cpp | 46 +- Analysis/src/BuiltinDefinitions.cpp | 244 +-------- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 22 +- Analysis/src/Error.cpp | 291 ++++++---- Analysis/src/Frontend.cpp | 9 +- Analysis/src/Linter.cpp | 10 +- Analysis/src/Module.cpp | 17 +- Analysis/src/Predicate.cpp | 4 - Analysis/src/RequireTracer.cpp | 8 +- Analysis/src/Substitution.cpp | 13 +- Analysis/src/ToString.cpp | 114 +--- Analysis/src/Transpiler.cpp | 117 +++- Analysis/src/TypeAttach.cpp | 41 +- Analysis/src/TypeInfer.cpp | 576 ++++++-------------- Analysis/src/TypePack.cpp | 8 +- Analysis/src/TypeVar.cpp | 70 ++- Analysis/src/Unifiable.cpp | 10 - Analysis/src/Unifier.cpp | 345 ++++++------ Ast/include/Luau/Parser.h | 1 - Ast/src/Parser.cpp | 51 +- CLI/Analyze.cpp | 27 +- CLI/Repl.cpp | 93 +++- CMakeLists.txt | 38 +- Compiler/include/Luau/Bytecode.h | 4 + Compiler/include/Luau/Compiler.h | 3 + Compiler/src/Compiler.cpp | 47 +- Makefile | 10 +- VM/include/lua.h | 4 + VM/src/lapi.cpp | 11 + VM/src/lbitlib.cpp | 42 ++ VM/src/lbuiltins.cpp | 54 +- VM/src/lgc.cpp | 164 +----- VM/src/lstrlib.cpp | 20 +- VM/src/ltable.cpp | 1 + VM/src/ltable.h | 1 - VM/src/ltablib.cpp | 8 - bench/tests/chess.lua | 80 +-- fuzz/luau.proto | 7 + fuzz/proto.cpp | 7 + fuzz/protoprint.cpp | 10 + tests/AstQuery.test.cpp | 1 - tests/Autocomplete.test.cpp | 3 - tests/Compiler.test.cpp | 126 +++++ tests/Conformance.test.cpp | 6 +- tests/Linter.test.cpp | 6 - tests/Parser.test.cpp | 43 +- tests/Predicate.test.cpp | 6 - tests/ToString.test.cpp | 9 - tests/Transpiler.test.cpp | 252 ++++++++- tests/TypeInfer.aliases.test.cpp | 3 - tests/TypeInfer.builtins.test.cpp | 8 - tests/TypeInfer.classes.test.cpp | 25 +- tests/TypeInfer.definitions.test.cpp | 3 - tests/TypeInfer.generics.test.cpp | 89 --- tests/TypeInfer.intersectionTypes.test.cpp | 39 ++ tests/TypeInfer.provisional.test.cpp | 7 +- tests/TypeInfer.refinements.test.cpp | 44 +- tests/TypeInfer.tables.test.cpp | 72 +++ tests/TypeInfer.test.cpp | 34 +- tests/TypeInfer.tryUnify.test.cpp | 5 - tests/TypeInfer.typePacks.cpp | 12 - tests/TypeInfer.unionTypes.test.cpp | 41 +- tests/TypeVar.test.cpp | 2 - tests/conformance/bitwise.lua | 16 + 73 files changed, 1734 insertions(+), 1843 deletions(-) diff --git a/.gitignore b/.gitignore index 0b2422ce..fa11b45b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ ^default.prof* ^fuzz-* ^luau$ +/.vs diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index ac6f13e9..9ee75004 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -8,11 +8,20 @@ namespace Luau { +struct TypeError; struct TypeMismatch { - TypeId wantedType; - TypeId givenType; + TypeMismatch() = default; + TypeMismatch(TypeId wantedType, TypeId givenType); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error); + + TypeId wantedType = nullptr; + TypeId givenType = nullptr; + + std::string reason; + std::shared_ptr error; bool operator==(const TypeMismatch& rhs) const; }; diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index a05ec5e9..9b74fc12 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -53,7 +53,7 @@ struct FileResolver } // DEPRECATED APIS - // These are going to be removed with LuauNewRequireTracer + // These are going to be removed with LuauNewRequireTrace2 virtual bool moduleExists(const ModuleName& name) const = 0; virtual std::optional fromAstFragment(AstExpr* expr) const = 0; virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; diff --git a/Analysis/include/Luau/Transpiler.h b/Analysis/include/Luau/Transpiler.h index 817459fe..df01008c 100644 --- a/Analysis/include/Luau/Transpiler.h +++ b/Analysis/include/Luau/Transpiler.h @@ -18,6 +18,7 @@ struct TranspileResult std::string parseError; // Nonempty if the transpile failed }; +std::string toString(AstNode* node); void dump(AstNode* node); // Never fails on a well-formed AST @@ -25,6 +26,6 @@ std::string transpile(AstStatBlock& ast); std::string transpileWithTypes(AstStatBlock& block); // Only fails when parsing fails -TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}); +TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}, bool withTypes = false); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 9d62fef0..306ac77d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -263,8 +263,6 @@ public: * */ TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location); - // Removed by FFlag::LuauRankNTypes - TypePackId DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location); // Replace any free types or type packs by `any`. // This is used when exporting types from modules, to make sure free types don't leak. @@ -298,8 +296,6 @@ private: // Produce a new free type var. TypeId freshType(const ScopePtr& scope); TypeId freshType(TypeLevel level); - TypeId DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric = false); - TypeId DEPRECATED_freshType(TypeLevel level, bool canBeGeneric = false); // Returns nullopt if the predicate filters down the TypeId to 0 options. std::optional filterMap(TypeId type, TypeIdPredicate predicate); @@ -326,10 +322,8 @@ private: TypePackId addTypePack(std::initializer_list&& ty); TypePackId freshTypePack(const ScopePtr& scope); TypePackId freshTypePack(TypeLevel level); - TypePackId DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric = false); - TypePackId DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric = false); - TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false); + TypeId resolveType(const ScopePtr& scope, const AstType& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index d987d46c..e72808da 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAG(LuauAddMissingFollow) - namespace Luau { @@ -128,13 +126,10 @@ TypePack* asMutable(const TypePack* tp); template const T* get(TypePackId tp) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tp); + LUAU_ASSERT(tp); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tp->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tp->ty) == nullptr); return get_if(&(tp->ty)); } @@ -142,13 +137,10 @@ const T* get(TypePackId tp) template T* getMutable(TypePackId tp) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tp); + LUAU_ASSERT(tp); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tp->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tp->ty) == nullptr); return get_if(&(asMutable(tp)->ty)); } diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 9611e881..6bd7932d 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -18,7 +18,6 @@ LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) -LUAU_FASTFLAG(LuauAddMissingFollow) namespace Luau { @@ -413,13 +412,17 @@ bool maybeGeneric(const TypeId ty); struct SingletonTypes { - const TypeId nilType = &nilType_; - const TypeId numberType = &numberType_; - const TypeId stringType = &stringType_; - const TypeId booleanType = &booleanType_; - const TypeId threadType = &threadType_; - const TypeId anyType = &anyType_; - const TypeId errorType = &errorType_; + const TypeId nilType; + const TypeId numberType; + const TypeId stringType; + const TypeId booleanType; + const TypeId threadType; + const TypeId anyType; + const TypeId errorType; + const TypeId optionalNumberType; + + const TypePackId anyTypePack; + const TypePackId errorTypePack; SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; @@ -427,14 +430,6 @@ struct SingletonTypes private: std::unique_ptr arena; - TypeVar nilType_; - TypeVar numberType_; - TypeVar stringType_; - TypeVar booleanType_; - TypeVar threadType_; - TypeVar anyType_; - TypeVar errorType_; - TypeId makeStringMetatable(); }; @@ -472,13 +467,10 @@ TypeVar* asMutable(TypeId ty); template const T* get(TypeId tv) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tv); + LUAU_ASSERT(tv); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tv->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); return get_if(&tv->ty); } @@ -486,13 +478,10 @@ const T* get(TypeId tv) template T* getMutable(TypeId tv) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tv); + LUAU_ASSERT(tv); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tv->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); return get_if(&asMutable(tv)->ty); } diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 10dbf333..c2e07e46 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -63,12 +63,9 @@ using Name = std::string; struct Free { explicit Free(TypeLevel level); - Free(TypeLevel level, bool DEPRECATED_canBeGeneric); int index; TypeLevel level; - // Removed by FFlag::LuauRankNTypes - bool DEPRECATED_canBeGeneric = false; // True if this free type variable is part of a mutually // recursive type alias whose definitions haven't been // resolved yet. diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 56632e33..be0aadd0 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -87,7 +87,6 @@ private: void tryUnifyWithAny(TypePackId any, TypePackId ty); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); - std::optional findMetatableEntry(TypeId type, std::string entry); public: // Report an "infinite type error" if the type "needle" already occurs within "haystack" @@ -102,6 +101,7 @@ private: bool isNonstrictMode() const; void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType); + void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType); [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 3c43c808..1c94bb68 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,7 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) @@ -369,20 +368,10 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId while (iter != endIter) { - if (FFlag::LuauAddMissingFollow) - { - if (isNil(*iter)) - ++iter; - else - break; - } + if (isNil(*iter)) + ++iter; else - { - if (auto primTy = Luau::get(*iter); primTy && primTy->type == PrimitiveTypeVar::NilType) - ++iter; - else - break; - } + break; } if (iter == endIter) @@ -397,21 +386,10 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen = seen; - if (FFlag::LuauAddMissingFollow) + if (isNil(*iter)) { - if (isNil(*iter)) - { - ++iter; - continue; - } - } - else - { - if (auto innerPrimTy = Luau::get(*iter); innerPrimTy && innerPrimTy->type == PrimitiveTypeVar::NilType) - { - ++iter; - continue; - } + ++iter; + continue; } autocompleteProps(module, typeArena, *iter, indexType, nodes, inner, innerSeen); @@ -496,7 +474,7 @@ static bool canSuggestInferredType(ScopePtr scope, TypeId ty) return false; // No syntax for unnamed tables with a metatable - if (const MetatableTypeVar* mtv = get(ty)) + if (get(ty)) return false; if (const TableTypeVar* ttv = get(ty)) @@ -688,7 +666,7 @@ static std::optional functionIsExpectedAt(const Module& module, AstNode* n TypeId expectedType = follow(*it); - if (const FunctionTypeVar* ftv = get(expectedType)) + if (get(expectedType)) return true; if (const IntersectionTypeVar* itv = get(expectedType)) @@ -1519,10 +1497,10 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName return {}; TypeChecker& typeChecker = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); if (!module) return {}; @@ -1550,7 +1528,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->commentLocations = std::move(result.commentLocations); TypeChecker& typeChecker = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index f6f2363c..62a06a3c 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,10 +8,7 @@ #include -LUAU_FASTFLAG(LuauParseGenericFunctions) -LUAU_FASTFLAG(LuauGenericFunctions) -LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauNewRequireTrace2) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -185,25 +182,11 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId numberType = typeChecker.numberType; TypeId booleanType = typeChecker.booleanType; TypeId nilType = typeChecker.nilType; - TypeId stringType = typeChecker.stringType; - TypeId threadType = typeChecker.threadType; - TypeId anyType = typeChecker.anyType; TypeArena& arena = typeChecker.globalTypes; - TypeId optionalNumber = makeOption(typeChecker, arena, numberType); - TypeId optionalString = makeOption(typeChecker, arena, stringType); - TypeId optionalBoolean = makeOption(typeChecker, arena, booleanType); - - TypeId stringOrNumber = makeUnion(arena, {stringType, numberType}); - - TypePackId emptyPack = arena.addTypePack({}); TypePackId oneNumberPack = arena.addTypePack({numberType}); - TypePackId oneStringPack = arena.addTypePack({stringType}); TypePackId oneBooleanPack = arena.addTypePack({booleanType}); - TypePackId oneAnyPack = arena.addTypePack({anyType}); - - TypePackId anyTypePack = typeChecker.anyTypePack; TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); @@ -215,8 +198,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack}); - TypeId stringToAnyMap = arena.addType(TableTypeVar{{}, TableIndexer(stringType, anyType), typeChecker.globalScope->level}); - LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); @@ -236,8 +217,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest"); } - TypeId anyFunction = arena.addType(FunctionTypeVar{anyTypePack, anyTypePack}); - TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); @@ -252,222 +231,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - if (!FFlag::LuauParseGenericFunctions || !FFlag::LuauGenericFunctions) - { - TableTypeVar::Props debugLib{ - {"info", {makeIntersection(arena, - { - arena.addType(FunctionTypeVar{arena.addTypePack({typeChecker.threadType, numberType, stringType}), anyTypePack}), - arena.addType(FunctionTypeVar{arena.addTypePack({numberType, stringType}), anyTypePack}), - arena.addType(FunctionTypeVar{arena.addTypePack({anyFunction, stringType}), anyTypePack}), - })}}, - {"traceback", {makeIntersection(arena, - { - makeFunction(arena, std::nullopt, {optionalString, optionalNumber}, {stringType}), - makeFunction(arena, std::nullopt, {typeChecker.threadType, optionalString, optionalNumber}, {stringType}), - })}}, - }; - - assignPropDocumentationSymbols(debugLib, "@luau/global/debug"); - addGlobalBinding(typeChecker, "debug", - arena.addType(TableTypeVar{debugLib, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}), "@luau"); - - TableTypeVar::Props utf8Lib = { - {"char", {arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneStringPack})}}, // FIXME - {"charpattern", {stringType}}, - {"codes", {makeFunction(arena, std::nullopt, {stringType}, - {makeFunction(arena, std::nullopt, {stringType, numberType}, {numberType, numberType}), stringType, numberType})}}, - {"codepoint", - {arena.addType(FunctionTypeVar{arena.addTypePack({stringType, optionalNumber, optionalNumber}), listOfAtLeastOneNumber})}}, // FIXME - {"len", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {optionalNumber, numberType})}}, - {"offset", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {numberType})}}, - {"nfdnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}}, - {"graphemes", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, - {makeFunction(arena, std::nullopt, {}, {numberType, numberType})})}}, - {"nfcnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}}, - }; - - assignPropDocumentationSymbols(utf8Lib, "@luau/global/utf8"); - addGlobalBinding( - typeChecker, "utf8", arena.addType(TableTypeVar{utf8Lib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TypeId optionalV = makeOption(typeChecker, arena, genericV); - - TypeId arrayOfV = arena.addType(TableTypeVar{{}, TableIndexer(numberType, genericV), typeChecker.globalScope->level}); - - TypePackId unpackArgsPack = arena.addTypePack(TypePack{{arrayOfV, optionalNumber, optionalNumber}}); - TypePackId unpackReturnPack = arena.addTypePack(TypePack{{}, anyTypePack}); - TypeId unpackFunc = arena.addType(FunctionTypeVar{{genericV}, {}, unpackArgsPack, unpackReturnPack}); - - TypeId packResult = arena.addType(TableTypeVar{ - TableTypeVar::Props{{"n", {numberType}}}, TableIndexer{numberType, numberType}, typeChecker.globalScope->level, TableState::Sealed}); - TypePackId packArgsPack = arena.addTypePack(TypePack{{}, anyTypePack}); - TypePackId packReturnPack = arena.addTypePack(TypePack{{packResult}}); - - TypeId comparator = makeFunction(arena, std::nullopt, {genericV, genericV}, {booleanType}); - TypeId optionalComparator = makeOption(typeChecker, arena, comparator); - - TypeId packFn = arena.addType(FunctionTypeVar(packArgsPack, packReturnPack)); - - TableTypeVar::Props tableLib = { - {"concat", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalString, optionalNumber, optionalNumber}, {stringType})}}, - {"insert", {makeIntersection(arena, {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV}, {}), - makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, genericV}, {})})}}, - {"maxn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}}, - {"remove", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalNumber}, {optionalV})}}, - {"sort", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalComparator}, {})}}, - {"create", {makeFunction(arena, std::nullopt, {genericV}, {}, {numberType, optionalV}, {arrayOfV})}}, - {"find", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV, optionalNumber}, {optionalNumber})}}, - - {"unpack", {unpackFunc}}, // FIXME - {"pack", {packFn}}, - - // Lua 5.0 compat - {"getn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}}, - {"foreach", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, - {mapOfKtoV, makeFunction(arena, std::nullopt, {genericK, genericV}, {})}, {})}}, - {"foreachi", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, makeFunction(arena, std::nullopt, {genericV}, {})}, {})}}, - - // backported from Lua 5.3 - {"move", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, numberType, numberType, arrayOfV}, {})}}, - - // added in Luau (borrowed from LuaJIT) - {"clear", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {})}}, - - {"freeze", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {mapOfKtoV})}}, - {"isfrozen", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {booleanType})}}, - }; - - assignPropDocumentationSymbols(tableLib, "@luau/global/table"); - addGlobalBinding( - typeChecker, "table", arena.addType(TableTypeVar{tableLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TableTypeVar::Props coroutineLib = { - {"create", {makeFunction(arena, std::nullopt, {anyFunction}, {threadType})}}, - {"resume", {arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{threadType}, anyTypePack}), anyTypePack})}}, - {"running", {makeFunction(arena, std::nullopt, {}, {threadType})}}, - {"status", {makeFunction(arena, std::nullopt, {threadType}, {stringType})}}, - {"wrap", {makeFunction( - arena, std::nullopt, {anyFunction}, {anyType})}}, // FIXME this technically returns a function, but we can't represent this - // atm since it can be called with different arg types at different times - {"yield", {arena.addType(FunctionTypeVar{anyTypePack, anyTypePack})}}, - {"isyieldable", {makeFunction(arena, std::nullopt, {}, {booleanType})}}, - }; - - assignPropDocumentationSymbols(coroutineLib, "@luau/global/coroutine"); - addGlobalBinding(typeChecker, "coroutine", - arena.addType(TableTypeVar{coroutineLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TypeId genericT = arena.addType(GenericTypeVar{"T"}); - TypeId genericR = arena.addType(GenericTypeVar{"R"}); - - // assert returns all arguments - TypePackId assertArgs = arena.addTypePack({genericT, optionalString}); - TypePackId assertRets = arena.addTypePack({genericT}); - addGlobalBinding(typeChecker, "assert", arena.addType(FunctionTypeVar{assertArgs, assertRets}), "@luau"); - - addGlobalBinding(typeChecker, "print", arena.addType(FunctionTypeVar{anyTypePack, emptyPack}), "@luau"); - - addGlobalBinding(typeChecker, "type", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - addGlobalBinding(typeChecker, "typeof", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - - addGlobalBinding(typeChecker, "error", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {}), "@luau"); - - addGlobalBinding(typeChecker, "tostring", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - addGlobalBinding( - typeChecker, "tonumber", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {numberType}), "@luau"); - - addGlobalBinding( - typeChecker, "rawequal", makeFunction(arena, std::nullopt, {genericT, genericR}, {}, {genericT, genericR}, {booleanType}), "@luau"); - addGlobalBinding( - typeChecker, "rawget", makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK}, {genericV}), "@luau"); - addGlobalBinding(typeChecker, "rawset", - makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK, genericV}, {mapOfKtoV}), "@luau"); - - TypePackId genericTPack = arena.addTypePack({genericT}); - TypePackId genericRPack = arena.addTypePack({genericR}); - TypeId genericArgsToReturnFunction = arena.addType( - FunctionTypeVar{{genericT, genericR}, {}, arena.addTypePack(TypePack{{}, genericTPack}), arena.addTypePack(TypePack{{}, genericRPack})}); - - TypeId setfenvArgType = makeUnion(arena, {numberType, genericArgsToReturnFunction}); - TypeId setfenvReturnType = makeOption(typeChecker, arena, genericArgsToReturnFunction); - addGlobalBinding(typeChecker, "setfenv", makeFunction(arena, std::nullopt, {setfenvArgType, stringToAnyMap}, {setfenvReturnType}), "@luau"); - - TypePackId ipairsArgsTypePack = arena.addTypePack({arrayOfV}); - - TypeId ipairsNextFunctionType = arena.addType( - FunctionTypeVar{{genericK, genericV}, {}, arena.addTypePack({arrayOfV, numberType}), arena.addTypePack({numberType, genericV})}); - - // ipairs returns 'next, Array, 0' so we would need type-level primitives and change to - // again, we have a direct reference to 'next' because ipairs returns it - // ipairs(t: Array) -> ((Array) -> (number, V), Array, 0) - TypePackId ipairsReturnTypePack = arena.addTypePack(TypePack{{ipairsNextFunctionType, arrayOfV, numberType}}); - - // ipairs(t: Array) -> ((Array) -> (number, V), Array, number) - addGlobalBinding(typeChecker, "ipairs", arena.addType(FunctionTypeVar{{genericV}, {}, ipairsArgsTypePack, ipairsReturnTypePack}), "@luau"); - - TypePackId pcallArg0FnArgs = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}}); - TypePackId pcallArg0FnRet = arena.addTypePack(TypePackVar{GenericTypeVar{"R"}}); - TypeId pcallArg0 = arena.addType(FunctionTypeVar{pcallArg0FnArgs, pcallArg0FnRet}); - TypePackId pcallArgsTypePack = arena.addTypePack(TypePack{{pcallArg0}, pcallArg0FnArgs}); - - TypePackId pcallReturnTypePack = arena.addTypePack(TypePack{{booleanType}, pcallArg0FnRet}); - - // pcall(f: (A...) -> R..., args: A...) -> boolean, R... - addGlobalBinding(typeChecker, "pcall", - arena.addType(FunctionTypeVar{{}, {pcallArg0FnArgs, pcallArg0FnRet}, pcallArgsTypePack, pcallReturnTypePack}), "@luau"); - - // errors thrown by the function 'f' are propagated onto the function 'err' that accepts it. - // and either 'f' or 'err' are valid results of this xpcall - // if 'err' did throw an error, then it returns: false, "error in error handling" - // TODO: the above is not represented (nor representable) in the type annotation below. - // - // The real type of xpcall is as such: (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, - // R2...) - TypePackId genericAPack = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}}); - TypePackId genericR1Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R1"}}); - TypePackId genericR2Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R2"}}); - - TypeId genericE = arena.addType(GenericTypeVar{"E"}); - - TypeId xpcallFArg = arena.addType(FunctionTypeVar{genericAPack, genericR1Pack}); - TypeId xpcallErrArg = arena.addType(FunctionTypeVar{arena.addTypePack({genericE}), genericR2Pack}); - - TypePackId xpcallArgsPack = arena.addTypePack({{xpcallFArg, xpcallErrArg}, genericAPack}); - TypePackId xpcallRetPack = arena.addTypePack({{booleanType}, genericR1Pack}); // FIXME - - addGlobalBinding(typeChecker, "xpcall", - arena.addType(FunctionTypeVar{{genericE}, {genericAPack, genericR1Pack, genericR2Pack}, xpcallArgsPack, xpcallRetPack}), "@luau"); - - addGlobalBinding(typeChecker, "unpack", unpackFunc, "@luau"); - - TypePackId selectArgsTypePack = arena.addTypePack(TypePack{ - {stringOrNumber}, - anyTypePack // FIXME? select() is tricky. - }); - - addGlobalBinding(typeChecker, "select", arena.addType(FunctionTypeVar{selectArgsTypePack, anyTypePack}), "@luau"); - - // TODO: not completely correct. loadstring's return type should be a function or (nil, string) - TypeId loadstringFunc = arena.addType(FunctionTypeVar{anyTypePack, oneAnyPack}); - - addGlobalBinding(typeChecker, "loadstring", - makeFunction(arena, std::nullopt, {stringType, optionalString}, - { - makeOption(typeChecker, arena, loadstringFunc), - makeOption(typeChecker, arena, stringType), - }), - "@luau"); - - // a userdata object is "roughly" the same as a sealed empty table - // except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too. - // another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT - // setmetatable. - // TODO: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`. - TypeId sealedTable = arena.addType(TableTypeVar(TableState::Sealed, typeChecker.globalScope->level)); - addGlobalBinding(typeChecker, "newproxy", makeFunction(arena, std::nullopt, {optionalBoolean}, {sealedTable}), "@luau"); - } - // next(t: Table, i: K | nil) -> (K, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); addGlobalBinding(typeChecker, "next", @@ -475,8 +238,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = (FFlag::LuauRankNTypes ? arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}) - : getGlobalBinding(typeChecker, "next")); + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); // NOTE we are missing 'i: K | nil' argument in the first return types' argument. @@ -711,7 +473,7 @@ static std::optional> magicFunctionRequire( if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; - const AstExpr* require = FFlag::LuauNewRequireTrace ? &expr : expr.args.data[0]; + const AstExpr* require = FFlag::LuauNewRequireTrace2 ? &expr : expr.args.data[0]; if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 1e91561a..96703ef1 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,9 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(LuauParseGenericFunctions) -LUAU_FASTFLAG(LuauGenericFunctions) - namespace Luau { @@ -19,6 +16,8 @@ declare bit32: { bnot: (number) -> number, extract: (number, number, number?) -> number, replace: (number, number, number, number?) -> number, + countlz: (number) -> number, + countrz: (number) -> number, } declare math: { @@ -103,15 +102,6 @@ declare _VERSION: string declare function gcinfo(): number -)BUILTIN_SRC"; - -std::string getBuiltinDefinitionSource() -{ - std::string src = kBuiltinDefinitionLuaSrc; - - if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) - { - src += R"( declare function print(...: T...) declare function type(value: T): string @@ -208,10 +198,12 @@ std::string getBuiltinDefinitionSource() -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. declare function unpack(tab: {V}, i: number?, j: number?): ...V - )"; - } - return src; +)BUILTIN_SRC"; + +std::string getBuiltinDefinitionSource() +{ + return kBuiltinDefinitionLuaSrc; } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 04d91444..46ff2c72 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -94,8 +94,23 @@ struct ErrorConverter { std::string operator()(const Luau::TypeMismatch& tm) const { - ToStringOptions opts; - return "Type '" + Luau::toString(tm.givenType, opts) + "' could not be converted into '" + Luau::toString(tm.wantedType, opts) + "'"; + std::string result = "Type '" + Luau::toString(tm.givenType) + "' could not be converted into '" + Luau::toString(tm.wantedType) + "'"; + + if (tm.error) + { + result += "\ncaused by:\n "; + + if (!tm.reason.empty()) + result += tm.reason + ". "; + + result += Luau::toString(*tm.error); + } + else if (!tm.reason.empty()) + { + result += "; " + tm.reason; + } + + return result; } std::string operator()(const Luau::UnknownSymbol& e) const @@ -478,9 +493,36 @@ struct InvalidNameChecker } }; +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType) + : wantedType(wantedType) + , givenType(givenType) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason) + : wantedType(wantedType) + , givenType(givenType) + , reason(reason) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error) + : wantedType(wantedType) + , givenType(givenType) + , reason(reason) + , error(std::make_shared(std::move(error))) +{ +} + bool TypeMismatch::operator==(const TypeMismatch& rhs) const { - return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType; + if (!!error != !!rhs.error) + return false; + + if (error && !(*error == *rhs.error)) + return false; + + return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType && reason == rhs.reason; } bool UnknownSymbol::operator==(const UnknownSymbol& rhs) const @@ -690,130 +732,141 @@ bool containsParseErrorName(const TypeError& error) return Luau::visit(InvalidNameChecker{}, error.data); } -void copyErrors(ErrorVec& errors, struct TypeArena& destArena) +template +void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) { - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - auto clone = [&](auto&& ty) { return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks); }; auto visitErrorData = [&](auto&& e) { - using T = std::decay_t; + copyError(e, destArena, seenTypes, seenTypePacks); + }; - if constexpr (false) - { - } - else if constexpr (std::is_same_v) - { - e.wantedType = clone(e.wantedType); - e.givenType = clone(e.givenType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.table = clone(e.table); - } - else if constexpr (std::is_same_v) - { - e.ty = clone(e.ty); - } - else if constexpr (std::is_same_v) - { - e.tableType = clone(e.tableType); - } - else if constexpr (std::is_same_v) - { - e.tableType = clone(e.tableType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.typeFun = clone(e.typeFun); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.table = clone(e.table); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.ty = clone(e.ty); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.expectedReturnType = clone(e.expectedReturnType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.superType = clone(e.superType); - e.subType = clone(e.subType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.optional = clone(e.optional); - } - else if constexpr (std::is_same_v) - { - e.type = clone(e.type); + if constexpr (false) + { + } + else if constexpr (std::is_same_v) + { + e.wantedType = clone(e.wantedType); + e.givenType = clone(e.givenType); - for (auto& ty : e.missing) - ty = clone(ty); - } - else - static_assert(always_false_v, "Non-exhaustive type switch"); + if (e.error) + visit(visitErrorData, e.error->data); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.table = clone(e.table); + } + else if constexpr (std::is_same_v) + { + e.ty = clone(e.ty); + } + else if constexpr (std::is_same_v) + { + e.tableType = clone(e.tableType); + } + else if constexpr (std::is_same_v) + { + e.tableType = clone(e.tableType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.typeFun = clone(e.typeFun); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.table = clone(e.table); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.ty = clone(e.ty); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.expectedReturnType = clone(e.expectedReturnType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.superType = clone(e.superType); + e.subType = clone(e.subType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.optional = clone(e.optional); + } + else if constexpr (std::is_same_v) + { + e.type = clone(e.type); + + for (auto& ty : e.missing) + ty = clone(ty); + } + else + static_assert(always_false_v, "Non-exhaustive type switch"); +} + +void copyErrors(ErrorVec& errors, TypeArena& destArena) +{ + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + auto visitErrorData = [&](auto&& e) { + copyError(e, destArena, seenTypes, seenTypePacks); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 5e7af50c..2f411274 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,11 +18,10 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false) LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) -LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau @@ -415,7 +414,7 @@ CheckResult Frontend::check(const ModuleName& name) // If we're typechecking twice, we do so. // The second typecheck is always in strict mode with DM awareness // to provide better typen information for IDE features. - if (options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel) + if (options.typecheckTwice) { ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; @@ -897,7 +896,7 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module const auto& exprs = it->second.exprs; const ModuleInfo* info = exprs.find(&pathExpr); - if (!info || (!FFlag::LuauNewRequireTrace && info->name.empty())) + if (!info || (!FFlag::LuauNewRequireTrace2 && info->name.empty())) return std::nullopt; return *info; @@ -914,7 +913,7 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - if (FFlag::LuauNewRequireTrace) + if (FFlag::LuauNewRequireTrace2) return frontend->sourceNodes.count(moduleName) != 0; else return frontend->fileResolver->moduleExists(moduleName); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index bff947a5..1a5b24fe 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -12,9 +12,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false) -LUAU_FASTFLAGVARIABLE(LuauLinterTableMoveZero, false) - namespace Luau { @@ -1110,10 +1107,7 @@ private: if (g && g->name == "type") { - if (FFlag::LuauLinterUnknownTypeVectorAware) - validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type"); - else - validateType(arg, {Kind_Primitive}, "primitive type"); + validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type"); } else if (g && g->name == "typeof") { @@ -2146,7 +2140,7 @@ private: "wrap it in parentheses to silence"); } - if (FFlag::LuauLinterTableMoveZero && func->index == "move" && node->args.size >= 4) + if (func->index == "move" && node->args.size >= 4) { // table.move(t, 0, _, _) if (isConstant(args[1], 0.0)) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 2fd95896..880ffd2e 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -12,7 +12,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false) @@ -290,9 +289,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) for (TypePackId genericPack : t.genericPacks) ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, encounteredFreeType)); - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ftv->tags = t.tags; - + ftv->tags = t.tags; ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, encounteredFreeType); ftv->argNames = t.argNames; ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, encounteredFreeType); @@ -319,12 +316,7 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->level = TypeLevel{0, 0}; for (const auto& [name, prop] : t.props) - { - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; - else - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; - } + ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; if (t.indexer) ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), @@ -379,10 +371,7 @@ void TypeCloner::operator()(const ClassTypeVar& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; - else - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; + ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; if (t.parent) ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, encounteredFreeType); diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/Predicate.cpp index 25e63bff..848627cf 100644 --- a/Analysis/src/Predicate.cpp +++ b/Analysis/src/Predicate.cpp @@ -3,8 +3,6 @@ #include "Luau/Ast.h" -LUAU_FASTFLAG(LuauOrPredicate) - namespace Luau { @@ -60,8 +58,6 @@ std::string toString(const LValue& lvalue) void merge(RefinementMap& l, const RefinementMap& r, std::function f) { - LUAU_ASSERT(FFlag::LuauOrPredicate); - auto itL = l.begin(); auto itR = r.begin(); while (itL != l.end() && itR != r.end()) diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 95910b56..b72f53f9 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -5,7 +5,7 @@ #include "Luau/Module.h" LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) -LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace, false) +LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace2, false) namespace Luau { @@ -19,7 +19,7 @@ struct RequireTracerOld : AstVisitor : fileResolver(fileResolver) , currentModuleName(currentModuleName) { - LUAU_ASSERT(!FFlag::LuauNewRequireTrace); + LUAU_ASSERT(!FFlag::LuauNewRequireTrace2); } FileResolver* const fileResolver; @@ -188,7 +188,7 @@ struct RequireTracer : AstVisitor , currentModuleName(currentModuleName) , locals(nullptr) { - LUAU_ASSERT(FFlag::LuauNewRequireTrace); + LUAU_ASSERT(FFlag::LuauNewRequireTrace2); } bool visit(AstExprTypeAssertion* expr) override @@ -332,7 +332,7 @@ struct RequireTracer : AstVisitor RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - if (FFlag::LuauNewRequireTrace) + if (FFlag::LuauNewRequireTrace2) { RequireTraceResult result; RequireTracer tracer{result, fileResolver, currentModuleName}; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index d861eb3d..ca2b30f5 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,8 +8,6 @@ LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) -LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau @@ -19,7 +17,7 @@ void Tarjan::visitChildren(TypeId ty, int index) { ty = follow(ty); - if (FFlag::LuauRankNTypes && ignoreChildren(ty)) + if (ignoreChildren(ty)) return; if (const FunctionTypeVar* ftv = get(ty)) @@ -68,7 +66,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) { tp = follow(tp); - if (FFlag::LuauRankNTypes && ignoreChildren(tp)) + if (ignoreChildren(tp)) return; if (const TypePack* tpp = get(tp)) @@ -399,8 +397,7 @@ TypeId Substitution::clone(TypeId ty) if (FFlag::LuauTypeAliasPacks) clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - clone.tags = ttv->tags; + clone.tags = ttv->tags; result = addType(std::move(clone)); } else if (const MetatableTypeVar* mtv = get(ty)) @@ -486,7 +483,7 @@ void Substitution::replaceChildren(TypeId ty) { ty = follow(ty); - if (FFlag::LuauRankNTypes && ignoreChildren(ty)) + if (ignoreChildren(ty)) return; if (FunctionTypeVar* ftv = getMutable(ty)) @@ -535,7 +532,7 @@ void Substitution::replaceChildren(TypePackId tp) { tp = follow(tp); - if (FFlag::LuauRankNTypes && ignoreChildren(tp)) + if (ignoreChildren(tp)) return; if (TypePack* tpp = getMutable(tp)) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index cd8180db..885fd489 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) -LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau @@ -237,15 +236,6 @@ struct TypeVarStringifier return; } - if (!FFlag::LuauAddMissingFollow) - { - if (get(tv)) - { - state.emit(state.getName(tv)); - return; - } - } - Luau::visit( [this, tv](auto&& t) { return (*this)(tv, t); @@ -316,11 +306,7 @@ struct TypeVarStringifier void operator()(TypeId ty, const Unifiable::Free& ftv) { state.result.invalid = true; - - if (FFlag::LuauAddMissingFollow) - state.emit(state.getName(ty)); - else - state.emit(""); + state.emit(state.getName(ty)); } void operator()(TypeId, const BoundTypeVar& btv) @@ -724,16 +710,6 @@ struct TypePackStringifier return; } - if (!FFlag::LuauAddMissingFollow) - { - if (get(tp)) - { - state.emit(state.getName(tp)); - state.emit("..."); - return; - } - } - auto it = state.cycleTpNames.find(tp); if (it != state.cycleTpNames.end()) { @@ -821,16 +797,8 @@ struct TypePackStringifier void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; - - if (FFlag::LuauAddMissingFollow) - { - state.emit(state.getName(tp)); - state.emit("..."); - } - else - { - state.emit(""); - } + state.emit(state.getName(tp)); + state.emit("..."); } void operator()(TypePackId, const BoundTypePack& btv) @@ -864,23 +832,15 @@ static void assignCycleNames(const std::unordered_set& cycles, const std std::string name; // TODO: use the stringified type list if there are no cycles - if (FFlag::LuauInstantiatedTypeParamRecursion) + if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) { - if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) - { - // If we have a cycle type in type parameters, assign a cycle name for this named table - if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { - return cycles.count(follow(el)); - }) != ttv->instantiatedTypeParams.end()) - cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; + // If we have a cycle type in type parameters, assign a cycle name for this named table + if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { + return cycles.count(follow(el)); + }) != ttv->instantiatedTypeParams.end()) + cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; - continue; - } - } - else - { - if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) - continue; + continue; } name = "t" + std::to_string(nextIndex); @@ -912,58 +872,6 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) ToStringResult result; - if (!FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive) - { - if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) - { - if (ttv->syntheticName) - result.invalid = true; - - // If scope if provided, add module name and check visibility - if (ttv->name && opts.scope) - { - auto [success, moduleName] = canUseTypeNameInScope(opts.scope, *ttv->name); - - if (!success) - result.invalid = true; - - if (moduleName) - result.name = format("%s.", moduleName->c_str()); - } - - result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - - if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) - return result; - - std::vector params; - for (TypeId tp : ttv->instantiatedTypeParams) - params.push_back(toString(tp)); - - if (FFlag::LuauTypeAliasPacks) - { - // Doesn't preserve grouping of multiple type packs - // But this is under a parent block of code that is being removed later - for (TypePackId tp : ttv->instantiatedTypePackParams) - { - std::string content = toString(tp); - - if (!content.empty()) - params.push_back(std::move(content)); - } - } - - result.name += "<" + join(params, ", ") + ">"; - return result; - } - else if (auto mtv = get(ty); mtv && mtv->syntheticName) - { - result.invalid = true; - result.name = *mtv->syntheticName; - return result; - } - } - StringifierState state{opts, result, opts.nameMap}; std::unordered_set cycles; @@ -975,7 +883,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) TypeVarStringifier tvs{state}; - if (FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive) + if (!opts.exhaustive) { if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) { diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 1b83ccdc..7d880af4 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace @@ -97,9 +96,6 @@ struct Writer { virtual ~Writer() {} - virtual void begin() {} - virtual void end() {} - virtual void advance(const Position&) = 0; virtual void newline() = 0; virtual void space() = 0; @@ -131,6 +127,7 @@ struct StringWriter : Writer if (pos.column < newPos.column) write(std::string(newPos.column - pos.column, ' ')); } + void maybeSpace(const Position& newPos, int reserve) override { if (pos.column + reserve < newPos.column) @@ -279,11 +276,14 @@ struct Printer writer.identifier(func->index.value); } - void visualizeTypePackAnnotation(const AstTypePack& annotation) + void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) { + advance(annotation.location.begin); if (const AstTypePackVariadic* variadicTp = annotation.as()) { - writer.symbol("..."); + if (!forVarArg) + writer.symbol("..."); + visualizeTypeAnnotation(*variadicTp->variadicType); } else if (const AstTypePackGeneric* genericTp = annotation.as()) @@ -293,6 +293,7 @@ struct Printer } else if (const AstTypePackExplicit* explicitTp = annotation.as()) { + LUAU_ASSERT(!forVarArg); visualizeTypeList(explicitTp->typeList, true); } else @@ -317,7 +318,7 @@ struct Printer // Only variadic tail if (list.types.size == 0) { - visualizeTypePackAnnotation(*list.tailType); + visualizeTypePackAnnotation(*list.tailType, false); } else { @@ -345,7 +346,7 @@ struct Printer if (list.tailType) { writer.symbol(","); - visualizeTypePackAnnotation(*list.tailType); + visualizeTypePackAnnotation(*list.tailType, false); } writer.symbol(")"); @@ -542,6 +543,7 @@ struct Printer case AstExprBinary::CompareLt: case AstExprBinary::CompareGt: writer.maybeSpace(a->right->location.begin, 2); + writer.symbol(toString(a->op)); break; case AstExprBinary::Concat: case AstExprBinary::CompareNe: @@ -550,19 +552,35 @@ struct Printer case AstExprBinary::CompareGe: case AstExprBinary::Or: writer.maybeSpace(a->right->location.begin, 3); + writer.keyword(toString(a->op)); break; case AstExprBinary::And: writer.maybeSpace(a->right->location.begin, 4); + writer.keyword(toString(a->op)); break; } - writer.symbol(toString(a->op)); - visualize(*a->right); } else if (const auto& a = expr.as()) { visualize(*a->expr); + + if (writeTypes) + { + writer.maybeSpace(a->annotation->location.begin, 2); + writer.symbol("::"); + visualizeTypeAnnotation(*a->annotation); + } + } + else if (const auto& a = expr.as()) + { + writer.keyword("if"); + visualize(*a->condition); + writer.keyword("then"); + visualize(*a->trueExpr); + writer.keyword("else"); + visualize(*a->falseExpr); } else if (const auto& a = expr.as()) { @@ -769,24 +787,31 @@ struct Printer switch (a->op) { case AstExprBinary::Add: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("+="); break; case AstExprBinary::Sub: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("-="); break; case AstExprBinary::Mul: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("*="); break; case AstExprBinary::Div: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("/="); break; case AstExprBinary::Mod: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("%="); break; case AstExprBinary::Pow: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("^="); break; case AstExprBinary::Concat: + writer.maybeSpace(a->value->location.begin, 3); writer.symbol("..="); break; default: @@ -874,7 +899,7 @@ struct Printer void visualizeFunctionBody(AstExprFunction& func) { - if (FFlag::LuauGenericFunctions && (func.generics.size > 0 || func.genericPacks.size > 0)) + if (func.generics.size > 0 || func.genericPacks.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); @@ -913,12 +938,13 @@ struct Printer if (func.vararg) { comma(); + advance(func.varargLocation.begin); writer.symbol("..."); if (func.varargAnnotation) { writer.symbol(":"); - visualizeTypePackAnnotation(*func.varargAnnotation); + visualizeTypePackAnnotation(*func.varargAnnotation, true); } } @@ -980,8 +1006,14 @@ struct Printer advance(typeAnnotation.location.begin); if (const auto& a = typeAnnotation.as()) { + if (a->hasPrefix) + { + writer.write(a->prefix.value); + writer.symbol("."); + } + writer.write(a->name.value); - if (a->parameters.size > 0) + if (a->parameters.size > 0 || a->hasParameterList) { CommaSeparatorInserter comma(writer); writer.symbol("<"); @@ -992,7 +1024,7 @@ struct Printer if (o.type) visualizeTypeAnnotation(*o.type); else - visualizeTypePackAnnotation(*o.typePack); + visualizeTypePackAnnotation(*o.typePack, false); } writer.symbol(">"); @@ -1000,7 +1032,7 @@ struct Printer } else if (const auto& a = typeAnnotation.as()) { - if (FFlag::LuauGenericFunctions && (a->generics.size > 0 || a->genericPacks.size > 0)) + if (a->generics.size > 0 || a->genericPacks.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); @@ -1075,7 +1107,16 @@ struct Printer auto rta = r->as(); if (rta && rta->name == "nil") { + bool wrap = l->as() || l->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*l); + + if (wrap) + writer.symbol(")"); + writer.symbol("?"); return; } @@ -1089,7 +1130,15 @@ struct Printer writer.symbol("|"); } + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); } } else if (const auto& a = typeAnnotation.as()) @@ -1102,7 +1151,15 @@ struct Printer writer.symbol("&"); } + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); } } else if (typeAnnotation.is()) @@ -1116,31 +1173,27 @@ struct Printer } }; -void dump(AstNode* node) +std::string toString(AstNode* node) { StringWriter writer; + writer.pos = node->location.begin; + Printer printer(writer); printer.writeTypes = true; if (auto statNode = dynamic_cast(node)) - { printer.visualize(*statNode); - printf("%s\n", writer.str().c_str()); - } else if (auto exprNode = dynamic_cast(node)) - { printer.visualize(*exprNode); - printf("%s\n", writer.str().c_str()); - } else if (auto typeNode = dynamic_cast(node)) - { printer.visualizeTypeAnnotation(*typeNode); - printf("%s\n", writer.str().c_str()); - } - else - { - printf("Can't dump this node\n"); - } + + return writer.str(); +} + +void dump(AstNode* node) +{ + printf("%s\n", toString(node).c_str()); } std::string transpile(AstStatBlock& block) @@ -1149,6 +1202,7 @@ std::string transpile(AstStatBlock& block) Printer(writer).visualizeBlock(block); return writer.str(); } + std::string transpileWithTypes(AstStatBlock& block) { StringWriter writer; @@ -1158,7 +1212,7 @@ std::string transpileWithTypes(AstStatBlock& block) return writer.str(); } -TranspileResult transpile(std::string_view source, ParseOptions options) +TranspileResult transpile(std::string_view source, ParseOptions options, bool withTypes) { auto allocator = Allocator{}; auto names = AstNameTable{allocator}; @@ -1176,6 +1230,9 @@ TranspileResult transpile(std::string_view source, ParseOptions options) if (!parseResult.root) return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"}; + if (withTypes) + return TranspileResult{transpileWithTypes(*parseResult.root)}; + return TranspileResult{transpile(*parseResult.root)}; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 49f8e0ca..11aa7b39 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -13,7 +13,6 @@ #include -LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauTypeAliasPacks) static char* allocateString(Luau::Allocator& allocator, std::string_view contents) @@ -203,39 +202,23 @@ public: return allocator->alloc(Location(), std::nullopt, AstName("")); AstArray generics; - if (FFlag::LuauGenericFunctions) + generics.size = ftv.generics.size(); + generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); + size_t numGenerics = 0; + for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { - generics.size = ftv.generics.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); - size_t i = 0; - for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) - { - if (auto gtv = get(*it)) - generics.data[i++] = AstName(gtv->name.c_str()); - } - } - else - { - generics.size = 0; - generics.data = nullptr; + if (auto gtv = get(*it)) + generics.data[numGenerics++] = AstName(gtv->name.c_str()); } AstArray genericPacks; - if (FFlag::LuauGenericFunctions) + genericPacks.size = ftv.genericPacks.size(); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); + size_t numGenericPacks = 0; + for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { - genericPacks.size = ftv.genericPacks.size(); - genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); - size_t i = 0; - for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) - { - if (auto gtv = get(*it)) - genericPacks.data[i++] = AstName(gtv->name.c_str()); - } - } - else - { - generics.size = 0; - generics.data = nullptr; + if (auto gtv = get(*it)) + genericPacks.data[numGenericPacks++] = AstName(gtv->name.c_str()); } AstArray argTypes; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 38e2e527..8fad1af9 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -22,29 +22,19 @@ LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauGenericVariadicsUnification, false) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) -LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) -LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false) -LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false) -LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) -LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) -LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau @@ -222,9 +212,9 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) , errorType(singletonTypes.errorType) - , optionalNumberType(globalTypes.addType(UnionTypeVar{{numberType, nilType}})) - , anyTypePack(globalTypes.addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}, true})) - , errorTypePack(globalTypes.addTypePack(TypePackVar{Unifiable::Error{}})) + , optionalNumberType(singletonTypes.optionalNumberType) + , anyTypePack(singletonTypes.anyTypePack) + , errorTypePack(singletonTypes.errorTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -251,10 +241,8 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona if (module.cyclic) moduleScope->returnType = addTypePack(TypePack{{anyType}, std::nullopt}); - else if (FFlag::LuauRankNTypes) - moduleScope->returnType = freshTypePack(moduleScope); else - moduleScope->returnType = DEPRECATED_freshTypePack(moduleScope, true); + moduleScope->returnType = freshTypePack(moduleScope); moduleScope->varargPack = anyTypePack; @@ -268,7 +256,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona checkBlock(moduleScope, *module.root); - if (get(FFlag::LuauAddMissingFollow ? follow(moduleScope->returnType) : moduleScope->returnType)) + if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); else moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); @@ -326,7 +314,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) check(scope, *typealias); else if (auto global = program.as()) { - TypeId globalType = (FFlag::LuauRankNTypes ? resolveType(scope, *global->type) : resolveType(scope, *global->type, true)); + TypeId globalType = resolveType(scope, *global->type); Name globalName(global->name.value); currentModule->declaredGlobals[globalName] = globalType; @@ -494,7 +482,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std Name name = typealias->name.value; TypeId type = bindings[name].type; - if (get(FFlag::LuauAddMissingFollow ? follow(type) : type)) + if (get(follow(type))) { *asMutable(type) = ErrorTypeVar{}; reportError(TypeError{typealias->location, OccursCheckFailed{}}); @@ -607,26 +595,22 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; + expectedTypes.reserve(return_.list.size); - if (FFlag::LuauInferReturnAssertAssign) + TypePackIterator expectedRetCurr = begin(scope->returnType); + TypePackIterator expectedRetEnd = end(scope->returnType); + + for (size_t i = 0; i < return_.list.size; ++i) { - expectedTypes.reserve(return_.list.size); - - TypePackIterator expectedRetCurr = begin(scope->returnType); - TypePackIterator expectedRetEnd = end(scope->returnType); - - for (size_t i = 0; i < return_.list.size; ++i) + if (expectedRetCurr != expectedRetEnd) { - if (expectedRetCurr != expectedRetEnd) - { - expectedTypes.push_back(*expectedRetCurr); - ++expectedRetCurr; - } - else if (auto expectedArgsTail = expectedRetCurr.tail()) - { - if (const VariadicTypePack* vtp = get(follow(*expectedArgsTail))) - expectedTypes.push_back(vtp->ty); - } + expectedTypes.push_back(*expectedRetCurr); + ++expectedRetCurr; + } + else if (auto expectedArgsTail = expectedRetCurr.tail()) + { + if (const VariadicTypePack* vtp = get(follow(*expectedArgsTail))) + expectedTypes.push_back(vtp->ty); } } @@ -672,34 +656,30 @@ ErrorVec TypeChecker::tryUnify_(Id left, Id right, const Location& location) void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { std::vector> expectedTypes; + expectedTypes.reserve(assign.vars.size); - if (FFlag::LuauInferReturnAssertAssign) + ScopePtr moduleScope = currentModule->getModuleScope(); + + for (size_t i = 0; i < assign.vars.size; ++i) { - expectedTypes.reserve(assign.vars.size); + AstExpr* dest = assign.vars.data[i]; - ScopePtr moduleScope = currentModule->getModuleScope(); - - for (size_t i = 0; i < assign.vars.size; ++i) + if (auto a = dest->as()) { - AstExpr* dest = assign.vars.data[i]; - - if (auto a = dest->as()) - { - // AstExprLocal l-values will have to be checked again because their type might have been mutated during checkExprList later - expectedTypes.push_back(scope->lookup(a->local)); - } - else if (auto a = dest->as()) - { - // AstExprGlobal l-values lookup is inlined here to avoid creating a global binding before checkExprList - if (auto it = moduleScope->bindings.find(a->name); it != moduleScope->bindings.end()) - expectedTypes.push_back(it->second.typeId); - else - expectedTypes.push_back(std::nullopt); - } + // AstExprLocal l-values will have to be checked again because their type might have been mutated during checkExprList later + expectedTypes.push_back(scope->lookup(a->local)); + } + else if (auto a = dest->as()) + { + // AstExprGlobal l-values lookup is inlined here to avoid creating a global binding before checkExprList + if (auto it = moduleScope->bindings.find(a->name); it != moduleScope->bindings.end()) + expectedTypes.push_back(it->second.typeId); else - { - expectedTypes.push_back(checkLValue(scope, *dest)); - } + expectedTypes.push_back(std::nullopt); + } + else + { + expectedTypes.push_back(checkLValue(scope, *dest)); } } @@ -715,7 +695,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) AstExpr* dest = assign.vars.data[i]; TypeId left = nullptr; - if (!FFlag::LuauInferReturnAssertAssign || dest->is() || dest->is()) + if (dest->is() || dest->is()) left = checkLValue(scope, *dest); else left = *expectedTypes[i]; @@ -751,11 +731,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) if (right) { - if (FFlag::LuauGenericFunctions && !maybeGeneric(left) && isGeneric(right)) - right = instantiate(scope, right, loc); - - if (!FFlag::LuauGenericFunctions && get(FFlag::LuauAddMissingFollow ? follow(left) : left) && - get(FFlag::LuauAddMissingFollow ? follow(right) : right)) + if (!maybeGeneric(left) && isGeneric(right)) right = instantiate(scope, right, loc); // Setting a table entry to nil doesn't mean nil is the type of the indexer, it is just deleting the entry @@ -766,7 +742,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) if (!destTableTypeReceivingNil || !destTableTypeReceivingNil->indexer) { // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. - if (isNonstrictMode() && get(FFlag::LuauAddMissingFollow ? follow(left) : left) && !get(follow(right))) + if (isNonstrictMode() && get(follow(left)) && !get(follow(right))) unify(left, anyType, loc); else unify(left, right, loc); @@ -815,7 +791,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) if (annotation) { - ty = (FFlag::LuauRankNTypes ? resolveType(scope, *annotation) : resolveType(scope, *annotation, true)); + ty = resolveType(scope, *annotation); // If the annotation type has an error, treat it as if there was no annotation if (get(follow(ty))) @@ -823,23 +799,19 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) } if (!ty) - ty = rhsIsTable ? (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)) - : isNonstrictMode() ? anyType : (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + ty = rhsIsTable ? freshType(scope) : isNonstrictMode() ? anyType : freshType(scope); varBindings.emplace_back(vars[i], Binding{ty, vars[i]->location}); variableTypes.push_back(ty); expectedTypes.push_back(ty); - if (FFlag::LuauGenericFunctions) - instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); - else - instantiateGenerics.push_back(annotation != nullptr && get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)); + instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); } if (local.values.size > 0) { - TypePackId variablePack = addTypePack(variableTypes, FFlag::LuauRankNTypes ? freshTypePack(scope) : DEPRECATED_freshTypePack(scope, true)); + TypePackId variablePack = addTypePack(variableTypes, freshTypePack(scope)); TypePackId valuePack = checkExprList(scope, local.location, local.values, /* substituteFreeForNil= */ true, instantiateGenerics, expectedTypes).type; @@ -979,8 +951,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) { AstExprCall* exprCall = firstValue->as(); callRetPack = checkExprPack(scope, *exprCall).type; - if (!FFlag::LuauRankNTypes) - callRetPack = DEPRECATED_instantiate(scope, callRetPack, exprCall->location); callRetPack = follow(callRetPack); if (get(callRetPack)) @@ -998,8 +968,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) else { iterTy = *first(callRetPack); - if (FFlag::LuauRankNTypes) - iterTy = instantiate(scope, iterTy, exprCall->location); + iterTy = instantiate(scope, iterTy, exprCall->location); } } else @@ -1158,10 +1127,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (FFlag::LuauGenericFunctions) - scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; - else - scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location}; + scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; } void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare) @@ -1199,7 +1165,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; @@ -1234,7 +1200,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; } - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; @@ -1266,7 +1232,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } } - TypeId ty = (FFlag::LuauRankNTypes ? resolveType(aliasScope, *typealias.type) : resolveType(aliasScope, *typealias.type, true)); + TypeId ty = resolveType(aliasScope, *typealias.type); if (auto ttv = getMutable(follow(ty))) { // If the table is already named and we want to rename the type function, we have to bind new alias to a copy @@ -1325,25 +1291,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0)); superTy = lookupType->type; - if (FFlag::LuauAddMissingFollow) + if (!get(follow(*superTy))) { - if (!get(follow(*superTy))) - { - reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", - superName.c_str(), declaredClass.name.value)}); + reportError(declaredClass.location, + GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); - return; - } - } - else - { - if (const ClassTypeVar* superCtv = get(*superTy); !superCtv) - { - reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", - superName.c_str(), declaredClass.name.value)}); - - return; - } + return; } } @@ -1558,8 +1511,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa } else if (auto ftp = get(varargPack)) { - TypeId head = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, ftp->DEPRECATED_canBeGeneric)); - TypePackId tail = (FFlag::LuauRankNTypes ? freshTypePack(scope) : DEPRECATED_freshTypePack(scope, ftp->DEPRECATED_canBeGeneric)); + TypeId head = freshType(scope); + TypePackId tail = freshTypePack(scope); *asMutable(varargPack) = TypePack{{head}, tail}; return {head}; } @@ -1567,7 +1520,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa return {errorType}; else if (auto vtp = get(varargPack)) return {vtp->ty}; - else if (FFlag::LuauGenericVariadicsUnification && get(varargPack)) + else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); @@ -1588,7 +1541,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa } else if (auto ftp = get(retPack)) { - TypeId head = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, ftp->DEPRECATED_canBeGeneric)); + TypeId head = freshType(scope); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); unify(retPack, pack, expr.location); return {head, std::move(result.predicates)}; @@ -1667,7 +1620,7 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (tableType->state == TableState::Free) { - TypeId result = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId result = freshType(scope); tableType->props[name] = {result}; return result; } @@ -2129,7 +2082,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!isNonstrictMode() && !isOrOp) return ty; - if (auto i = get(ty)) + if (get(ty)) { std::optional cleaned = tryStripUnionFromNil(ty); @@ -2158,16 +2111,9 @@ TypeId TypeChecker::checkRelationalOperation( { if (expr.op == AstExprBinary::Or && subexp->op == AstExprBinary::And) { - if (FFlag::LuauSlightlyMoreFlexibleBinaryPredicates) - { - ScopePtr subScope = childScope(scope, subexp->location); - reportErrors(resolve(predicates, subScope, true)); - return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); - } - else - { - return unionOfTypes(rhsType, checkExpr(scope, *subexp->right).type, expr.location); - } + ScopePtr subScope = childScope(scope, subexp->location); + reportErrors(resolve(predicates, subScope, true)); + return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); } } @@ -2217,10 +2163,8 @@ TypeId TypeChecker::checkRelationalOperation( std::string metamethodName = opToMetaTableEntry(expr.op); - std::optional leftMetatable = - isString(lhsType) ? std::nullopt : getMetatable(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType); - std::optional rightMetatable = - isString(rhsType) ? std::nullopt : getMetatable(FFlag::LuauAddMissingFollow ? follow(rhsType) : rhsType); + std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); + std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) { @@ -2266,7 +2210,7 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && !isEquality) + if (get(follow(lhsType)) && !isEquality) { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); @@ -2417,12 +2361,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi resolve(lhs.predicates, innerScope, true); ExprResult rhs = checkExpr(innerScope, *expr.right); - if (!FFlag::LuauSlightlyMoreFlexibleBinaryPredicates) - resolve(rhs.predicates, innerScope, true); return {checkBinaryOperation(innerScope, expr, lhs.type, rhs.type), {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } - else if (FFlag::LuauOrPredicate && expr.op == AstExprBinary::Or) + else if (expr.op == AstExprBinary::Or) { ExprResult lhs = checkExpr(scope, *expr.left); @@ -2468,19 +2410,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) { - ExprResult result; - TypeId annotationType; - - if (FFlag::LuauInferReturnAssertAssign) - { - annotationType = (FFlag::LuauRankNTypes ? resolveType(scope, *expr.annotation) : resolveType(scope, *expr.annotation, true)); - result = checkExpr(scope, *expr.expr, annotationType); - } - else - { - result = checkExpr(scope, *expr.expr); - annotationType = (FFlag::LuauRankNTypes ? resolveType(scope, *expr.annotation) : resolveType(scope, *expr.annotation, true)); - } + TypeId annotationType = resolveType(scope, *expr.annotation); + ExprResult result = checkExpr(scope, *expr.expr, annotationType); ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); if (!errorVec.empty()) @@ -2570,23 +2501,16 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (it != moduleScope->bindings.end()) return std::pair(it->second.typeId, &it->second.typeId); - if (isNonstrictMode() || FFlag::LuauSecondTypecheckKnowsTheDataModel) - { - TypeId result = (FFlag::LuauGenericFunctions && FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(moduleScope, true)); + TypeId result = freshType(scope); + Binding& binding = moduleScope->bindings[expr.name]; + binding = {result, expr.location}; - Binding& binding = moduleScope->bindings[expr.name]; - binding = {result, expr.location}; + // If we're in strict mode, we want to report defining a global as an error, + // but still add it to the bindings, so that autocomplete includes it in completions. + if (!isNonstrictMode()) + reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - // If we're in strict mode, we want to report defining a global as an error, - // but still add it to the bindings, so that autocomplete includes it in completions. - if (!isNonstrictMode()) - reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - - return std::pair(result, &binding.typeId); - } - - reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - return std::pair(errorType, nullptr); + return std::pair(result, &binding.typeId); } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) @@ -2611,7 +2535,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope } else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { - TypeId theType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; property.type = theType; property.location = expr.indexLocation; @@ -2683,7 +2607,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope AstExprConstantString* value = expr.index->as(); - if (value && FFlag::LuauClassPropertyAccessAsString) + if (value) { if (const ClassTypeVar* exprClass = get(exprType)) { @@ -2714,7 +2638,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; property.type = resultType; property.location = expr.index->location; @@ -2730,7 +2654,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId resultType = freshType(scope); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return std::pair(resultType, nullptr); } @@ -2758,7 +2682,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) } else { - TypeId ty = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshType(scope); globalScope->bindings[name] = {ty, funName.location}; return ty; } @@ -2768,7 +2692,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Symbol name = localName->local; Binding& binding = scope->bindings[name]; if (binding.typeId == nullptr) - binding = {(FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)), funName.location}; + binding = {freshType(scope), funName.location}; return binding.typeId; } @@ -2798,7 +2722,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Property& property = ttv->props[name]; - property.type = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + property.type = freshType(scope); property.location = indexName->indexLocation; ttv->methodDefinitionLocations[name] = funName.location; return property.type; @@ -2865,22 +2789,11 @@ std::pair TypeChecker::checkFunctionSignature( expectedFunctionType = nullptr; } - std::vector generics; - std::vector genericPacks; - - if (FFlag::LuauGenericFunctions) - { - std::tie(generics, genericPacks) = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); - } + auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); TypePackId retPack; if (expr.hasReturnAnnotation) - { - if (FFlag::LuauGenericFunctions) - retPack = resolveTypePack(funScope, expr.returnAnnotation); - else - retPack = resolveTypePack(scope, expr.returnAnnotation); - } + retPack = resolveTypePack(funScope, expr.returnAnnotation); else if (isNonstrictMode()) retPack = anyTypePack; else if (expectedFunctionType) @@ -2889,24 +2802,17 @@ std::pair TypeChecker::checkFunctionSignature( // Do not infer 'nil' as function return type if (!tail && head.size() == 1 && isNil(head[0])) - retPack = FFlag::LuauGenericFunctions ? freshTypePack(funScope) : freshTypePack(scope); + retPack = freshTypePack(funScope); else retPack = addTypePack(head, tail); } - else if (FFlag::LuauGenericFunctions) - retPack = freshTypePack(funScope); else - retPack = freshTypePack(scope); + retPack = freshTypePack(funScope); if (expr.vararg) { if (expr.varargAnnotation) - { - if (FFlag::LuauGenericFunctions) - funScope->varargPack = resolveTypePack(funScope, *expr.varargAnnotation); - else - funScope->varargPack = resolveTypePack(scope, *expr.varargAnnotation); - } + funScope->varargPack = resolveTypePack(funScope, *expr.varargAnnotation); else { if (expectedFunctionType && !isNonstrictMode()) @@ -2963,7 +2869,7 @@ std::pair TypeChecker::checkFunctionSignature( if (local->annotation) { - argType = resolveType((FFlag::LuauGenericFunctions ? funScope : scope), *local->annotation); + argType = resolveType(funScope, *local->annotation); // If the annotation type has an error, treat it as if there was no annotation if (get(follow(argType))) @@ -3022,7 +2928,7 @@ static bool allowsNoReturnValues(const TypePackId tp) { for (TypeId ty : tp) { - if (!get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) + if (!get(follow(ty))) { return false; } @@ -3058,7 +2964,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE check(scope, *function.body); // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (FFlag::LuauAddMissingFollow ? get_if(&funTy->retType->ty) : get(funTy->retType)) + if (get_if(&funTy->retType->ty)) *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; @@ -3287,7 +3193,7 @@ void TypeChecker::checkArgumentList( return; } - else if (FFlag::LuauGenericVariadicsUnification && get(tail)) + else if (get(tail)) { // Create a type pack out of the remaining argument types // and unify it with the tail. @@ -3310,7 +3216,7 @@ void TypeChecker::checkArgumentList( return; } - else if (FFlag::LuauRankNTypes && get(tail)) + else if (get(tail)) { // For this case, we want the error span to cover every errant extra parameter Location location = state.location; @@ -3323,10 +3229,7 @@ void TypeChecker::checkArgumentList( } else { - if (FFlag::LuauRankNTypes) - unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); - else - state.tryUnify(*paramIter, *argIter, /*isFunctionCall*/ false); + unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); ++argIter; ++paramIter; } @@ -3356,9 +3259,6 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A ice("method call expression has no 'self'"); selfType = checkExpr(scope, *indexExpr->expr).type; - if (!FFlag::LuauRankNTypes) - instantiate(scope, selfType, expr.func->location); - selfType = stripFromNilAndReport(selfType, expr.func->location); if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) @@ -3393,8 +3293,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); - TypePackId argList = argListResult.type; - TypePackId argPack = (FFlag::LuauRankNTypes ? argList : DEPRECATED_instantiate(scope, argList, expr.location)); + TypePackId argPack = argListResult.type; if (get(argPack)) return ExprResult{errorTypePack}; @@ -3526,8 +3425,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); TypeId fn = *ty; - if (FFlag::LuauRankNTypes) - fn = instantiate(scope, fn, expr.func->location); + fn = instantiate(scope, fn, expr.func->location); return checkCallOverload( scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); @@ -3800,7 +3698,7 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L TypeId actualType = substituteFreeForNil && expr->is() ? freshType(scope) : type; - if (instantiateGenerics.size() > i && instantiateGenerics[i] && (FFlag::LuauGenericFunctions || get(actualType))) + if (instantiateGenerics.size() > i && instantiateGenerics[i]) actualType = instantiate(scope, actualType, expr->location); if (expectedType) @@ -3837,7 +3735,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); - if (FFlag::LuauNewRequireTrace && moduleInfo.name.empty()) + if (FFlag::LuauNewRequireTrace2 && moduleInfo.name.empty()) { if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { @@ -3922,7 +3820,6 @@ bool TypeChecker::unify(TypePackId left, TypePackId right, const Location& locat bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location) { - LUAU_ASSERT(FFlag::LuauRankNTypes); Unifier state = mkUnifier(location); unifyWithInstantiationIfNeeded(scope, left, right, state); @@ -3933,7 +3830,6 @@ bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId l void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (!maybeGeneric(right)) // Quick check to see if we definitely can't instantiate state.tryUnify(left, right, /*isFunctionCall*/ false); @@ -3973,19 +3869,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId l bool Instantiation::isDirty(TypeId ty) { - if (FFlag::LuauRankNTypes) - { - if (get(ty)) - return true; - else - return false; - } - - if (const FunctionTypeVar* ftv = get(ty)) - return !ftv->generics.empty() || !ftv->genericPacks.empty(); - else if (const TableTypeVar* ttv = get(ty)) - return ttv->state == TableState::Generic; - else if (get(ty)) + if (get(ty)) return true; else return false; @@ -3993,18 +3877,11 @@ bool Instantiation::isDirty(TypeId ty) bool Instantiation::isDirty(TypePackId tp) { - if (FFlag::LuauRankNTypes) - return false; - - if (get(tp)) - return true; - else - return false; + return false; } bool Instantiation::ignoreChildren(TypeId ty) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (get(ty)) return true; else @@ -4013,63 +3890,38 @@ bool Instantiation::ignoreChildren(TypeId ty) TypeId Instantiation::clean(TypeId ty) { - LUAU_ASSERT(isDirty(ty)); + const FunctionTypeVar* ftv = get(ty); + LUAU_ASSERT(ftv); - if (const FunctionTypeVar* ftv = get(ty)) - { - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; - clone.magicFunction = ftv->magicFunction; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - TypeId result = addType(std::move(clone)); + FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + clone.magicFunction = ftv->magicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + TypeId result = addType(std::move(clone)); - if (FFlag::LuauRankNTypes) - { - // Annoyingly, we have to do this even if there are no generics, - // to replace any generic tables. - replaceGenerics.level = level; - replaceGenerics.currentModule = currentModule; - replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); - replaceGenerics.genericPacks.assign(ftv->genericPacks.begin(), ftv->genericPacks.end()); + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + replaceGenerics.level = level; + replaceGenerics.currentModule = currentModule; + replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); + replaceGenerics.genericPacks.assign(ftv->genericPacks.begin(), ftv->genericPacks.end()); - // TODO: What to do if this returns nullopt? - // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); - } + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } - else if (const TableTypeVar* ttv = get(ty)) - { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - TypeId result = addType(std::move(clone)); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } - else - { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - TypeId result = addType(FreeTypeVar{level}); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; } TypePackId Instantiation::clean(TypePackId tp) { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - return addTypePack(TypePackVar(FreeTypePack{level})); + LUAU_ASSERT(false); + return tp; } bool ReplaceGenerics::ignoreChildren(TypeId ty) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (const FunctionTypeVar* ftv = get(ty)) // We aren't recursing in the case of a generic function which // binds the same generics. This can happen if, for example, there's recursive types. @@ -4083,7 +3935,6 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) bool ReplaceGenerics::isDirty(TypeId ty) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (const TableTypeVar* ttv = get(ty)) return ttv->state == TableState::Generic; else if (get(ty)) @@ -4094,7 +3945,6 @@ bool ReplaceGenerics::isDirty(TypeId ty) bool ReplaceGenerics::isDirty(TypePackId tp) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (get(tp)) return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); else @@ -4255,21 +4105,6 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat } } -TypePackId TypeChecker::DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location) -{ - LUAU_ASSERT(!FFlag::LuauRankNTypes); - instantiation.level = scope->level; - instantiation.currentModule = currentModule; - std::optional instantiated = instantiation.substitute(ty); - if (instantiated.has_value()) - return *instantiated; - else - { - reportError(location, UnificationTooComplex{}); - return errorTypePack; - } -} - TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { anyification.anyType = anyType; @@ -4444,16 +4279,6 @@ TypeId TypeChecker::freshType(TypeLevel level) return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } -TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric) -{ - return DEPRECATED_freshType(scope->level, canBeGeneric); -} - -TypeId TypeChecker::DEPRECATED_freshType(TypeLevel level, bool canBeGeneric) -{ - return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level, canBeGeneric))); -} - std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); @@ -4509,21 +4334,8 @@ TypePackId TypeChecker::freshTypePack(TypeLevel level) return addTypePack(TypePackVar(FreeTypePack(level))); } -TypePackId TypeChecker::DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric) +TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation) { - return DEPRECATED_freshTypePack(scope->level, canBeGeneric); -} - -TypePackId TypeChecker::DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric) -{ - return addTypePack(TypePackVar(FreeTypePack(level, canBeGeneric))); -} - -TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation, bool DEPRECATED_canBeGeneric) -{ - if (DEPRECATED_canBeGeneric) - LUAU_ASSERT(!FFlag::LuauRankNTypes); - if (const auto& lit = annotation.as()) { std::optional tf; @@ -4668,11 +4480,11 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation std::optional tableIndexer; for (const auto& prop : table->props) - props[prop.name.value] = {resolveType(scope, *prop.type, DEPRECATED_canBeGeneric)}; + props[prop.name.value] = {resolveType(scope, *prop.type)}; if (const auto& indexer = table->indexer) tableIndexer = TableIndexer( - resolveType(scope, *indexer->indexType, DEPRECATED_canBeGeneric), resolveType(scope, *indexer->resultType, DEPRECATED_canBeGeneric)); + resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); return addType(TableTypeVar{ props, tableIndexer, scope->level, @@ -4683,17 +4495,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { ScopePtr funcScope = childScope(scope, func->location); - std::vector generics; - std::vector genericPacks; - - if (FFlag::LuauGenericFunctions) - { - std::tie(generics, genericPacks) = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); - } - - // TODO: better error message CLI-39912 - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !DEPRECATED_canBeGeneric && (generics.size() > 0 || genericPacks.size() > 0)) - reportError(TypeError{annotation.location, GenericError{"generic function where only monotypes are allowed"}}); + auto [generics, genericPacks] = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); TypePackId argTypes = resolveTypePack(funcScope, func->argTypes); TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes); @@ -4716,16 +4518,13 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (auto typeOf = annotation.as()) { TypeId ty = checkExpr(scope, *typeOf->expr).type; - // TODO: better error message CLI-39912 - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !DEPRECATED_canBeGeneric && isGeneric(ty)) - reportError(TypeError{annotation.location, GenericError{"typeof produced a polytype where only monotypes are allowed"}}); return ty; } else if (const auto& un = annotation.as()) { std::vector types; for (AstType* ann : un->types) - types.push_back(resolveType(scope, *ann, DEPRECATED_canBeGeneric)); + types.push_back(resolveType(scope, *ann)); return addType(UnionTypeVar{types}); } @@ -4733,7 +4532,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { std::vector types; for (AstType* ann : un->types) - types.push_back(resolveType(scope, *ann, DEPRECATED_canBeGeneric)); + types.push_back(resolveType(scope, *ann)); return addType(IntersectionTypeVar{types}); } @@ -4919,9 +4718,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (FFlag::LuauCloneCorrectlyBeforeMutatingTableType) { - // TODO: CLI-46926 it's a bad idea to rename the type whether we follow through the BoundTypeVar or not - TypeId target = FFlag::LuauFollowInTypeFunApply ? follow(instantiated) : instantiated; - + // TODO: CLI-46926 it's not a good idea to rename the type here + TypeId target = follow(instantiated); bool needsClone = follow(tf.type) == target; TableTypeVar* ttv = getMutableTableType(target); @@ -5152,31 +4950,18 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, Refi void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { - if (FFlag::LuauOrPredicate) + if (!sense) { - if (!sense) - { - OrPredicate orP{ - {NotPredicate{std::move(andP.lhs)}}, - {NotPredicate{std::move(andP.rhs)}}, - }; + OrPredicate orP{ + {NotPredicate{std::move(andP.lhs)}}, + {NotPredicate{std::move(andP.rhs)}}, + }; - return resolve(orP, errVec, refis, scope, !sense); - } - - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); + return resolve(orP, errVec, refis, scope, !sense); } - else - { - // And predicate is currently not resolvable when sense is false. 'not (a and b)' is synonymous with '(not a) or (not b)'. - // TODO: implement environment merging to permit this case. - if (!sense) - return; - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); - } + resolve(andP.lhs, errVec, refis, scope, sense); + resolve(andP.rhs, errVec, refis, scope, sense); } void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) @@ -5207,58 +4992,41 @@ void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMa void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { auto predicate = [&](TypeId option) -> std::optional { - if (FFlag::LuauTypeGuardPeelsAwaySubclasses) - { - // This by itself is not truly enough to determine that A is stronger than B or vice versa. - // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. - // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) - bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); - bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + // This by itself is not truly enough to determine that A is stronger than B or vice versa. + // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. + // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) + bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); + bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); - // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. - if (!optionIsSubtype && targetIsSubtype) + // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. + if (!optionIsSubtype && targetIsSubtype) + return sense ? isaP.ty : option; + + // If A is a subset of B, then if sense is true we pick A, otherwise we eliminate A. + if (optionIsSubtype && !targetIsSubtype) + return sense ? std::optional(option) : std::nullopt; + + // If neither has any relationship, we only return A if sense is false. + if (!optionIsSubtype && !targetIsSubtype) + return sense ? std::nullopt : std::optional(option); + + // If both are subtypes, then we're in one of the two situations: + // 1. Instance₁ <: Instance₂ ∧ Instance₂ <: Instance₁ + // 2. any <: Instance ∧ Instance <: any + // Right now, we have to look at the types to see if they were undecidables. + // By this point, we also know free tables are also subtypes and supertypes. + if (optionIsSubtype && targetIsSubtype) + { + // We can only have (any, Instance) because the rhs is never undecidable right now. + // So we can just return the right hand side immediately. + + // typeof(x) == "Instance" where x : any + auto ttv = get(option); + if (isUndecidable(option) || (ttv && ttv->state == TableState::Free)) return sense ? isaP.ty : option; - // If A is a subset of B, then if sense is true we pick A, otherwise we eliminate A. - if (optionIsSubtype && !targetIsSubtype) - return sense ? std::optional(option) : std::nullopt; - - // If neither has any relationship, we only return A if sense is false. - if (!optionIsSubtype && !targetIsSubtype) - return sense ? std::nullopt : std::optional(option); - - // If both are subtypes, then we're in one of the two situations: - // 1. Instance₁ <: Instance₂ ∧ Instance₂ <: Instance₁ - // 2. any <: Instance ∧ Instance <: any - // Right now, we have to look at the types to see if they were undecidables. - // By this point, we also know free tables are also subtypes and supertypes. - if (optionIsSubtype && targetIsSubtype) - { - // We can only have (any, Instance) because the rhs is never undecidable right now. - // So we can just return the right hand side immediately. - - // typeof(x) == "Instance" where x : any - auto ttv = get(option); - if (isUndecidable(option) || (ttv && ttv->state == TableState::Free)) - return sense ? isaP.ty : option; - - // typeof(x) == "Instance" where x : Instance - if (sense) - return isaP.ty; - } - } - else - { - auto lctv = get(option); - auto rctv = get(isaP.ty); - - if (isSubclass(lctv, rctv) == sense) - return option; - - if (isSubclass(rctv, lctv) == sense) - return isaP.ty; - - if (canUnify(option, isaP.ty, isaP.location).empty() == sense) + // typeof(x) == "Instance" where x : Instance + if (sense) return isaP.ty; } @@ -5457,7 +5225,7 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp TypePack* expectedPack = getMutable(expectedTypePack); LUAU_ASSERT(expectedPack); for (size_t i = 0; i < expectedLength; ++i) - expectedPack->head.push_back(FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + expectedPack->head.push_back(freshType(scope)); unify(expectedTypePack, tp, location); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 68a16ef0..228b1926 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -97,7 +97,7 @@ TypePackIterator begin(TypePackId tp) TypePackIterator end(TypePackId tp) { - return FFlag::LuauAddMissingFollow ? TypePackIterator{} : TypePackIterator{nullptr}; + return TypePackIterator{}; } bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) @@ -203,7 +203,7 @@ TypePackId follow(TypePackId tp) size_t size(TypePackId tp) { - if (auto pack = get(FFlag::LuauAddMissingFollow ? follow(tp) : tp)) + if (auto pack = get(follow(tp))) return size(*pack); else return 0; @@ -216,7 +216,7 @@ bool finite(TypePackId tp) if (auto pack = get(tp)) return pack->tail ? finite(*pack->tail) : true; - if (auto pack = get(tp)) + if (get(tp)) return false; return true; @@ -227,7 +227,7 @@ size_t size(const TypePack& tp) size_t result = tp.head.size(); if (tp.tail) { - const TypePack* tail = get(FFlag::LuauAddMissingFollow ? follow(*tp.tail) : *tp.tail); + const TypePack* tail = get(follow(*tp.tail)); if (tail) result += size(*tail); } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index e82f7519..cd447ca2 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,8 +19,6 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) @@ -42,7 +40,7 @@ TypeId follow(TypeId t) }; auto force = [](TypeId ty) { - if (auto ltv = FFlag::LuauAddMissingFollow ? get_if(&ty->ty) : get(ty)) + if (auto ltv = get_if(&ty->ty)) { TypeId res = ltv->thunk(); if (get(res)) @@ -296,7 +294,7 @@ bool maybeGeneric(TypeId ty) { ty = follow(ty); if (auto ftv = get(ty)) - return FFlag::LuauRankNTypes || ftv->DEPRECATED_canBeGeneric; + return true; else if (auto ttv = get(ty)) { // TODO: recurse on table types CLI-39914 @@ -545,15 +543,30 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); +static TypeVar nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true}; +static TypeVar numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true}; +static TypeVar stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}; +static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}; +static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}; +static TypeVar anyType_{AnyTypeVar{}}; +static TypeVar errorType_{ErrorTypeVar{}}; +static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}}; + +static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; +static TypePackVar errorTypePack_{Unifiable::Error{}}; + SingletonTypes::SingletonTypes() - : arena(new TypeArena) - , nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true} - , numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true} - , stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true} - , booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true} - , threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true} - , anyType_{AnyTypeVar{}} - , errorType_{ErrorTypeVar{}} + : nilType(&nilType_) + , numberType(&numberType_) + , stringType(&stringType_) + , booleanType(&booleanType_) + , threadType(&threadType_) + , anyType(&anyType_) + , errorType(&errorType_) + , optionalNumberType(&optionalNumberType_) + , anyTypePack(&anyTypePack_) + , errorTypePack(&errorTypePack_) + , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, makeStringMetatable()}; @@ -749,9 +762,9 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) { - if (const PrimitiveTypeVar* ptv = get(ty)) + if (get(ty)) formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); - else if (const AnyTypeVar* atv = get(ty)) + else if (get(ty)) formatAppend(result, "n%d [label=\"any\"];\n", index); } else @@ -902,19 +915,19 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } - else if (const AnyTypeVar* atv = get(ty)) + else if (get(ty)) { formatAppend(result, "AnyTypeVar %d", index); finishNodeLabel(ty); finishNode(); } - else if (const PrimitiveTypeVar* ptv = get(ty)) + else if (get(ty)) { formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); finishNodeLabel(ty); finishNode(); } - else if (const ErrorTypeVar* etv = get(ty)) + else if (get(ty)) { formatAppend(result, "ErrorTypeVar %d", index); finishNodeLabel(ty); @@ -994,7 +1007,7 @@ void StateDot::visitChildren(TypePackId tp, int index) finishNodeLabel(tp); finishNode(); } - else if (const Unifiable::Error* etp = get(tp)) + else if (get(tp)) { formatAppend(result, "ErrorTypePack %d", index); finishNodeLabel(tp); @@ -1372,24 +1385,6 @@ UnionTypeVarIterator end(const UnionTypeVar* utv) return UnionTypeVarIterator{}; } -static std::vector DEPRECATED_filterMap(TypeId type, TypeIdPredicate predicate) -{ - std::vector result; - - if (auto utv = get(follow(type))) - { - for (TypeId option : utv) - { - if (auto out = predicate(follow(option))) - result.push_back(*out); - } - } - else if (auto out = predicate(follow(type))) - return {*out}; - - return result; -} - static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs"; @@ -1470,9 +1465,6 @@ std::optional> magicFunctionFormat( std::vector filterMap(TypeId type, TypeIdPredicate predicate) { - if (!FFlag::LuauTypeGuardPeelsAwaySubclasses) - return DEPRECATED_filterMap(type, predicate); - type = follow(type); if (auto utv = get(type)) diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index cef07833..dc554664 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Unifiable.h" -LUAU_FASTFLAG(LuauRankNTypes) - namespace Luau { namespace Unifiable @@ -14,14 +12,6 @@ Free::Free(TypeLevel level) { } -Free::Free(TypeLevel level, bool DEPRECATED_canBeGeneric) - : index(++nextIndex) - , level(level) - , DEPRECATED_canBeGeneric(DEPRECATED_canBeGeneric) -{ - LUAU_ASSERT(!FFlag::LuauRankNTypes); -} - int Free::nextIndex = 0; Generic::Generic() diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 2539650a..82f621b6 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,17 +14,15 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); -LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); -LUAU_FASTFLAGVARIABLE(LuauDontMutatePersistentFunctions, false) -LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) -LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) LUAU_FASTFLAG(LuauShareTxnSeen); LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) +LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) +LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) namespace Luau { @@ -219,17 +217,12 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool *asMutable(subTy) = BoundTypeVar(superTy); } - if (!FFlag::LuauRankNTypes) - l->DEPRECATED_canBeGeneric &= r->DEPRECATED_canBeGeneric; - return; } - else if (l && r && FFlag::LuauGenericFunctions) + else if (l && r) { log(superTy); occursCheck(superTy, subTy); - if (!FFlag::LuauRankNTypes) - r->DEPRECATED_canBeGeneric &= l->DEPRECATED_canBeGeneric; r->level = min(r->level, l->level); *asMutable(superTy) = BoundTypeVar(subTy); return; @@ -240,7 +233,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto rightGeneric = get(subTy); - if (FFlag::LuauRankNTypes && rightGeneric && !rightGeneric->level.subsumes(l->level)) + if (rightGeneric && !rightGeneric->level.subsumes(l->level)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -266,31 +259,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto leftGeneric = get(superTy); - if (FFlag::LuauRankNTypes && leftGeneric && !leftGeneric->level.subsumes(r->level)) + if (leftGeneric && !leftGeneric->level.subsumes(r->level)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); return; } - // This is the old code which is just wrong - auto wrongGeneric = get(subTy); // Guaranteed to be null - if (!FFlag::LuauRankNTypes && FFlag::LuauGenericFunctions && wrongGeneric && r->level.subsumes(wrongGeneric->level)) - { - // This code is unreachable! Should we just remove it? - // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); - return; - } - - // Check if we're unifying a monotype with a polytype - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !r->DEPRECATED_canBeGeneric && isGeneric(superTy)) - { - // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Failed to unify a polytype with a monotype"}}); - return; - } - if (!get(subTy)) { if (auto leftLevel = getMutableLevel(superTy)) @@ -333,6 +308,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // A | B <: T if A <: T and B <: T bool failed = false; std::optional unificationTooComplex; + std::optional firstFailedOption; size_t count = uv->options.size(); size_t i = 0; @@ -345,7 +321,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; else if (!innerState.errors.empty()) + { + // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' + if (FFlag::LuauExtendedTypeMismatchError && !firstFailedOption && !isNil(type)) + firstFailedOption = {innerState.errors.front()}; + failed = true; + } if (i != count - 1) innerState.log.rollback(); @@ -358,7 +340,12 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (failed) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + { + if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible", *firstFailedOption}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } } else if (const UnionTypeVar* uv = get(superTy)) { @@ -425,14 +412,49 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (!found) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + { + if (FFlag::LuauExtendedTypeMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } } else if (const IntersectionTypeVar* uv = get(superTy)) { - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) + if (FFlag::LuauExtendedTypeMismatchError) { - tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + std::optional unificationTooComplex; + std::optional firstFailedOption; + + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) + { + if (!firstFailedOption) + firstFailedOption = {innerState.errors.front()}; + } + + log.concat(std::move(innerState.log)); + } + + if (unificationTooComplex) + errors.push_back(*unificationTooComplex); + else if (firstFailedOption) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible", *firstFailedOption}}); + } + else + { + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) + { + tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + } } } else if (const IntersectionTypeVar* uv = get(subTy)) @@ -480,7 +502,12 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (!found) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + { + if (FFlag::LuauExtendedTypeMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } } else if (get(superTy) && get(subTy)) tryUnifyPrimitives(superTy, subTy); @@ -773,8 +800,8 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - const bool lFreeTail = l->tail && get(FFlag::LuauAddMissingFollow ? follow(*l->tail) : *l->tail) != nullptr; - const bool rFreeTail = r->tail && get(FFlag::LuauAddMissingFollow ? follow(*r->tail) : *r->tail) != nullptr; + const bool lFreeTail = l->tail && get(follow(*l->tail)) != nullptr; + const bool rFreeTail = r->tail && get(follow(*r->tail)) != nullptr; if (lFreeTail && rFreeTail) tryUnify_(*l->tail, *r->tail); else if (lFreeTail) @@ -812,7 +839,7 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal } // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && get(FFlag::LuauAddMissingFollow ? follow(*superIter) : *superIter)) + else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) { superIter.advance(); continue; @@ -887,24 +914,21 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal ice("passed non-function types to unifyFunction"); size_t numGenerics = lf->generics.size(); - if (FFlag::LuauGenericFunctions && numGenerics != rf->generics.size()) + if (numGenerics != rf->generics.size()) { numGenerics = std::min(lf->generics.size(), rf->generics.size()); errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } size_t numGenericPacks = lf->genericPacks.size(); - if (FFlag::LuauGenericFunctions && numGenericPacks != rf->genericPacks.size()) + if (numGenericPacks != rf->genericPacks.size()) { numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } - if (FFlag::LuauGenericFunctions) - { - for (size_t i = 0; i < numGenerics; i++) - log.pushSeen(lf->generics[i], rf->generics[i]); - } + for (size_t i = 0; i < numGenerics; i++) + log.pushSeen(lf->generics[i], rf->generics[i]); CountMismatch::Context context = ctx; @@ -931,22 +955,19 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal tryUnify_(lf->retType, rf->retType); } - if (lf->definition && !rf->definition && (!FFlag::LuauDontMutatePersistentFunctions || !subTy->persistent)) + if (lf->definition && !rf->definition && !subTy->persistent) { rf->definition = lf->definition; } - else if (!lf->definition && rf->definition && (!FFlag::LuauDontMutatePersistentFunctions || !superTy->persistent)) + else if (!lf->definition && rf->definition && !superTy->persistent) { lf->definition = rf->definition; } ctx = context; - if (FFlag::LuauGenericFunctions) - { - for (int i = int(numGenerics) - 1; 0 <= i; i--) - log.popSeen(lf->generics[i], rf->generics[i]); - } + for (int i = int(numGenerics) - 1; 0 <= i; i--) + log.popSeen(lf->generics[i], rf->generics[i]); } namespace @@ -1032,7 +1053,12 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, r->second.type); - checkChildUnifierTypeMismatch(innerState.errors, left, right); + + if (FFlag::LuauExtendedTypeMismatchError) + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + else + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) log.concat(std::move(innerState.log)); else @@ -1047,7 +1073,12 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, rt->indexer->indexResultType); - checkChildUnifierTypeMismatch(innerState.errors, left, right); + + if (FFlag::LuauExtendedTypeMismatchError) + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + else + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) log.concat(std::move(innerState.log)); else @@ -1083,7 +1114,12 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, lt->indexer->indexResultType); - checkChildUnifierTypeMismatch(innerState.errors, left, right); + + if (FFlag::LuauExtendedTypeMismatchError) + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + else + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) log.concat(std::move(innerState.log)); else @@ -1384,21 +1420,8 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio const auto& r = rt->props.find(it.first); if (r == rt->props.end()) { - if (FFlag::LuauSealedTableUnifyOptionalFix) - { - if (isOptional(it.second.type)) - continue; - } - else - { - if (get(it.second.type)) - { - const UnionTypeVar* possiblyOptional = get(it.second.type); - const std::vector& options = possiblyOptional->options; - if (options.end() != std::find_if(options.begin(), options.end(), isNil)) - continue; - } - } + if (isOptional(it.second.type)) + continue; missingPropertiesInSuper.push_back(it.first); @@ -1482,21 +1505,8 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio const auto& r = lt->props.find(it.first); if (r == lt->props.end()) { - if (FFlag::LuauSealedTableUnifyOptionalFix) - { - if (isOptional(it.second.type)) - continue; - } - else - { - if (get(it.second.type)) - { - const UnionTypeVar* possiblyOptional = get(it.second.type); - const std::vector& options = possiblyOptional->options; - if (options.end() != std::find_if(options.begin(), options.end(), isNil)) - continue; - } - } + if (isOptional(it.second.type)) + continue; extraPropertiesInSub.push_back(it.first); } @@ -1526,7 +1536,18 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse innerState.tryUnify_(lhs->table, rhs->table); innerState.tryUnify_(lhs->metatable, rhs->metatable); - checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); + if (FFlag::LuauExtendedTypeMismatchError) + { + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty()) + errors.push_back( + TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); + } + else + { + checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); + } log.concat(std::move(innerState.log)); } @@ -1613,10 +1634,34 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - tryUnify_(prop.type, singletonTypes.errorType); + + if (!FFlag::LuauExtendedClassMismatchError) + tryUnify_(prop.type, singletonTypes.errorType); } else - tryUnify_(prop.type, classProp->type); + { + if (FFlag::LuauExtendedClassMismatchError) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, classProp->type); + + checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); + + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); + } + else + { + ok = false; + innerState.log.rollback(); + } + } + else + { + tryUnify_(prop.type, classProp->type); + } + } } if (table->indexer) @@ -1649,45 +1694,24 @@ static void queueTypePack_DEPRECATED( while (true) { - if (FFlag::LuauAddMissingFollow) - a = follow(a); + a = follow(a); if (seenTypePacks.count(a)) break; seenTypePacks.insert(a); - if (FFlag::LuauAddMissingFollow) + if (get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; } - else + else if (auto tp = get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - - if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; } } } @@ -1698,45 +1722,24 @@ static void queueTypePack(std::vector& queue, DenseHashSet& while (true) { - if (FFlag::LuauAddMissingFollow) - a = follow(a); + a = follow(a); if (seenTypePacks.find(a)) break; seenTypePacks.insert(a); - if (FFlag::LuauAddMissingFollow) + if (get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; } - else + else if (auto tp = get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - - if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; } } } @@ -1990,33 +1993,6 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); } -std::optional Unifier::findMetatableEntry(TypeId type, std::string entry) -{ - type = follow(type); - - std::optional metatable = getMetatable(type); - if (!metatable) - return std::nullopt; - - TypeId unwrapped = follow(*metatable); - - if (get(unwrapped)) - return singletonTypes.anyType; - - const TableTypeVar* mtt = getTableType(unwrapped); - if (!mtt) - { - errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}}); - return std::nullopt; - } - - auto it = mtt->props.find(entry); - if (it != mtt->props.end()) - return it->second.type; - else - return std::nullopt; -} - void Unifier::occursCheck(TypeId needle, TypeId haystack) { std::unordered_set seen_DEPRECATED; @@ -2168,7 +2144,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { for (const auto& ty : a->head) { - if (auto f = get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) + if (auto f = get(follow(ty))) { occursCheck(seen_DEPRECATED, seen, needle, f->argTypes); occursCheck(seen_DEPRECATED, seen, needle, f->retType); @@ -2207,6 +2183,17 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId errors.push_back(TypeError{location, TypeMismatch{wantedType, givenType}}); } +void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) +{ + LUAU_ASSERT(FFlag::LuauExtendedTypeMismatchError || FFlag::LuauExtendedClassMismatchError); + + if (auto e = hasUnificationTooComplex(innerErrors)) + errors.push_back(*e); + else if (!innerErrors.empty()) + errors.push_back( + TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible", prop.c_str()), innerErrors.front()}}); +} + void Unifier::ice(const std::string& message, const Location& location) { sharedState.iceHandler->ice(message, location); diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 42c64dc9..39c7d925 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -282,7 +282,6 @@ private: // `<' namelist `>' std::pair, AstArray> parseGenericTypeList(); - std::pair, AstArray> parseGenericTypeListIfFFlagParseGenericFunctions(); // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 846bc0ba..a1bad65e 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,13 +10,13 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsParserFix, false) -LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) +LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) +LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau { @@ -957,7 +957,7 @@ AstStat* Parser::parseAssignment(AstExpr* initial) { nextLexeme(); - AstExpr* expr = parsePrimaryExpr(/* asStatement= */ false); + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ FFlag::LuauFixAmbiguousErrorRecoveryInAssign); if (!isExprLValue(expr)) expr = reportExprError(expr->location, copy({expr}), "Assigned expression must be a variable or a field"); @@ -995,7 +995,7 @@ std::pair Parser::parseFunctionBody( { Location start = matchFunction.location; - auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions(); + auto [generics, genericPacks] = parseGenericTypeList(); Lexeme matchParen = lexer.current(); expectAndConsume('(', "function"); @@ -1343,19 +1343,18 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); - bool monomorphic = !(FFlag::LuauParseGenericFunctions && lexer.current().type == '<'); - - auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions(); + bool monomorphic = lexer.current().type != '<'; Lexeme begin = lexer.current(); - if (FFlag::LuauGenericFunctionsParserFix) - expectAndConsume('(', "function parameters"); - else - { - LUAU_ASSERT(begin.type == '('); - nextLexeme(); // ( - } + auto [generics, genericPacks] = parseGenericTypeList(); + + Lexeme parameterStart = lexer.current(); + + if (!FFlag::LuauParseGenericFunctionTypeBegin) + begin = parameterStart; + + expectAndConsume('(', "function parameters"); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; @@ -1366,7 +1365,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) if (lexer.current().type != ')') varargAnnotation = parseTypeList(params, names); - expectMatchAndConsume(')', begin, true); + expectMatchAndConsume(')', parameterStart, true); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; @@ -1585,7 +1584,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { return {parseTableTypeAnnotation(), {}}; } - else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<')) + else if (lexer.current().type == '(' || lexer.current().type == '<') { return parseFunctionTypeAnnotation(allowPack); } @@ -2315,19 +2314,6 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou return Name(nameError, location); } -std::pair, AstArray> Parser::parseGenericTypeListIfFFlagParseGenericFunctions() -{ - if (FFlag::LuauParseGenericFunctions) - return Parser::parseGenericTypeList(); - AstArray generics; - AstArray genericPacks; - generics.size = 0; - generics.data = nullptr; - genericPacks.size = 0; - genericPacks.data = nullptr; - return std::pair(generics, genericPacks); -} - std::pair, AstArray> Parser::parseGenericTypeList() { TempVector names{scratchName}; @@ -2342,7 +2328,7 @@ std::pair, AstArray> Parser::parseGenericTypeList() while (true) { AstName name = parseName().name; - if (FFlag::LuauParseGenericFunctions && lexer.current().type == Lexeme::Dot3) + if (lexer.current().type == Lexeme::Dot3) { seenPack = true; nextLexeme(); @@ -2379,15 +2365,12 @@ AstArray Parser::parseTypeParams() Lexeme begin = lexer.current(); nextLexeme(); - bool seenPack = false; while (true) { if (FFlag::LuauParseTypePackTypeParameters) { if (shouldParseTypePackAnnotation(lexer)) { - seenPack = true; - auto typePack = parseTypePackAnnotation(); if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them @@ -2399,8 +2382,6 @@ AstArray Parser::parseTypeParams() if (typePack) { - seenPack = true; - if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them parameters.push_back({{}, typePack}); } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index ed0552d7..9ab10aaf 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -34,8 +34,10 @@ static void report(ReportFormat format, const char* name, const Luau::Location& } } -static void reportError(ReportFormat format, const char* name, const Luau::TypeError& error) +static void reportError(ReportFormat format, const Luau::TypeError& error) { + const char* name = error.moduleName.c_str(); + if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, name, error.location, "SyntaxError", syntaxError->message.c_str()); else @@ -49,7 +51,10 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate) { - Luau::CheckResult cr = frontend.check(name); + Luau::CheckResult cr; + + if (frontend.isDirty(name)) + cr = frontend.check(name); if (!frontend.getSourceModule(name)) { @@ -58,7 +63,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat } for (auto& error : cr.errors) - reportError(format, name, error); + reportError(format, error); Luau::LintResult lr = frontend.lint(name); @@ -115,7 +120,12 @@ struct CliFileResolver : Luau::FileResolver { if (Luau::AstExprConstantString* expr = node->as()) { - Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".lua"; + Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau"; + if (!moduleExists(name)) + { + // fall back to .lua if a module with .luau doesn't exist + name = std::string(expr->value.data, expr->value.size) + ".lua"; + } return {{name}}; } @@ -236,8 +246,15 @@ int main(int argc, char** argv) if (isDirectory(argv[i])) { traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + // Look for .luau first and if absent, fall back to .lua + if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) + { failed += !analyzeFile(frontend, name.c_str(), format, annotate); + } + else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + { + failed += !analyzeFile(frontend, name.c_str(), format, annotate); + } }); } else diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 4968d080..5c904cca 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -13,6 +13,17 @@ #include +#ifdef _WIN32 +#include +#include +#endif + +enum class CompileFormat +{ + Default, + Binary +}; + static int lua_loadstring(lua_State* L) { size_t l = 0; @@ -51,9 +62,13 @@ static int lua_require(lua_State* L) return finishrequire(L); lua_pop(L, 1); - std::optional source = readFile(name + ".lua"); + std::optional source = readFile(name + ".luau"); if (!source) - luaL_argerrorL(L, 1, ("error loading " + name).c_str()); + { + source = readFile(name + ".lua"); // try .lua if .luau doesn't exist + if (!source) + luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error + } // module needs to run in a new thread, isolated from the rest lua_State* GL = lua_mainthread(L); @@ -183,6 +198,11 @@ static std::string runCode(lua_State* L, const std::string& source) error += "\nstack backtrace:\n"; error += lua_debugtrace(T); +#ifdef __EMSCRIPTEN__ + // nicer formatting for errors in web repl + error = "Error:" + error; +#endif + fprintf(stdout, "%s", error.c_str()); } @@ -190,6 +210,39 @@ static std::string runCode(lua_State* L, const std::string& source) return std::string(); } +#ifdef __EMSCRIPTEN__ +extern "C" +{ + const char* executeScript(const char* source) + { + // setup flags + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + // create new state + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // setup state + setupState(L); + + // sandbox thread + luaL_sandboxthread(L); + + // static string for caching result (prevents dangling ptr on function exit) + static std::string result; + + // run code + collect error + result = runCode(L, source); + + return result.empty() ? NULL : result.c_str(); + } +} +#endif + +// Excluded from emscripten compilation to avoid -Wunused-function errors. +#ifndef __EMSCRIPTEN__ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) { std::string_view lookup = editBuffer + start; @@ -366,7 +419,7 @@ static void reportError(const char* name, const Luau::CompileError& error) report(name, error.getLocation(), "CompileError", error.what()); } -static bool compileFile(const char* name) +static bool compileFile(const char* name, CompileFormat format) { std::optional source = readFile(name); if (!source) @@ -383,7 +436,15 @@ static bool compileFile(const char* name) Luau::compileOrThrow(bcb, *source); - printf("%s", bcb.dumpEverything().c_str()); + switch (format) + { + case CompileFormat::Default: + printf("%s", bcb.dumpEverything().c_str()); + break; + case CompileFormat::Binary: + fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); + break; + } return true; } @@ -408,7 +469,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available modes:\n"); printf(" omitted: compile and run input files one by one\n"); - printf(" --compile: compile input files and output resulting bytecode\n"); + printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); printf("\n"); printf("Available options:\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); @@ -440,8 +501,19 @@ int main(int argc, char** argv) return 0; } - if (argc >= 2 && strcmp(argv[1], "--compile") == 0) + + if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { + CompileFormat format = CompileFormat::Default; + + if (strcmp(argv[1], "--compile=binary") == 0) + format = CompileFormat::Binary; + +#ifdef _WIN32 + if (format == CompileFormat::Binary) + _setmode(_fileno(stdout), _O_BINARY); +#endif + int failed = 0; for (int i = 2; i < argc; ++i) @@ -452,13 +524,15 @@ int main(int argc, char** argv) if (isDirectory(argv[i])) { traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !compileFile(name.c_str()); + if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) + failed += !compileFile(name.c_str(), format); + else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + failed += !compileFile(name.c_str(), format); }); } else { - failed += !compileFile(argv[i]); + failed += !compileFile(argv[i], format); } } @@ -511,5 +585,6 @@ int main(int argc, char** argv) return failed; } } +#endif diff --git a/CMakeLists.txt b/CMakeLists.txt index d6598f2a..36014a98 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,17 +17,26 @@ add_library(Luau.VM STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) - add_executable(Luau.Analyze.CLI) + if(NOT EMSCRIPTEN) + add_executable(Luau.Analyze.CLI) + else() + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions") + endif() # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) - set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) + + if(NOT EMSCRIPTEN) + set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) + endif() endif() -if(LUAU_BUILD_TESTS) +if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) add_executable(Luau.UnitTest) add_executable(Luau.Conformance) endif() + include(Sources.cmake) target_compile_features(Luau.Ast PUBLIC cxx_std_17) @@ -53,10 +62,6 @@ if(MSVC) else() list(APPEND LUAU_OPTIONS -Wall) # All warnings list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors - - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - list(APPEND LUAU_OPTIONS -Wno-unused) # GCC considers variables declared/checked in if() as unused - endif() endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) @@ -65,7 +70,10 @@ target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) - target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) + + if(NOT EMSCRIPTEN) + target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) + endif() target_include_directories(Luau.Repl.CLI PRIVATE extern) target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) @@ -74,10 +82,20 @@ if(LUAU_BUILD_CLI) target_link_libraries(Luau.Repl.CLI PRIVATE pthread) endif() - target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + if(NOT EMSCRIPTEN) + target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + endif() + + if(EMSCRIPTEN) + # declare exported functions to emscripten + target_link_options(Luau.Repl.CLI PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -fexceptions) + + # custom output directory for wasm + js file + set_target_properties(Luau.Repl.CLI PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/docs/assets/luau) + endif() endif() -if(LUAU_BUILD_TESTS) +if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.UnitTest PRIVATE extern) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 4b03ed1c..71631d10 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -467,6 +467,10 @@ enum LuauBuiltinFunction // vector ctor LBF_VECTOR, + + // bit32.count + LBF_BIT32_COUNTLZ, + LBF_BIT32_COUNTRZ, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index f8d67158..4f88e602 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -36,6 +36,9 @@ struct CompileOptions // global builtin to construct vectors; disabled by default const char* vectorLib = nullptr; const char* vectorCtor = nullptr; + + // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these + const char** mutableGlobals = nullptr; }; class CompileError : public std::exception diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 7750a1d9..9712f02f 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -13,7 +13,9 @@ LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false) LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false) +LUAU_FASTFLAGVARIABLE(LuauGenericSpecialGlobals, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) +LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) namespace Luau { @@ -22,6 +24,7 @@ static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; +// TODO: Remove with LuauGenericSpecialGlobals static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"}; CompileError::CompileError(const Location& location, const std::string& message) @@ -1277,7 +1280,7 @@ struct Compiler { const Global* global = globals.find(expr->name); - return options.optimizationLevel >= 1 && (!global || (!global->written && !global->special)); + return options.optimizationLevel >= 1 && (!global || (!global->written && !global->writable)); } void compileExprIndexName(AstExprIndexName* expr, uint8_t target) @@ -2465,9 +2468,10 @@ struct Compiler } else if (node->is()) { + LUAU_ASSERT(!loops.empty()); + // before exiting out of the loop, we need to close all local variables that were captured in closures since loop start // normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here - LUAU_ASSERT(!loops.empty()); closeLocals(loops.back().localOffset); size_t label = bytecode.emitLabel(); @@ -2478,12 +2482,13 @@ struct Compiler } else if (AstStatContinue* stat = node->as()) { + LUAU_ASSERT(!loops.empty()); + if (loops.back().untilCondition) validateContinueUntil(stat, loops.back().untilCondition); // before continuing, we need to close all local variables that were captured in closures since loop start // normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here - LUAU_ASSERT(!loops.empty()); closeLocals(loops.back().localOffset); size_t label = bytecode.emitLabel(); @@ -2900,6 +2905,11 @@ struct Compiler break; case AstExprUnary::Len: + if (arg.type == Constant::Type_String) + { + result.type = Constant::Type_Number; + result.valueNumber = double(arg.valueString.size); + } break; default: @@ -3440,7 +3450,7 @@ struct Compiler struct Global { - bool special = false; + bool writable = false; bool written = false; }; @@ -3498,7 +3508,7 @@ struct Compiler { Global* g = globals.find(object->name); - return !g || (!g->special && !g->written) ? Builtin{object->name, expr->index} : Builtin(); + return !g || (!g->writable && !g->written) ? Builtin{object->name, expr->index} : Builtin(); } else { @@ -3629,6 +3639,10 @@ struct Compiler return LBF_BIT32_RROTATE; if (builtin.method == "rshift") return LBF_BIT32_RSHIFT; + if (builtin.method == "countlz" && FFlag::LuauBit32CountBuiltin) + return LBF_BIT32_COUNTLZ; + if (builtin.method == "countrz" && FFlag::LuauBit32CountBuiltin) + return LBF_BIT32_COUNTRZ; } if (builtin.object == "string") @@ -3696,13 +3710,24 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); - // since access to some global objects may result in values that change over time, we block table imports - for (const char* global : kSpecialGlobals) + // since access to some global objects may result in values that change over time, we block imports from non-readonly tables + if (FFlag::LuauGenericSpecialGlobals) { - AstName name = names.get(global); + if (AstName name = names.get("_G"); name.value) + compiler.globals[name].writable = true; - if (name.value) - compiler.globals[name].special = true; + if (options.mutableGlobals) + for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) + compiler.globals[name].writable = true; + } + else + { + for (const char* global : kSpecialGlobals) + { + if (AstName name = names.get(global); name.value) + compiler.globals[name].writable = true; + } } // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written @@ -3717,7 +3742,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found - if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1) + if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) { Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed); root->visit(&fenvVisitor); diff --git a/Makefile b/Makefile index 7788251d..5d51b3d4 100644 --- a/Makefile +++ b/Makefile @@ -49,7 +49,10 @@ OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(T CXXFLAGS=-g -Wall -Werror LDFLAGS= -CXXFLAGS+=-Wno-unused # temporary, for older gcc versions +# temporary, for older gcc versions as they treat var in `if (type var = val)` as unused +ifeq ($(findstring g++,$(shell $(CXX) --version)),g++) + CXXFLAGS+=-Wno-unused +endif # configuration-specific flags ifeq ($(config),release) @@ -134,12 +137,11 @@ $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET): # executable targets for fuzzing fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) + $(CXX) $^ $(LDFLAGS) -o $@ + fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator -fuzz-%: - $(CXX) $^ $(LDFLAGS) -o $@ - # static library targets $(AST_TARGET): $(AST_OBJECTS) $(COMPILER_TARGET): $(COMPILER_OBJECTS) diff --git a/VM/include/lua.h b/VM/include/lua.h index 2f93ad90..a9d3e875 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -213,6 +213,8 @@ LUA_API int lua_resume(lua_State* L, lua_State* from, int narg); LUA_API int lua_resumeerror(lua_State* L, lua_State* from); LUA_API int lua_status(lua_State* L); LUA_API int lua_isyieldable(lua_State* L); +LUA_API void* lua_getthreaddata(lua_State* L); +LUA_API void lua_setthreaddata(lua_State* L, void* data); /* ** garbage-collection function and options @@ -346,6 +348,8 @@ struct lua_Debug * can only be changed when the VM is not running any code */ struct lua_Callbacks { + void* userdata; /* arbitrary userdata pointer that is never overwritten by Luau */ + void (*interrupt)(lua_State* L, int gc); /* gets called at safepoints (loop back edges, call/ret, gc) if set */ void (*panic)(lua_State* L, int errcode); /* gets called when an unprotected error is raised (if longjmp is used) */ diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index f2e97c66..7e742644 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -703,6 +703,7 @@ void lua_setreadonly(lua_State* L, int objindex, bool value) const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); + api_check(L, t != hvalue(registry(L))); t->readonly = value; return; } @@ -987,6 +988,16 @@ int lua_status(lua_State* L) return L->status; } +void* lua_getthreaddata(lua_State* L) +{ + return L->userdata; +} + +void lua_setthreaddata(lua_State* L, void* data) +{ + L->userdata = data; +} + /* ** Garbage-collection function */ diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 0754a351..c72fe674 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -4,6 +4,8 @@ #include "lnumutils.h" +LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) + #define ALLONES ~0u #define NBITS int(8 * sizeof(unsigned)) @@ -177,6 +179,44 @@ static int b_replace(lua_State* L) return 1; } +static int b_countlz(lua_State* L) +{ + if (!FFlag::LuauBit32Count) + luaL_error(L, "bit32.countlz isn't enabled"); + + b_uint v = luaL_checkunsigned(L, 1); + + b_uint r = NBITS; + for (int i = 0; i < NBITS; ++i) + if (v & (1u << (NBITS - 1 - i))) + { + r = i; + break; + } + + lua_pushunsigned(L, r); + return 1; +} + +static int b_countrz(lua_State* L) +{ + if (!FFlag::LuauBit32Count) + luaL_error(L, "bit32.countrz isn't enabled"); + + b_uint v = luaL_checkunsigned(L, 1); + + b_uint r = NBITS; + for (int i = 0; i < NBITS; ++i) + if (v & (1u << i)) + { + r = i; + break; + } + + lua_pushunsigned(L, r); + return 1; +} + static const luaL_Reg bitlib[] = { {"arshift", b_arshift}, {"band", b_and}, @@ -190,6 +230,8 @@ static const luaL_Reg bitlib[] = { {"replace", b_replace}, {"rrotate", b_rrot}, {"rshift", b_rshift}, + {"countlz", b_countlz}, + {"countrz", b_countrz}, {NULL, NULL}, }; diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index e1c99b21..9ab57ac9 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -20,8 +20,9 @@ // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path // If luauF_* succeeds, it needs to return *all* requested arguments, filling results with nil as appropriate. // On input, nparams refers to the actual number of arguments (0+), whereas nresults contains LUA_MULTRET for arbitrary returns or 0+ for a -// fixed-length return Because of this, and the fact that "extra" returned values will be ignored, implementations below typically check that nresults -// is <= expected number, which covers the LUA_MULTRET case. +// fixed-length return +// Because of this, and the fact that "extra" returned values will be ignored, implementations below typically check that nresults is <= expected +// number, which covers the LUA_MULTRET case. static int luauF_assert(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { @@ -1030,6 +1031,52 @@ static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, Stk return -1; } +static int luauF_countlz(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + + unsigned n; + luai_num2unsigned(n, a1); + +#ifdef _MSC_VER + unsigned long rl; + int r = _BitScanReverse(&rl, n) ? 31 - int(rl) : 32; +#else + int r = n == 0 ? 32 : __builtin_clz(n); +#endif + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + +static int luauF_countrz(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + + unsigned n; + luai_num2unsigned(n, a1); + +#ifdef _MSC_VER + unsigned long rl; + int r = _BitScanForward(&rl, n) ? int(rl) : 32; +#else + int r = n == 0 ? 32 : __builtin_ctz(n); +#endif + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1096,4 +1143,7 @@ luau_FastFunction luauF_table[256] = { luauF_tunpack, luauF_vector, + + luauF_countlz, + luauF_countrz, }; diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 6553009f..11f79d1a 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,7 +13,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) -LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) LUAU_FASTFLAG(LuauArrayBoundary) @@ -677,117 +676,6 @@ static size_t atomic(lua_State* L) return work; } -static size_t singlestep(lua_State* L) -{ - size_t cost = 0; - global_State* g = L->global; - switch (g->gcstate) - { - case GCSpause: - { - markroot(L); /* start a new collection */ - LUAU_ASSERT(g->gcstate == GCSpropagate); - break; - } - case GCSpropagate: - { - if (g->gray) - { - g->gcstats.currcycle.markitems++; - - cost = propagatemark(g); - } - else - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; - - g->gcstate = GCSpropagateagain; - } - break; - } - case GCSpropagateagain: - { - if (g->gray) - { - g->gcstats.currcycle.markitems++; - - cost = propagatemark(g); - } - else /* no more `gray' objects */ - { - if (FFlag::LuauSeparateAtomic) - { - g->gcstate = GCSatomic; - } - else - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } - } - break; - } - case GCSatomic: - { - g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - cost = atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); - break; - } - case GCSsweepstring: - { - size_t traversedcount = 0; - sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount); - - // nothing more to sweep? - if (g->sweepstrgc >= g->strt.size) - { - // sweep string buffer list and preserve used string count - uint32_t nuse = L->global->strt.nuse; - sweepwholelist(L, &g->strbufgc, &traversedcount); - L->global->strt.nuse = nuse; - - g->gcstate = GCSsweep; // end sweep-string phase - } - - g->gcstats.currcycle.sweepitems += traversedcount; - - cost = GC_SWEEPCOST; - break; - } - case GCSsweep: - { - size_t traversedcount = 0; - g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); - - g->gcstats.currcycle.sweepitems += traversedcount; - - if (*g->sweepgc == NULL) - { /* nothing more to sweep? */ - shrinkbuffers(L); - g->gcstate = GCSpause; /* end collection */ - } - cost = GC_SWEEPMAX * GC_SWEEPCOST; - break; - } - default: - LUAU_ASSERT(!"Unexpected GC state"); - } - - return cost; -} - static size_t gcstep(lua_State* L, size_t limit) { size_t cost = 0; @@ -980,37 +868,12 @@ void luaC_step(lua_State* L, bool assist) int lastgcstate = g->gcstate; double lasttimestamp = lua_clock(); - if (FFlag::LuauConsolidatedStep) - { - size_t work = gcstep(L, lim); + size_t work = gcstep(L, lim); - if (assist) - g->gcstats.currcycle.assistwork += work; - else - g->gcstats.currcycle.explicitwork += work; - } + if (assist) + g->gcstats.currcycle.assistwork += work; else - { - // always perform at least one single step - do - { - lim -= singlestep(L); - - // if we have switched to a different state, capture the duration of last stage - // this way we reduce the number of timer calls we make - if (lastgcstate != g->gcstate) - { - GC_INTERRUPT(lastgcstate); - - double now = lua_clock(); - - recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist); - - lasttimestamp = now; - lastgcstate = g->gcstate; - } - } while (lim > 0 && g->gcstate != GCSpause); - } + g->gcstats.currcycle.explicitwork += work; recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); @@ -1037,14 +900,7 @@ void luaC_step(lua_State* L, bool assist) g->GCthreshold -= debt; } - if (FFlag::LuauConsolidatedStep) - { - GC_INTERRUPT(lastgcstate); - } - else - { - GC_INTERRUPT(g->gcstate); - } + GC_INTERRUPT(lastgcstate); } void luaC_fullgc(lua_State* L) @@ -1070,10 +926,7 @@ void luaC_fullgc(lua_State* L) while (g->gcstate != GCSpause) { LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); - if (FFlag::LuauConsolidatedStep) - gcstep(L, SIZE_MAX); - else - singlestep(L); + gcstep(L, SIZE_MAX); } finishGcCycleStats(g); @@ -1084,10 +937,7 @@ void luaC_fullgc(lua_State* L) markroot(L); while (g->gcstate != GCSpause) { - if (FFlag::LuauConsolidatedStep) - gcstep(L, SIZE_MAX); - else - singlestep(L); + gcstep(L, SIZE_MAX); } /* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */ shrinkbuffersfull(L); diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index a9db3727..80a34483 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,6 +8,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauStrPackUBCastFix, false) + /* macro to `unsign' a character */ #define uchar(c) ((unsigned char)(c)) @@ -1404,10 +1406,20 @@ static int str_pack(lua_State* L) } case Kuint: { /* unsigned integers */ - unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, n, h.islittle, size, 0); + if (FFlag::LuauStrPackUBCastFix) + { + long long n = (long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, (unsigned long long)n, h.islittle, size, 0); + } + else + { + unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, n, h.islittle, size, 0); + } break; } case Kfloat: diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 883442ae..07d22d59 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -30,6 +30,7 @@ LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false) #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) +static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast in gval2slot is incorrect"); // TKey is bitpacked for memory efficiency so we need to validate bit counts for worst case static_assert(TKey{{NULL}, 0, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt"); static_assert(TKey{{NULL}, 0, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); diff --git a/VM/src/ltable.h b/VM/src/ltable.h index f98d87b1..45061443 100644 --- a/VM/src/ltable.h +++ b/VM/src/ltable.h @@ -9,7 +9,6 @@ #define gval(n) (&(n)->val) #define gnext(n) ((n)->key.next) -static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast below is incorrect"); #define gval2slot(t, v) int(cast_to(LuaNode*, static_cast(v)) - t->node) LUAI_FUNC const TValue* luaH_getnum(Table* t, int key); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index de5788eb..37025818 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -9,8 +9,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false) - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -491,9 +489,6 @@ static int tclear(lua_State* L) static int tfreeze(lua_State* L) { - if (!FFlag::LuauTableFreeze) - luaG_runerror(L, "table.freeze is disabled"); - luaL_checktype(L, 1, LUA_TTABLE); luaL_argcheck(L, !lua_getreadonly(L, 1), 1, "table is already frozen"); luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); @@ -506,9 +501,6 @@ static int tfreeze(lua_State* L) static int tisfrozen(lua_State* L) { - if (!FFlag::LuauTableFreeze) - luaG_runerror(L, "table.isfrozen is disabled"); - luaL_checktype(L, 1, LUA_TTABLE); lua_pushboolean(L, lua_getreadonly(L, 1)); diff --git a/bench/tests/chess.lua b/bench/tests/chess.lua index 87b9abfd..f6ae2cc6 100644 --- a/bench/tests/chess.lua +++ b/bench/tests/chess.lua @@ -205,38 +205,48 @@ function Bitboard:empty() return self.h == 0 and self.l == 0 end -function Bitboard:ctz() - local target = self.l - local offset = 0 - local result = 0 - - if target == 0 then - target = self.h - result = 32 +if not bit32.countrz then + local function ctz(v) + if v == 0 then return 32 end + local offset = 0 + while bit32.extract(v, offset) == 0 do + offset = offset + 1 + end + return offset end - - if target == 0 then - return 64 - end - - while bit32.extract(target, offset) == 0 do - offset = offset + 1 - end - - return result + offset -end - -function Bitboard:ctzafter(start) - start = start + 1 - if start < 32 then - for i=start,31 do - if bit32.extract(self.l, i) == 1 then return i end + function Bitboard:ctz() + local result = ctz(self.l) + if result == 32 then + return ctz(self.h) + 32 + else + return result end end - for i=math.max(32,start),63 do - if bit32.extract(self.h, i-32) == 1 then return i end + function Bitboard:ctzafter(start) + start = start + 1 + if start < 32 then + for i=start,31 do + if bit32.extract(self.l, i) == 1 then return i end + end + end + for i=math.max(32,start),63 do + if bit32.extract(self.h, i-32) == 1 then return i end + end + return 64 + end +else + function Bitboard:ctz() + local result = bit32.countrz(self.l) + if result == 32 then + return bit32.countrz(self.h) + 32 + else + return result + end + end + function Bitboard:ctzafter(start) + local masked = self:band(Bitboard.full:lshift(start+1)) + return masked:ctz() end - return 64 end @@ -245,7 +255,7 @@ function Bitboard:lshift(amt) if amt == 0 then return self end if amt > 31 then - return Bitboard.from(0, bit32.lshift(self.l, amt-31)) + return Bitboard.from(0, bit32.lshift(self.l, amt-32)) end local l = bit32.lshift(self.l, amt) @@ -832,12 +842,12 @@ end local testCases = {} local function addTest(...) table.insert(testCases, {...}) end -addTest(StartingFen, 3, 8902) -addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039) -addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812) -addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467) -addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486) -addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079) +addTest(StartingFen, 2, 400) +addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 1, 48) +addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 2, 191) +addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 2, 264) +addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 1, 44) +addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 1, 46) local function chess() diff --git a/fuzz/luau.proto b/fuzz/luau.proto index 41a1d077..c78fcf31 100644 --- a/fuzz/luau.proto +++ b/fuzz/luau.proto @@ -19,6 +19,7 @@ message Expr { ExprTable table = 13; ExprUnary unary = 14; ExprBinary binary = 15; + ExprIfElse ifelse = 16; } } @@ -149,6 +150,12 @@ message ExprBinary { required Expr right = 3; } +message ExprIfElse { + required Expr cond = 1; + required Expr then = 2; + required Expr else = 3; +} + message LValue { oneof lvalue_oneof { ExprLocal local = 1; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 6c230b67..c85fac7d 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -11,6 +11,7 @@ #include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" #include "Luau/ToString.h" +#include "Luau/Transpiler.h" #include "lua.h" #include "lualib.h" @@ -23,6 +24,7 @@ const bool kFuzzLinter = true; const bool kFuzzTypeck = true; const bool kFuzzVM = true; const bool kFuzzTypes = true; +const bool kFuzzTranspile = true; static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); @@ -242,6 +244,11 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) } } + if (kFuzzTranspile && parseResult.root) + { + transpileWithTypes(*parseResult.root); + } + // run resulting bytecode if (kFuzzVM && bytecode.size()) { diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index 2c861a55..e61b6936 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -476,6 +476,16 @@ struct ProtoToLuau print(expr.right()); } + void print(const luau::ExprIfElse& expr) + { + source += " if "; + print(expr.cond()); + source += " then "; + print(expr.then()); + source += " else "; + print(expr.else_()); + } + void print(const luau::LValue& expr) { if (expr.has_local()) diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index dd49e675..aa53a92b 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -45,7 +45,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "prop") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") { ScopedFastFlag sffs[] = { - {"LuauDontMutatePersistentFunctions", true}, {"LuauPersistDefinitionFileTypes", true}, }; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 44b8362d..8a7798f3 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1287,9 +1287,6 @@ local e: (n: n@5 TEST_CASE_FIXTURE(ACFixture, "generic_types") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - check(R"( function f(a: T@1 local b: string = "don't trip" diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index bbac3302..7f03019c 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -13,6 +13,7 @@ LUAU_FASTFLAG(LuauPreloadClosures) LUAU_FASTFLAG(LuauPreloadClosuresFenv) LUAU_FASTFLAG(LuauPreloadClosuresUpval) +LUAU_FASTFLAG(LuauGenericSpecialGlobals) using namespace Luau; @@ -1168,6 +1169,17 @@ RETURN R0 1 )"); } +TEST_CASE("ConstantFoldStringLen") +{ + CHECK_EQ("\n" + compileFunction0("return #'string', #'', #'a', #('b')"), R"( +LOADN R0 6 +LOADN R1 0 +LOADN R2 1 +LOADN R3 1 +RETURN R0 4 +)"); +} + TEST_CASE("ConstantFoldCompare") { // ordered comparisons @@ -3659,4 +3671,118 @@ RETURN R0 0 )"); } +TEST_CASE("LuauGenericSpecialGlobals") +{ + const char* source = R"( +print() +Game.print() +Workspace.print() +_G.print() +game.print() +plugin.print() +script.print() +shared.print() +workspace.print() +)"; + + { + ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", false}; + + // Check Roblox globals are here + CHECK_EQ("\n" + compileFunction0(source), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R1 3 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 5 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 9 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 11 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 13 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 15 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 17 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +RETURN R0 0 +)"); + } + + ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", true}; + + // Check Roblox globals are no longer here + CHECK_EQ("\n" + compileFunction0(source), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R0 3 +CALL R0 0 0 +GETIMPORT R0 5 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R0 9 +CALL R0 0 0 +GETIMPORT R0 11 +CALL R0 0 0 +GETIMPORT R0 13 +CALL R0 0 0 +GETIMPORT R0 15 +CALL R0 0 0 +GETIMPORT R0 17 +CALL R0 0 0 +RETURN R0 0 +)"); + + // Check we can add them back + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + const char* mutableGlobals[] = {"Game", "Workspace", "game", "plugin", "script", "shared", "workspace", NULL}; + options.mutableGlobals = &mutableGlobals[0]; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R1 3 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 5 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 9 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 11 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 13 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 15 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 17 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +RETURN R0 0 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 06b3c523..c1b790b9 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -240,8 +240,6 @@ TEST_CASE("Math") TEST_CASE("Table") { - ScopedFastFlag sff("LuauTableFreeze", true); - runConformance("nextvar.lua"); } @@ -322,6 +320,8 @@ TEST_CASE("GC") TEST_CASE("Bitwise") { + ScopedFastFlag sff("LuauBit32Count", true); + runConformance("bitwise.lua"); } @@ -359,6 +359,8 @@ TEST_CASE("PCall") TEST_CASE("Pack") { + ScopedFastFlag sff{ "LuauStrPackUBCastFix", true }; + runConformance("tpack.lua"); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 37f1b60b..7ba40c50 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -479,10 +479,6 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { - ScopedFastFlag sff{"LuauLinterUnknownTypeVectorAware", true}; - - SourceModule sm; - unfreeze(typeChecker.globalTypes); TableTypeVar::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, @@ -1400,8 +1396,6 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { - ScopedFastFlag sff("LuauLinterTableMoveZero", true); - LintResult result = lintTyped(R"( local t = {} local tt = {} diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index a80718e4..e3e6ce6d 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauFixAmbiguousErrorRecoveryInAssign) + using namespace Luau; namespace @@ -625,10 +627,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_messages") )"), "Cannot have more than one table indexer"); - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauGenericFunctionsParserFix", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CHECK_EQ(getParseError(R"( type T = foo )"), @@ -1624,6 +1622,20 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_confusing_function_call") "statements"); CHECK(result3.errors.size() == 1); + + auto result4 = matchParseError(R"( + local t = {} + function f() return t end + t.x, (f) + ().y = 5, 6 + )", + "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " + "statements"); + + if (FFlag::LuauFixAmbiguousErrorRecoveryInAssign) + CHECK(result4.errors.size() == 1); + else + CHECK(result4.errors.size() == 5); } TEST_CASE_FIXTURE(Fixture, "parse_error_varargs") @@ -1824,9 +1836,6 @@ TEST_CASE_FIXTURE(Fixture, "variadic_definition_parsing") TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") { - // Doesn't need LuauGenericFunctions - ScopedFastFlag sffs{"LuauParseGenericFunctions", true}; - ParseResult result = parseEx(R"( function f(...: a...) end @@ -1861,9 +1870,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") TEST_CASE_FIXTURE(Fixture, "generic_function_declaration_parsing") { - // Doesn't need LuauGenericFunctions - ScopedFastFlag sffs{"LuauParseGenericFunctions", true}; - ParseResult result = parseEx(R"( declare function f() )"); @@ -1953,12 +1959,7 @@ TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments") matchParseError("type MyFunc = (a: number, b: string, c: number) -> (d: number, e: string, f: number)", "Expected '->' when parsing function type, got "); - { - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - ScopedFastFlag luauGenericFunctionsParserFix{"LuauGenericFunctionsParserFix", true}; - - matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); - } + matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); } TEST_SUITE_END(); @@ -2362,8 +2363,6 @@ type Fn = ( CHECK_EQ("Expected '->' when parsing function type, got ')'", e.getErrors().front().getMessage()); } - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - try { parse(R"(type Fn = (any, string | number | ()) -> any)"); @@ -2397,8 +2396,6 @@ TEST_CASE_FIXTURE(Fixture, "AstName_comparison") TEST_CASE_FIXTURE(Fixture, "generic_type_list_recovery") { - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - try { parse(R"( @@ -2521,7 +2518,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); AstStat* stat = parse(R"( @@ -2534,4 +2530,9 @@ type C = Packed<(number, X...)> REQUIRE(stat != nullptr); } +TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") +{ + matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); +} + TEST_SUITE_END(); diff --git a/tests/Predicate.test.cpp b/tests/Predicate.test.cpp index bb5a93c5..7081693e 100644 --- a/tests/Predicate.test.cpp +++ b/tests/Predicate.test.cpp @@ -33,8 +33,6 @@ TEST_SUITE_BEGIN("Predicate"); TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - RefinementMap m{ {"b", typeChecker.stringType}, {"c", typeChecker.numberType}, @@ -61,8 +59,6 @@ TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - RefinementMap m{ {"a", typeChecker.stringType}, {"b", typeChecker.stringType}, @@ -89,8 +85,6 @@ TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") TEST_CASE_FIXTURE(Fixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - RefinementMap m{ {"a", typeChecker.stringType}, {"b", typeChecker.numberType}, diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index e18bf7cd..b076e9ad 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -259,9 +259,6 @@ TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names") TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names_generic") { - ScopedFastFlag luauGenericFunctions{"LuauGenericFunctions", true}; - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - CheckResult result = check("local function f(n: number, ...: a...): (a...) return ... end"); LUAU_REQUIRE_NO_ERRORS(result); @@ -340,10 +337,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") { - ScopedFastFlag sff[] = { - {"LuauGenericFunctions", true}, - }; - CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -468,8 +461,6 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_inters TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - ScopedFastFlag luauInstantiatedTypeParamRecursion{"LuauInstantiatedTypeParamRecursion", true}; - TypeVar tableTy{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tableTy); ttv->name = "Table"; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index bfff60f9..928c03a3 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -21,7 +21,7 @@ local function isPortal(element) return false end - return element.component==Core.Portal + return element.component == Core.Portal end )"; @@ -223,12 +223,24 @@ TEST_CASE("escaped_strings") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("escaped_strings_2") +{ + const std::string code = R"( local s="\a\b\f\n\r\t\v\'\"\\" )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("need_a_space_between_number_literals_and_dots") { const std::string code = R"( return point and math.ceil(point* 100000* 100)/ 100000 .. '%'or '' )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("binary_keywords") +{ + const std::string code = "local c = a0 ._ or b0 ._"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("do_blocks") { const std::string code = R"( @@ -364,10 +376,10 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") )"; std::string expected = R"( - local a:(string,number,...string)->(string,...number)=function(a:string,b:number,...:...string): (string,...number) + local a:(string,number,...string)->(string,...number)=function(a:string,b:number,...:string): (string,...number) end - local b:(...string)->(...number)=function(...:...string): ...number + local b:(...string)->(...number)=function(...:string): ...number end local c:()->()=function(): () @@ -400,4 +412,238 @@ TEST_CASE_FIXTURE(Fixture, "function_type_location") CHECK_EQ(expected, actual); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion") +{ + std::string code = "local a = 5 :: number"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") +{ + ScopedFastFlag luauIfElseExpressionBaseSupport("LuauIfElseExpressionBaseSupport", true); + + std::string code = "local a = if 1 then 2 else 3"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_type_reference_import") +{ + fileResolver.source["game/A"] = R"( +export type Type = { a: number } +return {} + )"; + + std::string code = R"( +local Import = require(game.A) +local a: Import.Type + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_type_packs") +{ + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + std::string code = R"( +type Packed = (T...)->(T...) +local a: Packed<> +local b: Packed<(number, string)> + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested") +{ + std::string code = "local a: ((number)->(string))|((string)->(string))"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_2") +{ + std::string code = "local a: (number&string)|(string&boolean)"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_3") +{ + std::string code = "local a: nil | (string & number)"; + + CHECK_EQ("local a: ( string & number)?", transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested") +{ + std::string code = "local a: ((number)->(string))&((string)->(string))"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested_2") +{ + std::string code = "local a: (number|string)&(string|boolean)"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_varargs") +{ + std::string code = "local function f(...) return ... end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_index_expr") +{ + std::string code = "local a = {1, 2, 3} local b = a[2]"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_unary") +{ + std::string code = R"( +local a = 1 +local b = -1 +local c = true +local d = not c +local e = 'hello' +local d = #e + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_break_continue") +{ + std::string code = R"( +local a, b, c +repeat + if a then break end + if b then continue end +until c + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_compound_assignmenr") +{ + std::string code = R"( +local a = 1 +a += 2 +a -= 3 +a *= 4 +a /= 5 +a %= 6 +a ^= 7 +a ..= ' - result' + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") +{ + std::string code = "a, b, c = 1, 2, 3"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_generic_function") +{ + ScopedFastFlag luauParseGenericFunctionTypeBegin("LuauParseGenericFunctionTypeBegin", true); + + std::string code = R"( +local function foo(a: T, ...: S...) return 1 end +local f: (T, S...)->(number) = foo + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_reverse") +{ + std::string code = "local a: nil | number"; + + CHECK_EQ("local a: number?", transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple") +{ + std::string code = "for k,v in next,{}do print(k,v) end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_expr") +{ + std::string code = "local a = f:-"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("local a = (error-expr: f.%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_stat") +{ + std::string code = "-"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("(error-stat: (error-expr))", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_type") +{ + std::string code = "local a: "; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("local a:%error-type%", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_parse_error") +{ + std::string code = "local a = -"; + + auto result = transpile(code); + CHECK_EQ("", result.code); + CHECK_EQ("Expected identifier when parsing expression, got ", result.parseError); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_to_string") +{ + std::string code = "local a: string = 'hello'"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + REQUIRE(parseResult.root); + REQUIRE(parseResult.root->body.size == 1); + AstStatLocal* statLocal = parseResult.root->body.data[0]->as(); + REQUIRE(statLocal); + CHECK_EQ("local a: string = 'hello'", toString(statLocal)); + REQUIRE(statLocal->vars.size == 1); + AstLocal* local = statLocal->vars.data[0]; + REQUIRE(local->annotation); + CHECK_EQ("string", toString(local->annotation)); + REQUIRE(statLocal->values.size == 1); + AstExpr* expr = statLocal->values.data[0]; + CHECK_EQ("'hello'", toString(expr)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index f580604c..c27f8083 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -247,9 +247,6 @@ TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") { - ScopedFastFlag sffs3{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( type Node = { value: T, child: Node? } diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 17e32e9f..1e2eae14 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -444,8 +444,6 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local co = coroutine.create(function() end) )"); @@ -456,8 +454,6 @@ TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local function nifty(x, y) print(x, y) @@ -476,8 +472,6 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( --!nonstrict local function nifty(x, y) @@ -822,8 +816,6 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local f = math.sin local function g(x) return math.sin(x) end diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index eabf7e65..1ff23fe6 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -232,8 +232,6 @@ TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class") TEST_CASE_FIXTURE(ClassFixture, "can_read_prop_of_base_class_using_string") { - ScopedFastFlag luauClassPropertyAccessAsString("LuauClassPropertyAccessAsString", true); - CheckResult result = check(R"( local c = ChildClass.New() local x = 1 + c["BaseField"] @@ -244,8 +242,6 @@ TEST_CASE_FIXTURE(ClassFixture, "can_read_prop_of_base_class_using_string") TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class_using_string") { - ScopedFastFlag luauClassPropertyAccessAsString("LuauClassPropertyAccessAsString", true); - CheckResult result = check(R"( local c = ChildClass.New() c["BaseField"] = 444 @@ -451,4 +447,25 @@ b.X = 2 -- real Vector2.X is also read-only CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors[3])); } +TEST_CASE_FIXTURE(ClassFixture, "detailed_class_unification_error") +{ + ScopedFastFlag luauExtendedClassMismatchError{"LuauExtendedClassMismatchError", true}; + + CheckResult result = check(R"( +local function foo(v) + return v.X :: number + string.len(v.Y) +end + +local a: Vector2 +local b = foo +b(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'Vector2' could not be converted into '{- X: a, Y: string -}' +caused by: + Property 'Y' is not compatible. Type 'number' could not be converted into 'string')", + toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 41e3e45a..2652486b 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -171,9 +171,6 @@ TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") TEST_CASE_FIXTURE(Fixture, "declaring_generic_functions") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - loadDefinition(R"( declare function f(a: a, b: b): string declare function g(...: a...): b... diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 581375a1..de2f0154 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -13,8 +13,6 @@ TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function id(x:a): a return x @@ -27,8 +25,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_function") TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a): a return x @@ -41,10 +37,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") TEST_CASE_FIXTURE(Fixture, "check_generic_typepack_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauGenericVariadicsUnification", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function id(...: a...): (a...) return ... end local x: string, y: boolean = id("hi", true) @@ -56,8 +48,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_typepack_function") TEST_CASE_FIXTURE(Fixture, "types_before_typepacks") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -66,8 +56,6 @@ TEST_CASE_FIXTURE(Fixture, "types_before_typepacks") TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a):a return x end local f: (a)->a = id @@ -79,7 +67,6 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x end print("This is bogus") -- TODO: CLI-39916 @@ -92,7 +79,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x end print("This is bogus") -- TODO: CLI-39916 @@ -104,8 +90,6 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local t = {} t.m = function(x: a):a return x end @@ -117,8 +101,6 @@ TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local t: { m: (number)->number } = { m = function(x:number) return x+1 end } local function id(x:a):a return x end @@ -129,8 +111,6 @@ TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") TEST_CASE_FIXTURE(Fixture, "check_nested_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function f() local function id(x:a): a @@ -145,8 +125,6 @@ TEST_CASE_FIXTURE(Fixture, "check_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "check_recursive_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a):a local y: string = id("hi") @@ -159,8 +137,6 @@ TEST_CASE_FIXTURE(Fixture, "check_recursive_generic_function") TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local id2 local function id1(x:a):a @@ -179,8 +155,6 @@ TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( type T = { id: (a) -> a } local x: T = { id = function(x:a):a return x end } @@ -192,8 +166,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") TEST_CASE_FIXTURE(Fixture, "generic_factories") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( type T = { id: (a) -> a } type Factory = { build: () -> T } @@ -215,10 +187,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_factories") TEST_CASE_FIXTURE(Fixture, "factories_of_generics") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; - CheckResult result = check(R"( type T = { id: (a) -> a } type Factory = { build: () -> T } @@ -241,7 +209,6 @@ TEST_CASE_FIXTURE(Fixture, "factories_of_generics") TEST_CASE_FIXTURE(Fixture, "infer_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function id(x) return x @@ -265,7 +232,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x @@ -289,7 +255,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f() local function id(x) @@ -304,7 +269,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local x = {} function x:id(x) return x end @@ -316,7 +280,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local x = {} function x:id(x) return x end @@ -331,8 +294,6 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") TEST_CASE_FIXTURE(Fixture, "infer_generic_property") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauRankNTypes", true}; CheckResult result = check(R"( local t = {} t.m = function(x) return x end @@ -344,9 +305,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_property") TEST_CASE_FIXTURE(Fixture, "function_arguments_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function f(g: (a)->a) local x: number = g(37) @@ -358,9 +316,6 @@ TEST_CASE_FIXTURE(Fixture, "function_arguments_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "function_results_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function f() : (a)->a local function id(x:a):a return x end @@ -372,9 +327,6 @@ TEST_CASE_FIXTURE(Fixture, "function_results_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "type_parameters_can_be_polytypes") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function id(x:a):a return x end local f: (a)->a = id(id) @@ -384,7 +336,6 @@ TEST_CASE_FIXTURE(Fixture, "type_parameters_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f(y) -- this will only typecheck if we infer z: any @@ -406,7 +357,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f(y) local z = y @@ -423,12 +373,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") { - ScopedFastFlag sffs[] = { - {"LuauGenericFunctions", true}, - {"LuauParseGenericFunctions", true}, - {"LuauRankNTypes", true}, - }; - CheckResult result = check(R"( type T = { m: (a) -> T } function f(t : T) @@ -440,10 +384,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") TEST_CASE_FIXTURE(Fixture, "dont_unify_bound_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; - CheckResult result = check(R"( type F = () -> (a, b) -> a type G = (b, b) -> b @@ -470,7 +410,6 @@ TEST_CASE_FIXTURE(Fixture, "mutable_state_polymorphism") // Replaying the classic problem with polymorphism and mutable state in Luau // See, e.g. Tofte (1990) // https://www.sciencedirect.com/science/article/pii/089054019090018D. - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( --!strict -- Our old friend the polymorphic identity function @@ -508,7 +447,6 @@ TEST_CASE_FIXTURE(Fixture, "mutable_state_polymorphism") TEST_CASE_FIXTURE(Fixture, "rank_N_types_via_typeof") { - ScopedFastFlag sffs{"LuauGenericFunctions", false}; CheckResult result = check(R"( --!strict local function id(x) return x end @@ -531,8 +469,6 @@ TEST_CASE_FIXTURE(Fixture, "rank_N_types_via_typeof") TEST_CASE_FIXTURE(Fixture, "duplicate_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function f(x:a):a return x end )"); @@ -541,7 +477,6 @@ TEST_CASE_FIXTURE(Fixture, "duplicate_generic_types") TEST_CASE_FIXTURE(Fixture, "duplicate_generic_type_packs") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -550,7 +485,6 @@ TEST_CASE_FIXTURE(Fixture, "duplicate_generic_type_packs") TEST_CASE_FIXTURE(Fixture, "typepacks_before_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -559,9 +493,6 @@ TEST_CASE_FIXTURE(Fixture, "typepacks_before_types") TEST_CASE_FIXTURE(Fixture, "variadic_generics") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a) end @@ -573,9 +504,6 @@ TEST_CASE_FIXTURE(Fixture, "variadic_generics") TEST_CASE_FIXTURE(Fixture, "generic_type_pack_syntax") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a...): (a...) return ... end )"); @@ -586,10 +514,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_type_pack_syntax") TEST_CASE_FIXTURE(Fixture, "generic_type_pack_parentheses") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauGenericVariadicsUnification", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a...): any return (...) end )"); @@ -599,9 +523,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_type_pack_parentheses") TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: T...) return ... @@ -626,9 +547,6 @@ TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") TEST_CASE_FIXTURE(Fixture, "reject_clashing_generic_and_pack_names") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f() end )"); @@ -641,8 +559,6 @@ TEST_CASE_FIXTURE(Fixture, "reject_clashing_generic_and_pack_names") TEST_CASE_FIXTURE(Fixture, "instantiation_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( function f(z) local o = {} @@ -665,8 +581,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiation_sharing_types") TEST_CASE_FIXTURE(Fixture, "quantification_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( function f(x) return {5} end function g(x, y) return f(x) end @@ -680,8 +594,6 @@ TEST_CASE_FIXTURE(Fixture, "quantification_sharing_types") TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( type T = { x: {a}, y: {number} } local o1: T = { x = {true}, y = {5} } @@ -697,7 +609,6 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") { - ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true}; ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true}; CheckResult result = check(R"( diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 9685f4f3..893bc2b3 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -341,4 +341,43 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X & Y & Z + +local a: XYZ = 3 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' +caused by: + Not all intersection parts are compatible. Type 'number' could not be converted into 'X')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X & Y & Z + +local a: XYZ +local b: number = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 419da8ad..e5c14dde 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -194,9 +194,6 @@ TEST_CASE_FIXTURE(Fixture, "normal_conditional_expression_has_refinements") // Luau currently doesn't yet know how to allow assignments when the binding was refined. TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( type Node = { value: T, child: Node? } @@ -596,11 +593,9 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; - ScopedFastFlag luauFollowInTypeFunApply{"LuauFollowInTypeFunApply", true}; - ScopedFastFlag luauInstantiatedTypeParamRecursion{"LuauInstantiatedTypeParamRecursion", true}; // Mutability in type function application right now can create strange recursive types - // TODO: instantiation right now is problematic, it this example should either leave the Table type alone + // TODO: instantiation right now is problematic, in this example should either leave the Table type alone // or it should rename the type to 'Self' so that the result will be 'Self' CheckResult result = check(R"( type Table = { a: number } diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 36dcaa95..733fc39b 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -7,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauWeakEqConstraint) -LUAU_FASTFLAG(LuauOrPredicate) LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -133,11 +132,8 @@ TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); - if (FFlag::LuauOrPredicate) - { - CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); - } + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); } TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") @@ -283,6 +279,8 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( local t: {x: number?} = {x = nil} @@ -293,7 +291,10 @@ TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_ty )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type '{| x: number? |}' could not be converted into '{| x: number |}'", toString(result.errors[0])); + CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' +caused by: + Property 'x' is not compatible. Type 'number?' could not be converted into 'number')", + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") @@ -749,8 +750,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") { - ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; - CheckResult result = check(R"( local function f(x: Part | Folder | string) if typeof(x) == "Instance" then @@ -769,8 +768,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") { - ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; - CheckResult result = check(R"( local function f(x: Part | Folder | Instance | string | Vector3 | any) if typeof(x) == "Instance" then @@ -789,8 +786,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( --!nonstrict @@ -811,11 +806,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; - CheckResult result = check(R"( local function f(x: Part | Folder | string) if typeof(x) ~= "Instance" or not x:IsA("Part") then @@ -890,8 +880,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_s TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if (not a) or (not b) then @@ -909,8 +897,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b2") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if not (a and b) then @@ -928,8 +914,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b2") TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if (not a) and (not b) then @@ -947,8 +931,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if not (a or b) then @@ -966,8 +948,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") TEST_CASE_FIXTURE(Fixture, "either_number_or_string") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(x: any) if type(x) == "number" or type(x) == "string" then @@ -983,8 +963,6 @@ TEST_CASE_FIXTURE(Fixture, "either_number_or_string") TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(t: {x: boolean}?) if not t or t.x then @@ -1000,8 +978,6 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local a: (number | string)? assert(a) @@ -1018,8 +994,6 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. CheckResult result = check(R"( local function f(b: string | { x: string }, a) @@ -1039,8 +1013,6 @@ TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: string | number | boolean) if type(a) ~= "number" and type(a) ~= "string" then diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index f1451a81..c3694be7 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1950,4 +1950,76 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") +{ + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type A = { x: number, y: number } +type B = { x: number, y: string } + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") +{ + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type AS = { x: number, y: number } +type BS = { x: number, y: string } + +type A = { a: boolean, b: AS } +type B = { a: boolean, b: BS } + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") +{ + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); +local b1 = setmetatable({ x = 2, y = "hello" }, { __call = function(s) end }); +local c1: typeof(a1) = b1 + +local a2 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); +local b2 = setmetatable({ x = 2, y = 4 }, { __call = function(s, t) end }); +local c2: typeof(a2) = b2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' +caused by: + Type '{| x: number, y: string |}' could not be converted into '{| x: number, y: number |}' +caused by: + Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); + + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 45381757..30d9130a 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -3926,8 +3926,6 @@ local b: number = 1 or a TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - CheckResult result = check(R"( --!strict local tbl = {} @@ -4493,10 +4491,6 @@ f(function(x) print(x) end) TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") { - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end return sum(2, 3, function(a, b) return a + b end) @@ -4525,10 +4519,6 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") { - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - CheckResult result = check(R"( local function g1(a: T, f: (T) -> T) return f(a) end local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end @@ -4579,10 +4569,6 @@ local a: TableWithFunc = { x = 3, y = 4, f = function(a, b) return a + b end } TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") { - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end @@ -4600,8 +4586,6 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local function f(): {string|number} return {1, "b", 3} @@ -4625,8 +4609,6 @@ end TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local function f() return {4, "b", 3} :: {string|number} @@ -4638,8 +4620,6 @@ end TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local a: (number, number) -> number = function(a, b) return a - b end @@ -4655,8 +4635,6 @@ b, c = {2, "s"}, {"b", 4} TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types_mutable_lval") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local a = {} a.x = 2 @@ -4668,8 +4646,6 @@ a = setmetatable(a, { __call = function(x) end }) TEST_CASE_FIXTURE(Fixture, "refine_and_or") { - ScopedFastFlag sff{"LuauSlightlyMoreFlexibleBinaryPredicates", true}; - CheckResult result = check(R"( local t: {x: number?}? = {x = nil} local u = t and t.x or 5 @@ -4682,10 +4658,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_and_or") TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") { - ScopedFastFlag sffs[] = { - {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - }; - CheckResult result = check(R"( local t: {x: number?}? = {x = nil} local u = t.x and t or 5 @@ -4698,10 +4670,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") { - ScopedFastFlag sffs[] = { - {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - }; - CheckResult result = check(R"( local t: {x: number?}? = {x = nil} local u = t and t.x == 5 or t.x == 31337 @@ -4714,7 +4682,7 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { - ScopedFastFlag luauFollowInTypeFunApply("LuauFollowInTypeFunApply", true); + ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; CheckResult result = check(R"( type A = { x: number } diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 1192a8ac..2d697fc9 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -178,9 +178,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_wor TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( --!strict local function f(...: T): ...T @@ -199,8 +196,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unification") { - ScopedFastFlag sffs2("LuauGenericFunctions", true); - CheckResult result = check(R"( --!strict table.insert() diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 8dab2605..c6de0abf 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -296,7 +296,6 @@ end TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -361,7 +360,6 @@ local c: Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -395,7 +393,6 @@ local d: { a: typeof(c) } TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -434,7 +431,6 @@ type C = Import.Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -456,7 +452,6 @@ type Packed4 = (Packed3, T...) -> (Packed3, T...) TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -475,7 +470,6 @@ type E = X<(number, ...string)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -507,7 +501,6 @@ type I = W TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -534,7 +527,6 @@ type F = X<(string, ...number)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -557,10 +549,8 @@ type D = Y TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - ScopedFastFlag luauInstantiatedTypeParamRecursion("LuauInstantiatedTypeParamRecursion", true); // For correct toString block CheckResult result = check(R"( type Y = { f: (T...) -> (U...) } @@ -577,7 +567,6 @@ local b: Y<(), ()> TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -599,7 +588,6 @@ type C = Y<(number), boolean> TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 34c25a9f..9f29b642 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -400,8 +400,6 @@ local e = a.z TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") { - ScopedFastFlag luauSealedTableUnifyOptionalFix("LuauSealedTableUnifyOptionalFix", true); - CheckResult result = check(R"( local x: { x: number } = { x = 3 } type A = number? @@ -426,4 +424,43 @@ y = x LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_union_part") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X | Y | Z + +local a: XYZ +local b: { w: number } = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'X | Y | Z' could not be converted into '{| w: number |}' +caused by: + Not all union options are compatible. Table type 'X' not compatible with type '{| w: number |}' because the former is missing field 'w')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_union_all") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X | Y | Z + +local a: XYZ = { w = 4 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X | Y | Z'; none of the union options are compatible)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 930c1a39..91efa818 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -11,8 +11,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauGenericFunctions); - TEST_SUITE_BEGIN("TypeVarTests"); TEST_CASE_FIXTURE(Fixture, "primitives_are_equal") diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua index 6efa5960..13be3f94 100644 --- a/tests/conformance/bitwise.lua +++ b/tests/conformance/bitwise.lua @@ -113,6 +113,20 @@ assert(bit32.replace(0, -1, 4) == 2^4) assert(bit32.replace(-1, 0, 31) == 2^31 - 1) assert(bit32.replace(-1, 0, 1, 2) == 2^32 - 7) +-- testing countlz/countrc +assert(bit32.countlz(0) == 32) +assert(bit32.countlz(42) == 26) +assert(bit32.countlz(0xffffffff) == 0) +assert(bit32.countlz(0x80000000) == 0) +assert(bit32.countlz(0x7fffffff) == 1) + +assert(bit32.countrz(0) == 32) +assert(bit32.countrz(1) == 0) +assert(bit32.countrz(42) == 1) +assert(bit32.countrz(0x80000000) == 31) +assert(bit32.countrz(0x40000000) == 30) +assert(bit32.countrz(0x7fffffff) == 0) + --[[ This test verifies a fix in luauF_replace() where if the 4th parameter was not a number, but the first three are numbers, it will @@ -136,5 +150,7 @@ assert(bit32.bxor("1", 3) == 2) assert(bit32.bxor(1, "3") == 2) assert(bit32.btest(1, "3") == true) assert(bit32.btest("1", 3) == true) +assert(bit32.countlz("42") == 26) +assert(bit32.countrz("42") == 1) return('OK') From 8fe0dc0b6d553dc053a43f5ed6607c9c4d1a26c0 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 18:23:34 -0800 Subject: [PATCH 004/102] Fix build --- VM/src/lbitlib.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index c72fe674..907c43c4 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -2,6 +2,7 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lualib.h" +#include "lcommon.h" #include "lnumutils.h" LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) From 863d3ff6ffa64398a6d13dc8c189bae30fefd557 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 19:42:50 -0800 Subject: [PATCH 005/102] Attempt to work around non-sensical error --- Analysis/src/TypeInfer.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8fad1af9..b27f3e17 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5030,7 +5030,8 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } - return std::nullopt; + std::optional res = std::nullopt; + return res; }; std::optional ty = resolveLValue(refis, scope, isaP.lvalue); From 3c3541aba84d9209b6098a5c6ae01727ab11ec32 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 20:36:53 -0800 Subject: [PATCH 006/102] Add a comment --- Analysis/src/TypeInfer.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b27f3e17..a6696efd 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5030,6 +5030,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } + // local variable works around an odd gcc 9.3 warning: may be used uninitialized std::optional res = std::nullopt; return res; }; From 60e6e86adb5c7a687153989bc471fc05fc4637e8 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 18 Nov 2021 14:21:07 -0800 Subject: [PATCH 007/102] Sync to upstream/release/505 --- Analysis/include/Luau/Documentation.h | 11 +- Analysis/include/Luau/ToString.h | 11 +- Analysis/include/Luau/TypeInfer.h | 24 +- Analysis/include/Luau/TypeVar.h | 87 ++++- Analysis/include/Luau/Unifiable.h | 2 + Analysis/include/Luau/Unifier.h | 1 + Analysis/src/Autocomplete.cpp | 33 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 1 + Analysis/src/Error.cpp | 8 +- Analysis/src/Frontend.cpp | 4 +- Analysis/src/Module.cpp | 16 +- Analysis/src/ToString.cpp | 115 +++++- Analysis/src/Transpiler.cpp | 55 --- Analysis/src/TypeAttach.cpp | 16 + Analysis/src/TypeInfer.cpp | 374 ++++++++++++------- Analysis/src/TypePack.cpp | 1 - Analysis/src/TypeVar.cpp | 50 ++- Analysis/src/Unifier.cpp | 64 +++- Ast/include/Luau/Ast.h | 32 ++ Ast/include/Luau/Parser.h | 1 + Ast/include/Luau/StringUtils.h | 2 + Ast/src/Ast.cpp | 22 ++ Ast/src/Parser.cpp | 68 +++- Ast/src/StringUtils.cpp | 58 +++ CLI/Analyze.cpp | 28 +- CLI/FileUtils.cpp | 46 ++- CLI/FileUtils.h | 3 + CLI/Repl.cpp | 81 ++--- CMakeLists.txt | 17 +- Compiler/include/Luau/Compiler.h | 4 +- Compiler/include/luacode.h | 39 ++ Compiler/src/Compiler.cpp | 46 +-- Compiler/src/lcode.cpp | 29 ++ Makefile | 10 +- Sources.cmake | 3 + VM/include/lua.h | 18 +- VM/include/lualib.h | 6 +- VM/src/lapi.cpp | 10 +- VM/src/lbaselib.cpp | 8 +- VM/src/lbitlib.cpp | 1 + VM/src/lcorolib.cpp | 36 +- VM/src/ldebug.cpp | 10 +- VM/src/ldo.cpp | 6 +- VM/src/linit.cpp | 2 +- VM/src/lstate.cpp | 28 ++ VM/src/lstrlib.cpp | 2 +- VM/src/lutf8lib.cpp | 2 +- bench/bench.py | 32 +- fuzz/proto.cpp | 2 +- tests/Autocomplete.test.cpp | 16 +- tests/Compiler.test.cpp | 63 +--- tests/Conformance.test.cpp | 97 ++--- tests/IostreamOptional.h | 3 +- tests/Module.test.cpp | 6 +- tests/ToString.test.cpp | 111 +++++- tests/TypeInfer.aliases.test.cpp | 2 +- tests/TypeInfer.annotations.test.cpp | 4 +- tests/TypeInfer.generics.test.cpp | 6 +- tests/TypeInfer.refinements.test.cpp | 3 +- tests/TypeInfer.singletons.test.cpp | 377 ++++++++++++++++++++ tests/TypeInfer.test.cpp | 87 +++-- tests/TypeInfer.tryUnify.test.cpp | 30 +- tests/TypeInfer.unionTypes.test.cpp | 5 +- tests/conformance/coroutine.lua | 54 +++ tests/conformance/debugger.lua | 9 + 65 files changed, 1819 insertions(+), 579 deletions(-) create mode 100644 Compiler/include/luacode.h create mode 100644 Compiler/src/lcode.cpp create mode 100644 tests/TypeInfer.singletons.test.cpp diff --git a/Analysis/include/Luau/Documentation.h b/Analysis/include/Luau/Documentation.h index 7b609b4f..68ff3a7c 100644 --- a/Analysis/include/Luau/Documentation.h +++ b/Analysis/include/Luau/Documentation.h @@ -12,10 +12,17 @@ namespace Luau struct FunctionDocumentation; struct TableDocumentation; struct OverloadedFunctionDocumentation; +struct BasicDocumentation; -using Documentation = Luau::Variant; +using Documentation = Luau::Variant; using DocumentationSymbol = std::string; +struct BasicDocumentation +{ + std::string documentation; + std::string learnMoreLink; +}; + struct FunctionParameterDocumentation { std::string name; @@ -29,6 +36,7 @@ struct FunctionDocumentation std::string documentation; std::vector parameters; std::vector returns; + std::string learnMoreLink; }; struct OverloadedFunctionDocumentation @@ -43,6 +51,7 @@ struct TableDocumentation { std::string documentation; Luau::DenseHashMap keys; + std::string learnMoreLink; }; using DocumentationDatabase = Luau::DenseHashMap; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index e5683fc4..50379c1c 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -23,10 +23,11 @@ struct ToStringNameMap struct ToStringOptions { - bool exhaustive = false; // If true, we produce complete output rather than comprehensible output - bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. - bool functionTypeArguments = false; // If true, output function type argument names when they are available - bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool exhaustive = false; // If true, we produce complete output rather than comprehensible output + bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. + bool functionTypeArguments = false; // If true, output function type argument names when they are available + bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; @@ -64,6 +65,8 @@ inline std::string toString(TypePackId ty) std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts = {}); + // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression void dump(TypeId ty); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 306ac77d..78d642c5 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -175,10 +175,10 @@ struct TypeChecker std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors); + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); - ExprResult reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, const std::vector& errors); @@ -282,6 +282,14 @@ public: // Wrapper for merge(l, r, toUnion) but without the lambda junk. void merge(RefinementMap& l, const RefinementMap& r); + // Produce an "emergency backup type" for recovery from type errors. + // This comes in two flavours, depening on whether or not we can make a good guess + // for an error recovery type. + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(const ScopePtr& scope); + TypePackId errorRecoveryTypePack(const ScopePtr& scope); + private: void prepareErrorsForDisplay(ErrorVec& errVec); void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data); @@ -297,6 +305,10 @@ private: TypeId freshType(const ScopePtr& scope); TypeId freshType(TypeLevel level); + // Produce a new singleton type var. + TypeId singletonType(bool value); + TypeId singletonType(std::string value); + // Returns nullopt if the predicate filters down the TypeId to 0 options. std::optional filterMap(TypeId type, TypeIdPredicate predicate); @@ -330,8 +342,8 @@ private: const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. - std::pair, std::vector> createGenericTypes( - const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + std::pair, std::vector> createGenericTypes(const ScopePtr& scope, std::optional levelOpt, + const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -347,7 +359,6 @@ private: void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; @@ -387,12 +398,9 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; - - const TypeId errorType; const TypeId optionalNumberType; const TypePackId anyTypePack; - const TypePackId errorTypePack; private: int checkRecursionCount = 0; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 6bd7932d..093ea431 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -108,6 +108,79 @@ struct PrimitiveTypeVar } }; +// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md +// Types for true and false +struct BoolSingleton +{ + bool value; + + bool operator==(const BoolSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const BoolSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// Types for "foo", "bar" etc. +struct StringSingleton +{ + std::string value; + + bool operator==(const StringSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const StringSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// No type for float singletons, partly because === isn't any equalivalence on floats +// (NaN != NaN). + +using SingletonVariant = Luau::Variant; + +struct SingletonTypeVar +{ + explicit SingletonTypeVar(const SingletonVariant& variant) + : variant(variant) + { + } + + explicit SingletonTypeVar(SingletonVariant&& variant) + : variant(std::move(variant)) + { + } + + // Default operator== is C++20. + bool operator==(const SingletonTypeVar& rhs) const + { + return variant == rhs.variant; + } + + bool operator!=(const SingletonTypeVar& rhs) const + { + return !(*this == rhs); + } + + SingletonVariant variant; +}; + +template +const T* get(const SingletonTypeVar* stv) +{ + if (stv) + return get_if(&stv->variant); + else + return nullptr; +} + struct FunctionArgument { Name name; @@ -332,8 +405,8 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -410,6 +483,9 @@ bool isGeneric(const TypeId ty); // Checks if a type may be instantiated to one containing generic type binders bool maybeGeneric(const TypeId ty); +// Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton +bool maybeSingleton(TypeId ty); + struct SingletonTypes { const TypeId nilType; @@ -418,16 +494,19 @@ struct SingletonTypes const TypeId booleanType; const TypeId threadType; const TypeId anyType; - const TypeId errorType; const TypeId optionalNumberType; const TypePackId anyTypePack; - const TypePackId errorTypePack; SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete; + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(); + TypePackId errorRecoveryTypePack(); + private: std::unique_ptr arena; TypeId makeStringMetatable(); diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index c2e07e46..b47610fc 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -105,6 +105,8 @@ private: struct Error { + // This constructor has to be public, since it's used in TypeVar and TypePack, + // but shouldn't be called directly. Please use errorRecoveryType() instead. Error(); int index; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index be0aadd0..503034a1 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -65,6 +65,7 @@ struct Unifier private: void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); void tryUnifyPrimitives(TypeId superTy, TypeId subTy); + void tryUnifySingletons(TypeId superTy, TypeId subTy); void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1c94bb68..6fc0b3f8 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -198,11 +199,24 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); - unifier.tryUnify(expectedType, actualType); + if (FFlag::LuauAutocompleteAvoidMutation) + { + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr); + actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr); - bool ok = unifier.errors.empty(); - unifier.log.rollback(); - return ok; + auto errors = unifier.canUnify(expectedType, actualType); + return errors.empty(); + } + else + { + unifier.tryUnify(expectedType, actualType); + + bool ok = unifier.errors.empty(); + unifier.log.rollback(); + return ok; + } }; auto expr = node->asExpr(); @@ -1496,11 +1510,9 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName if (!sourceModule) return {}; - TypeChecker& typeChecker = - (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = - (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); if (!module) return {}; @@ -1527,8 +1539,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->mode = Mode::Strict; sourceModule->commentLocations = std::move(result.commentLocations); - TypeChecker& typeChecker = - (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 96703ef1..9f5c8250 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -153,6 +153,7 @@ declare function gcinfo(): number wrap: ((A...) -> R...) -> any, yield: (A...) -> R..., isyieldable: () -> boolean, + close: (thread) -> (boolean, any?) } declare table: { diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 46ff2c72..f80d50a7 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -180,13 +180,13 @@ struct ErrorConverter switch (e.context) { case CountMismatch::Return: - return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + - std::to_string(e.actual) + " " + actualVerb + " returned here"; + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + + actualVerb + " returned here"; case CountMismatch::Result: // It is alright if right hand side produces more values than the // left hand side accepts. In this context consider only the opposite case. - return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + - std::to_string(e.actual) + " are required here"; + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + + " are required here"; case CountMismatch::Arg: if (FFlag::LuauTypeAliasPacks) return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 2f411274..1e97705d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -22,7 +22,6 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) LUAU_FASTFLAG(LuauNewRequireTrace2) -LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau { @@ -458,8 +457,7 @@ CheckResult Frontend::check(const ModuleName& name) module->astTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); - if (FFlag::LuauClearScopes) - module->scopes.resize(1); + module->scopes.resize(1); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 880ffd2e..32a0646a 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -161,6 +161,7 @@ struct TypeCloner void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); void operator()(const PrimitiveTypeVar& t); + void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); void operator()(const MetatableTypeVar& t); @@ -199,7 +200,9 @@ struct TypePackCloner if (encounteredFreeType) *encounteredFreeType = true; - seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}}); + TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); + TypePackId cloned = dest.addTypePack(*err); + seenTypePacks[typePackId] = cloned; } void operator()(const Unifiable::Generic& t) @@ -251,8 +254,9 @@ void TypeCloner::operator()(const Unifiable::Free& t) { if (encounteredFreeType) *encounteredFreeType = true; - - seenTypes[typeId] = dest.addType(ErrorTypeVar{}); + TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); + TypeId cloned = dest.addType(*err); + seenTypes[typeId] = cloned; } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -270,11 +274,17 @@ void TypeCloner::operator()(const Unifiable::Error& t) { defaultClone(t); } + void TypeCloner::operator()(const PrimitiveTypeVar& t) { defaultClone(t); } +void TypeCloner::operator()(const SingletonTypeVar& t) +{ + defaultClone(t); +} + void TypeCloner::operator()(const FunctionTypeVar& t) { TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 885fd489..735bfa50 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -350,6 +350,23 @@ struct TypeVarStringifier } } + void operator()(TypeId, const SingletonTypeVar& stv) + { + if (const BoolSingleton* bs = Luau::get(&stv)) + state.emit(bs->value ? "true" : "false"); + else if (const StringSingleton* ss = Luau::get(&stv)) + { + state.emit("\""); + state.emit(escape(ss->value)); + state.emit("\""); + } + else + { + LUAU_ASSERT(!"Unknown singleton type"); + throw std::runtime_error("Unknown singleton type"); + } + } + void operator()(TypeId, const FunctionTypeVar& ftv) { if (state.hasSeen(&ftv)) @@ -359,6 +376,7 @@ struct TypeVarStringifier return; } + // We should not be respecting opts.hideNamedFunctionTypeParameters here. if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0) { state.emit("<"); @@ -514,7 +532,14 @@ struct TypeVarStringifier break; } - state.emit(name); + if (isIdentifier(name)) + state.emit(name); + else + { + state.emit("[\""); + state.emit(escape(name)); + state.emit("\"]"); + } state.emit(": "); stringify(prop.type); comma = true; @@ -1084,6 +1109,94 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +{ + std::string s = prefix; + + auto toString_ = [&opts](TypeId ty) -> std::string { + ToStringResult res = toStringDetailed(ty, opts); + opts.nameMap = std::move(res.nameMap); + return res.name; + }; + + auto toStringPack_ = [&opts](TypePackId ty) -> std::string { + ToStringResult res = toStringDetailed(ty, opts); + opts.nameMap = std::move(res.nameMap); + return res.name; + }; + + if (!opts.hideNamedFunctionTypeParameters && (!ftv.generics.empty() || !ftv.genericPacks.empty())) + { + s += "<"; + + bool first = true; + for (TypeId g : ftv.generics) + { + if (!first) + s += ", "; + first = false; + s += toString_(g); + } + + for (TypePackId gp : ftv.genericPacks) + { + if (!first) + s += ", "; + first = false; + s += toStringPack_(gp); + } + + s += ">"; + } + + s += "("; + + auto argPackIter = begin(ftv.argTypes); + auto argNameIter = ftv.argNames.begin(); + + bool first = true; + while (argPackIter != end(ftv.argTypes)) + { + if (!first) + s += ", "; + first = false; + + // argNames is guaranteed to be equal to argTypes iff argNames is not empty. + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (!ftv.argNames.empty()) + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + s += toString_(*argPackIter); + + ++argPackIter; + if (!ftv.argNames.empty()) + { + LUAU_ASSERT(argNameIter != ftv.argNames.end()); + ++argNameIter; + } + } + + if (argPackIter.tail()) + { + if (auto vtp = get(*argPackIter.tail())) + s += ", ...: " + toString_(vtp->ty); + else + s += ", ...: " + toStringPack_(*argPackIter.tail()); + } + + s += "): "; + + size_t retSize = size(ftv.retType); + bool hasTail = !finite(ftv.retType); + if (retSize == 0 && !hasTail) + s += "()"; + else if ((retSize == 0 && hasTail) || (retSize == 1 && !hasTail)) + s += toStringPack_(ftv.retType); + else + s += "(" + toStringPack_(ftv.retType) + ")"; + + return s; +} + void dump(TypeId ty) { ToStringOptions opts; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 7d880af4..6627fbe3 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -14,61 +14,6 @@ LUAU_FASTFLAG(LuauTypeAliasPacks) namespace { - -std::string escape(std::string_view s) -{ - std::string r; - r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting - - for (uint8_t c : s) - { - if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') - r += c; - else - { - r += '\\'; - - switch (c) - { - case '\a': - r += 'a'; - break; - case '\b': - r += 'b'; - break; - case '\f': - r += 'f'; - break; - case '\n': - r += 'n'; - break; - case '\r': - r += 'r'; - break; - case '\t': - r += 't'; - break; - case '\v': - r += 'v'; - break; - case '\'': - r += '\''; - break; - case '\"': - r += '\"'; - break; - case '\\': - r += '\\'; - break; - default: - Luau::formatAppend(r, "%03u", c); - } - } - } - - return r; -} - bool isIdentifierStartChar(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_'; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 11aa7b39..af6d2543 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -96,6 +96,22 @@ public: return nullptr; } } + + AstType* operator()(const SingletonTypeVar& stv) + { + if (const BoolSingleton* bs = get(&stv)) + return allocator->alloc(Location(), bs->value); + else if (const StringSingleton* ss = get(&stv)) + { + AstArray value; + value.data = const_cast(ss->value.c_str()); + value.size = strlen(value.data); + return allocator->alloc(Location(), value); + } + else + return nullptr; + } + AstType* operator()(const AnyTypeVar&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8fad1af9..b2ae94c7 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,9 @@ LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) +LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) namespace Luau { @@ -211,10 +214,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(singletonTypes.booleanType) , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) - , errorType(singletonTypes.errorType) , optionalNumberType(singletonTypes.optionalNumberType) , anyTypePack(singletonTypes.anyTypePack) - , errorTypePack(singletonTypes.errorTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -484,7 +485,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - *asMutable(type) = ErrorTypeVar{}; + *asMutable(type) = *errorRecoveryType(anyType); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } } @@ -719,7 +720,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) else if (auto tail = valueIter.tail()) { if (get(*tail)) - right = errorType; + right = errorRecoveryType(scope); else if (auto vtp = get(*tail)) right = vtp->ty; else if (get(*tail)) @@ -961,7 +962,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) - unify(var, errorType, forin.location); + unify(var, errorRecoveryType(scope), forin.location); return check(loopScope, *forin.body); } @@ -979,7 +980,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) const FunctionTypeVar* iterFunc = get(iterTy); if (!iterFunc) { - TypeId varTy = get(iterTy) ? anyType : errorType; + TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) unify(var, varTy, forin.location); @@ -1152,9 +1153,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); if (FFlag::LuauTypeAliasPacks) - bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType}; + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; else - bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + bindingsMap[name] = TypeFun{binding->typeParams, errorRecoveryType(anyType)}; } else { @@ -1398,7 +1399,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult result; @@ -1407,12 +1408,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result = checkExpr(scope, *a->expr); else if (expr.is()) result = {nilType}; - else if (expr.is()) - result = {booleanType}; + else if (const AstExprConstantBool* bexpr = expr.as()) + { + if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + result = {singletonType(bexpr->value)}; + else + result = {booleanType}; + } + else if (const AstExprConstantString* sexpr = expr.as()) + { + if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; + else + result = {stringType}; + } else if (expr.is()) result = {numberType}; - else if (expr.is()) - result = {stringType}; else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1485,7 +1496,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint // ice("AstExprLocal exists but no binding definition for it?", expr.location); reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) @@ -1497,7 +1508,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) @@ -1509,7 +1520,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa std::vector types = flatten(varargPack).first; return {!types.empty() ? types[0] : nilType}; } - else if (auto ftp = get(varargPack)) + else if (get(varargPack)) { TypeId head = freshType(scope); TypePackId tail = freshTypePack(scope); @@ -1517,14 +1528,14 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa return {head}; } if (get(varargPack)) - return {errorType}; + return {errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return {vtp->ty}; else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); - return {errorType}; + return {errorRecoveryType(scope)}; } else ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); @@ -1539,7 +1550,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (auto ftp = get(retPack)) + else if (get(retPack)) { TypeId head = freshType(scope); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); @@ -1547,7 +1558,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa return {head, std::move(result.predicates)}; } if (get(retPack)) - return {errorType, std::move(result.predicates)}; + return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) @@ -1572,7 +1583,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) return {*ty}; - return {errorType}; + return {errorRecoveryType(scope)}; } std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) @@ -1876,6 +1887,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; + const UnionTypeVar* expectedUnion = nullptr; std::optional expectedIndexType; std::optional expectedIndexResultType; @@ -1894,6 +1906,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa } } } + else if (FFlag::LuauExpectedTypesOfProperties) + if (const UnionTypeVar* utv = get(follow(*expectedType))) + expectedUnion = utv; } for (size_t i = 0; i < expr.items.size; ++i) @@ -1916,6 +1931,18 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; } + else if (FFlag::LuauExpectedTypesOfProperties && expectedUnion) + { + std::vector expectedResultTypes; + for (TypeId expectedOption : expectedUnion) + if (const TableTypeVar* ttv = get(follow(expectedOption))) + if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) + expectedResultTypes.push_back(prop->second.type); + if (expectedResultTypes.size() == 1) + expectedResultType = expectedResultTypes[0]; + else if (expectedResultTypes.size() > 1) + expectedResultType = addType(UnionTypeVar{expectedResultTypes}); + } } else { @@ -1958,21 +1985,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) - return {errorType}; + retType = errorRecoveryType(retType); - return {first(retType).value_or(nilType)}; + return {retType}; } reportError(expr.location, GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); - return {errorType}; + return {errorRecoveryType(scope)}; } reportErrors(tryUnify(numberType, operandType, expr.location)); @@ -1984,7 +2012,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn operandType = stripFromNilAndReport(operandType, expr.location); if (get(operandType)) - return {errorType}; + return {errorRecoveryType(scope)}; if (get(operandType)) return {numberType}; // Not strictly correct: metatables permit overriding this @@ -2044,7 +2072,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b if (unify(a, b, location)) return a; - return errorType; + return errorRecoveryType(anyType); } if (*a == *b) @@ -2166,11 +2194,13 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); + // TODO: this check seems odd, the second part is redundant + // is it meant to be if (leftMetatable && rightMetatable && leftMetatable != rightMetatable) if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) { reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + return errorRecoveryType(booleanType); } if (leftMetatable) @@ -2188,7 +2218,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!state.errors.empty()) { reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } } @@ -2206,7 +2236,7 @@ TypeId TypeChecker::checkRelationalOperation( { reportError( expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } @@ -2214,14 +2244,14 @@ TypeId TypeChecker::checkRelationalOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); - return errorType; + return errorRecoveryType(booleanType); } if (needsMetamethod) { reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", toString(lhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + return errorRecoveryType(booleanType); } return booleanType; @@ -2266,7 +2296,8 @@ TypeId TypeChecker::checkBinaryOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - return errorType; + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } // If we know nothing at all about the lhs type, we can usually say nothing about the result. @@ -2296,18 +2327,33 @@ TypeId TypeChecker::checkBinaryOperation( auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypePackId arguments = addTypePack({lhst, rhst}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); reportErrors(state.errors); + bool hasErrors = !state.errors.empty(); - if (!state.errors.empty()) - return errorType; + if (FFlag::LuauErrorRecoveryType && hasErrors) + { + // If there are unification errors, the return type may still be unknown + // so we loosen the argument types to see if that helps. + TypePackId fallbackArguments = freshTypePack(scope); + TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); + state.log.rollback(); + state.errors.clear(); + state.tryUnify(fallbackFunctionType, actualFunctionType, /*isFunctionCall*/ true); + if (!state.errors.empty()) + state.log.rollback(); + } - return first(retType).value_or(nilType); + TypeId retType = first(retTypePack).value_or(nilType); + if (hasErrors) + retType = errorRecoveryType(retType); + + return retType; }; std::string op = opToMetaTableEntry(expr.op); @@ -2321,7 +2367,8 @@ TypeId TypeChecker::checkBinaryOperation( reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), toString(lhsType).c_str(), toString(rhsType).c_str())}); - return errorType; + + return errorRecoveryType(scope); } switch (expr.op) @@ -2414,11 +2461,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy ExprResult result = checkExpr(scope, *expr.expr, annotationType); ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + reportErrors(errorVec); if (!errorVec.empty()) - { - reportErrors(errorVec); - return {errorType, std::move(result.predicates)}; - } + annotationType = errorRecoveryType(annotationType); return {annotationType, std::move(result.predicates)}; } @@ -2434,7 +2479,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr // any type errors that may arise from it are going to be useless. currentModule->errors.resize(oldSize); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) @@ -2476,7 +2521,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { for (AstExpr* expr : a->expressions) checkExpr(scope, *expr); - return std::pair(errorType, nullptr); + return {errorRecoveryType(scope), nullptr}; } else ice("Unexpected AST node in checkLValue", expr.location); @@ -2488,7 +2533,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope return {*ty, nullptr}; reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); - return {errorType, nullptr}; + return {errorRecoveryType(scope), nullptr}; } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) @@ -2545,24 +2590,25 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { Unifier state = mkUnifier(expr.location); state.tryUnify(indexer->indexType, stringType); + TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { state.log.rollback(); reportError(expr.location, UnknownProperty{lhs, name}); - return std::pair(errorType, nullptr); + retType = errorRecoveryType(retType); } - return std::pair(indexer->indexResultType, nullptr); + return std::pair(retType, nullptr); } else if (lhsTable->state == TableState::Sealed) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } else { reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } } else if (const ClassTypeVar* lhsClass = get(lhs)) @@ -2571,7 +2617,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } return std::pair(prop->type, nullptr); @@ -2585,12 +2631,12 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (isTableIntersection(lhs)) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } } reportError(TypeError{expr.location, NotATable{lhs}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) @@ -2615,7 +2661,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } return std::pair(prop->type, nullptr); } @@ -2626,7 +2672,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { reportError(TypeError{expr.expr->location, NotATable{exprType}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } if (value) @@ -2678,7 +2724,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) if (isNonstrictMode()) return globalScope->bindings[name].typeId; - return errorType; + return errorRecoveryType(scope); } else { @@ -2705,20 +2751,21 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) TableTypeVar* ttv = getMutableTableType(lhsType); if (!ttv) { - if (!isTableIntersection(lhsType)) + if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) + // This error now gets reported when we check the function body. reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); - return errorType; + return errorRecoveryType(scope); } // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check if (lhsType->persistent || ttv->state == TableState::Sealed) - return errorType; + return errorRecoveryType(scope); Name name = indexName->index.value; if (ttv->props.count(name)) - return errorType; + return errorRecoveryType(scope); Property& property = ttv->props[name]; @@ -2728,9 +2775,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) return property.type; } else if (funName.is()) - { - return errorType; - } + return errorRecoveryType(scope); else { ice("Unexpected AST node type", funName.location); @@ -2991,7 +3036,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A else if (expr.is()) { if (!scope->varargPack) - return {addTypePack({addType(ErrorTypeVar())})}; + return {errorRecoveryTypePack(scope)}; return {*scope->varargPack}; } @@ -3095,10 +3140,9 @@ void TypeChecker::checkArgumentList( if (get(tail)) { // Unify remaining parameters so we don't leave any free-types hanging around. - TypeId argTy = errorType; while (paramIter != endIter) { - state.tryUnify(*paramIter, argTy); + state.tryUnify(*paramIter, errorRecoveryType(anyType)); ++paramIter; } return; @@ -3157,7 +3201,7 @@ void TypeChecker::checkArgumentList( { while (argIter != endIter) { - unify(*argIter, errorType, state.location); + unify(*argIter, errorRecoveryType(scope), state.location); ++argIter; } // For this case, we want the error span to cover every errant extra parameter @@ -3246,7 +3290,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A // For each overload // Compare parameter and argument types // Report any errors (also speculate dot vs colon warnings!) - // If there are no errors, return the resulting return type + // Return the resulting return type (even if there are errors) + // If there are no matching overloads, unify with (a...) -> (b...) and return b... TypeId selfType = nullptr; TypeId functionType = nullptr; @@ -3268,8 +3313,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } else { - functionType = errorType; - actualFunctionType = errorType; + functionType = errorRecoveryType(scope); + actualFunctionType = functionType; } } else @@ -3296,7 +3341,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A TypePackId argPack = argListResult.type; if (get(argPack)) - return ExprResult{errorTypePack}; + return {errorRecoveryTypePack(scope)}; TypePack* args = getMutable(argPack); LUAU_ASSERT(args != nullptr); @@ -3314,19 +3359,34 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector errors; // errors encountered for each overload std::vector overloadsThatMatchArgCount; + std::vector overloadsThatDont; for (TypeId fn : overloads) { fn = follow(fn); - if (auto ret = checkCallOverload(scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, errors)) + if (auto ret = checkCallOverload( + scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) return *ret; } if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) return {retPack}; - return reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + + if (FFlag::LuauErrorRecoveryType) + { + const FunctionTypeVar* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return {errorRecoveryTypePack(overload->retType)}; + } + + return {errorRecoveryTypePack(retPack)}; } std::vector> TypeChecker::getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall) @@ -3382,7 +3442,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors) + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { fn = stripFromNilAndReport(fn, expr.func->location); @@ -3394,7 +3454,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (get(fn)) { - return {{addTypePack(TypePackVar{Unifiable::Error{}})}}; + return {{errorRecoveryTypePack(scope)}}; } if (get(fn)) @@ -3427,14 +3487,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope TypeId fn = *ty; fn = instantiate(scope, fn, expr.func->location); - return checkCallOverload( - scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); + return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, + overloadsThatMatchArgCount, overloadsThatDont, errors); } } reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); - unify(retPack, errorTypePack, expr.func->location); - return {{errorTypePack}}; + unify(retPack, errorRecoveryTypePack(scope), expr.func->location); + return {{errorRecoveryTypePack(retPack)}}; } // When this function type has magic functions and did return something, we select that overload instead. @@ -3476,6 +3536,8 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!argMismatch) overloadsThatMatchArgCount.push_back(fn); + else if (FFlag::LuauErrorRecoveryType) + overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); state.log.rollback(); @@ -3586,14 +3648,14 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return false; } -ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, - TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, - const std::vector& overloadsThatMatchArgCount, const std::vector& errors) +void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, + const std::vector& errors) { if (overloads.size() == 1) { reportErrors(std::get<0>(errors.front())); - return {errorTypePack}; + return; } std::vector overloadTypes = overloadsThatMatchArgCount; @@ -3622,7 +3684,7 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr // If only one overload matched, we don't need this error because we provided the previous errors. if (overloadsThatMatchArgCount.size() == 1) - return {errorTypePack}; + return; } std::string s; @@ -3655,7 +3717,7 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr reportError(expr.func->location, ExtraInformation{"Other overloads are also not viable: " + s}); // No viable overload - return {errorTypePack}; + return; } ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, @@ -3740,7 +3802,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { reportError(TypeError{location, UnknownRequire{}}); - return errorType; + return errorRecoveryType(anyType); } return anyType; @@ -3758,14 +3820,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module reportError(TypeError{location, UnknownRequire{reportedModulePath}}); } - return errorType; + return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } std::optional moduleType = first(module->getModuleScope()->returnType); @@ -3773,7 +3835,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } SeenTypes seenTypes; @@ -4078,7 +4140,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!qty.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } if (ty == *qty) @@ -4101,7 +4163,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } } @@ -4116,7 +4178,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(anyType); } } @@ -4131,7 +4193,7 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo else { reportError(location, UnificationTooComplex{}); - return errorTypePack; + return errorRecoveryTypePack(anyTypePack); } } @@ -4279,6 +4341,38 @@ TypeId TypeChecker::freshType(TypeLevel level) return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } +TypeId TypeChecker::singletonType(bool value) +{ + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value}))); +} + +TypeId TypeChecker::singletonType(std::string value) +{ + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(StringSingleton{std::move(value)}))); +} + +TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) +{ + return singletonTypes.errorRecoveryType(); +} + +TypeId TypeChecker::errorRecoveryType(TypeId guess) +{ + return singletonTypes.errorRecoveryType(guess); +} + +TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) +{ + return singletonTypes.errorRecoveryTypePack(); +} + +TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) +{ + return singletonTypes.errorRecoveryTypePack(guess); +} + std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); @@ -4350,7 +4444,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size != 1 || !lit->parameters.data[0].type) { reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(anyType); } ToStringOptions opts; @@ -4368,7 +4462,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (!tf) { if (lit->name == Parser::errorName) - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); std::string typeName; if (lit->hasPrefix) @@ -4380,7 +4474,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) @@ -4390,14 +4484,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) { reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); - return addType(ErrorTypeVar{}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } - else if (FFlag::LuauTypeAliasPacks) + + if (FFlag::LuauTypeAliasPacks) { if (!lit->hasParameterList && !tf->typePackParams.empty()) { reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - return addType(ErrorTypeVar{}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } std::vector typeParams; @@ -4445,7 +4542,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { reportError( TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); - return addType(ErrorTypeVar{}); + + if (FFlag::LuauErrorRecoveryType) + { + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) + typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); + } + else + return errorRecoveryType(scope); } if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) @@ -4464,6 +4571,14 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation for (const auto& param : lit->parameters) typeParams.push_back(resolveType(scope, *param.type)); + if (FFlag::LuauErrorRecoveryType) + { + // If there aren't enough type parameters, pad them out with error recovery types + // (we've already reported the error) + while (typeParams.size() < lit->parameters.size) + typeParams.push_back(errorRecoveryType(scope)); + } + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) { // If the generic parameters and the type arguments are the same, we are about to @@ -4483,8 +4598,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation props[prop.name.value] = {resolveType(scope, *prop.type)}; if (const auto& indexer = table->indexer) - tableIndexer = TableIndexer( - resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); return addType(TableTypeVar{ props, tableIndexer, scope->level, @@ -4536,14 +4650,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return addType(IntersectionTypeVar{types}); } - else if (annotation.is()) + else if (const auto& tsb = annotation.as()) { - return addType(ErrorTypeVar{}); + return singletonType(tsb->value); } + else if (const auto& tss = annotation.as()) + { + return singletonType(std::string(tss->value.data, tss->value.size)); + } + else if (annotation.is()) + return errorRecoveryType(scope); else { reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } } @@ -4584,7 +4704,7 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack else reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); - return addTypePack(TypePackVar{Unifiable::Error{}}); + return errorRecoveryTypePack(scope); } return *genericTy; @@ -4706,12 +4826,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (!maybeInstantiated.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) { reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); - return errorType; + return errorRecoveryType(scope); } TypeId instantiated = *maybeInstantiated; @@ -4773,8 +4893,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, return instantiated; } -std::pair, std::vector> TypeChecker::createGenericTypes( - const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +std::pair, std::vector> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, + const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); @@ -5030,7 +5150,9 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } - return std::nullopt; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; }; std::optional ty = resolveLValue(refis, scope, isaP.lvalue); @@ -5041,7 +5163,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement addRefinement(refis, isaP.lvalue, *result); else { - addRefinement(refis, isaP.lvalue, errorType); + addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); } } @@ -5105,7 +5227,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec addRefinement(refis, typeguardP.lvalue, *result); else { - addRefinement(refis, typeguardP.lvalue, errorType); + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); if (sense) errVec.push_back( TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); @@ -5116,7 +5238,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec auto fail = [&](const TypeErrorData& err) { errVec.push_back(TypeError{typeguardP.location, err}); - addRefinement(refis, typeguardP.lvalue, errorType); + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); }; if (!typeguardP.isTypeof) @@ -5137,28 +5259,6 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); } -void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) -{ - if (!sense) - return; - - static std::vector primitives{ - "string", "number", "boolean", "nil", "thread", - "table", // no op. Requires special handling. - "function", // no op. Requires special handling. - "userdata", // no op. Requires special handling. - }; - - if (auto typeFun = globalScope->lookupType(typeguardP.kind); - typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty())) - { - if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - else if (typeguardP.isTypeof) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - } -} - void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 228b1926..d3221c73 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -286,5 +286,4 @@ TypePack* asMutable(const TypePack* tp) { return const_cast(tp); } - } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index cd447ca2..924bf082 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) +LUAU_FASTFLAG(LuauErrorRecoveryType) namespace Luau { @@ -293,7 +294,7 @@ bool isGeneric(TypeId ty) bool maybeGeneric(TypeId ty) { ty = follow(ty); - if (auto ftv = get(ty)) + if (get(ty)) return true; else if (auto ttv = get(ty)) { @@ -305,6 +306,18 @@ bool maybeGeneric(TypeId ty) return isGeneric(ty); } +bool maybeSingleton(TypeId ty) +{ + ty = follow(ty); + if (get(ty)) + return true; + if (const UnionTypeVar* utv = get(ty)) + for (TypeId option : utv) + if (get(follow(option))) + return true; + return false; +} + FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) : argTypes(argTypes) , retType(retType) @@ -562,10 +575,8 @@ SingletonTypes::SingletonTypes() , booleanType(&booleanType_) , threadType(&threadType_) , anyType(&anyType_) - , errorType(&errorType_) , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) - , errorTypePack(&errorTypePack_) , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); @@ -634,6 +645,32 @@ TypeId SingletonTypes::makeStringMetatable() return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } +TypeId SingletonTypes::errorRecoveryType() +{ + return &errorType_; +} + +TypePackId SingletonTypes::errorRecoveryTypePack() +{ + return &errorTypePack_; +} + +TypeId SingletonTypes::errorRecoveryType(TypeId guess) +{ + if (FFlag::LuauErrorRecoveryType) + return guess; + else + return &errorType_; +} + +TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) +{ + if (FFlag::LuauErrorRecoveryType) + return guess; + else + return &errorTypePack_; +} + SingletonTypes singletonTypes; void persist(TypeId ty) @@ -1141,6 +1178,11 @@ struct QVarFinder return false; } + bool operator()(const SingletonTypeVar&) const + { + return false; + } + bool operator()(const FunctionTypeVar& ftv) const { if (hasGeneric(ftv.argTypes)) @@ -1412,7 +1454,7 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha else if (strchr(options, data[i])) result.push_back(typechecker.numberType); else - result.push_back(typechecker.errorType); + result.push_back(typechecker.errorRecoveryType(typechecker.anyType)); } } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 82f621b6..e1a52be4 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,7 +22,9 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) LUAU_FASTFLAG(LuauShareTxnSeen); LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) +LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) +LUAU_FASTFLAG(LuauErrorRecoveryType); namespace Luau { @@ -211,6 +213,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { occursCheck(subTy, superTy); + // The occurrence check might have caused superTy no longer to be a free type if (!get(subTy)) { log(subTy); @@ -221,10 +224,20 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } else if (l && r) { - log(superTy); + if (!FFlag::LuauErrorRecoveryType) + log(superTy); occursCheck(superTy, subTy); r->level = min(r->level, l->level); - *asMutable(superTy) = BoundTypeVar(subTy); + + // The occurrence check might have caused superTy no longer to be a free type + if (!FFlag::LuauErrorRecoveryType) + *asMutable(superTy) = BoundTypeVar(subTy); + else if (!get(superTy)) + { + log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + } + return; } else if (l) @@ -240,6 +253,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool return; } + // The occurrence check might have caused superTy no longer to be a free type if (!get(superTy)) { if (auto rightLevel = getMutableLevel(subTy)) @@ -251,6 +265,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool log(superTy); *asMutable(superTy) = BoundTypeVar(subTy); } + return; } else if (r) @@ -512,6 +527,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (get(superTy) && get(subTy)) tryUnifyPrimitives(superTy, subTy); + else if (FFlag::LuauSingletonTypes && (get(superTy) || get(superTy)) && get(subTy)) + tryUnifySingletons(superTy, subTy); + else if (get(superTy) && get(subTy)) tryUnifyFunctions(superTy, subTy, isFunctionCall); @@ -723,17 +741,18 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal { occursCheck(superTp, subTp); + // The occurrence check might have caused superTp no longer to be a free type if (!get(superTp)) { log(superTp); *asMutable(superTp) = Unifiable::Bound(subTp); } } - else if (get(subTp)) { occursCheck(subTp, superTp); + // The occurrence check might have caused superTp no longer to be a free type if (!get(subTp)) { log(subTp); @@ -874,13 +893,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(singletonTypes.errorType, *superIter); + tryUnify_(singletonTypes.errorRecoveryType(), *superIter); superIter.advance(); } while (subIter.good()) { - tryUnify_(singletonTypes.errorType, *subIter); + tryUnify_(singletonTypes.errorRecoveryType(), *subIter); subIter.advance(); } @@ -906,6 +925,27 @@ void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } +void Unifier::tryUnifySingletons(TypeId superTy, TypeId subTy) +{ + const PrimitiveTypeVar* lp = get(superTy); + const SingletonTypeVar* ls = get(superTy); + const SingletonTypeVar* rs = get(subTy); + + if ((!lp && !ls) || !rs) + ice("passed non singleton/primitive types to unifySingletons"); + + if (ls && *ls == *rs) + return; + + if (lp && lp->type == PrimitiveTypeVar::Boolean && get(rs) && variance == Covariant) + return; + + if (lp && lp->type == PrimitiveTypeVar::String && get(rs) && variance == Covariant) + return; + + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); +} + void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) { FunctionTypeVar* lf = getMutable(superTy); @@ -1023,7 +1063,8 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) } // And vice versa if we're invariant - if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && + lt->state != TableState::Free) { for (const auto& [propName, subProp] : rt->props) { @@ -1038,7 +1079,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) return; } } - + // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. @@ -1634,9 +1675,8 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, singletonTypes.errorType); + tryUnify_(prop.type, singletonTypes.errorRecoveryType()); } else { @@ -1952,7 +1992,7 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { LUAU_ASSERT(get(any)); - const TypeId anyTy = singletonTypes.errorType; + const TypeId anyTy = singletonTypes.errorRecoveryType(); if (FFlag::LuauTypecheckOpts) { @@ -2046,7 +2086,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = ErrorTypeVar{}; + *asMutable(needle) = *singletonTypes.errorRecoveryType(); return; } @@ -2134,7 +2174,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = ErrorTypeVar{}; + *asMutable(needle) = *singletonTypes.errorRecoveryTypePack(); return; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index a2189f7b..5b4bfa03 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -255,6 +255,14 @@ public: { return visit((class AstType*)node); } + virtual bool visit(class AstTypeSingletonBool* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeSingletonString* node) + { + return visit((class AstType*)node); + } virtual bool visit(class AstTypeError* node) { return visit((class AstType*)node); @@ -1158,6 +1166,30 @@ public: unsigned messageIndex; }; +class AstTypeSingletonBool : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonBool) + + AstTypeSingletonBool(const Location& location, bool value); + + void visit(AstVisitor* visitor) override; + + bool value; +}; + +class AstTypeSingletonString : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonString) + + AstTypeSingletonString(const Location& location, const AstArray& value); + + void visit(AstVisitor* visitor) override; + + const AstArray value; +}; + class AstTypePack : public AstNode { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 39c7d925..87ebc48b 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -286,6 +286,7 @@ private: // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); + std::optional> parseCharArray(); AstExpr* parseString(); AstLocal* pushLocal(const Binding& binding); diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index 4f7673fa..6ecf0606 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -34,4 +34,6 @@ bool equalsLower(std::string_view lhs, std::string_view rhs); size_t hashRange(const char* data, size_t size); +std::string escape(std::string_view s); +bool isIdentifier(std::string_view s); } // namespace Luau diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index b1209faa..e709894d 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -841,6 +841,28 @@ void AstTypeIntersection::visit(AstVisitor* visitor) } } +AstTypeSingletonBool::AstTypeSingletonBool(const Location& location, bool value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonBool::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstTypeSingletonString::AstTypeSingletonString(const Location& location, const AstArray& value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonString::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + AstTypeError::AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex) : AstType(ClassIndex(), location) , types(types) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index a1bad65e..bc63e37d 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) +LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau @@ -1278,7 +1279,27 @@ AstType* Parser::parseTableTypeAnnotation() while (lexer.current().type != '}') { - if (lexer.current().type == '[') + if (FFlag::LuauParseSingletonTypes && lexer.current().type == '[' && + (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) + { + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "table field"); + + AstType* type = parseTypeAnnotation(); + + // TODO: since AstName conains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back({AstName(chars->data), begin.location, type}); + else + report(begin.location, "String literal contains malformed escape sequence"); + } + else if (lexer.current().type == '[') { if (indexer) { @@ -1528,6 +1549,32 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) nextLexeme(); return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedTrue) + { + nextLexeme(); + return {allocator.alloc(begin, true)}; + } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedFalse) + { + nextLexeme(); + return {allocator.alloc(begin, false)}; + } + else if (FFlag::LuauParseSingletonTypes && (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString)) + { + if (std::optional> value = parseCharArray()) + { + AstArray svalue = *value; + return {allocator.alloc(begin, svalue)}; + } + else + return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; + } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::BrokenString) + { + Location location = lexer.current().location; + nextLexeme(); + return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, "Malformed string")}; + } else if (lexer.current().type == Lexeme::Name) { std::optional prefix; @@ -2416,7 +2463,7 @@ AstArray Parser::parseTypeParams() return copy(parameters); } -AstExpr* Parser::parseString() +std::optional> Parser::parseCharArray() { LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString); @@ -2426,11 +2473,8 @@ AstExpr* Parser::parseString() { if (!Lexer::fixupQuotedString(scratchData)) { - Location location = lexer.current().location; - nextLexeme(); - - return reportExprError(location, {}, "String literal contains malformed escape sequence"); + return std::nullopt; } } else @@ -2438,12 +2482,18 @@ AstExpr* Parser::parseString() Lexer::fixupMultilineString(scratchData); } - Location start = lexer.current().location; AstArray value = copy(scratchData); - nextLexeme(); + return value; +} - return allocator.alloc(start, value); +AstExpr* Parser::parseString() +{ + Location location = lexer.current().location; + if (std::optional> value = parseCharArray()) + return allocator.alloc(location, *value); + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); } AstLocal* Parser::pushLocal(const Binding& binding) diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 24b2283a..9c7fed31 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -225,4 +225,62 @@ size_t hashRange(const char* data, size_t size) return hash; } +bool isIdentifier(std::string_view s) +{ + return (s.find_first_not_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_") == std::string::npos); +} + +std::string escape(std::string_view s) +{ + std::string r; + r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting + + for (uint8_t c : s) + { + if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') + r += c; + else + { + r += '\\'; + + switch (c) + { + case '\a': + r += 'a'; + break; + case '\b': + r += 'b'; + break; + case '\f': + r += 'f'; + break; + case '\n': + r += 'n'; + break; + case '\r': + r += 'r'; + break; + case '\t': + r += 't'; + break; + case '\v': + r += 'v'; + break; + case '\'': + r += '\''; + break; + case '\"': + r += '\"'; + break; + case '\\': + r += '\\'; + break; + default: + Luau::formatAppend(r, "%03u", c); + } + } + } + + return r; +} } // namespace Luau diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 9ab10aaf..ebdd7896 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -236,32 +236,12 @@ int main(int argc, char** argv) Luau::registerBuiltinTypes(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - // Look for .luau first and if absent, fall back to .lua - if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) - { - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - } - else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - { - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - } - }); - } - else - { - failed += !analyzeFile(frontend, argv[i], format, annotate); - } - } + for (const std::string& path : files) + failed += !analyzeFile(frontend, path.c_str(), format, annotate); if (!configResolver.configErrors.empty()) { diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index 0702b74f..b3c9557b 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -142,6 +142,7 @@ static bool traverseDirectoryRec(const std::string& path, const std::function getParentPath(const std::string& path) return ""; } + +static std::string getExtension(const std::string& path) +{ + std::string::size_type dot = path.find_last_of(".\\/"); + + if (dot == std::string::npos || path[dot] != '.') + return ""; + + return path.substr(dot); +} + +std::vector getSourceFiles(int argc, char** argv) +{ + std::vector files; + + for (int i = 1; i < argc; ++i) + { + if (argv[i][0] == '-') + continue; + + if (isDirectory(argv[i])) + { + traverseDirectory(argv[i], [&](const std::string& name) { + std::string ext = getExtension(name); + + if (ext == ".lua" || ext == ".luau") + files.push_back(name); + }); + } + else + { + files.push_back(argv[i]); + } + } + + return files; +} diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index f7fbe8af..da11f512 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -4,6 +4,7 @@ #include #include #include +#include std::optional readFile(const std::string& name); @@ -12,3 +13,5 @@ bool traverseDirectory(const std::string& path, const std::function getParentPath(const std::string& path); + +std::vector getSourceFiles(int argc, char** argv); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 5c904cca..b29cd6f9 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -20,7 +20,7 @@ enum class CompileFormat { - Default, + Text, Binary }; @@ -33,7 +33,7 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0) return 1; lua_pushnil(L); @@ -80,7 +80,7 @@ static int lua_require(lua_State* L) // now we can compile & run module on the new thread std::string bytecode = Luau::compile(*source); - if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { int status = lua_resume(ML, L, 0); @@ -151,7 +151,7 @@ static std::string runCode(lua_State* L, const std::string& source) { std::string bytecode = Luau::compile(source); - if (luau_load(L, "=stdin", bytecode.data(), bytecode.size()) != 0) + if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) { size_t len; const char* msg = lua_tolstring(L, -1, &len); @@ -370,7 +370,7 @@ static bool runFile(const char* name, lua_State* GL) std::string bytecode = Luau::compile(*source); int status = 0; - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { status = lua_resume(L, NULL, 0); } @@ -379,11 +379,7 @@ static bool runFile(const char* name, lua_State* GL) status = LUA_ERRSYNTAX; } - if (status == 0) - { - return true; - } - else + if (status != 0) { std::string error; @@ -400,8 +396,10 @@ static bool runFile(const char* name, lua_State* GL) error += lua_debugtrace(L); fprintf(stderr, "%s", error.c_str()); - return false; } + + lua_pop(GL, 1); + return status == 0; } static void report(const char* name, const Luau::Location& location, const char* type, const char* message) @@ -431,14 +429,18 @@ static bool compileFile(const char* name, CompileFormat format) try { Luau::BytecodeBuilder bcb; - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); - bcb.setDumpSource(*source); + + if (format == CompileFormat::Text) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); + bcb.setDumpSource(*source); + } Luau::compileOrThrow(bcb, *source); switch (format) { - case CompileFormat::Default: + case CompileFormat::Text: printf("%s", bcb.dumpEverything().c_str()); break; case CompileFormat::Binary: @@ -504,7 +506,7 @@ int main(int argc, char** argv) if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { - CompileFormat format = CompileFormat::Default; + CompileFormat format = CompileFormat::Text; if (strcmp(argv[1], "--compile=binary") == 0) format = CompileFormat::Binary; @@ -514,27 +516,12 @@ int main(int argc, char** argv) _setmode(_fileno(stdout), _O_BINARY); #endif + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 2; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) - failed += !compileFile(name.c_str(), format); - else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !compileFile(name.c_str(), format); - }); - } - else - { - failed += !compileFile(argv[i], format); - } - } + for (const std::string& path : files) + failed += !compileFile(path.c_str(), format); return failed; } @@ -548,33 +535,25 @@ int main(int argc, char** argv) int profile = 0; for (int i = 1; i < argc; ++i) + { + if (argv[i][0] != '-') + continue; + if (strcmp(argv[i], "--profile") == 0) profile = 10000; // default to 10 KHz else if (strncmp(argv[i], "--profile=", 10) == 0) profile = atoi(argv[i] + 10); + } if (profile) profilerStart(L, profile); + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !runFile(name.c_str(), L); - }); - } - else - { - failed += !runFile(argv[i], L); - } - } + for (const std::string& path : files) + failed += !runFile(path.c_str(), L); if (profile) { diff --git a/CMakeLists.txt b/CMakeLists.txt index 36014a98..9c69521e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,7 @@ project(Luau LANGUAGES CXX) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) +option(LUAU_WERROR "Warnings as errors" OFF) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) @@ -57,11 +58,18 @@ set(LUAU_OPTIONS) if(MSVC) list(APPEND LUAU_OPTIONS /D_CRT_SECURE_NO_WARNINGS) # We need to use the portable CRT functions. - list(APPEND LUAU_OPTIONS /WX) # Warnings are errors list(APPEND LUAU_OPTIONS /MP) # Distribute single project compilation across multiple cores else() list(APPEND LUAU_OPTIONS -Wall) # All warnings - list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors +endif() + +# Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere +if(LUAU_WERROR) + if(MSVC) + list(APPEND LUAU_OPTIONS /WX) # Warnings are errors + else() + list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors + endif() endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) @@ -79,7 +87,10 @@ if(LUAU_BUILD_CLI) target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) if(UNIX) - target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + find_library(LIBPTHREAD pthread) + if (LIBPTHREAD) + target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + endif() endif() if(NOT EMSCRIPTEN) diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 4f88e602..65e962da 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -13,11 +13,9 @@ class AstNameTable; class BytecodeBuilder; class BytecodeEncoder; +// Note: this structure is duplicated in luacode.h, don't forget to change these in sync! struct CompileOptions { - // default bytecode version target; can be used to compile code for older clients - int bytecodeVersion = 1; - // 0 - no optimization // 1 - baseline optimization level that doesn't prevent debuggability // 2 - includes optimizations that harm debuggability such as inlining diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h new file mode 100644 index 00000000..e235a2e7 --- /dev/null +++ b/Compiler/include/luacode.h @@ -0,0 +1,39 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +/* Can be used to reconfigure visibility/exports for public APIs */ +#ifndef LUACODE_API +#define LUACODE_API extern +#endif + +typedef struct lua_CompileOptions lua_CompileOptions; + +struct lua_CompileOptions +{ + // 0 - no optimization + // 1 - baseline optimization level that doesn't prevent debuggability + // 2 - includes optimizations that harm debuggability such as inlining + int optimizationLevel; // default=1 + + // 0 - no debugging support + // 1 - line info & function names only; sufficient for backtraces + // 2 - full debug info with local & upvalue names; necessary for debugger + int debugLevel; // default=1 + + // 0 - no code coverage support + // 1 - statement coverage + // 2 - statement and expression coverage (verbose) + int coverageLevel; // default=0 + + // global builtin to construct vectors; disabled by default + const char* vectorLib; + const char* vectorCtor; + + // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these + const char** mutableGlobals; +}; + +/* compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy */ +LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9712f02f..5b93c1dc 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -11,9 +11,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) -LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false) -LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false) -LUAU_FASTFLAGVARIABLE(LuauGenericSpecialGlobals, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) @@ -24,9 +21,6 @@ static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; -// TODO: Remove with LuauGenericSpecialGlobals -static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"}; - CompileError::CompileError(const Location& location, const std::string& message) : location(location) , message(message) @@ -466,7 +460,7 @@ struct Compiler bool shared = false; - if (FFlag::LuauPreloadClosuresUpval) + if (FFlag::LuauPreloadClosures) { // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it @@ -482,18 +476,6 @@ struct Compiler } } } - // Optimization: when closure has no upvalues, instead of allocating it every time we can share closure objects - // (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it is used) - else if (FFlag::LuauPreloadClosures && options.optimizationLevel >= 1 && f->upvals.empty() && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); - return; - } - } if (!shared) bytecode.emitAD(LOP_NEWCLOSURE, target, pid); @@ -3298,8 +3280,7 @@ struct Compiler bool visit(AstStatLocalFunction* node) override { // record local->function association for some optimizations - if (FFlag::LuauPreloadClosuresUpval) - self->locals[node->name].func = node->func; + self->locals[node->name].func = node->func; return true; } @@ -3711,24 +3692,13 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables - if (FFlag::LuauGenericSpecialGlobals) - { - if (AstName name = names.get("_G"); name.value) - compiler.globals[name].writable = true; + if (AstName name = names.get("_G"); name.value) + compiler.globals[name].writable = true; - if (options.mutableGlobals) - for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) - if (AstName name = names.get(*ptr); name.value) - compiler.globals[name].writable = true; - } - else - { - for (const char* global : kSpecialGlobals) - { - if (AstName name = names.get(global); name.value) + if (options.mutableGlobals) + for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) compiler.globals[name].writable = true; - } - } // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written Compiler::AssignmentVisitor assignmentVisitor(&compiler); @@ -3742,7 +3712,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found - if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) + if (options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) { Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed); root->visit(&fenvVisitor); diff --git a/Compiler/src/lcode.cpp b/Compiler/src/lcode.cpp new file mode 100644 index 00000000..ee150b17 --- /dev/null +++ b/Compiler/src/lcode.cpp @@ -0,0 +1,29 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "luacode.h" + +#include "Luau/Compiler.h" + +#include + +char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize) +{ + LUAU_ASSERT(outsize); + + Luau::CompileOptions opts; + + if (options) + { + static_assert(sizeof(lua_CompileOptions) == sizeof(Luau::CompileOptions), "C and C++ interface must match"); + memcpy(static_cast(&opts), options, sizeof(opts)); + } + + std::string result = compile(std::string(source, size), opts); + + char* copy = static_cast(malloc(result.size())); + if (!copy) + return nullptr; + + memcpy(copy, result.data(), result.size()); + *outsize = result.size(); + return copy; +} diff --git a/Makefile b/Makefile index 5d51b3d4..cab3d43f 100644 --- a/Makefile +++ b/Makefile @@ -46,14 +46,20 @@ endif OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) # common flags -CXXFLAGS=-g -Wall -Werror +CXXFLAGS=-g -Wall LDFLAGS= -# temporary, for older gcc versions as they treat var in `if (type var = val)` as unused +# some gcc versions treat var in `if (type var = val)` as unused +# some gcc versions treat variables used in constexpr if blocks as unused ifeq ($(findstring g++,$(shell $(CXX) --version)),g++) CXXFLAGS+=-Wno-unused endif +# enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere +ifneq ($(werror),) + CXXFLAGS+=-Werror +endif + # configuration-specific flags ifeq ($(config),release) CXXFLAGS+=-O2 -DNDEBUG diff --git a/Sources.cmake b/Sources.cmake index c30cf77d..23b931c6 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -25,9 +25,11 @@ target_sources(Luau.Compiler PRIVATE Compiler/include/Luau/Bytecode.h Compiler/include/Luau/BytecodeBuilder.h Compiler/include/Luau/Compiler.h + Compiler/include/luacode.h Compiler/src/BytecodeBuilder.cpp Compiler/src/Compiler.cpp + Compiler/src/lcode.cpp ) # Luau.Analysis Sources @@ -204,6 +206,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.intersectionTypes.test.cpp tests/TypeInfer.provisional.test.cpp tests/TypeInfer.refinements.test.cpp + tests/TypeInfer.singletons.test.cpp tests/TypeInfer.tables.test.cpp tests/TypeInfer.test.cpp tests/TypeInfer.tryUnify.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index a9d3e875..1568d191 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -102,6 +102,8 @@ LUA_API lua_State* lua_newstate(lua_Alloc f, void* ud); LUA_API void lua_close(lua_State* L); LUA_API lua_State* lua_newthread(lua_State* L); LUA_API lua_State* lua_mainthread(lua_State* L); +LUA_API void lua_resetthread(lua_State* L); +LUA_API int lua_isthreadreset(lua_State* L); /* ** basic stack manipulation @@ -162,8 +164,7 @@ LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l); LUA_API void lua_pushstring(lua_State* L, const char* s); LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp); LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...); -LUA_API void lua_pushcfunction( - lua_State* L, lua_CFunction fn, const char* debugname = NULL, int nup = 0, lua_Continuation cont = NULL); +LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont); LUA_API void lua_pushboolean(lua_State* L, int b); LUA_API void lua_pushlightuserdata(lua_State* L, void* p); LUA_API int lua_pushthread(lua_State* L); @@ -178,9 +179,9 @@ LUA_API void lua_rawget(lua_State* L, int idx); LUA_API void lua_rawgeti(lua_State* L, int idx, int n); LUA_API void lua_createtable(lua_State* L, int narr, int nrec); -LUA_API void lua_setreadonly(lua_State* L, int idx, bool value); +LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); LUA_API int lua_getreadonly(lua_State* L, int idx); -LUA_API void lua_setsafeenv(lua_State* L, int idx, bool value); +LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled); LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); @@ -200,7 +201,7 @@ LUA_API int lua_setfenv(lua_State* L, int idx); /* ** `load' and `call' functions (load and run Luau bytecode) */ -LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env = 0); +LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env); LUA_API void lua_call(lua_State* L, int nargs, int nresults); LUA_API int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc); @@ -293,6 +294,8 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) #define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1) +#define lua_pushcfunction(L, fn, debugname) lua_pushcclosurek(L, fn, debugname, 0, NULL) +#define lua_pushcclosure(L, fn, debugname, nup) lua_pushcclosurek(L, fn, debugname, nup, NULL) #define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s)) #define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s)) @@ -319,8 +322,8 @@ LUA_API const char* lua_setlocal(lua_State* L, int level, int n); LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n); LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); -LUA_API void lua_singlestep(lua_State* L, bool singlestep); -LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable); +LUA_API void lua_singlestep(lua_State* L, int enabled); +LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); /* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ LUA_API const char* lua_debugtrace(lua_State* L); @@ -361,6 +364,7 @@ struct lua_Callbacks void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */ void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */ }; +typedef struct lua_Callbacks lua_Callbacks; LUA_API lua_Callbacks* lua_callbacks(lua_State* L); diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 30cffaff..fa836955 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -8,11 +8,12 @@ #define luaL_typeerror(L, narg, tname) luaL_typeerrorL(L, narg, tname) #define luaL_argerror(L, narg, extramsg) luaL_argerrorL(L, narg, extramsg) -typedef struct luaL_Reg +struct luaL_Reg { const char* name; lua_CFunction func; -} luaL_Reg; +}; +typedef struct luaL_Reg luaL_Reg; LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l); LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* e); @@ -75,6 +76,7 @@ struct luaL_Buffer struct TString* storage; char buffer[LUA_BUFFERSIZE]; }; +typedef struct luaL_Buffer luaL_Buffer; // when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack // in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 7e742644..a79ba0d4 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -593,7 +593,7 @@ const char* lua_pushfstringL(lua_State* L, const char* fmt, ...) return ret; } -void lua_pushcfunction(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) +void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) { luaC_checkGC(L); luaC_checkthreadsleep(L); @@ -698,13 +698,13 @@ void lua_createtable(lua_State* L, int narray, int nrec) return; } -void lua_setreadonly(lua_State* L, int objindex, bool value) +void lua_setreadonly(lua_State* L, int objindex, int enabled) { const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); api_check(L, t != hvalue(registry(L))); - t->readonly = value; + t->readonly = bool(enabled); return; } @@ -717,12 +717,12 @@ int lua_getreadonly(lua_State* L, int objindex) return res; } -void lua_setsafeenv(lua_State* L, int objindex, bool value) +void lua_setsafeenv(lua_State* L, int objindex, int enabled) { const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); - t->safeenv = value; + t->safeenv = bool(enabled); return; } diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 87fc1631..61798e2b 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -436,8 +436,8 @@ static const luaL_Reg base_funcs[] = { static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFunction u) { - lua_pushcfunction(L, u); - lua_pushcfunction(L, f, name, 1); + lua_pushcfunction(L, u, NULL); + lua_pushcclosure(L, f, name, 1); lua_setfield(L, -2, name); } @@ -456,10 +456,10 @@ LUALIB_API int luaopen_base(lua_State* L) auxopen(L, "ipairs", luaB_ipairs, luaB_inext); auxopen(L, "pairs", luaB_pairs, luaB_next); - lua_pushcfunction(L, luaB_pcally, "pcall", 0, luaB_pcallcont); + lua_pushcclosurek(L, luaB_pcally, "pcall", 0, luaB_pcallcont); lua_setfield(L, -2, "pcall"); - lua_pushcfunction(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); + lua_pushcclosurek(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); lua_setfield(L, -2, "xpcall"); return 1; diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index c72fe674..907c43c4 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -2,6 +2,7 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lualib.h" +#include "lcommon.h" #include "lnumutils.h" LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 9724c0e7..0178fae8 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,6 +5,8 @@ #include "lstate.h" #include "lvm.h" +LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false) + #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -208,8 +210,7 @@ static int cowrap(lua_State* L) { cocreate(L); - lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont); - + lua_pushcclosurek(L, auxwrapy, NULL, 1, auxwrapcont); return 1; } @@ -232,6 +233,34 @@ static int coyieldable(lua_State* L) return 1; } +static int coclose(lua_State* L) +{ + if (!FFlag::LuauCoroutineClose) + luaL_error(L, "coroutine.close is not enabled"); + + lua_State* co = lua_tothread(L, 1); + luaL_argexpected(L, co, 1, "thread"); + + int status = auxstatus(L, co); + if (status != CO_DEAD && status != CO_SUS) + luaL_error(L, "cannot close %s coroutine", statnames[status]); + + if (co->status == LUA_OK || co->status == LUA_YIELD) + { + lua_pushboolean(L, true); + lua_resetthread(co); + return 1; + } + else + { + lua_pushboolean(L, false); + if (lua_gettop(co)) + lua_xmove(co, L, 1); /* move error message */ + lua_resetthread(co); + return 2; + } +} + static const luaL_Reg co_funcs[] = { {"create", cocreate}, {"running", corunning}, @@ -239,6 +268,7 @@ static const luaL_Reg co_funcs[] = { {"wrap", cowrap}, {"yield", coyield}, {"isyieldable", coyieldable}, + {"close", coclose}, {NULL, NULL}, }; @@ -246,7 +276,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); - lua_pushcfunction(L, coresumey, "resume", 0, coresumecont); + lua_pushcclosurek(L, coresumey, "resume", 0, coresumecont); lua_setfield(L, -2, "resume"); return 1; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 1890e682..d77f84ef 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -316,7 +316,7 @@ void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable) p->debuginsn[j] = LUAU_INSN_OP(p->code[j]); } - uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->code[i]); + uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->debuginsn[i]); // patch just the opcode byte, leave arguments alone p->code[i] &= ~0xff; @@ -357,17 +357,17 @@ int luaG_getline(Proto* p, int pc) return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc]; } -void lua_singlestep(lua_State* L, bool singlestep) +void lua_singlestep(lua_State* L, int enabled) { - L->singlestep = singlestep; + L->singlestep = bool(enabled); } -void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable) +void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) { const TValue* func = luaA_toobject(L, funcindex); api_check(L, ttisfunction(func) && !clvalue(func)->isC); - luaG_breakpoint(L, clvalue(func)->l.p, line, enable); + luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); } static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 328b47e6..1259d461 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) +LUAU_FASTFLAG(LuauCoroutineClose) /* ** {====================================================== @@ -300,7 +301,10 @@ static void resume(lua_State* L, void* ud) if (L->status == 0) { // start coroutine - LUAU_ASSERT(L->ci == L->base_ci && firstArg > L->base); + LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base); + if (FFlag::LuauCoroutineClose && firstArg == L->base) + luaG_runerror(L, "cannot resume dead coroutine"); + if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA) return; diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index bf5e738f..4e40165a 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -22,7 +22,7 @@ LUALIB_API void luaL_openlibs(lua_State* L) const luaL_Reg* lib = lualibs; for (; lib->func; lib++) { - lua_pushcfunction(L, lib->func); + lua_pushcfunction(L, lib->func, NULL); lua_pushstring(L, lib->name); lua_call(L, 1, 0); } diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 0b2dfb69..24e97063 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -124,6 +124,34 @@ void luaE_freethread(lua_State* L, lua_State* L1) luaM_free(L, L1, sizeof(lua_State), L1->memcat); } +void lua_resetthread(lua_State* L) +{ + /* close upvalues before clearing anything */ + luaF_close(L, L->stack); + /* clear call frames */ + CallInfo* ci = L->base_ci; + ci->func = L->stack; + ci->base = ci->func + 1; + ci->top = ci->base + LUA_MINSTACK; + setnilvalue(ci->func); + L->ci = ci; + luaD_reallocCI(L, BASIC_CI_SIZE); + /* clear thread state */ + L->status = LUA_OK; + L->base = L->ci->base; + L->top = L->ci->base; + L->nCcalls = L->baseCcalls = 0; + /* clear thread stack */ + luaD_reallocstack(L, BASIC_STACK_SIZE); + for (int i = 0; i < L->stacksize; i++) + setnilvalue(L->stack + i); +} + +int lua_isthreadreset(lua_State* L) +{ + return L->ci == L->base_ci && L->base == L->top && L->status == LUA_OK; +} + lua_State* lua_newstate(lua_Alloc f, void* ud) { int i; diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 80a34483..b576f809 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -748,7 +748,7 @@ static int gmatch(lua_State* L) luaL_checkstring(L, 2); lua_settop(L, 2); lua_pushinteger(L, 0); - lua_pushcfunction(L, gmatch_aux, NULL, 3); + lua_pushcclosure(L, gmatch_aux, NULL, 3); return 1; } diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp index 6a026296..378de3d0 100644 --- a/VM/src/lutf8lib.cpp +++ b/VM/src/lutf8lib.cpp @@ -265,7 +265,7 @@ static int iter_aux(lua_State* L) static int iter_codes(lua_State* L) { luaL_checkstring(L, 1); - lua_pushcfunction(L, iter_aux); + lua_pushcfunction(L, iter_aux, NULL); lua_pushvalue(L, 1); lua_pushinteger(L, 0); return 3; diff --git a/bench/bench.py b/bench/bench.py index b23ca891..39f219f3 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -25,8 +25,8 @@ try: import scipy from scipy import stats except ModuleNotFoundError: - print("scipy package is required") - exit(1) + print("Warning: scipy package is not installed, confidence values will not be available") + stats = None scriptdir = os.path.dirname(os.path.realpath(__file__)) defaultVm = 'luau.exe' if os.name == "nt" else './luau' @@ -200,11 +200,14 @@ def finalizeResult(result): result.sampleStdDev = math.sqrt(sumOfSquares / (result.count - 1)) result.unbiasedEst = result.sampleStdDev * result.sampleStdDev - # Two-tailed distribution with 95% conf. - tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) + if stats: + # Two-tailed distribution with 95% conf. + tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) - # Compute confidence interval - result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + # Compute confidence interval + result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + else: + result.sampleConfidenceInterval = result.sampleStdDev else: result.sampleStdDev = 0 result.unbiasedEst = 0 @@ -377,14 +380,19 @@ def analyzeResult(subdir, main, comparisons): tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) degreesOfFreedom = 2 * main.count - 2 - # Two-tailed distribution with 95% conf. - tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) + if stats: + # Two-tailed distribution with 95% conf. + tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) - noSignificantDifference = tStat < tCritical + noSignificantDifference = tStat < tCritical + pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) + else: + noSignificantDifference = None + pValue = -1 - pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) - - if noSignificantDifference: + if noSignificantDifference is None: + verdict = "" + elif noSignificantDifference: verdict = "likely same" elif main.avg < compare.avg: verdict = "likely worse" diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index c85fac7d..ae2399e4 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -257,7 +257,7 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) lua_State* L = lua_newthread(globalState); luaL_sandboxthread(L); - if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) { interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 8a7798f3..5a7c8602 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -91,10 +91,6 @@ struct ACFixture : ACFixtureImpl { }; -struct UnfrozenACFixture : ACFixtureImpl -{ -}; - TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") @@ -1919,9 +1915,10 @@ local bar: @1= foo CHECK(!ac.entryMap.count("foo")); } -// CLI-45692: Remove UnfrozenACFixture here -TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis") +TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") { + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end @@ -1950,9 +1947,10 @@ local fp: @1= f CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// CLI-45692: Remove UnfrozenACFixture here -TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords") +TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") { + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + check(R"( local function a(x: boolean) end local function b(x: number?) end @@ -2484,7 +2482,7 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(UnfrozenFixture, "autocomplete_documentation_symbols") { loadDefinition(R"( declare y: { diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 7f03019c..4ce8d08a 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -11,8 +11,6 @@ #include LUAU_FASTFLAG(LuauPreloadClosures) -LUAU_FASTFLAG(LuauPreloadClosuresFenv) -LUAU_FASTFLAG(LuauPreloadClosuresUpval) LUAU_FASTFLAG(LuauGenericSpecialGlobals) using namespace Luau; @@ -2797,7 +2795,7 @@ CAPTURE UPVAL U1 RETURN R0 1 )"); - if (FFlag::LuauPreloadClosuresUpval) + if (FFlag::LuauPreloadClosures) { // recursive capture CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( @@ -3479,15 +3477,13 @@ CAPTURE VAL R0 RETURN R1 1 )"); - if (FFlag::LuauPreloadClosuresFenv) - { - // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion - CHECK_EQ("\n" + compileFunction(R"( + // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion + CHECK_EQ("\n" + compileFunction(R"( setfenv(1, {}) return function() print("hi") end )", - 1), - R"( + 1), + R"( GETIMPORT R0 1 LOADN R1 1 NEWTABLE R2 0 0 @@ -3496,23 +3492,21 @@ NEWCLOSURE R0 P0 RETURN R0 1 )"); - // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature - CHECK_EQ("\n" + compileFunction(R"( + // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature + CHECK_EQ("\n" + compileFunction(R"( if false then setfenv(1, {}) end return function() print("hi") end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 RETURN R0 1 )"); - } } TEST_CASE("SharedClosure") { ScopedFastFlag sff1("LuauPreloadClosures", true); - ScopedFastFlag sff2("LuauPreloadClosuresUpval", true); // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( @@ -3671,7 +3665,7 @@ RETURN R0 0 )"); } -TEST_CASE("LuauGenericSpecialGlobals") +TEST_CASE("MutableGlobals") { const char* source = R"( print() @@ -3685,43 +3679,6 @@ shared.print() workspace.print() )"; - { - ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", false}; - - // Check Roblox globals are here - CHECK_EQ("\n" + compileFunction0(source), R"( -GETIMPORT R0 1 -CALL R0 0 0 -GETIMPORT R1 3 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 5 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 7 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 9 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 11 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 13 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 15 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 17 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -RETURN R0 0 -)"); - } - - ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", true}; - // Check Roblox globals are no longer here CHECK_EQ("\n" + compileFunction0(source), R"( GETIMPORT R0 1 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index c1b790b9..e495a213 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1,5 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Compiler.h" +#include "lua.h" +#include "lualib.h" +#include "luacode.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/ModuleResolver.h" @@ -10,9 +12,6 @@ #include "doctest.h" #include "ScopedFlags.h" -#include "lua.h" -#include "lualib.h" - #include #include @@ -49,8 +48,12 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + size_t bytecodeSize = 0; + char* bytecode = luau_compile(s, l, nullptr, &bytecodeSize); + int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0); + free(bytecode); + + if (result == 0) return 1; lua_pushnil(L); @@ -179,21 +182,17 @@ static StateRef runConformance( std::string chunkname = "=" + std::string(name); - Luau::CompileOptions copts; + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; // default copts.debugLevel = 2; // for debugger tests copts.vectorCtor = "vector"; // for vector tests - std::string bytecode = Luau::compile(source, copts); - int status = 0; + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize); + int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); + free(bytecode); - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) - { - status = lua_resume(L, nullptr, 0); - } - else - { - status = LUA_ERRSYNTAX; - } + int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX; while (yield && (status == LUA_YIELD || status == LUA_BREAK)) { @@ -332,53 +331,61 @@ TEST_CASE("UTF8") TEST_CASE("Coroutine") { + ScopedFastFlag sff("LuauCoroutineClose", true); + runConformance("coroutine.lua"); } +static int cxxthrow(lua_State* L) +{ +#if LUA_USE_LONGJMP + luaL_error(L, "oops"); +#else + throw std::runtime_error("oops"); +#endif +} + TEST_CASE("PCall") { runConformance("pcall.lua", [](lua_State* L) { - lua_pushcfunction(L, [](lua_State* L) -> int { -#if LUA_USE_LONGJMP - luaL_error(L, "oops"); -#else - throw std::runtime_error("oops"); -#endif - }); + lua_pushcfunction(L, cxxthrow, "cxxthrow"); lua_setglobal(L, "cxxthrow"); - lua_pushcfunction(L, [](lua_State* L) -> int { - lua_State* co = lua_tothread(L, 1); - lua_xmove(L, co, 1); - lua_resumeerror(co, L); - return 0; - }); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + lua_State* co = lua_tothread(L, 1); + lua_xmove(L, co, 1); + lua_resumeerror(co, L); + return 0; + }, + "resumeerror"); lua_setglobal(L, "resumeerror"); }); } TEST_CASE("Pack") { - ScopedFastFlag sff{ "LuauStrPackUBCastFix", true }; - + ScopedFastFlag sff{"LuauStrPackUBCastFix", true}; + runConformance("tpack.lua"); } TEST_CASE("Vector") { runConformance("vector.lua", [](lua_State* L) { - lua_pushcfunction(L, lua_vector); + lua_pushcfunction(L, lua_vector, "vector"); lua_setglobal(L, "vector"); lua_pushvector(L, 0.0f, 0.0f, 0.0f); luaL_newmetatable(L, "vector"); lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index); + lua_pushcfunction(L, lua_vector_index, nullptr); lua_settable(L, -3); lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall); + lua_pushcfunction(L, lua_vector_namecall, nullptr); lua_settable(L, -3); lua_setreadonly(L, -1, true); @@ -513,15 +520,19 @@ TEST_CASE("Debugger") }; // add breakpoint() function - lua_pushcfunction(L, [](lua_State* L) -> int { - int line = luaL_checkinteger(L, 1); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + int line = luaL_checkinteger(L, 1); + bool enabled = lua_isboolean(L, 2) ? lua_toboolean(L, 2) : true; - lua_Debug ar = {}; - lua_getinfo(L, 1, "f", &ar); + lua_Debug ar = {}; + lua_getinfo(L, 1, "f", &ar); - lua_breakpoint(L, -1, line, true); - return 0; - }); + lua_breakpoint(L, -1, line, enabled); + return 0; + }, + "breakpoint"); lua_setglobal(L, "breakpoint"); }, [](lua_State* L) { @@ -744,7 +755,7 @@ TEST_CASE("ExceptionObject") if (nsize == 0) { free(ptr); - return NULL; + return nullptr; } else if (nsize > 512 * 1024) { diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index e55b5b0c..e0756bad 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -4,7 +4,8 @@ #include #include -namespace std { +namespace std +{ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) { diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 18f55d2c..7a3543c7 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -203,6 +203,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); @@ -212,12 +214,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") bool encounteredFreeType = false; TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTy)); + CHECK_EQ("any", toString(clonedTy)); CHECK(encounteredFreeType); encounteredFreeType = false; TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTp)); + CHECK_EQ("...any", toString(clonedTp)); CHECK(encounteredFreeType); } diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index b076e9ad..80a258f5 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -198,7 +198,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table TypeVar tv{ttv}; - ToStringOptions o{/* exhaustive= */ false, /* useLineBreaks= */ false, /* functionTypeArguments= */ false, /* hideTableKind= */ false, 40}; + ToStringOptions o; + o.maxTableLength = 40; CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}"); } @@ -395,7 +396,7 @@ local function target(callback: nil) return callback(4, "hello") end )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(requireType("target")), "(nil) -> (*unknown*)"); + CHECK_EQ("(nil) -> (*unknown*)", toString(requireType("target"))); } TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") @@ -469,4 +470,110 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") CHECK_EQ(toString(tableTy), "Table
"); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") +{ + CheckResult result = check(R"( + local function id(x) return x end + )"); + + TypeId ty = requireType("id"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("id(x: a): a", toStringNamedFunction("id", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") +{ + CheckResult result = check(R"( + local function map(arr, fn) + local t = {} + for i = 0, #arr do + t[i] = fn(arr[i]) + end + return t + end + )"); + + TypeId ty = requireType("map"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); +} + +TEST_CASE("toStringNamedFunction_unit_f") +{ + TypePackVar empty{TypePack{}}; + FunctionTypeVar ftv{&empty, &empty, {}, false}; + CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") +{ + CheckResult result = check(R"( + local function f(x: a, ...): (a, a, b...) + return x, x, ... + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(x: a, ...: any): (a, a, b...)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") +{ + CheckResult result = check(R"( + local function f(): ...number + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): ...number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") +{ + CheckResult result = check(R"( + local function f(): (string, ...number) + return 'a', 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): (string, ...number)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") +{ + CheckResult result = check(R"( + local f: (number, y: number) -> number + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(_: number, y: number): number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") +{ + CheckResult result = check(R"( + local function f(x: T, g: (T) -> U)): () + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + ToStringOptions opts; + opts.hideNamedFunctionTypeParameters = true; + CHECK_EQ("f(x: T, g: (T) -> U): ()", toStringNamedFunction("f", *ftv, opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index c27f8083..74ce155c 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_typ CheckResult result = check(R"( type A = number type A = string -- Redefinition of type 'A', previously defined at line 1 - local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" + local foo: string = 1 -- "Type 'number' could not be converted into 'string'" )"); LUAU_REQUIRE_ERROR_COUNT(2, result); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2e400164..091c2f01 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -381,6 +381,8 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr") TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( type A = B type B = A @@ -390,7 +392,7 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") )"); TypeId fType = requireType("aa"); - const ErrorTypeVar* ftv = get(follow(fType)); + const AnyTypeVar* ftv = get(follow(fType)); REQUIRE(ftv != nullptr); REQUIRE(!result.errors.empty()); } diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index de2f0154..88c2dc85 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -289,7 +289,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") end )"); // TODO: Should typecheck but currently errors CLI-39916 - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_generic_property") @@ -352,7 +352,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") -- so this assignment should fail local b: boolean = f(true) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") @@ -368,7 +368,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") local y: number = id(37) end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 733fc39b..fe8e7ff9 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -704,9 +704,10 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); else CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - if (FFlag::LuauQuantifyInPlace2) + if (FFlag::LuauQuantifyInPlace2) CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" else CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp new file mode 100644 index 00000000..5f95efd5 --- /dev/null +++ b/tests/TypeInfer.singletons.test.cpp @@ -0,0 +1,377 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeSingletons"); + +TEST_CASE_FIXTURE(Fixture, "bool_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = true + local b: false = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: "bar" = "bar" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = false + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'false' could not be converted into 'true'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "bar" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "\n" = "\000\r" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '"\000\r"' could not be converted into '"\n"')", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = true + local b: boolean = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: string = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "foo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "bar") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, "foo") + g(false, 37) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, 37) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "foo" + local b : MyEnum = "bar" + local c : MyEnum = "baz" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauExtendedTypeMismatchError", true}, + }; + + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "bang" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'; none of the union options are compatible", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type MyEnum1 = "foo" | "bar" + type MyEnum2 = MyEnum1 | "baz" + local a : MyEnum1 = "foo" + local b : MyEnum2 = a + local c : string = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauExpectedTypesOfProperties", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Dog = { tag = "Dog", howls = true } + local b : Animal = { tag = "Cat", meows = true } + local c : Animal = a + c = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", howls = true } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", meows = true } + a.tag = "Dog" + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type T = { + ["foo"] : number, + ["$$bar"] : string, + baz : boolean + } + local t: T = { + ["foo"] = 37, + ["$$bar"] = "hi", + baz = true + } + local a: number = t.foo + local b: string = t["$$bar"] + local c: boolean = t.baz + t.foo = 5 + t["$$bar"] = "lo" + t.baz = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type T = { + ["$$bar"] : string, + } + local t: T = { + ["$$bar"] = "hi", + } + t["$$bar"] = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type S = "bar" + type T = { + [("foo")] : number, + [S] : string, + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + local x: { ["<>"] : number } + x = { ["\n"] = 5 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + toString(result.errors[0])); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 30d9130a..99fd8339 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -362,7 +362,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") CHECK_EQ(2, result.errors.size()); TypeId p = requireType("p"); - CHECK_EQ(*p, *typeChecker.errorType); + CHECK_EQ("*unknown*", toString(p)); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") @@ -480,7 +480,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") @@ -496,7 +496,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") @@ -512,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") @@ -526,7 +526,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(typeChecker.errorType, requireType("a")); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -542,7 +542,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(typeChecker.errorType, requireType("a")); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") @@ -673,7 +673,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index") REQUIRE(nat); CHECK_EQ("string", toString(nat->ty)); - CHECK(get(requireType("t"))); + CHECK_EQ("*unknown*", toString(requireType("t"))); } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") @@ -1456,7 +1456,7 @@ TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") auto hootyType = requireType(bModule, "Hooty"); - CHECK_MESSAGE(get(follow(hootyType)) != nullptr, "Should be an error: " << toString(hootyType)); + CHECK_EQ("*unknown*", toString(hootyType)); } TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") @@ -2032,7 +2032,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); } -TEST_CASE_FIXTURE(Fixture, "error_types_propagate") +TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") { CheckResult result = check(R"( local err = (true).x @@ -2049,10 +2049,10 @@ TEST_CASE_FIXTURE(Fixture, "error_types_propagate") CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("x", err->key); - CHECK(nullptr != get(requireType("c"))); - CHECK(nullptr != get(requireType("d"))); - CHECK(nullptr != get(requireType("e"))); - CHECK(nullptr != get(requireType("f"))); + CHECK_EQ("*unknown*", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("*unknown*", toString(requireType("e"))); + CHECK_EQ("*unknown*", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") @@ -2068,7 +2068,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK(nullptr != get(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -2077,9 +2077,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - TypeId aType = requireType("a"); - - REQUIRE_MESSAGE(nullptr != get(aType), "Not an error: " << toString(aType)); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") @@ -2146,6 +2144,8 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( --!strict local Vec3 = {} @@ -2175,11 +2175,13 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( --!strict local Vec3 = {} @@ -2209,7 +2211,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "compare_numbers") @@ -2901,6 +2903,8 @@ end TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( local x: number = 9999 function x:y(z: number) @@ -2908,7 +2912,7 @@ function x:y(z: number) end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") @@ -2920,7 +2924,7 @@ function x:y(z: number) end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") @@ -3799,7 +3803,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") print(a) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); } @@ -4215,7 +4219,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + CHECK_EQ("*unknown*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4238,7 +4242,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + CHECK_EQ("*unknown*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4394,6 +4398,25 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") CHECK_EQ(toString(*it), "(number) -> number"); } +TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number, number) -> number) + local abc: Overload + local x = abc(true) + local y = abc(true,true) + local z = abc(true,true,true) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("number", toString(requireType("y"))); + // Should this be string|number? + CHECK_EQ("string", toString(requireType("z"))); +} + TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation @@ -4740,4 +4763,20 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") } } +TEST_CASE_FIXTURE(Fixture, "type_error_addition") +{ + CheckResult result = check(R"( +--!strict +local foo = makesandwich() +local bar = foo.nutrition + 100 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + // We should definitely get this error + CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0])); + // We get this error if makesandwich() returns a free type + // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 2d697fc9..9f9a007f 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -121,9 +121,26 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId bType = requireType("b"); + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("b"))); +} - CHECK_MESSAGE(get(bType), "Should be an error: " << toString(bType)); +TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + function f(arg: number) return arg end + local a + local b + local c = f(a, b) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("b"))); + CHECK_EQ("number", toString(requireType("c"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails") @@ -167,15 +184,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") CHECK(state.errors.empty()); } -TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_work") -{ - TypePackId variadicPack = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); - TypePackId errorPack = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{Unifiable::Error{}})}); - - state.tryUnify(variadicPack, errorPack); - REQUIRE_EQ(0, state.errors.size()); -} - TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 9f29b642..48496b89 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -200,8 +200,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - TypeId r = requireType("r"); - CHECK_MESSAGE(get(r), "Expected error, got " << toString(r)); + CHECK_EQ("*unknown*", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") @@ -283,7 +282,7 @@ local c = b:foo(1, 2) CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "optional_union_follow") +TEST_CASE_FIXTURE(UnfrozenFixture, "optional_union_follow") { CheckResult result = check(R"( local y: number? = 2 diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 75329642..4d9b1295 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -319,4 +319,58 @@ for i=0,30 do assert(#T2 == 1 or T2[#T2] == 42) end +-- test coroutine.close +do + -- ok to close a dead coroutine + local co = coroutine.create(type) + assert(coroutine.resume(co, "testing 'coroutine.close'")) + assert(coroutine.status(co) == "dead") + local st, msg = coroutine.close(co) + assert(st and msg == nil) + -- also ok to close it again + st, msg = coroutine.close(co) + assert(st and msg == nil) + + + -- cannot close the running coroutine + coroutine.wrap(function() + local st, msg = pcall(coroutine.close, coroutine.running()) + assert(not st and string.find(msg, "running")) + end)() + + -- cannot close a "normal" coroutine + coroutine.wrap(function() + local co = coroutine.running() + coroutine.wrap(function () + local st, msg = pcall(coroutine.close, co) + assert(not st and string.find(msg, "normal")) + end)() + end)() + + -- closing a coroutine after an error + local co = coroutine.create(error) + local obj = {42} + local st, msg = coroutine.resume(co, obj) + assert(not st and msg == obj) + st, msg = coroutine.close(co) + assert(not st and msg == obj) + -- after closing, no more errors + st, msg = coroutine.close(co) + assert(st and msg == nil) + + -- closing a coroutine that has outstanding upvalues + local f + local co = coroutine.create(function() + local a = 42 + f = function() return a end + coroutine.yield() + a = 20 + end) + coroutine.resume(co) + assert(f() == 42) + st, msg = coroutine.close(co) + assert(st and msg == nil) + assert(f() == 42) +end + return'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index 5e69fc6b..6ba99fb9 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -45,4 +45,13 @@ breakpoint(38) -- break inside corobad() local co = coroutine.create(corobad) assert(coroutine.resume(co) == false) -- this breaks, resumes and dies! +function bar() + print("in bar") +end + +breakpoint(49) +breakpoint(49, false) -- validate that disabling breakpoints works + +bar() + return 'OK' From eed18acec8677b380fdbb7f424d79e1c7dae4273 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 2 Dec 2021 15:20:08 -0800 Subject: [PATCH 008/102] Sync to upstream/release/506 --- Analysis/include/Luau/FileResolver.h | 23 - Analysis/include/Luau/Module.h | 12 +- Analysis/include/Luau/ToDot.h | 31 ++ Analysis/include/Luau/TxnLog.h | 9 - Analysis/include/Luau/TypeInfer.h | 1 - Analysis/include/Luau/TypeVar.h | 15 - Analysis/include/Luau/Unifier.h | 25 +- Analysis/include/Luau/UnifierSharedState.h | 8 + Analysis/include/Luau/VisitTypeVar.h | 4 +- Analysis/src/Autocomplete.cpp | 72 ++- Analysis/src/BuiltinDefinitions.cpp | 6 +- Analysis/src/Error.cpp | 116 +---- Analysis/src/Frontend.cpp | 27 +- Analysis/src/IostreamHelpers.cpp | 19 +- Analysis/src/JsonEncoder.cpp | 9 +- Analysis/src/Module.cpp | 132 ++--- Analysis/src/Quantify.cpp | 13 +- Analysis/src/RequireTracer.cpp | 195 +------ Analysis/src/Substitution.cpp | 29 +- Analysis/src/ToDot.cpp | 378 ++++++++++++++ Analysis/src/ToString.cpp | 172 +++---- Analysis/src/Transpiler.cpp | 15 +- Analysis/src/TxnLog.cpp | 37 +- Analysis/src/TypeAttach.cpp | 52 +- Analysis/src/TypeInfer.cpp | 331 +++++------- Analysis/src/TypeVar.cpp | 364 -------------- Analysis/src/Unifier.cpp | 306 ++--------- Ast/src/Parser.cpp | 74 +-- CLI/Analyze.cpp | 23 +- CLI/Repl.cpp | 39 -- CLI/Web.cpp | 106 ++++ CMakeLists.txt | 57 ++- Compiler/src/Compiler.cpp | 14 +- Makefile | 6 +- Sources.cmake | 10 + VM/include/lua.h | 11 +- VM/include/luaconf.h | 38 ++ VM/include/lualib.h | 6 + VM/src/lapi.cpp | 67 ++- VM/src/laux.cpp | 95 ++-- VM/src/lbaselib.cpp | 4 +- VM/src/lbitlib.cpp | 2 +- VM/src/lbuiltins.cpp | 12 +- VM/src/lcorolib.cpp | 2 +- VM/src/ldblib.cpp | 2 +- VM/src/ldo.cpp | 71 +-- VM/src/lgc.cpp | 548 +------------------- VM/src/lgcdebug.cpp | 558 +++++++++++++++++++++ VM/src/linit.cpp | 10 +- VM/src/lmathlib.cpp | 5 +- VM/src/lmem.cpp | 8 +- VM/src/lnumutils.h | 8 + VM/src/lobject.cpp | 2 +- VM/src/lobject.h | 23 +- VM/src/loslib.cpp | 2 +- VM/src/lstrlib.cpp | 2 +- VM/src/ltable.cpp | 19 +- VM/src/ltablib.cpp | 2 +- VM/src/lutf8lib.cpp | 2 +- VM/src/lvmexecute.cpp | 30 +- VM/src/lvmload.cpp | 6 +- VM/src/lvmutils.cpp | 18 +- bench/gc/test_LB_mandel.lua | 2 +- bench/tests/shootout/mandel.lua | 2 +- bench/tests/shootout/qt.lua | 10 +- fuzz/proto.cpp | 4 +- tests/AstQuery.test.cpp | 23 + tests/Autocomplete.test.cpp | 40 +- tests/Compiler.test.cpp | 168 +++++++ tests/Conformance.test.cpp | 116 +++-- tests/Fixture.cpp | 38 +- tests/Fixture.h | 5 +- tests/Frontend.test.cpp | 17 - tests/Linter.test.cpp | 16 + tests/Module.test.cpp | 66 ++- tests/Parser.test.cpp | 2 - tests/ToDot.test.cpp | 366 ++++++++++++++ tests/Transpiler.test.cpp | 3 - tests/TypeInfer.aliases.test.cpp | 2 - tests/TypeInfer.generics.test.cpp | 21 +- tests/TypeInfer.provisional.test.cpp | 2 - tests/TypeInfer.tables.test.cpp | 70 +++ tests/TypeInfer.test.cpp | 20 + tests/TypeInfer.typePacks.cpp | 33 -- tests/TypeVar.test.cpp | 45 ++ tests/conformance/apicalls.lua | 8 +- tests/conformance/basic.lua | 5 +- tests/conformance/closure.lua | 8 +- tests/conformance/constructs.lua | 2 +- tests/conformance/coroutine.lua | 2 +- tests/conformance/datetime.lua | 2 +- tests/conformance/debug.lua | 2 +- tests/conformance/errors.lua | 32 +- tests/conformance/gc.lua | 8 +- tests/conformance/nextvar.lua | 6 +- tests/conformance/pcall.lua | 2 +- tests/conformance/utf8.lua | 2 +- tests/conformance/vector.lua | 31 +- tools/svg.py | 9 +- 99 files changed, 2905 insertions(+), 2568 deletions(-) create mode 100644 Analysis/include/Luau/ToDot.h create mode 100644 Analysis/src/ToDot.cpp create mode 100644 CLI/Web.cpp create mode 100644 VM/src/lgcdebug.cpp create mode 100644 tests/ToDot.test.cpp diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 9b74fc12..0fdcce16 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -51,13 +51,6 @@ struct FileResolver { return std::nullopt; } - - // DEPRECATED APIS - // These are going to be removed with LuauNewRequireTrace2 - virtual bool moduleExists(const ModuleName& name) const = 0; - virtual std::optional fromAstFragment(AstExpr* expr) const = 0; - virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; - virtual std::optional getParentModuleName(const ModuleName& name) const = 0; }; struct NullFileResolver : FileResolver @@ -66,22 +59,6 @@ struct NullFileResolver : FileResolver { return std::nullopt; } - bool moduleExists(const ModuleName& name) const override - { - return false; - } - std::optional fromAstFragment(AstExpr* expr) const override - { - return std::nullopt; - } - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override - { - return lhs; - } - std::optional getParentModuleName(const ModuleName& name) const override - { - return std::nullopt; - } }; } // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index d0844835..2e41674b 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -78,9 +78,15 @@ void unfreeze(TypeArena& arena); using SeenTypes = std::unordered_map; using SeenTypePacks = std::unordered_map; -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); +struct CloneState +{ + int recursionCount = 0; + bool encounteredFreeType = false; +}; + +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); struct Module { diff --git a/Analysis/include/Luau/ToDot.h b/Analysis/include/Luau/ToDot.h new file mode 100644 index 00000000..ce518d3a --- /dev/null +++ b/Analysis/include/Luau/ToDot.h @@ -0,0 +1,31 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +namespace Luau +{ +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +struct ToDotOptions +{ + bool showPointers = true; // Show pointer value in the node label + bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes +}; + +std::string toDot(TypeId ty, const ToDotOptions& opts); +std::string toDot(TypePackId tp, const ToDotOptions& opts); + +std::string toDot(TypeId ty); +std::string toDot(TypePackId tp); + +void dumpDot(TypeId ty); +void dumpDot(TypePackId tp); + +} // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 322abd19..29988a3b 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -25,15 +25,6 @@ struct TxnLog { } - explicit TxnLog(const std::vector>& ownedSeen) - : originalSeenSize(ownedSeen.size()) - , ownedSeen(ownedSeen) - , sharedSeen(nullptr) - { - // This is deprecated! - LUAU_ASSERT(!FFlag::LuauShareTxnSeen); - } - TxnLog(const TxnLog&) = delete; TxnLog& operator=(const TxnLog&) = delete; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 78d642c5..9f553bc1 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -297,7 +297,6 @@ private: private: Unifier mkUnifier(const Location& location); - Unifier mkUnifier(const std::vector>& seen, const Location& location); // These functions are only safe to call when we are in the process of typechecking a module. diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 093ea431..8c4c2f34 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -517,21 +517,6 @@ extern SingletonTypes singletonTypes; void persist(TypeId ty); void persist(TypePackId tp); -struct ToDotOptions -{ - bool showPointers = true; // Show pointer value in the node label - bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes -}; - -std::string toDot(TypeId ty, const ToDotOptions& opts); -std::string toDot(TypePackId tp, const ToDotOptions& opts); - -std::string toDot(TypeId ty); -std::string toDot(TypePackId tp); - -void dumpDot(TypeId ty); -void dumpDot(TypePackId tp); - const TypeLevel* getLevel(TypeId ty); TypeLevel* getMutableLevel(TypeId ty); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 503034a1..4588cdd8 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -19,12 +19,6 @@ enum Variance Invariant }; -struct UnifierCounters -{ - int recursionCount = 0; - int iterationCount = 0; -}; - struct Unifier { TypeArena* const types; @@ -37,20 +31,11 @@ struct Unifier Variance variance = Covariant; CountMismatch::Context ctx = CountMismatch::Arg; - UnifierCounters* counters; - UnifierCounters countersData; - - std::shared_ptr counters_DEPRECATED; - UnifierSharedState& sharedState; Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState); - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, - UnifierCounters* counters = nullptr); Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, - UnifierCounters* counters = nullptr); + Variance variance, UnifierSharedState& sharedState); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId superTy, TypeId subTy); @@ -92,9 +77,9 @@ private: public: // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); - void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack); + void occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); void occursCheck(TypePackId needle, TypePackId haystack); - void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack); + void occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); Unifier makeChildUnifier(); @@ -106,10 +91,6 @@ private: [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); - - // Remove with FFlagLuauCacheUnifyTableResults - DenseHashSet tempSeenTy_DEPRECATED{nullptr}; - DenseHashSet tempSeenTp_DEPRECATED{nullptr}; }; } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index f252a004..88997c41 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -24,6 +24,12 @@ struct TypeIdPairHash } }; +struct UnifierCounters +{ + int recursionCount = 0; + int iterationCount = 0; +}; + struct UnifierSharedState { UnifierSharedState(InternalErrorReporter* iceHandler) @@ -39,6 +45,8 @@ struct UnifierSharedState DenseHashSet tempSeenTy{nullptr}; DenseHashSet tempSeenTp{nullptr}; + + UnifierCounters counters; }; } // namespace Luau diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index a866655c..740854b3 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -5,8 +5,6 @@ #include "Luau/TypeVar.h" #include "Luau/TypePack.h" -LUAU_FASTFLAG(LuauCacheUnifyTableResults) - namespace Luau { @@ -101,7 +99,7 @@ void visit(TypeId ty, F& f, Set& seen) // Some visitors want to see bound tables, that's why we visit the original type if (apply(ty, *ttv, seen, f)) { - if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo) + if (ttv->boundTo) { visit(*ttv->boundTo, f, seen); } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 6fc0b3f8..db2d1d0e 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,9 +12,9 @@ #include #include -LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -203,8 +203,9 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { SeenTypes seenTypes; SeenTypePacks seenTypePacks; - expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr); - actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr); + CloneState cloneState; + expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, cloneState); + actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, cloneState); auto errors = unifier.canUnify(expectedType, actualType); return errors.empty(); @@ -229,28 +230,51 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*it); - if (canUnify(expectedType, ty)) - return TypeCorrectKind::Correct; - - // We also want to suggest functions that return compatible result - const FunctionTypeVar* ftv = get(ty); - - if (!ftv) - return TypeCorrectKind::None; - - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty()) - return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) + if (FFlag::LuauAutocompletePreferToCallFunctions) { - if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; - } + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty)) + { + auto [retHead, retTail] = flatten(ftv->retType); - return TypeCorrectKind::None; + if (!retHead.empty() && canUnify(expectedType, retHead.front())) + return TypeCorrectKind::CorrectFunctionResult; + + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(expectedType, vtp->ty)) + return TypeCorrectKind::CorrectFunctionResult; + } + } + + return canUnify(expectedType, ty) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + } + else + { + if (canUnify(expectedType, ty)) + return TypeCorrectKind::Correct; + + // We also want to suggest functions that return compatible result + const FunctionTypeVar* ftv = get(ty); + + if (!ftv) + return TypeCorrectKind::None; + + auto [retHead, retTail] = flatten(ftv->retType); + + if (!retHead.empty()) + return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail))) + return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + } + + return TypeCorrectKind::None; + } } enum class PropIndexType @@ -1413,7 +1437,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatWhile* statWhile = extractStat(finder.ancestry); statWhile && !statWhile->hasDo) return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; - else if (AstStatIf* statIf = node->as(); FFlag::ElseElseIfCompletionImprovements && statIf && !statIf->hasElse) + else if (AstStatIf* statIf = node->as(); statIf && !statIf->hasElse) { return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 62a06a3c..bac94a2b 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,8 +8,6 @@ #include -LUAU_FASTFLAG(LuauNewRequireTrace2) - /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. For instance, we do not yet have a way to talk @@ -473,9 +471,7 @@ static std::optional> magicFunctionRequire( if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; - const AstExpr* require = FFlag::LuauNewRequireTrace2 ? &expr : expr.args.data[0]; - - if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr)) return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; return std::nullopt; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index f80d50a7..8334bd62 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,57 +7,14 @@ #include -LUAU_FASTFLAG(LuauTypeAliasPacks) - -static std::string wrongNumberOfArgsString_DEPRECATED(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) -{ - std::string s = "expects " + std::to_string(expectedCount) + " "; - - if (isTypeArgs) - s += "type "; - - s += "argument"; - if (expectedCount != 1) - s += "s"; - - s += ", but "; - - if (actualCount == 0) - { - s += "none"; - } - else - { - if (actualCount < expectedCount) - s += "only "; - - s += std::to_string(actualCount); - } - - s += (actualCount == 1) ? " is" : " are"; - - s += " specified"; - - return s; -} - static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { - std::string s; + std::string s = "expects "; - if (FFlag::LuauTypeAliasPacks) - { - s = "expects "; + if (isVariadic) + s += "at least "; - if (isVariadic) - s += "at least "; - - s += std::to_string(expectedCount) + " "; - } - else - { - s = "expects " + std::to_string(expectedCount) + " "; - } + s += std::to_string(expectedCount) + " "; if (argPrefix) s += std::string(argPrefix) + " "; @@ -188,10 +145,7 @@ struct ErrorConverter return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - if (FFlag::LuauTypeAliasPacks) - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); - else - return "Argument count mismatch. Function " + wrongNumberOfArgsString_DEPRECATED(e.expected, e.actual); + return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); } LUAU_ASSERT(!"Unknown context"); @@ -232,7 +186,7 @@ struct ErrorConverter std::string operator()(const Luau::IncorrectGenericParameterCount& e) const { std::string name = e.name; - if (!e.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !e.typeFun.typePackParams.empty())) + if (!e.typeFun.typeParams.empty() || !e.typeFun.typePackParams.empty()) { name += "<"; bool first = true; @@ -246,36 +200,25 @@ struct ErrorConverter name += toString(t); } - if (FFlag::LuauTypeAliasPacks) + for (TypePackId t : e.typeFun.typePackParams) { - for (TypePackId t : e.typeFun.typePackParams) - { - if (first) - first = false; - else - name += ", "; + if (first) + first = false; + else + name += ", "; - name += toString(t); - } + name += toString(t); } name += ">"; } - if (FFlag::LuauTypeAliasPacks) - { - if (e.typeFun.typeParams.size() != e.actualParameters) - return "Generic type '" + name + "' " + - wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); + if (e.typeFun.typeParams.size() != e.actualParameters) + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); - return "Generic type '" + name + "' " + - wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); - } - else - { - return "Generic type '" + name + "' " + - wrongNumberOfArgsString_DEPRECATED(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); - } + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); } std::string operator()(const Luau::SyntaxError& e) const @@ -591,11 +534,8 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size()) return false; - if (FFlag::LuauTypeAliasPacks) - { - if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) - return false; - } + if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) + return false; for (size_t i = 0; i < typeFun.typeParams.size(); ++i) { @@ -603,13 +543,10 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC return false; } - if (FFlag::LuauTypeAliasPacks) + for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) { - for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) - { - if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) - return false; - } + if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + return false; } return true; @@ -733,14 +670,14 @@ bool containsParseErrorName(const TypeError& error) } template -void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) +void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState cloneState) { auto clone = [&](auto&& ty) { - return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks); + return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks, cloneState); }; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks); + copyError(e, destArena, seenTypes, seenTypePacks, cloneState); }; if constexpr (false) @@ -864,9 +801,10 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) { SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks); + copyError(e, destArena, seenTypes, seenTypePacks, cloneState); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 1e97705d..e332f07d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,10 +18,7 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) -LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) -LUAU_FASTFLAG(LuauNewRequireTrace2) namespace Luau { @@ -96,10 +93,11 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); + TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; @@ -110,7 +108,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); + TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; @@ -427,15 +425,16 @@ CheckResult Frontend::check(const ModuleName& name) SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; for (const auto& [expr, strictTy] : strictModule->astTypes) - module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); for (const auto& [expr, strictTy] : strictModule->astOriginalCallTypes) - module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); for (const auto& [expr, strictTy] : strictModule->astExpectedTypes) - module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); } stats.timeCheck += getTimestamp() - timestamp; @@ -885,16 +884,13 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module // If we can't find the current module name, that's because we bypassed the frontend's initializer // and called typeChecker.check directly. (This is done by autocompleteSource, for example). // In that case, requires will always fail. - if (FFlag::LuauResolveModuleNameWithoutACurrentModule) - return std::nullopt; - else - throw std::runtime_error("Frontend::resolveModuleName: Unknown currentModuleName '" + currentModuleName + "'"); + return std::nullopt; } const auto& exprs = it->second.exprs; const ModuleInfo* info = exprs.find(&pathExpr); - if (!info || (!FFlag::LuauNewRequireTrace2 && info->name.empty())) + if (!info) return std::nullopt; return *info; @@ -911,10 +907,7 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - if (FFlag::LuauNewRequireTrace2) - return frontend->sourceNodes.count(moduleName) != 0; - else - return frontend->fileResolver->moduleExists(moduleName); + return frontend->sourceNodes.count(moduleName) != 0; } std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 3b267121..ac46b5a4 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -2,8 +2,6 @@ #include "Luau/IostreamHelpers.h" #include "Luau/ToString.h" -LUAU_FASTFLAG(LuauTypeAliasPacks) - namespace Luau { @@ -94,7 +92,7 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "IncorrectGenericParameterCount { name = " << error.name; - if (!error.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !error.typeFun.typePackParams.empty())) + if (!error.typeFun.typeParams.empty() || !error.typeFun.typePackParams.empty()) { stream << "<"; bool first = true; @@ -108,17 +106,14 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo stream << toString(t); } - if (FFlag::LuauTypeAliasPacks) + for (TypePackId t : error.typeFun.typePackParams) { - for (TypePackId t : error.typeFun.typePackParams) - { - if (first) - first = false; - else - stream << ", "; + if (first) + first = false; + else + stream << ", "; - stream << toString(t); - } + stream << toString(t); } stream << ">"; diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 064accba..c7f623ee 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -5,8 +5,6 @@ #include "Luau/StringUtils.h" #include "Luau/Common.h" -LUAU_FASTFLAG(LuauTypeAliasPacks) - namespace Luau { @@ -615,12 +613,7 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstStatTypeAlias", [&]() { PROP(name); PROP(generics); - - if (FFlag::LuauTypeAliasPacks) - { - PROP(genericPacks); - } - + PROP(genericPacks); PROP(type); PROP(exported); }); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 32a0646a..b4b6eb42 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,20 +1,20 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Common.h" +#include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" -#include "Luau/Common.h" #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) -LUAU_FASTFLAG(LuauTypeAliasPacks) -LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false) +LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 0) namespace Luau { @@ -120,12 +120,6 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) return allocated; } -using SeenTypes = std::unordered_map; -using SeenTypePacks = std::unordered_map; - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType); - namespace { @@ -138,11 +132,12 @@ struct TypePackCloner; struct TypeCloner { - TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) + TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) : dest(dest) , typeId(typeId) , seenTypes(seenTypes) , seenTypePacks(seenTypePacks) + , cloneState(cloneState) { } @@ -150,8 +145,7 @@ struct TypeCloner TypeId typeId; SeenTypes& seenTypes; SeenTypePacks& seenTypePacks; - - bool* encounteredFreeType = nullptr; + CloneState& cloneState; template void defaultClone(const T& t); @@ -178,13 +172,14 @@ struct TypePackCloner TypePackId typePackId; SeenTypes& seenTypes; SeenTypePacks& seenTypePacks; - bool* encounteredFreeType = nullptr; + CloneState& cloneState; - TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) + TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) : dest(dest) , typePackId(typePackId) , seenTypes(seenTypes) , seenTypePacks(seenTypePacks) + , cloneState(cloneState) { } @@ -197,8 +192,7 @@ struct TypePackCloner void operator()(const Unifiable::Free& t) { - if (encounteredFreeType) - *encounteredFreeType = true; + cloneState.encounteredFreeType = true; TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); TypePackId cloned = dest.addTypePack(*err); @@ -218,13 +212,13 @@ struct TypePackCloner // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. void operator()(const Unifiable::Bound& t) { - TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); seenTypePacks[typePackId] = cloned; } void operator()(const VariadicTypePack& t) { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}}); + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); seenTypePacks[typePackId] = cloned; } @@ -236,10 +230,10 @@ struct TypePackCloner seenTypePacks[typePackId] = cloned; for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); if (t.tail) - destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, encounteredFreeType); + destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); } }; @@ -252,8 +246,7 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { - if (encounteredFreeType) - *encounteredFreeType = true; + cloneState.encounteredFreeType = true; TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); TypeId cloned = dest.addType(*err); seenTypes[typeId] = cloned; @@ -266,7 +259,7 @@ void TypeCloner::operator()(const Unifiable::Generic& t) void TypeCloner::operator()(const Unifiable::Bound& t) { - TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); seenTypes[typeId] = boundTo; } @@ -294,23 +287,23 @@ void TypeCloner::operator()(const FunctionTypeVar& t) seenTypes[typeId] = result; for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, encounteredFreeType)); + ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, encounteredFreeType)); + ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, encounteredFreeType); + ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, encounteredFreeType); + ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); } void TypeCloner::operator()(const TableTypeVar& t) { // If table is now bound to another one, we ignore the content of the original - if (FFlag::LuauCloneBoundTables && t.boundTo) + if (t.boundTo) { - TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); seenTypes[typeId] = boundTo; return; } @@ -326,34 +319,21 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->level = TypeLevel{0, 0}; for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; + ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), - clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)}; - - if (!FFlag::LuauCloneBoundTables) - { - if (t.boundTo) - ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); - } + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), + clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); - } + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); if (ttv->state == TableState::Free) { - if (FFlag::LuauCloneBoundTables || !t.boundTo) - { - if (encounteredFreeType) - *encounteredFreeType = true; - } + cloneState.encounteredFreeType = true; ttv->state = TableState::Sealed; } @@ -369,8 +349,8 @@ void TypeCloner::operator()(const MetatableTypeVar& t) MetatableTypeVar* mtv = getMutable(result); seenTypes[typeId] = result; - mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, encounteredFreeType); - mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType); + mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); + mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); } void TypeCloner::operator()(const ClassTypeVar& t) @@ -381,13 +361,13 @@ void TypeCloner::operator()(const ClassTypeVar& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; + ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.parent) - ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, encounteredFreeType); + ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType); + ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); } void TypeCloner::operator()(const AnyTypeVar& t) @@ -404,7 +384,7 @@ void TypeCloner::operator()(const UnionTypeVar& t) LUAU_ASSERT(option != nullptr); for (TypeId ty : t.options) - option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); } void TypeCloner::operator()(const IntersectionTypeVar& t) @@ -416,7 +396,7 @@ void TypeCloner::operator()(const IntersectionTypeVar& t) LUAU_ASSERT(option != nullptr); for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); } void TypeCloner::operator()(const LazyTypeVar& t) @@ -426,17 +406,18 @@ void TypeCloner::operator()(const LazyTypeVar& t) } // anonymous namespace -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { if (tp->persistent) return tp; + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + TypePackId& res = seenTypePacks[tp]; if (res == nullptr) { - TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks}; - cloner.encounteredFreeType = encounteredFreeType; + TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. } @@ -446,17 +427,18 @@ TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypeP return res; } -TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { if (typeId->persistent) return typeId; + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + TypeId& res = seenTypes[typeId]; if (res == nullptr) { - TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks}; - cloner.encounteredFreeType = encounteredFreeType; + TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. asMutable(res)->documentationSymbol = typeId->documentationSymbol; } @@ -467,19 +449,16 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks return res; } -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { TypeFun result; for (TypeId ty : typeFun.typeParams) - result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId tp : typeFun.typePackParams) - result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, encounteredFreeType)); - } + for (TypePackId tp : typeFun.typePackParams) + result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, cloneState)); - result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType); + result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); return result; } @@ -519,19 +498,18 @@ bool Module::clonePublicInterface() LUAU_ASSERT(interfaceTypes.typeVars.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); - bool encounteredFreeType = false; - - SeenTypePacks seenTypePacks; SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + CloneState cloneState; ScopePtr moduleScope = getModuleScope(); - moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, cloneState); if (moduleScope->varargPack) - moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); for (auto& pair : moduleScope->exportedTypeBindings) - pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, cloneState); for (TypeId ty : moduleScope->returnType) if (get(follow(ty))) @@ -540,7 +518,7 @@ bool Module::clonePublicInterface() freeze(internalTypes); freeze(interfaceTypes); - return encounteredFreeType; + return cloneState.encounteredFreeType; } } // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index bf6d81aa..c773e208 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,6 +4,8 @@ #include "Luau/VisitTypeVar.h" +LUAU_FASTFLAGVARIABLE(LuauQuantifyVisitOnce, false) + namespace Luau { @@ -79,7 +81,16 @@ struct Quantifier void quantify(ModulePtr module, TypeId ty, TypeLevel level) { Quantifier q{std::move(module), level}; - visitTypeVar(ty, q); + + if (FFlag::LuauQuantifyVisitOnce) + { + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, q, seen); + } + else + { + visitTypeVar(ty, q); + } FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index b72f53f9..8ed245fb 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -4,182 +4,9 @@ #include "Luau/Ast.h" #include "Luau/Module.h" -LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) -LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace2, false) - namespace Luau { -namespace -{ - -struct RequireTracerOld : AstVisitor -{ - explicit RequireTracerOld(FileResolver* fileResolver, const ModuleName& currentModuleName) - : fileResolver(fileResolver) - , currentModuleName(currentModuleName) - { - LUAU_ASSERT(!FFlag::LuauNewRequireTrace2); - } - - FileResolver* const fileResolver; - ModuleName currentModuleName; - DenseHashMap locals{nullptr}; - RequireTraceResult result; - - std::optional fromAstFragment(AstExpr* expr) - { - if (auto g = expr->as(); g && g->name == "script") - return currentModuleName; - - return fileResolver->fromAstFragment(expr); - } - - bool visit(AstStatLocal* stat) override - { - for (size_t i = 0; i < stat->vars.size; ++i) - { - AstLocal* local = stat->vars.data[i]; - - if (local->annotation) - { - if (AstTypeTypeof* ann = local->annotation->as()) - ann->expr->visit(this); - } - - if (i < stat->values.size) - { - AstExpr* expr = stat->values.data[i]; - expr->visit(this); - - const ModuleInfo* info = result.exprs.find(expr); - if (info) - locals[local] = info->name; - } - } - - return false; - } - - bool visit(AstExprGlobal* global) override - { - std::optional name = fromAstFragment(global); - if (name) - result.exprs[global] = {*name}; - - return false; - } - - bool visit(AstExprLocal* local) override - { - const ModuleName* name = locals.find(local->local); - if (name) - result.exprs[local] = {*name}; - - return false; - } - - bool visit(AstExprIndexName* indexName) override - { - indexName->expr->visit(this); - - const ModuleInfo* info = result.exprs.find(indexName->expr); - if (info) - { - if (indexName->index == "parent" || indexName->index == "Parent") - { - if (auto parent = fileResolver->getParentModuleName(info->name)) - result.exprs[indexName] = {*parent}; - } - else - result.exprs[indexName] = {fileResolver->concat(info->name, indexName->index.value)}; - } - - return false; - } - - bool visit(AstExprIndexExpr* indexExpr) override - { - indexExpr->expr->visit(this); - - const ModuleInfo* info = result.exprs.find(indexExpr->expr); - const AstExprConstantString* str = indexExpr->index->as(); - if (info && str) - { - result.exprs[indexExpr] = {fileResolver->concat(info->name, std::string_view(str->value.data, str->value.size))}; - } - - indexExpr->index->visit(this); - - return false; - } - - bool visit(AstExprTypeAssertion* expr) override - { - return false; - } - - // If we see game:GetService("StringLiteral") or Game:GetService("StringLiteral"), then rewrite to game.StringLiteral. - // Else traverse arguments and trace requires to them. - bool visit(AstExprCall* call) override - { - for (AstExpr* arg : call->args) - arg->visit(this); - - call->func->visit(this); - - AstExprGlobal* globalName = call->func->as(); - if (globalName && globalName->name == "require" && call->args.size >= 1) - { - if (const ModuleInfo* moduleInfo = result.exprs.find(call->args.data[0])) - result.requires.push_back({moduleInfo->name, call->location}); - - return false; - } - - AstExprIndexName* indexName = call->func->as(); - if (!indexName) - return false; - - std::optional rootName = fromAstFragment(indexName->expr); - - if (FFlag::LuauTraceRequireLookupChild && !rootName) - { - if (const ModuleInfo* moduleInfo = result.exprs.find(indexName->expr)) - rootName = moduleInfo->name; - } - - if (!rootName) - return false; - - bool supportedLookup = indexName->index == "GetService" || - (FFlag::LuauTraceRequireLookupChild && (indexName->index == "FindFirstChild" || indexName->index == "WaitForChild")); - - if (!supportedLookup) - return false; - - if (call->args.size != 1) - return false; - - AstExprConstantString* name = call->args.data[0]->as(); - if (!name) - return false; - - std::string_view v{name->value.data, name->value.size}; - if (v.end() != std::find(v.begin(), v.end(), '/')) - return false; - - result.exprs[call] = {fileResolver->concat(*rootName, v)}; - - // 'WaitForChild' can be used on modules that are not available at the typecheck time, but will be available at runtime - // If we fail to find such module, we will not report an UnknownRequire error - if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") - result.exprs[call].optional = true; - - return false; - } -}; - struct RequireTracer : AstVisitor { RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName) @@ -188,7 +15,6 @@ struct RequireTracer : AstVisitor , currentModuleName(currentModuleName) , locals(nullptr) { - LUAU_ASSERT(FFlag::LuauNewRequireTrace2); } bool visit(AstExprTypeAssertion* expr) override @@ -328,24 +154,13 @@ struct RequireTracer : AstVisitor std::vector requires; }; -} // anonymous namespace - RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - if (FFlag::LuauNewRequireTrace2) - { - RequireTraceResult result; - RequireTracer tracer{result, fileResolver, currentModuleName}; - root->visit(&tracer); - tracer.process(); - return result; - } - else - { - RequireTracerOld tracer{fileResolver, currentModuleName}; - root->visit(&tracer); - return tracer.result; - } + RequireTraceResult result; + RequireTracer tracer{result, fileResolver, currentModuleName}; + root->visit(&tracer); + tracer.process(); + return result; } } // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index ca2b30f5..3d004bee 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -7,8 +7,6 @@ #include LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) -LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false) -LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -39,11 +37,8 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId itp : ttv->instantiatedTypePackParams) - visitChild(itp); - } + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp); } else if (const MetatableTypeVar* mtv = get(ty)) { @@ -339,10 +334,10 @@ std::optional Substitution::substitute(TypeId ty) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + if (!ignoreChildren(oldTy)) replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + if (!ignoreChildren(oldTp)) replaceChildren(newTp); TypeId newTy = replace(ty); return newTy; @@ -359,10 +354,10 @@ std::optional Substitution::substitute(TypePackId tp) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + if (!ignoreChildren(oldTy)) replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + if (!ignoreChildren(oldTp)) replaceChildren(newTp); TypePackId newTp = replace(tp); return newTp; @@ -393,10 +388,7 @@ TypeId Substitution::clone(TypeId ty) clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - - if (FFlag::LuauTypeAliasPacks) - clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; clone.tags = ttv->tags; result = addType(std::move(clone)); } @@ -505,11 +497,8 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& itp : ttv->instantiatedTypeParams) itp = replace(itp); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId& itp : ttv->instantiatedTypePackParams) - itp = replace(itp); - } + for (TypePackId& itp : ttv->instantiatedTypePackParams) + itp = replace(itp); } else if (MetatableTypeVar* mtv = getMutable(ty)) { diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp new file mode 100644 index 00000000..df9d4188 --- /dev/null +++ b/Analysis/src/ToDot.cpp @@ -0,0 +1,378 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ToDot.h" + +#include "Luau/ToString.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" +#include "Luau/StringUtils.h" + +#include +#include + +namespace Luau +{ + +namespace +{ + +struct StateDot +{ + StateDot(ToDotOptions opts) + : opts(opts) + { + } + + ToDotOptions opts; + + std::unordered_set seenTy; + std::unordered_set seenTp; + std::unordered_map tyToIndex; + std::unordered_map tpToIndex; + int nextIndex = 1; + std::string result; + + bool canDuplicatePrimitive(TypeId ty); + + void visitChildren(TypeId ty, int index); + void visitChildren(TypePackId ty, int index); + + void visitChild(TypeId ty, int parentIndex, const char* linkName = nullptr); + void visitChild(TypePackId tp, int parentIndex, const char* linkName = nullptr); + + void startNode(int index); + void finishNode(); + + void startNodeLabel(); + void finishNodeLabel(TypeId ty); + void finishNodeLabel(TypePackId tp); +}; + +bool StateDot::canDuplicatePrimitive(TypeId ty) +{ + if (get(ty)) + return false; + + return get(ty) || get(ty); +} + +void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) +{ + if (!tyToIndex.count(ty) || (opts.duplicatePrimitives && canDuplicatePrimitive(ty))) + tyToIndex[ty] = nextIndex++; + + int index = tyToIndex[ty]; + + if (parentIndex != 0) + { + if (linkName) + formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, index, linkName); + else + formatAppend(result, "n%d -> n%d;\n", parentIndex, index); + } + + if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) + { + if (get(ty)) + formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); + else if (get(ty)) + formatAppend(result, "n%d [label=\"any\"];\n", index); + } + else + { + visitChildren(ty, index); + } +} + +void StateDot::visitChild(TypePackId tp, int parentIndex, const char* linkName) +{ + if (!tpToIndex.count(tp)) + tpToIndex[tp] = nextIndex++; + + if (parentIndex != 0) + { + if (linkName) + formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, tpToIndex[tp], linkName); + else + formatAppend(result, "n%d -> n%d;\n", parentIndex, tpToIndex[tp]); + } + + visitChildren(tp, tpToIndex[tp]); +} + +void StateDot::startNode(int index) +{ + formatAppend(result, "n%d [", index); +} + +void StateDot::finishNode() +{ + formatAppend(result, "];\n"); +} + +void StateDot::startNodeLabel() +{ + formatAppend(result, "label=\""); +} + +void StateDot::finishNodeLabel(TypeId ty) +{ + if (opts.showPointers) + formatAppend(result, "\n0x%p", ty); + // additional common attributes can be added here as well + result += "\""; +} + +void StateDot::finishNodeLabel(TypePackId tp) +{ + if (opts.showPointers) + formatAppend(result, "\n0x%p", tp); + // additional common attributes can be added here as well + result += "\""; +} + +void StateDot::visitChildren(TypeId ty, int index) +{ + if (seenTy.count(ty)) + return; + seenTy.insert(ty); + + startNode(index); + startNodeLabel(); + + if (const BoundTypeVar* btv = get(ty)) + { + formatAppend(result, "BoundTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(btv->boundTo, index); + } + else if (const FunctionTypeVar* ftv = get(ty)) + { + formatAppend(result, "FunctionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(ftv->argTypes, index, "arg"); + visitChild(ftv->retType, index, "ret"); + } + else if (const TableTypeVar* ttv = get(ty)) + { + if (ttv->name) + formatAppend(result, "TableTypeVar %s", ttv->name->c_str()); + else if (ttv->syntheticName) + formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str()); + else + formatAppend(result, "TableTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + if (ttv->boundTo) + return visitChild(*ttv->boundTo, index, "boundTo"); + + for (const auto& [name, prop] : ttv->props) + visitChild(prop.type, index, name.c_str()); + if (ttv->indexer) + { + visitChild(ttv->indexer->indexType, index, "[index]"); + visitChild(ttv->indexer->indexResultType, index, "[value]"); + } + for (TypeId itp : ttv->instantiatedTypeParams) + visitChild(itp, index, "typeParam"); + + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } + else if (const MetatableTypeVar* mtv = get(ty)) + { + formatAppend(result, "MetatableTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(mtv->table, index, "table"); + visitChild(mtv->metatable, index, "metatable"); + } + else if (const UnionTypeVar* utv = get(ty)) + { + formatAppend(result, "UnionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId opt : utv->options) + visitChild(opt, index); + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + formatAppend(result, "IntersectionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : itv->parts) + visitChild(part, index); + } + else if (const GenericTypeVar* gtv = get(ty)) + { + if (gtv->explicitName) + formatAppend(result, "GenericTypeVar %s", gtv->name.c_str()); + else + formatAppend(result, "GenericTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const FreeTypeVar* ftv = get(ty)) + { + formatAppend(result, "FreeTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "AnyTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "ErrorTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const ClassTypeVar* ctv = get(ty)) + { + formatAppend(result, "ClassTypeVar %s", ctv->name.c_str()); + finishNodeLabel(ty); + finishNode(); + + for (const auto& [name, prop] : ctv->props) + visitChild(prop.type, index, name.c_str()); + + if (ctv->parent) + visitChild(*ctv->parent, index, "[parent]"); + + if (ctv->metatable) + visitChild(*ctv->metatable, index, "[metatable]"); + } + else + { + LUAU_ASSERT(!"unknown type kind"); + finishNodeLabel(ty); + finishNode(); + } +} + +void StateDot::visitChildren(TypePackId tp, int index) +{ + if (seenTp.count(tp)) + return; + seenTp.insert(tp); + + startNode(index); + startNodeLabel(); + + if (const BoundTypePack* btp = get(tp)) + { + formatAppend(result, "BoundTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + visitChild(btp->boundTo, index); + } + else if (const TypePack* tpp = get(tp)) + { + formatAppend(result, "TypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + for (TypeId tv : tpp->head) + visitChild(tv, index); + if (tpp->tail) + visitChild(*tpp->tail, index, "tail"); + } + else if (const VariadicTypePack* vtp = get(tp)) + { + formatAppend(result, "VariadicTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + visitChild(vtp->ty, index); + } + else if (const FreeTypePack* ftp = get(tp)) + { + formatAppend(result, "FreeTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else if (const GenericTypePack* gtp = get(tp)) + { + if (gtp->explicitName) + formatAppend(result, "GenericTypePack %s", gtp->name.c_str()); + else + formatAppend(result, "GenericTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else if (get(tp)) + { + formatAppend(result, "ErrorTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else + { + LUAU_ASSERT(!"unknown type pack kind"); + finishNodeLabel(tp); + finishNode(); + } +} + +} // namespace + +std::string toDot(TypeId ty, const ToDotOptions& opts) +{ + StateDot state{opts}; + + state.result = "digraph graphname {\n"; + state.visitChild(ty, 0); + state.result += "}"; + + return state.result; +} + +std::string toDot(TypePackId tp, const ToDotOptions& opts) +{ + StateDot state{opts}; + + state.result = "digraph graphname {\n"; + state.visitChild(tp, 0); + state.result += "}"; + + return state.result; +} + +std::string toDot(TypeId ty) +{ + return toDot(ty, {}); +} + +std::string toDot(TypePackId tp) +{ + return toDot(tp, {}); +} + +void dumpDot(TypeId ty) +{ + printf("%s\n", toDot(ty).c_str()); +} + +void dumpDot(TypePackId tp) +{ + printf("%s\n", toDot(tp).c_str()); +} + +} // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 735bfa50..6322096c 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,7 +11,7 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) -LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauFunctionArgumentNameSize, false) namespace Luau { @@ -59,11 +59,8 @@ struct FindCyclicTypes for (TypeId itp : ttv.instantiatedTypeParams) visitTypeVar(itp, *this, seen); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId itp : ttv.instantiatedTypePackParams) - visitTypeVar(itp, *this, seen); - } + for (TypePackId itp : ttv.instantiatedTypePackParams) + visitTypeVar(itp, *this, seen); return exhaustive; } @@ -248,58 +245,45 @@ struct TypeVarStringifier void stringify(const std::vector& types, const std::vector& typePacks) { - if (types.size() == 0 && (!FFlag::LuauTypeAliasPacks || typePacks.size() == 0)) + if (types.size() == 0 && typePacks.size() == 0) return; - if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) + if (types.size() || typePacks.size()) state.emit("<"); - if (FFlag::LuauTypeAliasPacks) - { - bool first = true; + bool first = true; - for (TypeId ty : types) - { - if (!first) - state.emit(", "); + for (TypeId ty : types) + { + if (!first) + state.emit(", "); + first = false; + + stringify(ty); + } + + bool singleTp = typePacks.size() == 1; + + for (TypePackId tp : typePacks) + { + if (isEmpty(tp) && singleTp) + continue; + + if (!first) + state.emit(", "); + else first = false; - stringify(ty); - } + if (!singleTp) + state.emit("("); - bool singleTp = typePacks.size() == 1; + stringify(tp); - for (TypePackId tp : typePacks) - { - if (isEmpty(tp) && singleTp) - continue; - - if (!first) - state.emit(", "); - else - first = false; - - if (!singleTp) - state.emit("("); - - stringify(tp); - - if (!singleTp) - state.emit(")"); - } - } - else - { - for (size_t i = 0; i < types.size(); ++i) - { - if (i > 0) - state.emit(", "); - - stringify(types[i]); - } + if (!singleTp) + state.emit(")"); } - if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) + if (types.size() || typePacks.size()) state.emit(">"); } @@ -767,12 +751,23 @@ struct TypePackStringifier else state.emit(", "); - LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); - - if (!elemNames.empty() && elemNames[elemIndex]) + if (FFlag::LuauFunctionArgumentNameSize) { - state.emit(elemNames[elemIndex]->name); - state.emit(": "); + if (elemIndex < elemNames.size() && elemNames[elemIndex]) + { + state.emit(elemNames[elemIndex]->name); + state.emit(": "); + } + } + else + { + LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); + + if (!elemNames.empty() && elemNames[elemIndex]) + { + state.emit(elemNames[elemIndex]->name); + state.emit(": "); + } } elemIndex++; @@ -929,38 +924,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (FFlag::LuauTypeAliasPacks) - { - tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); - } - else - { - if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) - return result; - - result.name += "<"; - - bool first = true; - for (TypeId ty : ttv->instantiatedTypeParams) - { - if (!first) - result.name += ", "; - else - first = false; - - tvs.stringify(ty); - } - - if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) - { - result.truncated = true; - result.name += "... "; - } - else - { - result.name += ">"; - } - } + tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); return result; } @@ -1161,17 +1125,37 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV s += ", "; first = false; - // argNames is guaranteed to be equal to argTypes iff argNames is not empty. - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (!ftv.argNames.empty()) - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; - s += toString_(*argPackIter); - - ++argPackIter; - if (!ftv.argNames.empty()) + if (FFlag::LuauFunctionArgumentNameSize) { - LUAU_ASSERT(argNameIter != ftv.argNames.end()); - ++argNameIter; + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) + { + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + ++argNameIter; + } + else + { + s += "_: "; + } + } + else + { + // argNames is guaranteed to be equal to argTypes iff argNames is not empty. + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (!ftv.argNames.empty()) + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + } + + s += toString_(*argPackIter); + ++argPackIter; + + if (!FFlag::LuauFunctionArgumentNameSize) + { + if (!ftv.argNames.empty()) + { + LUAU_ASSERT(argNameIter != ftv.argNames.end()); + ++argNameIter; + } } } diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 6627fbe3..8e13ea5b 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,8 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauTypeAliasPacks) - namespace { bool isIdentifierStartChar(char c) @@ -787,7 +785,7 @@ struct Printer writer.keyword("type"); writer.identifier(a->name.value); - if (a->generics.size > 0 || (FFlag::LuauTypeAliasPacks && a->genericPacks.size > 0)) + if (a->generics.size > 0 || a->genericPacks.size > 0) { writer.symbol("<"); CommaSeparatorInserter comma(writer); @@ -798,14 +796,11 @@ struct Printer writer.identifier(o.value); } - if (FFlag::LuauTypeAliasPacks) + for (auto o : a->genericPacks) { - for (auto o : a->genericPacks) - { - comma(); - writer.identifier(o.value); - writer.symbol("..."); - } + comma(); + writer.identifier(o.value); + writer.symbol("..."); } writer.symbol(">"); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 383bb050..f6a61581 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false) - namespace Luau { @@ -36,11 +34,8 @@ void TxnLog::rollback() for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) std::swap(it->first->boundTo, it->second); - if (FFlag::LuauShareTxnSeen) - { - LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); - sharedSeen->resize(originalSeenSize); - } + LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); + sharedSeen->resize(originalSeenSize); } void TxnLog::concat(TxnLog rhs) @@ -53,45 +48,25 @@ void TxnLog::concat(TxnLog rhs) tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); rhs.tableChanges.clear(); - - if (!FFlag::LuauShareTxnSeen) - { - ownedSeen.swap(rhs.ownedSeen); - rhs.ownedSeen.clear(); - } } bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (FFlag::LuauShareTxnSeen) - return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); - else - return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair)); + return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); } void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (FFlag::LuauShareTxnSeen) - sharedSeen->push_back(sortedPair); - else - ownedSeen.push_back(sortedPair); + sharedSeen->push_back(sortedPair); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (FFlag::LuauShareTxnSeen) - { - LUAU_ASSERT(sortedPair == sharedSeen->back()); - sharedSeen->pop_back(); - } - else - { - LUAU_ASSERT(sortedPair == ownedSeen.back()); - ownedSeen.pop_back(); - } + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); } } // namespace Luau diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index af6d2543..9e61c792 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAG(LuauTypeAliasPacks) - static char* allocateString(Luau::Allocator& allocator, std::string_view contents) { char* result = (char*)allocator.allocate(contents.size() + 1); @@ -131,12 +129,9 @@ public: parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}}; } - if (FFlag::LuauTypeAliasPacks) + for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) { - for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) - { - parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; - } + parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; } return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters); @@ -250,20 +245,7 @@ public: AstTypePack* argTailAnnotation = nullptr; if (argTail) - { - if (FFlag::LuauTypeAliasPacks) - { - argTailAnnotation = rehydrate(*argTail); - } - else - { - TypePackId tail = *argTail; - if (const VariadicTypePack* vtp = get(tail)) - { - argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); - } - } - } + argTailAnnotation = rehydrate(*argTail); AstArray> argNames; argNames.size = ftv.argNames.size(); @@ -292,20 +274,7 @@ public: AstTypePack* retTailAnnotation = nullptr; if (retTail) - { - if (FFlag::LuauTypeAliasPacks) - { - retTailAnnotation = rehydrate(*retTail); - } - else - { - TypePackId tail = *retTail; - if (const VariadicTypePack* vtp = get(tail)) - { - retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); - } - } - } + retTailAnnotation = rehydrate(*retTail); return allocator->alloc( Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); @@ -518,18 +487,7 @@ public: const auto& [v, tail] = flatten(ret); if (tail) - { - if (FFlag::LuauTypeAliasPacks) - { - variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); - } - else - { - TypePackId tailPack = *tail; - if (const VariadicTypePack* vtp = get(tailPack)) - variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); - } - } + variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation}; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b2ae94c7..617bf482 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -23,22 +23,20 @@ LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) -LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) -LUAU_FASTFLAG(LuauNewRequireTrace2) -LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) +LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) +LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) +LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) namespace Luau { @@ -562,12 +560,6 @@ ErrorVec TypeChecker::canUnify(TypePackId left, TypePackId right, const Location return canUnify_(left, right, location); } -ErrorVec TypeChecker::canUnify(const std::vector>& seen, TypeId superTy, TypeId subTy, const Location& location) -{ - Unifier state = mkUnifier(seen, location); - return state.canUnify(superTy, subTy); -} - template ErrorVec TypeChecker::canUnify_(Id superTy, Id subTy, const Location& location) { @@ -1152,61 +1144,20 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias Location location = scope->typeAliasLocations[name]; reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); - if (FFlag::LuauTypeAliasPacks) - bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; - else - bindingsMap[name] = TypeFun{binding->typeParams, errorRecoveryType(anyType)}; + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; } else { ScopePtr aliasScope = FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); - if (FFlag::LuauTypeAliasPacks) - { - auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); + auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); - TypeId ty = freshType(aliasScope); - FreeTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; - } - else - { - std::vector generics; - for (AstName generic : typealias.generics) - { - Name n = generic.value; - - // These generics are the only thing that will ever be added to aliasScope, so we can be certain that - // a collision can only occur when two generic typevars have the same name. - if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) - { - // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. - reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); - } - - TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypeId& cached = scope->typeAliasTypeParameters[n]; - if (!cached) - cached = addType(GenericTypeVar{aliasScope->level, n}); - g = cached; - } - else - g = addType(GenericTypeVar{aliasScope->level, n}); - generics.push_back(g); - aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; - } - - TypeId ty = freshType(aliasScope); - FreeTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), ty}; - } + TypeId ty = freshType(aliasScope); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; } } else @@ -1223,14 +1174,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; } - if (FFlag::LuauTypeAliasPacks) + for (TypePackId tp : binding->typePackParams) { - for (TypePackId tp : binding->typePackParams) - { - auto generic = get(tp); - LUAU_ASSERT(generic); - aliasScope->privateTypePackBindings[generic->name] = tp; - } + auto generic = get(tp); + LUAU_ASSERT(generic); + aliasScope->privateTypePackBindings[generic->name] = tp; } TypeId ty = resolveType(aliasScope, *typealias.type); @@ -1241,19 +1189,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // Copy can be skipped if this is an identical alias if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || - (FFlag::LuauTypeAliasPacks && ttv->instantiatedTypePackParams != binding->typePackParams)) + ttv->instantiatedTypePackParams != binding->typePackParams) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; - clone.name = name; clone.instantiatedTypeParams = binding->typeParams; - - if (FFlag::LuauTypeAliasPacks) - clone.instantiatedTypePackParams = binding->typePackParams; + clone.instantiatedTypePackParams = binding->typePackParams; ty = addType(std::move(clone)); } @@ -1262,9 +1207,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { ttv->name = name; ttv->instantiatedTypeParams = binding->typeParams; - - if (FFlag::LuauTypeAliasPacks) - ttv->instantiatedTypePackParams = binding->typePackParams; + ttv->instantiatedTypePackParams = binding->typePackParams; } } else if (auto mtv = getMutable(follow(ty))) @@ -1289,7 +1232,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0)); + LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); superTy = lookupType->type; if (!get(follow(*superTy))) @@ -1851,6 +1794,24 @@ TypeId TypeChecker::checkExprTable( if (isNonstrictMode() && !getTableType(exprType) && !get(exprType)) exprType = anyType; + if (FFlag::LuauPropertiesGetExpectedType && expectedTable) + { + auto it = expectedTable->props.find(key->value.data); + if (it != expectedTable->props.end()) + { + Property expectedProp = it->second; + ErrorVec errors = tryUnify(expectedProp.type, exprType, k->location); + if (errors.empty()) + exprType = expectedProp.type; + } + else if (expectedTable->indexer && isString(expectedTable->indexer->indexType)) + { + ErrorVec errors = tryUnify(expectedTable->indexer->indexResultType, exprType, k->location); + if (errors.empty()) + exprType = expectedTable->indexer->indexResultType; + } + } + props[key->value.data] = {exprType, /* deprecated */ false, {}, k->location}; } else @@ -3744,17 +3705,29 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L for (size_t i = 0; i < exprs.size; ++i) { AstExpr* expr = exprs.data[i]; + std::optional expectedType = i < expectedTypes.size() ? expectedTypes[i] : std::nullopt; if (i == lastIndex && (expr->is() || expr->is())) { auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); + if (FFlag::LuauTailArgumentTypeInfo) + { + if (std::optional firstTy = first(typePack)) + { + if (!currentModule->astTypes.find(expr)) + currentModule->astTypes[expr] = follow(*firstTy); + } + + if (expectedType) + currentModule->astExpectedTypes[expr] = *expectedType; + } + tp->tail = typePack; } else { - std::optional expectedType = i < expectedTypes.size() ? expectedTypes[i] : std::nullopt; auto [type, exprPredicates] = checkExpr(scope, *expr, expectedType); insert(exprPredicates); @@ -3797,7 +3770,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); - if (FFlag::LuauNewRequireTrace2 && moduleInfo.name.empty()) + if (moduleInfo.name.empty()) { if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { @@ -3814,7 +3787,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // There are two reasons why we might fail to find the module: // either the file does not exist or there's a cycle. If there's a cycle // we will already have reported the error. - if (!resolver->moduleExists(moduleInfo.name) && (FFlag::LuauTraceRequireLookupChild ? !moduleInfo.optional : true)) + if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) { std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(TypeError{location, UnknownRequire{reportedModulePath}}); @@ -3830,7 +3803,12 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return errorRecoveryType(scope); } - std::optional moduleType = first(module->getModuleScope()->returnType); + TypePackId modulePack = module->getModuleScope()->returnType; + + if (FFlag::LuauModuleRequireErrorPack && get(modulePack)) + return errorRecoveryType(scope); + + std::optional moduleType = first(modulePack); if (!moduleType) { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); @@ -3840,7 +3818,8 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module SeenTypes seenTypes; SeenTypePacks seenTypePacks; - return clone(*moduleType, currentModule->internalTypes, seenTypes, seenTypePacks); + CloneState cloneState; + return clone(*moduleType, currentModule->internalTypes, seenTypes, seenTypePacks, cloneState); } void TypeChecker::tablify(TypeId type) @@ -4326,11 +4305,6 @@ Unifier TypeChecker::mkUnifier(const Location& location) return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState}; } -Unifier TypeChecker::mkUnifier(const std::vector>& seen, const Location& location) -{ - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, unifierState}; -} - TypeId TypeChecker::freshType(const ScopePtr& scope) { return freshType(scope->level); @@ -4477,117 +4451,82 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return errorRecoveryType(scope); } - if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) - { + if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) return tf->type; - } - else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) + + if (!lit->hasParameterList && !tf->typePackParams.empty()) { - reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); if (!FFlag::LuauErrorRecoveryType) return errorRecoveryType(scope); } - if (FFlag::LuauTypeAliasPacks) + std::vector typeParams; + std::vector extraTypes; + std::vector typePackParams; + + for (size_t i = 0; i < lit->parameters.size; ++i) { - if (!lit->hasParameterList && !tf->typePackParams.empty()) + if (AstType* type = lit->parameters.data[i].type) { - reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); - } + TypeId ty = resolveType(scope, *type); - std::vector typeParams; - std::vector extraTypes; - std::vector typePackParams; - - for (size_t i = 0; i < lit->parameters.size; ++i) - { - if (AstType* type = lit->parameters.data[i].type) - { - TypeId ty = resolveType(scope, *type); - - if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) - typeParams.push_back(ty); - else if (typePackParams.empty()) - extraTypes.push_back(ty); - else - reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); - } - else if (AstTypePack* typePack = lit->parameters.data[i].typePack) - { - TypePackId tp = resolveTypePack(scope, *typePack); - - // If we have collected an implicit type pack, materialize it - if (typePackParams.empty() && !extraTypes.empty()) - typePackParams.push_back(addTypePack(extraTypes)); - - // If we need more regular types, we can use single element type packs to fill those in - if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) - typeParams.push_back(*first(tp)); - else - typePackParams.push_back(tp); - } - } - - // If we still haven't meterialized an implicit type pack, do it now - if (typePackParams.empty() && !extraTypes.empty()) - typePackParams.push_back(addTypePack(extraTypes)); - - // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack - if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) - typePackParams.push_back(addTypePack({})); - - if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) - { - reportError( - TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); - - if (FFlag::LuauErrorRecoveryType) - { - // Pad the types out with error recovery types - while (typeParams.size() < tf->typeParams.size()) - typeParams.push_back(errorRecoveryType(scope)); - while (typePackParams.size() < tf->typePackParams.size()) - typePackParams.push_back(errorRecoveryTypePack(scope)); - } + if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) + typeParams.push_back(ty); + else if (typePackParams.empty()) + extraTypes.push_back(ty); else - return errorRecoveryType(scope); + reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); } - - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + else if (AstTypePack* typePack = lit->parameters.data[i].typePack) { - // If the generic parameters and the type arguments are the same, we are about to - // perform an identity substitution, which we can just short-circuit. - return tf->type; + TypePackId tp = resolveTypePack(scope, *typePack); + + // If we have collected an implicit type pack, materialize it + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we need more regular types, we can use single element type packs to fill those in + if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) + typeParams.push_back(*first(tp)); + else + typePackParams.push_back(tp); } - - return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); } - else - { - std::vector typeParams; - for (const auto& param : lit->parameters) - typeParams.push_back(resolveType(scope, *param.type)); + // If we still haven't meterialized an implicit type pack, do it now + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack + if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) + typePackParams.push_back(addTypePack({})); + + if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) + { + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); if (FFlag::LuauErrorRecoveryType) { - // If there aren't enough type parameters, pad them out with error recovery types - // (we've already reported the error) - while (typeParams.size() < lit->parameters.size) + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); } - - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) - { - // If the generic parameters and the type arguments are the same, we are about to - // perform an identity substitution, which we can just short-circuit. - return tf->type; - } - - return instantiateTypeFun(scope, *tf, typeParams, {}, annotation.location); + else + return errorRecoveryType(scope); } + + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + { + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + return tf->type; + } + + return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); } else if (const auto& table = annotation.as()) { @@ -4757,7 +4696,7 @@ bool ApplyTypeFunction::isDirty(TypePackId tp) bool ApplyTypeFunction::ignoreChildren(TypeId ty) { - if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(ty)) + if (get(ty)) return true; else return false; @@ -4765,7 +4704,7 @@ bool ApplyTypeFunction::ignoreChildren(TypeId ty) bool ApplyTypeFunction::ignoreChildren(TypePackId tp) { - if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(tp)) + if (get(tp)) return true; else return false; @@ -4788,36 +4727,26 @@ TypePackId ApplyTypeFunction::clean(TypePackId tp) // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - if (FFlag::LuauTypeAliasPacks) - { - TypePackId& arg = typePackArguments[tp]; - if (arg) - return arg; - else - return addTypePack(FreeTypePack{level}); - } + TypePackId& arg = typePackArguments[tp]; + if (arg) + return arg; else - { return addTypePack(FreeTypePack{level}); - } } TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const std::vector& typePackParams, const Location& location) { - if (tf.typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf.typePackParams.empty())) + if (tf.typeParams.empty() && tf.typePackParams.empty()) return tf.type; applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; - if (FFlag::LuauTypeAliasPacks) - { - applyTypeFunction.typePackArguments.clear(); - for (size_t i = 0; i < tf.typePackParams.size(); ++i) - applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; - } + applyTypeFunction.typePackArguments.clear(); + for (size_t i = 0; i < tf.typePackParams.size(); ++i) + applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; @@ -4866,9 +4795,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (ttv) { ttv->instantiatedTypeParams = typeParams; - - if (FFlag::LuauTypeAliasPacks) - ttv->instantiatedTypePackParams = typePackParams; + ttv->instantiatedTypePackParams = typePackParams; } } else @@ -4884,9 +4811,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } ttv->instantiatedTypeParams = typeParams; - - if (FFlag::LuauTypeAliasPacks) - ttv->instantiatedTypePackParams = typePackParams; + ttv->instantiatedTypePackParams = typePackParams; } } @@ -4914,7 +4839,7 @@ std::pair, std::vector> TypeChecker::createGener } TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + if (FFlag::LuauRecursiveTypeParameterRestriction) { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) @@ -4944,7 +4869,7 @@ std::pair, std::vector> TypeChecker::createGener } TypePackId g; - if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + if (FFlag::LuauRecursiveTypeParameterRestriction) { TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) @@ -5245,7 +5170,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); auto typeFun = globalScope->lookupType(typeguardP.kind); - if (!typeFun || !typeFun->typeParams.empty() || (FFlag::LuauTypeAliasPacks && !typeFun->typePackParams.empty())) + if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); TypeId type = follow(typeFun->type); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 924bf082..62715af5 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,7 +19,6 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) @@ -739,369 +738,6 @@ void persist(TypePackId tp) } } -namespace -{ - -struct StateDot -{ - StateDot(ToDotOptions opts) - : opts(opts) - { - } - - ToDotOptions opts; - - std::unordered_set seenTy; - std::unordered_set seenTp; - std::unordered_map tyToIndex; - std::unordered_map tpToIndex; - int nextIndex = 1; - std::string result; - - bool canDuplicatePrimitive(TypeId ty); - - void visitChildren(TypeId ty, int index); - void visitChildren(TypePackId ty, int index); - - void visitChild(TypeId ty, int parentIndex, const char* linkName = nullptr); - void visitChild(TypePackId tp, int parentIndex, const char* linkName = nullptr); - - void startNode(int index); - void finishNode(); - - void startNodeLabel(); - void finishNodeLabel(TypeId ty); - void finishNodeLabel(TypePackId tp); -}; - -bool StateDot::canDuplicatePrimitive(TypeId ty) -{ - if (get(ty)) - return false; - - return get(ty) || get(ty); -} - -void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) -{ - if (!tyToIndex.count(ty) || (opts.duplicatePrimitives && canDuplicatePrimitive(ty))) - tyToIndex[ty] = nextIndex++; - - int index = tyToIndex[ty]; - - if (parentIndex != 0) - { - if (linkName) - formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, index, linkName); - else - formatAppend(result, "n%d -> n%d;\n", parentIndex, index); - } - - if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) - { - if (get(ty)) - formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); - else if (get(ty)) - formatAppend(result, "n%d [label=\"any\"];\n", index); - } - else - { - visitChildren(ty, index); - } -} - -void StateDot::visitChild(TypePackId tp, int parentIndex, const char* linkName) -{ - if (!tpToIndex.count(tp)) - tpToIndex[tp] = nextIndex++; - - if (linkName) - formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, tpToIndex[tp], linkName); - else - formatAppend(result, "n%d -> n%d;\n", parentIndex, tpToIndex[tp]); - - visitChildren(tp, tpToIndex[tp]); -} - -void StateDot::startNode(int index) -{ - formatAppend(result, "n%d [", index); -} - -void StateDot::finishNode() -{ - formatAppend(result, "];\n"); -} - -void StateDot::startNodeLabel() -{ - formatAppend(result, "label=\""); -} - -void StateDot::finishNodeLabel(TypeId ty) -{ - if (opts.showPointers) - formatAppend(result, "\n0x%p", ty); - // additional common attributes can be added here as well - result += "\""; -} - -void StateDot::finishNodeLabel(TypePackId tp) -{ - if (opts.showPointers) - formatAppend(result, "\n0x%p", tp); - // additional common attributes can be added here as well - result += "\""; -} - -void StateDot::visitChildren(TypeId ty, int index) -{ - if (seenTy.count(ty)) - return; - seenTy.insert(ty); - - startNode(index); - startNodeLabel(); - - if (const BoundTypeVar* btv = get(ty)) - { - formatAppend(result, "BoundTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(btv->boundTo, index); - } - else if (const FunctionTypeVar* ftv = get(ty)) - { - formatAppend(result, "FunctionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retType, index, "ret"); - } - else if (const TableTypeVar* ttv = get(ty)) - { - if (ttv->name) - formatAppend(result, "TableTypeVar %s", ttv->name->c_str()); - else if (ttv->syntheticName) - formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str()); - else - formatAppend(result, "TableTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - if (ttv->boundTo) - return visitChild(*ttv->boundTo, index, "boundTo"); - - for (const auto& [name, prop] : ttv->props) - visitChild(prop.type, index, name.c_str()); - if (ttv->indexer) - { - visitChild(ttv->indexer->indexType, index, "[index]"); - visitChild(ttv->indexer->indexResultType, index, "[value]"); - } - for (TypeId itp : ttv->instantiatedTypeParams) - visitChild(itp, index, "typeParam"); - - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId itp : ttv->instantiatedTypePackParams) - visitChild(itp, index, "typePackParam"); - } - } - else if (const MetatableTypeVar* mtv = get(ty)) - { - formatAppend(result, "MetatableTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(mtv->table, index, "table"); - visitChild(mtv->metatable, index, "metatable"); - } - else if (const UnionTypeVar* utv = get(ty)) - { - formatAppend(result, "UnionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId opt : utv->options) - visitChild(opt, index); - } - else if (const IntersectionTypeVar* itv = get(ty)) - { - formatAppend(result, "IntersectionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId part : itv->parts) - visitChild(part, index); - } - else if (const GenericTypeVar* gtv = get(ty)) - { - if (gtv->explicitName) - formatAppend(result, "GenericTypeVar %s", gtv->name.c_str()); - else - formatAppend(result, "GenericTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const FreeTypeVar* ftv = get(ty)) - { - formatAppend(result, "FreeTypeVar %d", ftv->index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "AnyTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "ErrorTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const ClassTypeVar* ctv = get(ty)) - { - formatAppend(result, "ClassTypeVar %s", ctv->name.c_str()); - finishNodeLabel(ty); - finishNode(); - - for (const auto& [name, prop] : ctv->props) - visitChild(prop.type, index, name.c_str()); - - if (ctv->parent) - visitChild(*ctv->parent, index, "[parent]"); - - if (ctv->metatable) - visitChild(*ctv->metatable, index, "[metatable]"); - } - else - { - LUAU_ASSERT(!"unknown type kind"); - finishNodeLabel(ty); - finishNode(); - } -} - -void StateDot::visitChildren(TypePackId tp, int index) -{ - if (seenTp.count(tp)) - return; - seenTp.insert(tp); - - startNode(index); - startNodeLabel(); - - if (const BoundTypePack* btp = get(tp)) - { - formatAppend(result, "BoundTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - visitChild(btp->boundTo, index); - } - else if (const TypePack* tpp = get(tp)) - { - formatAppend(result, "TypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - for (TypeId tv : tpp->head) - visitChild(tv, index); - if (tpp->tail) - visitChild(*tpp->tail, index, "tail"); - } - else if (const VariadicTypePack* vtp = get(tp)) - { - formatAppend(result, "VariadicTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - visitChild(vtp->ty, index); - } - else if (const FreeTypePack* ftp = get(tp)) - { - formatAppend(result, "FreeTypePack %d", ftp->index); - finishNodeLabel(tp); - finishNode(); - } - else if (const GenericTypePack* gtp = get(tp)) - { - if (gtp->explicitName) - formatAppend(result, "GenericTypePack %s", gtp->name.c_str()); - else - formatAppend(result, "GenericTypePack %d", gtp->index); - finishNodeLabel(tp); - finishNode(); - } - else if (get(tp)) - { - formatAppend(result, "ErrorTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - } - else - { - LUAU_ASSERT(!"unknown type pack kind"); - finishNodeLabel(tp); - finishNode(); - } -} - -} // namespace - -std::string toDot(TypeId ty, const ToDotOptions& opts) -{ - StateDot state{opts}; - - state.result = "digraph graphname {\n"; - state.visitChild(ty, 0); - state.result += "}"; - - return state.result; -} - -std::string toDot(TypePackId tp, const ToDotOptions& opts) -{ - StateDot state{opts}; - - state.result = "digraph graphname {\n"; - state.visitChild(tp, 0); - state.result += "}"; - - return state.result; -} - -std::string toDot(TypeId ty) -{ - return toDot(ty, {}); -} - -std::string toDot(TypePackId tp) -{ - return toDot(tp, {}); -} - -void dumpDot(TypeId ty) -{ - printf("%s\n", toDot(ty).c_str()); -} - -void dumpDot(TypePackId tp) -{ - printf("%s\n", toDot(tp).c_str()); -} - const TypeLevel* getLevel(TypeId ty) { ty = follow(ty); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index e1a52be4..d0b18837 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,9 +18,6 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) -LUAU_FASTFLAG(LuauShareTxnSeen); -LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) @@ -136,38 +133,19 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , globalScope(std::move(globalScope)) , location(location) , variance(variance) - , counters(&countersData) - , counters_DEPRECATED(std::make_shared()) - , sharedState(sharedState) -{ - LUAU_ASSERT(sharedState.iceHandler); -} - -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) - : types(types) - , mode(mode) - , globalScope(std::move(globalScope)) - , log(ownedSeen) - , location(location) - , variance(variance) - , counters(counters ? counters : &countersData) - , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) , sharedState(sharedState) { LUAU_ASSERT(sharedState.iceHandler); } Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) + Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) , globalScope(std::move(globalScope)) , log(sharedSeen) , location(location) , variance(variance) - , counters(counters ? counters : &countersData) - , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) , sharedState(sharedState) { LUAU_ASSERT(sharedState.iceHandler); @@ -175,26 +153,18 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector< void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - if (FFlag::LuauTypecheckOpts) - counters->iterationCount = 0; - else - counters_DEPRECATED->iterationCount = 0; + sharedState.counters.iterationCount = 0; tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - if (FFlag::LuauTypecheckOpts) - ++counters->iterationCount; - else - ++counters_DEPRECATED->iterationCount; + ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && - FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -302,7 +272,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (get(subTy) || get(subTy)) return tryUnifyWithAny(subTy, superTy); - bool cacheEnabled = FFlag::LuauCacheUnifyTableResults && !isFunctionCall && !isIntersection; + bool cacheEnabled = !isFunctionCall && !isIntersection; auto& cache = sharedState.cachedUnify; // What if the types are immutable and we proved their relation before @@ -563,8 +533,6 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool void Unifier::cacheResult(TypeId superTy, TypeId subTy) { - LUAU_ASSERT(FFlag::LuauCacheUnifyTableResults); - bool* superTyInfo = sharedState.skipCacheForType.find(superTy); if (superTyInfo && *superTyInfo) @@ -686,10 +654,7 @@ ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunction void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - if (FFlag::LuauTypecheckOpts) - counters->iterationCount = 0; - else - counters_DEPRECATED->iterationCount = 0; + sharedState.counters.iterationCount = 0; tryUnify_(superTp, subTp, isFunctionCall); } @@ -700,16 +665,11 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall */ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - if (FFlag::LuauTypecheckOpts) - ++counters->iterationCount; - else - ++counters_DEPRECATED->iterationCount; + ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && - FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -1727,39 +1687,8 @@ void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& sub tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); } -static void queueTypePack_DEPRECATED( - std::vector& queue, std::unordered_set& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) -{ - LUAU_ASSERT(!FFlag::LuauTypecheckOpts); - - while (true) - { - a = follow(a); - - if (seenTypePacks.count(a)) - break; - seenTypePacks.insert(a); - - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } - } -} - static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { - LUAU_ASSERT(FFlag::LuauTypecheckOpts); - while (true) { a = follow(a); @@ -1837,66 +1766,9 @@ void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool rever } } -static void tryUnifyWithAny_DEPRECATED( - std::vector& queue, Unifier& state, std::unordered_set& seenTypePacks, TypeId anyType, TypePackId anyTypePack) -{ - LUAU_ASSERT(!FFlag::LuauTypecheckOpts); - - std::unordered_set seen; - - while (!queue.empty()) - { - TypeId ty = follow(queue.back()); - queue.pop_back(); - if (seen.count(ty)) - continue; - seen.insert(ty); - - if (get(ty)) - { - state.log(ty); - *asMutable(ty) = BoundTypeVar{anyType}; - } - else if (auto fun = get(ty)) - { - queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->retType, anyTypePack); - } - else if (auto table = get(ty)) - { - for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); - - if (table->indexer) - { - queue.push_back(table->indexer->indexType); - queue.push_back(table->indexer->indexResultType); - } - } - else if (auto mt = get(ty)) - { - queue.push_back(mt->table); - queue.push_back(mt->metatable); - } - else if (get(ty)) - { - // ClassTypeVars never contain free typevars. - } - else if (auto union_ = get(ty)) - queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = get(ty)) - queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); - else - { - } // Primitives, any, errors, and generics are left untouched. - } -} - static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, TypeId anyType, TypePackId anyTypePack) { - LUAU_ASSERT(FFlag::LuauTypecheckOpts); - while (!queue.empty()) { TypeId ty = follow(queue.back()); @@ -1949,43 +1821,20 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { LUAU_ASSERT(get(any) || get(any)); - if (FFlag::LuauTypecheckOpts) - { - // These types are not visited in general loop below - if (get(ty) || get(ty) || get(ty)) - return; - } + // These types are not visited in general loop below + if (get(ty) || get(ty) || get(ty)) + return; const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - if (FFlag::LuauTypecheckOpts) - { - std::vector queue = {ty}; + std::vector queue = {ty}; - if (FFlag::LuauCacheUnifyTableResults) - { - sharedState.tempSeenTy.clear(); - sharedState.tempSeenTp.clear(); + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); - } - else - { - tempSeenTy_DEPRECATED.clear(); - tempSeenTp_DEPRECATED.clear(); - - Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, singletonTypes.anyType, anyTP); - } - } - else - { - std::unordered_set seenTypePacks; - std::vector queue = {ty}; - - Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); - } + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) @@ -1994,38 +1843,14 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) const TypeId anyTy = singletonTypes.errorRecoveryType(); - if (FFlag::LuauTypecheckOpts) - { - std::vector queue; + std::vector queue; - if (FFlag::LuauCacheUnifyTableResults) - { - sharedState.tempSeenTy.clear(); - sharedState.tempSeenTp.clear(); + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); - } - else - { - tempSeenTy_DEPRECATED.clear(); - tempSeenTp_DEPRECATED.clear(); - - queueTypePack(queue, tempSeenTp_DEPRECATED, *this, ty, any); - - Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, anyTy, any); - } - } - else - { - std::unordered_set seenTypePacks; - std::vector queue; - - queueTypePack_DEPRECATED(queue, seenTypePacks, *this, ty, any); - - Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, anyTy, any); - } + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) @@ -2035,46 +1860,22 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N void Unifier::occursCheck(TypeId needle, TypeId haystack) { - std::unordered_set seen_DEPRECATED; + sharedState.tempSeenTy.clear(); - if (FFlag::LuauCacheUnifyTableResults) - { - if (FFlag::LuauTypecheckOpts) - sharedState.tempSeenTy.clear(); - - return occursCheck(seen_DEPRECATED, sharedState.tempSeenTy, needle, haystack); - } - else - { - if (FFlag::LuauTypecheckOpts) - tempSeenTy_DEPRECATED.clear(); - - return occursCheck(seen_DEPRECATED, tempSeenTy_DEPRECATED, needle, haystack); - } + return occursCheck(sharedState.tempSeenTy, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) +void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); needle = follow(needle); haystack = follow(haystack); - if (FFlag::LuauTypecheckOpts) - { - if (seen.find(haystack)) - return; + if (seen.find(haystack)) + return; - seen.insert(haystack); - } - else - { - if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) - return; - - seen_DEPRECATED.insert(haystack); - } + seen.insert(haystack); if (get(needle)) return; @@ -2091,7 +1892,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash } auto check = [&](TypeId tv) { - occursCheck(seen_DEPRECATED, seen, needle, tv); + occursCheck(seen, needle, tv); }; if (get(haystack)) @@ -2121,43 +1922,20 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { - std::unordered_set seen_DEPRECATED; + sharedState.tempSeenTp.clear(); - if (FFlag::LuauCacheUnifyTableResults) - { - if (FFlag::LuauTypecheckOpts) - sharedState.tempSeenTp.clear(); - - return occursCheck(seen_DEPRECATED, sharedState.tempSeenTp, needle, haystack); - } - else - { - if (FFlag::LuauTypecheckOpts) - tempSeenTp_DEPRECATED.clear(); - - return occursCheck(seen_DEPRECATED, tempSeenTp_DEPRECATED, needle, haystack); - } + return occursCheck(sharedState.tempSeenTp, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) +void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) { needle = follow(needle); haystack = follow(haystack); - if (FFlag::LuauTypecheckOpts) - { - if (seen.find(haystack)) - return; + if (seen.find(haystack)) + return; - seen.insert(haystack); - } - else - { - if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) - return; - - seen_DEPRECATED.insert(haystack); - } + seen.insert(haystack); if (get(needle)) return; @@ -2165,8 +1943,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense if (!get(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); while (!get(haystack)) { @@ -2186,8 +1963,8 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { if (auto f = get(follow(ty))) { - occursCheck(seen_DEPRECATED, seen, needle, f->argTypes); - occursCheck(seen_DEPRECATED, seen, needle, f->retType); + occursCheck(seen, needle, f->argTypes); + occursCheck(seen, needle, f->retType); } } } @@ -2204,10 +1981,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense Unifier Unifier::makeChildUnifier() { - if (FFlag::LuauShareTxnSeen) - return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; - else - return Unifier{types, mode, globalScope, log.ownedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; + return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index bc63e37d..3d0d5b7e 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -13,8 +13,6 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) -LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) -LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) @@ -782,8 +780,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) AstType* type = parseTypeAnnotation(); - return allocator.alloc( - Location(start, type->location), name->name, generics, FFlag::LuauTypeAliasPacks ? genericPacks : AstArray{}, type, exported); + return allocator.alloc(Location(start, type->location), name->name, generics, genericPacks, type, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -1602,30 +1599,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) return {allocator.alloc(Location(begin, end), expr), {}}; } - if (FFlag::LuauParseTypePackTypeParameters) + bool hasParameters = false; + AstArray parameters{}; + + if (lexer.current().type == '<') { - bool hasParameters = false; - AstArray parameters{}; - - if (lexer.current().type == '<') - { - hasParameters = true; - parameters = parseTypeParams(); - } - - Location end = lexer.previousLocation(); - - return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; + hasParameters = true; + parameters = parseTypeParams(); } - else - { - AstArray generics = parseTypeParams(); - Location end = lexer.previousLocation(); + Location end = lexer.previousLocation(); - // false in 'hasParameterList' as it is not used without FFlagLuauTypeAliasPacks - return {allocator.alloc(Location(begin, end), prefix, name.name, false, generics), {}}; - } + return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; } else if (lexer.current().type == '{') { @@ -2414,37 +2399,24 @@ AstArray Parser::parseTypeParams() while (true) { - if (FFlag::LuauParseTypePackTypeParameters) + if (shouldParseTypePackAnnotation(lexer)) { - if (shouldParseTypePackAnnotation(lexer)) - { - auto typePack = parseTypePackAnnotation(); + auto typePack = parseTypePackAnnotation(); - if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them - parameters.push_back({{}, typePack}); - } - else if (lexer.current().type == '(') - { - auto [type, typePack] = parseTypeOrPackAnnotation(); + parameters.push_back({{}, typePack}); + } + else if (lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); - if (typePack) - { - if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them - parameters.push_back({{}, typePack}); - } - else - { - parameters.push_back({type, {}}); - } - } - else if (lexer.current().type == '>' && parameters.empty()) - { - break; - } + if (typePack) + parameters.push_back({{}, typePack}); else - { - parameters.push_back({parseTypeAnnotation(), {}}); - } + parameters.push_back({type, {}}); + } + else if (lexer.current().type == '>' && parameters.empty()) + { + break; } else { diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index ebdd7896..9230d80d 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -121,7 +121,7 @@ struct CliFileResolver : Luau::FileResolver if (Luau::AstExprConstantString* expr = node->as()) { Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau"; - if (!moduleExists(name)) + if (!readFile(name)) { // fall back to .lua if a module with .luau doesn't exist name = std::string(expr->value.data, expr->value.size) + ".lua"; @@ -132,27 +132,6 @@ struct CliFileResolver : Luau::FileResolver return std::nullopt; } - - bool moduleExists(const Luau::ModuleName& name) const override - { - return !!readFile(name); - } - - - std::optional fromAstFragment(Luau::AstExpr* expr) const override - { - return std::nullopt; - } - - Luau::ModuleName concat(const Luau::ModuleName& lhs, std::string_view rhs) const override - { - return lhs + "/" + std::string(rhs); - } - - std::optional getParentModuleName(const Luau::ModuleName& name) const override - { - return std::nullopt; - } }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index b29cd6f9..2cdd0062 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -198,11 +198,6 @@ static std::string runCode(lua_State* L, const std::string& source) error += "\nstack backtrace:\n"; error += lua_debugtrace(T); -#ifdef __EMSCRIPTEN__ - // nicer formatting for errors in web repl - error = "Error:" + error; -#endif - fprintf(stdout, "%s", error.c_str()); } @@ -210,39 +205,6 @@ static std::string runCode(lua_State* L, const std::string& source) return std::string(); } -#ifdef __EMSCRIPTEN__ -extern "C" -{ - const char* executeScript(const char* source) - { - // setup flags - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - if (strncmp(flag->name, "Luau", 4) == 0) - flag->value = true; - - // create new state - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - // setup state - setupState(L); - - // sandbox thread - luaL_sandboxthread(L); - - // static string for caching result (prevents dangling ptr on function exit) - static std::string result; - - // run code + collect error - result = runCode(L, source); - - return result.empty() ? NULL : result.c_str(); - } -} -#endif - -// Excluded from emscripten compilation to avoid -Wunused-function errors. -#ifndef __EMSCRIPTEN__ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) { std::string_view lookup = editBuffer + start; @@ -564,6 +526,5 @@ int main(int argc, char** argv) return failed; } } -#endif diff --git a/CLI/Web.cpp b/CLI/Web.cpp new file mode 100644 index 00000000..cf5c831e --- /dev/null +++ b/CLI/Web.cpp @@ -0,0 +1,106 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" +#include "luacode.h" + +#include "Luau/Common.h" + +#include + +#include + +static void setupState(lua_State* L) +{ + luaL_openlibs(L); + + luaL_sandbox(L); +} + +static std::string runCode(lua_State* L, const std::string& source) +{ + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.length(), nullptr, &bytecodeSize); + int result = luau_load(L, "=stdin", bytecode, bytecodeSize, 0); + free(bytecode); + + if (result != 0) + { + size_t len; + const char* msg = lua_tolstring(L, -1, &len); + + std::string error(msg, len); + lua_pop(L, 1); + + return error; + } + + lua_State* T = lua_newthread(L); + + lua_pushvalue(L, -2); + lua_remove(L, -3); + lua_xmove(L, T, 1); + + int status = lua_resume(T, NULL, 0); + + if (status == 0) + { + int n = lua_gettop(T); + + if (n) + { + luaL_checkstack(T, LUA_MINSTACK, "too many results to print"); + lua_getglobal(T, "print"); + lua_insert(T, 1); + lua_pcall(T, n, 0, 0); + } + } + else + { + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(T, -1)) + { + error = str; + } + + error += "\nstack backtrace:\n"; + error += lua_debugtrace(T); + + error = "Error:" + error; + + fprintf(stdout, "%s", error.c_str()); + } + + lua_pop(L, 1); + return std::string(); +} + +extern "C" const char* executeScript(const char* source) +{ + // setup flags + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + // create new state + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // setup state + setupState(L); + + // sandbox thread + luaL_sandboxthread(L); + + // static string for caching result (prevents dangling ptr on function exit) + static std::string result; + + // run code + collect error + result = runCode(L, source); + + return result.empty() ? NULL : result.c_str(); +} diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c69521e..bafc59e5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,7 @@ project(Luau LANGUAGES CXX) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) +option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) add_library(Luau.Ast STATIC) @@ -18,26 +19,22 @@ add_library(Luau.VM STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) - if(NOT EMSCRIPTEN) - add_executable(Luau.Analyze.CLI) - else() - # add -fexceptions for emscripten to allow exceptions to be caught in C++ - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions") - endif() + add_executable(Luau.Analyze.CLI) # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) - - if(NOT EMSCRIPTEN) - set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) - endif() + set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) endif() -if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) +if(LUAU_BUILD_TESTS) add_executable(Luau.UnitTest) add_executable(Luau.Conformance) endif() +if(LUAU_BUILD_WEB) + add_executable(Luau.Web) +endif() + include(Sources.cmake) target_compile_features(Luau.Ast PUBLIC cxx_std_17) @@ -72,16 +69,18 @@ if(LUAU_WERROR) endif() endif() +if(LUAU_BUILD_WEB) + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + list(APPEND LUAU_OPTIONS -fexceptions) +endif() + target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) - - if(NOT EMSCRIPTEN) - target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) - endif() + target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Repl.CLI PRIVATE extern) target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) @@ -93,20 +92,10 @@ if(LUAU_BUILD_CLI) endif() endif() - if(NOT EMSCRIPTEN) - target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) - endif() - - if(EMSCRIPTEN) - # declare exported functions to emscripten - target_link_options(Luau.Repl.CLI PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -fexceptions) - - # custom output directory for wasm + js file - set_target_properties(Luau.Repl.CLI PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/docs/assets/luau) - endif() + target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) endif() -if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) +if(LUAU_BUILD_TESTS) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.UnitTest PRIVATE extern) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) @@ -115,3 +104,17 @@ if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) target_include_directories(Luau.Conformance PRIVATE extern) target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.VM) endif() + +if(LUAU_BUILD_WEB) + target_compile_options(Luau.Web PRIVATE ${LUAU_OPTIONS}) + target_link_libraries(Luau.Web PRIVATE Luau.Compiler Luau.VM) + + # declare exported functions to emscripten + target_link_options(Luau.Web PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap']) + + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + target_link_options(Luau.Web PRIVATE -fexceptions) + + # the output is a single .js file with an embedded wasm blob + target_link_options(Luau.Web PRIVATE -sSINGLE_FILE=1) +endif() diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 5b93c1dc..2c1e85ff 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -321,13 +321,15 @@ struct Compiler compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); } - setDebugLine(expr->func); + setDebugLineEnd(expr->func); if (expr->self) { AstExprIndexName* fi = expr->func->as(); LUAU_ASSERT(fi); + setDebugLine(fi->indexLocation); + BytecodeBuilder::StringRef iname = sref(fi->index); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) @@ -1313,6 +1315,8 @@ struct Compiler RegScope rs(this); uint8_t reg = compileExprAuto(expr->expr, rs); + setDebugLine(expr->indexLocation); + BytecodeBuilder::StringRef iname = sref(expr->index); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) @@ -2710,6 +2714,12 @@ struct Compiler bytecode.setDebugLine(node->location.begin.line + 1); } + void setDebugLine(const Location& location) + { + if (options.debugLevel >= 1) + bytecode.setDebugLine(location.begin.line + 1); + } + void setDebugLineEnd(AstNode* node) { if (options.debugLevel >= 1) @@ -3650,7 +3660,7 @@ struct Compiler { if (options.vectorLib) { - if (builtin.object == options.vectorLib && builtin.method == options.vectorCtor) + if (builtin.isMethod(options.vectorLib, options.vectorCtor)) return LBF_VECTOR; } else diff --git a/Makefile b/Makefile index cab3d43f..15c7ff7a 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Analyze.cpp ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o) ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze -FUZZ_SOURCES=$(wildcard fuzz/*.cpp) +FUZZ_SOURCES=$(wildcard fuzz/*.cpp) fuzz/luau.pb.cpp FUZZ_OBJECTS=$(FUZZ_SOURCES:%=$(BUILD)/%.o) TESTS_ARGS= @@ -167,8 +167,8 @@ fuzz/luau.pb.cpp: fuzz/luau.proto build/libprotobuf-mutator cd fuzz && ../build/libprotobuf-mutator/external.protobuf/bin/protoc luau.proto --cpp_out=. mv fuzz/luau.pb.cc fuzz/luau.pb.cpp -$(BUILD)/fuzz/proto.cpp.o: build/libprotobuf-mutator -$(BUILD)/fuzz/protoprint.cpp.o: build/libprotobuf-mutator +$(BUILD)/fuzz/proto.cpp.o: fuzz/luau.pb.cpp +$(BUILD)/fuzz/protoprint.cpp.o: fuzz/luau.pb.cpp build/libprotobuf-mutator: git clone https://github.com/google/libprotobuf-mutator build/libprotobuf-mutator diff --git a/Sources.cmake b/Sources.cmake index 23b931c6..57df9b91 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -54,6 +54,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Scope.h Analysis/include/Luau/Substitution.h Analysis/include/Luau/Symbol.h + Analysis/include/Luau/ToDot.h Analysis/include/Luau/TopoSortStatements.h Analysis/include/Luau/ToString.h Analysis/include/Luau/Transpiler.h @@ -86,6 +87,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Scope.cpp Analysis/src/Substitution.cpp Analysis/src/Symbol.cpp + Analysis/src/ToDot.cpp Analysis/src/TopoSortStatements.cpp Analysis/src/ToString.cpp Analysis/src/Transpiler.cpp @@ -118,6 +120,7 @@ target_sources(Luau.VM PRIVATE VM/src/ldo.cpp VM/src/lfunc.cpp VM/src/lgc.cpp + VM/src/lgcdebug.cpp VM/src/linit.cpp VM/src/lmathlib.cpp VM/src/lmem.cpp @@ -194,6 +197,7 @@ if(TARGET Luau.UnitTest) tests/RequireTracer.test.cpp tests/StringUtils.test.cpp tests/Symbol.test.cpp + tests/ToDot.test.cpp tests/TopoSort.test.cpp tests/ToString.test.cpp tests/Transpiler.test.cpp @@ -224,3 +228,9 @@ if(TARGET Luau.Conformance) tests/Conformance.test.cpp tests/main.cpp) endif() + +if(TARGET Luau.Web) + # Luau.Web Sources + target_sources(Luau.Web PRIVATE + CLI/Web.cpp) +endif() diff --git a/VM/include/lua.h b/VM/include/lua.h index 1568d191..7078acd0 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -21,6 +21,7 @@ #define LUA_ENVIRONINDEX (-10001) #define LUA_GLOBALSINDEX (-10002) #define lua_upvalueindex(i) (LUA_GLOBALSINDEX - (i)) +#define lua_ispseudo(i) ((i) <= LUA_REGISTRYINDEX) /* thread status; 0 is OK */ enum lua_Status @@ -108,6 +109,7 @@ LUA_API int lua_isthreadreset(lua_State* L); /* ** basic stack manipulation */ +LUA_API int lua_absindex(lua_State* L, int idx); LUA_API int lua_gettop(lua_State* L); LUA_API void lua_settop(lua_State* L, int idx); LUA_API void lua_pushvalue(lua_State* L, int idx); @@ -159,7 +161,11 @@ LUA_API void lua_pushnil(lua_State* L); LUA_API void lua_pushnumber(lua_State* L, double n); LUA_API void lua_pushinteger(lua_State* L, int n); LUA_API void lua_pushunsigned(lua_State* L, unsigned n); +#if LUA_VECTOR_SIZE == 4 +LUA_API void lua_pushvector(lua_State* L, float x, float y, float z, float w); +#else LUA_API void lua_pushvector(lua_State* L, float x, float y, float z); +#endif LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l); LUA_API void lua_pushstring(lua_State* L, const char* s); LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp); @@ -183,7 +189,7 @@ LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); LUA_API int lua_getreadonly(lua_State* L, int idx); LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled); -LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag); +LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); LUA_API int lua_getmetatable(lua_State* L, int objindex); LUA_API void lua_getfenv(lua_State* L, int idx); @@ -227,6 +233,7 @@ enum lua_GCOp LUA_GCRESTART, LUA_GCCOLLECT, LUA_GCCOUNT, + LUA_GCCOUNTB, LUA_GCISRUNNING, // garbage collection is handled by 'assists' that perform some amount of GC work matching pace of allocation @@ -281,6 +288,7 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_pop(L, n) lua_settop(L, -(n)-1) #define lua_newtable(L) lua_createtable(L, 0, 0) +#define lua_newuserdata(L, s) lua_newuserdatatagged(L, s, 0) #define lua_strlen(L, i) lua_objlen(L, (i)) @@ -289,6 +297,7 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_islightuserdata(L, n) (lua_type(L, (n)) == LUA_TLIGHTUSERDATA) #define lua_isnil(L, n) (lua_type(L, (n)) == LUA_TNIL) #define lua_isboolean(L, n) (lua_type(L, (n)) == LUA_TBOOLEAN) +#define lua_isvector(L, n) (lua_type(L, (n)) == LUA_TVECTOR) #define lua_isthread(L, n) (lua_type(L, (n)) == LUA_TTHREAD) #define lua_isnone(L, n) (lua_type(L, (n)) == LUA_TNONE) #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index aa008a24..a01a1481 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -34,7 +34,10 @@ #endif /* Can be used to reconfigure visibility/exports for public APIs */ +#ifndef LUA_API #define LUA_API extern +#endif + #define LUALIB_API LUA_API /* Can be used to reconfigure visibility for internal APIs */ @@ -47,10 +50,14 @@ #endif /* Can be used to reconfigure internal error handling to use longjmp instead of C++ EH */ +#ifndef LUA_USE_LONGJMP #define LUA_USE_LONGJMP 0 +#endif /* LUA_IDSIZE gives the maximum size for the description of the source */ +#ifndef LUA_IDSIZE #define LUA_IDSIZE 256 +#endif /* @@ LUAI_GCGOAL defines the desired top heap size in relation to the live heap @@ -59,7 +66,9 @@ ** mean larger GC pauses which mean slower collection.) You can also change ** this value dynamically. */ +#ifndef LUAI_GCGOAL #define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ +#endif /* @@ LUAI_GCSTEPMUL / LUAI_GCSTEPSIZE define the default speed of garbage collection @@ -69,38 +78,63 @@ ** CHANGE it if you want to change the granularity of the garbage ** collection. */ +#ifndef LUAI_GCSTEPMUL #define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ +#endif + +#ifndef LUAI_GCSTEPSIZE #define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ +#endif /* LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function */ +#ifndef LUA_MINSTACK #define LUA_MINSTACK 20 +#endif /* LUAI_MAXCSTACK limits the number of Lua stack slots that a C function can use */ +#ifndef LUAI_MAXCSTACK #define LUAI_MAXCSTACK 8000 +#endif /* LUAI_MAXCALLS limits the number of nested calls */ +#ifndef LUAI_MAXCALLS #define LUAI_MAXCALLS 20000 +#endif /* LUAI_MAXCCALLS is the maximum depth for nested C calls; this limit depends on native stack size */ +#ifndef LUAI_MAXCCALLS #define LUAI_MAXCCALLS 200 +#endif /* buffer size used for on-stack string operations; this limit depends on native stack size */ +#ifndef LUA_BUFFERSIZE #define LUA_BUFFERSIZE 512 +#endif /* number of valid Lua userdata tags */ +#ifndef LUA_UTAG_LIMIT #define LUA_UTAG_LIMIT 128 +#endif /* upper bound for number of size classes used by page allocator */ +#ifndef LUA_SIZECLASSES #define LUA_SIZECLASSES 32 +#endif /* available number of separate memory categories */ +#ifndef LUA_MEMORY_CATEGORIES #define LUA_MEMORY_CATEGORIES 256 +#endif /* minimum size for the string table (must be power of 2) */ +#ifndef LUA_MINSTRTABSIZE #define LUA_MINSTRTABSIZE 32 +#endif /* maximum number of captures supported by pattern matching */ +#ifndef LUA_MAXCAPTURES #define LUA_MAXCAPTURES 32 +#endif /* }================================================================== */ @@ -122,3 +156,7 @@ void* s; \ long l; \ } + +#define LUA_VECTOR_SIZE 3 /* must be 3 or 4 */ + +#define LUA_EXTRA_SIZE LUA_VECTOR_SIZE - 2 diff --git a/VM/include/lualib.h b/VM/include/lualib.h index fa836955..baf27b47 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -25,11 +25,17 @@ LUALIB_API const char* luaL_optlstring(lua_State* L, int numArg, const char* def LUALIB_API double luaL_checknumber(lua_State* L, int numArg); LUALIB_API double luaL_optnumber(lua_State* L, int nArg, double def); +LUALIB_API int luaL_checkboolean(lua_State* L, int narg); +LUALIB_API int luaL_optboolean(lua_State* L, int narg, int def); + LUALIB_API int luaL_checkinteger(lua_State* L, int numArg); LUALIB_API int luaL_optinteger(lua_State* L, int nArg, int def); LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int numArg); LUALIB_API unsigned luaL_optunsigned(lua_State* L, int numArg, unsigned def); +LUALIB_API const float* luaL_checkvector(lua_State* L, int narg); +LUALIB_API const float* luaL_optvector(lua_State* L, int narg, const float* def); + LUALIB_API void luaL_checkstack(lua_State* L, int sz, const char* msg); LUALIB_API void luaL_checktype(lua_State* L, int narg, int t); LUALIB_API void luaL_checkany(lua_State* L, int narg); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index a79ba0d4..76043b9c 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAG(LuauActivateBeforeExec) + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -170,6 +172,12 @@ lua_State* lua_mainthread(lua_State* L) ** basic stack manipulation */ +int lua_absindex(lua_State* L, int idx) +{ + api_check(L, (idx > 0 && idx <= L->top - L->base) || (idx < 0 && -idx <= L->top - L->base) || lua_ispseudo(idx)); + return idx > 0 || lua_ispseudo(idx) ? idx : cast_int(L->top - L->base) + idx + 1; +} + int lua_gettop(lua_State* L) { return cast_int(L->top - L->base); @@ -550,12 +558,21 @@ void lua_pushunsigned(lua_State* L, unsigned u) return; } -void lua_pushvector(lua_State* L, float x, float y, float z) +#if LUA_VECTOR_SIZE == 4 +void lua_pushvector(lua_State* L, float x, float y, float z, float w) { - setvvalue(L->top, x, y, z); + setvvalue(L->top, x, y, z, w); api_incr_top(L); return; } +#else +void lua_pushvector(lua_State* L, float x, float y, float z) +{ + setvvalue(L->top, x, y, z, 0.0f); + api_incr_top(L); + return; +} +#endif void lua_pushlstring(lua_State* L, const char* s, size_t len) { @@ -922,14 +939,21 @@ void lua_call(lua_State* L, int nargs, int nresults) checkresults(L, nargs, nresults); func = L->top - (nargs + 1); - int wasActive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + if (FFlag::LuauActivateBeforeExec) + { + luaD_call(L, func, nresults); + } + else + { + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); - luaD_call(L, func, nresults); + luaD_call(L, func, nresults); - if (!wasActive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } adjustresults(L, nresults); return; @@ -970,14 +994,21 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) c.func = L->top - (nargs + 1); /* function to be called */ c.nresults = nresults; - int wasActive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + if (FFlag::LuauActivateBeforeExec) + { + status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); + } + else + { + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); + status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); - if (!wasActive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } adjustresults(L, nresults); return status; @@ -1030,6 +1061,11 @@ int lua_gc(lua_State* L, int what, int data) res = cast_int(g->totalbytes >> 10); break; } + case LUA_GCCOUNTB: + { + res = cast_int(g->totalbytes & 1023); + break; + } case LUA_GCISRUNNING: { res = (g->GCthreshold != SIZE_MAX); @@ -1146,7 +1182,7 @@ void lua_concat(lua_State* L, int n) return; } -void* lua_newuserdata(lua_State* L, size_t sz, int tag) +void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) { api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); luaC_checkGC(L); @@ -1231,6 +1267,7 @@ uintptr_t lua_encodepointer(lua_State* L, uintptr_t p) int lua_ref(lua_State* L, int idx) { + api_check(L, idx != LUA_REGISTRYINDEX); /* idx is a stack index for value */ int ref = LUA_REFNIL; global_State* g = L->global; StkId p = index2adr(L, idx); diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 2a684ee4..7ed2a62e 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -30,7 +30,7 @@ static const char* currfuncname(lua_State* L) return debugname; } -LUALIB_API l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg) +l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg) { const char* fname = currfuncname(L); @@ -40,7 +40,7 @@ LUALIB_API l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg) luaL_error(L, "invalid argument #%d (%s)", narg, extramsg); } -LUALIB_API l_noret luaL_typeerrorL(lua_State* L, int narg, const char* tname) +l_noret luaL_typeerrorL(lua_State* L, int narg, const char* tname) { const char* fname = currfuncname(L); const TValue* obj = luaA_toobject(L, narg); @@ -66,7 +66,7 @@ static l_noret tag_error(lua_State* L, int narg, int tag) luaL_typeerrorL(L, narg, lua_typename(L, tag)); } -LUALIB_API void luaL_where(lua_State* L, int level) +void luaL_where(lua_State* L, int level) { lua_Debug ar; if (lua_getinfo(L, level, "sl", &ar) && ar.currentline > 0) @@ -77,7 +77,7 @@ LUALIB_API void luaL_where(lua_State* L, int level) lua_pushliteral(L, ""); /* else, no information available... */ } -LUALIB_API l_noret luaL_errorL(lua_State* L, const char* fmt, ...) +l_noret luaL_errorL(lua_State* L, const char* fmt, ...) { va_list argp; va_start(argp, fmt); @@ -90,7 +90,7 @@ LUALIB_API l_noret luaL_errorL(lua_State* L, const char* fmt, ...) /* }====================================================== */ -LUALIB_API int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const lst[]) +int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const lst[]) { const char* name = (def) ? luaL_optstring(L, narg, def) : luaL_checkstring(L, narg); int i; @@ -101,7 +101,7 @@ LUALIB_API int luaL_checkoption(lua_State* L, int narg, const char* def, const c luaL_argerrorL(L, narg, msg); } -LUALIB_API int luaL_newmetatable(lua_State* L, const char* tname) +int luaL_newmetatable(lua_State* L, const char* tname) { lua_getfield(L, LUA_REGISTRYINDEX, tname); /* get registry.name */ if (!lua_isnil(L, -1)) /* name already in use? */ @@ -113,7 +113,7 @@ LUALIB_API int luaL_newmetatable(lua_State* L, const char* tname) return 1; } -LUALIB_API void* luaL_checkudata(lua_State* L, int ud, const char* tname) +void* luaL_checkudata(lua_State* L, int ud, const char* tname) { void* p = lua_touserdata(L, ud); if (p != NULL) @@ -131,25 +131,25 @@ LUALIB_API void* luaL_checkudata(lua_State* L, int ud, const char* tname) luaL_typeerrorL(L, ud, tname); /* else error */ } -LUALIB_API void luaL_checkstack(lua_State* L, int space, const char* mes) +void luaL_checkstack(lua_State* L, int space, const char* mes) { if (!lua_checkstack(L, space)) luaL_error(L, "stack overflow (%s)", mes); } -LUALIB_API void luaL_checktype(lua_State* L, int narg, int t) +void luaL_checktype(lua_State* L, int narg, int t) { if (lua_type(L, narg) != t) tag_error(L, narg, t); } -LUALIB_API void luaL_checkany(lua_State* L, int narg) +void luaL_checkany(lua_State* L, int narg) { if (lua_type(L, narg) == LUA_TNONE) luaL_error(L, "missing argument #%d", narg); } -LUALIB_API const char* luaL_checklstring(lua_State* L, int narg, size_t* len) +const char* luaL_checklstring(lua_State* L, int narg, size_t* len) { const char* s = lua_tolstring(L, narg, len); if (!s) @@ -157,7 +157,7 @@ LUALIB_API const char* luaL_checklstring(lua_State* L, int narg, size_t* len) return s; } -LUALIB_API const char* luaL_optlstring(lua_State* L, int narg, const char* def, size_t* len) +const char* luaL_optlstring(lua_State* L, int narg, const char* def, size_t* len) { if (lua_isnoneornil(L, narg)) { @@ -169,7 +169,7 @@ LUALIB_API const char* luaL_optlstring(lua_State* L, int narg, const char* def, return luaL_checklstring(L, narg, len); } -LUALIB_API double luaL_checknumber(lua_State* L, int narg) +double luaL_checknumber(lua_State* L, int narg) { int isnum; double d = lua_tonumberx(L, narg, &isnum); @@ -178,12 +178,28 @@ LUALIB_API double luaL_checknumber(lua_State* L, int narg) return d; } -LUALIB_API double luaL_optnumber(lua_State* L, int narg, double def) +double luaL_optnumber(lua_State* L, int narg, double def) { return luaL_opt(L, luaL_checknumber, narg, def); } -LUALIB_API int luaL_checkinteger(lua_State* L, int narg) +int luaL_checkboolean(lua_State* L, int narg) +{ + // This checks specifically for boolean values, ignoring + // all other truthy/falsy values. If the desired result + // is true if value is present then lua_toboolean should + // directly be used instead. + if (!lua_isboolean(L, narg)) + tag_error(L, narg, LUA_TBOOLEAN); + return lua_toboolean(L, narg); +} + +int luaL_optboolean(lua_State* L, int narg, int def) +{ + return luaL_opt(L, luaL_checkboolean, narg, def); +} + +int luaL_checkinteger(lua_State* L, int narg) { int isnum; int d = lua_tointegerx(L, narg, &isnum); @@ -192,12 +208,12 @@ LUALIB_API int luaL_checkinteger(lua_State* L, int narg) return d; } -LUALIB_API int luaL_optinteger(lua_State* L, int narg, int def) +int luaL_optinteger(lua_State* L, int narg, int def) { return luaL_opt(L, luaL_checkinteger, narg, def); } -LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int narg) +unsigned luaL_checkunsigned(lua_State* L, int narg) { int isnum; unsigned d = lua_tounsignedx(L, narg, &isnum); @@ -206,12 +222,25 @@ LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int narg) return d; } -LUALIB_API unsigned luaL_optunsigned(lua_State* L, int narg, unsigned def) +unsigned luaL_optunsigned(lua_State* L, int narg, unsigned def) { return luaL_opt(L, luaL_checkunsigned, narg, def); } -LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* event) +const float* luaL_checkvector(lua_State* L, int narg) +{ + const float* v = lua_tovector(L, narg); + if (!v) + tag_error(L, narg, LUA_TVECTOR); + return v; +} + +const float* luaL_optvector(lua_State* L, int narg, const float* def) +{ + return luaL_opt(L, luaL_checkvector, narg, def); +} + +int luaL_getmetafield(lua_State* L, int obj, const char* event) { if (!lua_getmetatable(L, obj)) /* no metatable? */ return 0; @@ -229,7 +258,7 @@ LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* event) } } -LUALIB_API int luaL_callmeta(lua_State* L, int obj, const char* event) +int luaL_callmeta(lua_State* L, int obj, const char* event) { obj = abs_index(L, obj); if (!luaL_getmetafield(L, obj, event)) /* no metafield? */ @@ -247,7 +276,7 @@ static int libsize(const luaL_Reg* l) return size; } -LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l) +void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l) { if (libname) { @@ -273,7 +302,7 @@ LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* } } -LUALIB_API const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) +const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) { const char* e; lua_pushvalue(L, idx); @@ -324,7 +353,7 @@ static size_t getnextbuffersize(lua_State* L, size_t currentsize, size_t desired return newsize; } -LUALIB_API void luaL_buffinit(lua_State* L, luaL_Buffer* B) +void luaL_buffinit(lua_State* L, luaL_Buffer* B) { // start with an internal buffer B->p = B->buffer; @@ -334,14 +363,14 @@ LUALIB_API void luaL_buffinit(lua_State* L, luaL_Buffer* B) B->storage = nullptr; } -LUALIB_API char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size) +char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size) { luaL_buffinit(L, B); luaL_reservebuffer(B, size, -1); return B->p; } -LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc) +char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc) { lua_State* L = B->L; @@ -372,13 +401,13 @@ LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int bo return B->p; } -LUALIB_API void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc) +void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc) { if (size_t(B->end - B->p) < size) luaL_extendbuffer(B, size - (B->end - B->p), boxloc); } -LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) +void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) { if (size_t(B->end - B->p) < len) luaL_extendbuffer(B, len - (B->end - B->p), -1); @@ -387,7 +416,7 @@ LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) B->p += len; } -LUALIB_API void luaL_addvalue(luaL_Buffer* B) +void luaL_addvalue(luaL_Buffer* B) { lua_State* L = B->L; @@ -404,7 +433,7 @@ LUALIB_API void luaL_addvalue(luaL_Buffer* B) } } -LUALIB_API void luaL_pushresult(luaL_Buffer* B) +void luaL_pushresult(luaL_Buffer* B) { lua_State* L = B->L; @@ -428,7 +457,7 @@ LUALIB_API void luaL_pushresult(luaL_Buffer* B) } } -LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size) +void luaL_pushresultsize(luaL_Buffer* B, size_t size) { B->p += size; luaL_pushresult(B); @@ -436,7 +465,7 @@ LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size) /* }====================================================== */ -LUALIB_API const char* luaL_tolstring(lua_State* L, int idx, size_t* len) +const char* luaL_tolstring(lua_State* L, int idx, size_t* len) { if (luaL_callmeta(L, idx, "__tostring")) /* is there a metafield? */ { @@ -462,7 +491,11 @@ LUALIB_API const char* luaL_tolstring(lua_State* L, int idx, size_t* len) case LUA_TVECTOR: { const float* v = lua_tovector(L, idx); +#if LUA_VECTOR_SIZE == 4 + lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); +#else lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); +#endif break; } default: diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 61798e2b..881c804d 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -401,7 +401,7 @@ static int luaB_newproxy(lua_State* L) bool needsmt = lua_toboolean(L, 1); - lua_newuserdata(L, 0, 0); + lua_newuserdata(L, 0); if (needsmt) { @@ -441,7 +441,7 @@ static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFuncti lua_setfield(L, -2, name); } -LUALIB_API int luaopen_base(lua_State* L) +int luaopen_base(lua_State* L) { /* set global _G */ lua_pushvalue(L, LUA_GLOBALSINDEX); diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 907c43c4..8b511edf 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -236,7 +236,7 @@ static const luaL_Reg bitlib[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_bit32(lua_State* L) +int luaopen_bit32(lua_State* L) { luaL_register(L, LUA_BITLIBNAME, bitlib); diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 9ab57ac9..34e9ebc1 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1018,13 +1018,23 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { +#if LUA_VECTOR_SIZE == 4 + if (nparams >= 4 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1) && ttisnumber(args + 2)) +#else if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) +#endif { double x = nvalue(arg0); double y = nvalue(args); double z = nvalue(args + 1); - setvvalue(res, float(x), float(y), float(z)); +#if LUA_VECTOR_SIZE == 4 + double w = nvalue(args + 2); + setvvalue(res, float(x), float(y), float(z), float(w)); +#else + setvvalue(res, float(x), float(y), float(z), 0.0f); +#endif + return 1; } diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 0178fae8..abcde779 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -272,7 +272,7 @@ static const luaL_Reg co_funcs[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_coroutine(lua_State* L) +int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); diff --git a/VM/src/ldblib.cpp b/VM/src/ldblib.cpp index 965d2b3d..93d8703a 100644 --- a/VM/src/ldblib.cpp +++ b/VM/src/ldblib.cpp @@ -160,7 +160,7 @@ static const luaL_Reg dblib[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_debug(lua_State* L) +int luaopen_debug(lua_State* L) { luaL_register(L, LUA_DBLIBNAME, dblib); return 1; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 1259d461..62bbdb7c 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,9 +17,9 @@ #include -LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAG(LuauCoroutineClose) +LUAU_FASTFLAGVARIABLE(LuauActivateBeforeExec, false) /* ** {====================================================== @@ -74,35 +74,28 @@ public: const char* what() const throw() override { - if (FFlag::LuauExceptionMessageFix) + // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. + if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) { - // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. - if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) + // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. + if (const char* str = lua_tostring(L, -1)) { - // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. - if (const char* str = lua_tostring(L, -1)) - { - return str; - } - } - - switch (status) - { - case LUA_ERRRUN: - return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; - case LUA_ERRSYNTAX: - return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; - case LUA_ERRMEM: - return "lua_exception: " LUA_MEMERRMSG; - case LUA_ERRERR: - return "lua_exception: " LUA_ERRERRMSG; - default: - return "lua_exception: unexpected exception status"; + return str; } } - else + + switch (status) { - return lua_tostring(L, -1); + case LUA_ERRRUN: + return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; + case LUA_ERRSYNTAX: + return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; + case LUA_ERRMEM: + return "lua_exception: " LUA_MEMERRMSG; + case LUA_ERRERR: + return "lua_exception: " LUA_ERRERRMSG; + default: + return "lua_exception: unexpected exception status"; } } @@ -234,7 +227,22 @@ void luaD_call(lua_State* L, StkId func, int nResults) if (luau_precall(L, func, nResults) == PCRLUA) { /* is a Lua function? */ L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */ - luau_execute(L); /* call it */ + + if (FFlag::LuauActivateBeforeExec) + { + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); + + luau_execute(L); /* call it */ + + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } + else + { + luau_execute(L); /* call it */ + } } L->nCcalls--; luaC_checkGC(L); @@ -527,10 +535,10 @@ static void restore_stack_limit(lua_State* L) int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t ef) { - int status; unsigned short oldnCcalls = L->nCcalls; ptrdiff_t old_ci = saveci(L, L->ci); - status = luaD_rawrunprotected(L, func, u); + int oldactive = luaC_threadactive(L); + int status = luaD_rawrunprotected(L, func, u); if (status != 0) { // call user-defined error function (used in xpcall) @@ -541,6 +549,13 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e status = LUA_ERRERR; } + if (FFlag::LuauActivateBeforeExec) + { + // since the call failed with an error, we might have to reset the 'active' thread state + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } + if (FFlag::LuauCcallRestoreFix) { // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 11f79d1a..ab416041 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -10,9 +10,7 @@ #include "ldo.h" #include -#include -LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) LUAU_FASTFLAG(LuauArrayBoundary) @@ -988,7 +986,7 @@ void luaC_barriertable(lua_State* L, Table* t, GCObject* v) GCObject* o = obj2gco(t); // in the second propagation stage, table assignment barrier works as a forward barrier - if (FFlag::LuauRescanGrayAgainForwardBarrier && g->gcstate == GCSpropagateagain) + if (g->gcstate == GCSpropagateagain) { LUAU_ASSERT(isblack(o) && iswhite(v) && !isdead(g, v) && !isdead(g, o)); reallymarkobject(g, v); @@ -1044,550 +1042,6 @@ void luaC_linkupval(lua_State* L, UpVal* uv) } } -static void validateobjref(global_State* g, GCObject* f, GCObject* t) -{ - LUAU_ASSERT(!isdead(g, t)); - - if (keepinvariant(g)) - { - /* basic incremental invariant: black can't point to white */ - LUAU_ASSERT(!(isblack(f) && iswhite(t))); - } -} - -static void validateref(global_State* g, GCObject* f, TValue* v) -{ - if (iscollectable(v)) - { - LUAU_ASSERT(ttype(v) == gcvalue(v)->gch.tt); - validateobjref(g, f, gcvalue(v)); - } -} - -static void validatetable(global_State* g, Table* h) -{ - int sizenode = 1 << h->lsizenode; - - if (FFlag::LuauArrayBoundary) - LUAU_ASSERT(h->lastfree <= sizenode); - else - LUAU_ASSERT(h->lastfree >= 0 && h->lastfree <= sizenode); - - if (h->metatable) - validateobjref(g, obj2gco(h), obj2gco(h->metatable)); - - for (int i = 0; i < h->sizearray; ++i) - validateref(g, obj2gco(h), &h->array[i]); - - for (int i = 0; i < sizenode; ++i) - { - LuaNode* n = &h->node[i]; - - LUAU_ASSERT(ttype(gkey(n)) != LUA_TDEADKEY || ttisnil(gval(n))); - LUAU_ASSERT(i + gnext(n) >= 0 && i + gnext(n) < sizenode); - - if (!ttisnil(gval(n))) - { - TValue k = {}; - k.tt = gkey(n)->tt; - k.value = gkey(n)->value; - - validateref(g, obj2gco(h), &k); - validateref(g, obj2gco(h), gval(n)); - } - } -} - -static void validateclosure(global_State* g, Closure* cl) -{ - validateobjref(g, obj2gco(cl), obj2gco(cl->env)); - - if (cl->isC) - { - for (int i = 0; i < cl->nupvalues; ++i) - validateref(g, obj2gco(cl), &cl->c.upvals[i]); - } - else - { - LUAU_ASSERT(cl->nupvalues == cl->l.p->nups); - - validateobjref(g, obj2gco(cl), obj2gco(cl->l.p)); - - for (int i = 0; i < cl->nupvalues; ++i) - validateref(g, obj2gco(cl), &cl->l.uprefs[i]); - } -} - -static void validatestack(global_State* g, lua_State* l) -{ - validateref(g, obj2gco(l), gt(l)); - - for (CallInfo* ci = l->base_ci; ci <= l->ci; ++ci) - { - LUAU_ASSERT(l->stack <= ci->base); - LUAU_ASSERT(ci->func <= ci->base && ci->base <= ci->top); - LUAU_ASSERT(ci->top <= l->stack_last); - } - - // note: stack refs can violate gc invariant so we only check for liveness - for (StkId o = l->stack; o < l->top; ++o) - checkliveness(g, o); - - if (l->namecall) - validateobjref(g, obj2gco(l), obj2gco(l->namecall)); - - for (GCObject* uv = l->openupval; uv; uv = uv->gch.next) - { - LUAU_ASSERT(uv->gch.tt == LUA_TUPVAL); - LUAU_ASSERT(gco2uv(uv)->v != &gco2uv(uv)->u.value); - } -} - -static void validateproto(global_State* g, Proto* f) -{ - if (f->source) - validateobjref(g, obj2gco(f), obj2gco(f->source)); - - if (f->debugname) - validateobjref(g, obj2gco(f), obj2gco(f->debugname)); - - for (int i = 0; i < f->sizek; ++i) - validateref(g, obj2gco(f), &f->k[i]); - - for (int i = 0; i < f->sizeupvalues; ++i) - if (f->upvalues[i]) - validateobjref(g, obj2gco(f), obj2gco(f->upvalues[i])); - - for (int i = 0; i < f->sizep; ++i) - if (f->p[i]) - validateobjref(g, obj2gco(f), obj2gco(f->p[i])); - - for (int i = 0; i < f->sizelocvars; i++) - if (f->locvars[i].varname) - validateobjref(g, obj2gco(f), obj2gco(f->locvars[i].varname)); -} - -static void validateobj(global_State* g, GCObject* o) -{ - /* dead objects can only occur during sweep */ - if (isdead(g, o)) - { - LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); - return; - } - - switch (o->gch.tt) - { - case LUA_TSTRING: - break; - - case LUA_TTABLE: - validatetable(g, gco2h(o)); - break; - - case LUA_TFUNCTION: - validateclosure(g, gco2cl(o)); - break; - - case LUA_TUSERDATA: - if (gco2u(o)->metatable) - validateobjref(g, o, obj2gco(gco2u(o)->metatable)); - break; - - case LUA_TTHREAD: - validatestack(g, gco2th(o)); - break; - - case LUA_TPROTO: - validateproto(g, gco2p(o)); - break; - - case LUA_TUPVAL: - validateref(g, o, gco2uv(o)->v); - break; - - default: - LUAU_ASSERT(!"unexpected object type"); - } -} - -static void validatelist(global_State* g, GCObject* o) -{ - while (o) - { - validateobj(g, o); - - o = o->gch.next; - } -} - -static void validategraylist(global_State* g, GCObject* o) -{ - if (!keepinvariant(g)) - return; - - while (o) - { - LUAU_ASSERT(isgray(o)); - - switch (o->gch.tt) - { - case LUA_TTABLE: - o = gco2h(o)->gclist; - break; - case LUA_TFUNCTION: - o = gco2cl(o)->gclist; - break; - case LUA_TTHREAD: - o = gco2th(o)->gclist; - break; - case LUA_TPROTO: - o = gco2p(o)->gclist; - break; - default: - LUAU_ASSERT(!"unknown object in gray list"); - return; - } - } -} - -void luaC_validate(lua_State* L) -{ - global_State* g = L->global; - - LUAU_ASSERT(!isdead(g, obj2gco(g->mainthread))); - checkliveness(g, &g->registry); - - for (int i = 0; i < LUA_T_COUNT; ++i) - if (g->mt[i]) - LUAU_ASSERT(!isdead(g, obj2gco(g->mt[i]))); - - validategraylist(g, g->weak); - validategraylist(g, g->gray); - validategraylist(g, g->grayagain); - - for (int i = 0; i < g->strt.size; ++i) - validatelist(g, g->strt.hash[i]); - - validatelist(g, g->rootgc); - validatelist(g, g->strbufgc); - - for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) - { - LUAU_ASSERT(uv->tt == LUA_TUPVAL); - LUAU_ASSERT(uv->v != &uv->u.value); - LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); - } -} - -inline bool safejson(char ch) -{ - return unsigned(ch) < 128 && ch >= 32 && ch != '\\' && ch != '\"'; -} - -static void dumpref(FILE* f, GCObject* o) -{ - fprintf(f, "\"%p\"", o); -} - -static void dumprefs(FILE* f, TValue* data, size_t size) -{ - bool first = true; - - for (size_t i = 0; i < size; ++i) - { - if (iscollectable(&data[i])) - { - if (!first) - fputc(',', f); - first = false; - - dumpref(f, gcvalue(&data[i])); - } - } -} - -static void dumpstringdata(FILE* f, const char* data, size_t len) -{ - for (size_t i = 0; i < len; ++i) - fputc(safejson(data[i]) ? data[i] : '?', f); -} - -static void dumpstring(FILE* f, TString* ts) -{ - fprintf(f, "{\"type\":\"string\",\"cat\":%d,\"size\":%d,\"data\":\"", ts->memcat, int(sizestring(ts->len))); - dumpstringdata(f, ts->data, ts->len); - fprintf(f, "\"}"); -} - -static void dumptable(FILE* f, Table* h) -{ - size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); - - fprintf(f, "{\"type\":\"table\",\"cat\":%d,\"size\":%d", h->memcat, int(size)); - - if (h->node != &luaH_dummynode) - { - fprintf(f, ",\"pairs\":["); - - bool first = true; - - for (int i = 0; i < sizenode(h); ++i) - { - const LuaNode& n = h->node[i]; - - if (!ttisnil(&n.val) && (iscollectable(&n.key) || iscollectable(&n.val))) - { - if (!first) - fputc(',', f); - first = false; - - if (iscollectable(&n.key)) - dumpref(f, gcvalue(&n.key)); - else - fprintf(f, "null"); - - fputc(',', f); - - if (iscollectable(&n.val)) - dumpref(f, gcvalue(&n.val)); - else - fprintf(f, "null"); - } - } - - fprintf(f, "]"); - } - if (h->sizearray) - { - fprintf(f, ",\"array\":["); - dumprefs(f, h->array, h->sizearray); - fprintf(f, "]"); - } - if (h->metatable) - { - fprintf(f, ",\"metatable\":"); - dumpref(f, obj2gco(h->metatable)); - } - fprintf(f, "}"); -} - -static void dumpclosure(FILE* f, Closure* cl) -{ - fprintf(f, "{\"type\":\"function\",\"cat\":%d,\"size\":%d", cl->memcat, - cl->isC ? int(sizeCclosure(cl->nupvalues)) : int(sizeLclosure(cl->nupvalues))); - - fprintf(f, ",\"env\":"); - dumpref(f, obj2gco(cl->env)); - if (cl->isC) - { - if (cl->nupvalues) - { - fprintf(f, ",\"upvalues\":["); - dumprefs(f, cl->c.upvals, cl->nupvalues); - fprintf(f, "]"); - } - } - else - { - fprintf(f, ",\"proto\":"); - dumpref(f, obj2gco(cl->l.p)); - if (cl->nupvalues) - { - fprintf(f, ",\"upvalues\":["); - dumprefs(f, cl->l.uprefs, cl->nupvalues); - fprintf(f, "]"); - } - } - fprintf(f, "}"); -} - -static void dumpudata(FILE* f, Udata* u) -{ - fprintf(f, "{\"type\":\"userdata\",\"cat\":%d,\"size\":%d,\"tag\":%d", u->memcat, int(sizeudata(u->len)), u->tag); - - if (u->metatable) - { - fprintf(f, ",\"metatable\":"); - dumpref(f, obj2gco(u->metatable)); - } - fprintf(f, "}"); -} - -static void dumpthread(FILE* f, lua_State* th) -{ - size_t size = sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; - - fprintf(f, "{\"type\":\"thread\",\"cat\":%d,\"size\":%d", th->memcat, int(size)); - - if (iscollectable(&th->l_gt)) - { - fprintf(f, ",\"env\":"); - dumpref(f, gcvalue(&th->l_gt)); - } - - Closure* tcl = 0; - for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) - { - if (ttisfunction(ci->func)) - { - tcl = clvalue(ci->func); - break; - } - } - - if (tcl && !tcl->isC && tcl->l.p->source) - { - Proto* p = tcl->l.p; - - fprintf(f, ",\"source\":\""); - dumpstringdata(f, p->source->data, p->source->len); - fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); - } - - if (th->top > th->stack) - { - fprintf(f, ",\"stack\":["); - dumprefs(f, th->stack, th->top - th->stack); - fprintf(f, "]"); - } - fprintf(f, "}"); -} - -static void dumpproto(FILE* f, Proto* p) -{ - size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + - sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; - - fprintf(f, "{\"type\":\"proto\",\"cat\":%d,\"size\":%d", p->memcat, int(size)); - - if (p->source) - { - fprintf(f, ",\"source\":\""); - dumpstringdata(f, p->source->data, p->source->len); - fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); - } - - if (p->sizek) - { - fprintf(f, ",\"constants\":["); - dumprefs(f, p->k, p->sizek); - fprintf(f, "]"); - } - - if (p->sizep) - { - fprintf(f, ",\"protos\":["); - for (int i = 0; i < p->sizep; ++i) - { - if (i != 0) - fputc(',', f); - dumpref(f, obj2gco(p->p[i])); - } - fprintf(f, "]"); - } - - fprintf(f, "}"); -} - -static void dumpupval(FILE* f, UpVal* uv) -{ - fprintf(f, "{\"type\":\"upvalue\",\"cat\":%d,\"size\":%d", uv->memcat, int(sizeof(UpVal))); - - if (iscollectable(uv->v)) - { - fprintf(f, ",\"object\":"); - dumpref(f, gcvalue(uv->v)); - } - fprintf(f, "}"); -} - -static void dumpobj(FILE* f, GCObject* o) -{ - switch (o->gch.tt) - { - case LUA_TSTRING: - return dumpstring(f, gco2ts(o)); - - case LUA_TTABLE: - return dumptable(f, gco2h(o)); - - case LUA_TFUNCTION: - return dumpclosure(f, gco2cl(o)); - - case LUA_TUSERDATA: - return dumpudata(f, gco2u(o)); - - case LUA_TTHREAD: - return dumpthread(f, gco2th(o)); - - case LUA_TPROTO: - return dumpproto(f, gco2p(o)); - - case LUA_TUPVAL: - return dumpupval(f, gco2uv(o)); - - default: - LUAU_ASSERT(0); - } -} - -static void dumplist(FILE* f, GCObject* o) -{ - while (o) - { - dumpref(f, o); - fputc(':', f); - dumpobj(f, o); - fputc(',', f); - fputc('\n', f); - - // thread has additional list containing collectable objects that are not present in rootgc - if (o->gch.tt == LUA_TTHREAD) - dumplist(f, gco2th(o)->openupval); - - o = o->gch.next; - } -} - -void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)) -{ - global_State* g = L->global; - FILE* f = static_cast(file); - - fprintf(f, "{\"objects\":{\n"); - dumplist(f, g->rootgc); - dumplist(f, g->strbufgc); - for (int i = 0; i < g->strt.size; ++i) - dumplist(f, g->strt.hash[i]); - - fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , - fprintf(f, "},\"roots\":{\n"); - fprintf(f, "\"mainthread\":"); - dumpref(f, obj2gco(g->mainthread)); - fprintf(f, ",\"registry\":"); - dumpref(f, gcvalue(&g->registry)); - - fprintf(f, "},\"stats\":{\n"); - - fprintf(f, "\"size\":%d,\n", int(g->totalbytes)); - - fprintf(f, "\"categories\":{\n"); - for (int i = 0; i < LUA_MEMORY_CATEGORIES; i++) - { - if (size_t bytes = g->memcatbytes[i]) - { - if (categoryName) - fprintf(f, "\"%d\":{\"name\":\"%s\", \"size\":%d},\n", i, categoryName(L, i), int(bytes)); - else - fprintf(f, "\"%d\":{\"size\":%d},\n", i, int(bytes)); - } - } - fprintf(f, "\"none\":{}\n"); // to avoid issues with trailing , - fprintf(f, "}\n"); - fprintf(f, "}}\n"); -} - // measure the allocation rate in bytes/sec // returns -1 if allocation rate cannot be measured int64_t luaC_allocationrate(lua_State* L) diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp new file mode 100644 index 00000000..a79e7b95 --- /dev/null +++ b/VM/src/lgcdebug.cpp @@ -0,0 +1,558 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lgc.h" + +#include "lobject.h" +#include "lstate.h" +#include "ltable.h" +#include "lfunc.h" +#include "lstring.h" + +#include +#include + +LUAU_FASTFLAG(LuauArrayBoundary) + +static void validateobjref(global_State* g, GCObject* f, GCObject* t) +{ + LUAU_ASSERT(!isdead(g, t)); + + if (keepinvariant(g)) + { + /* basic incremental invariant: black can't point to white */ + LUAU_ASSERT(!(isblack(f) && iswhite(t))); + } +} + +static void validateref(global_State* g, GCObject* f, TValue* v) +{ + if (iscollectable(v)) + { + LUAU_ASSERT(ttype(v) == gcvalue(v)->gch.tt); + validateobjref(g, f, gcvalue(v)); + } +} + +static void validatetable(global_State* g, Table* h) +{ + int sizenode = 1 << h->lsizenode; + + if (FFlag::LuauArrayBoundary) + LUAU_ASSERT(h->lastfree <= sizenode); + else + LUAU_ASSERT(h->lastfree >= 0 && h->lastfree <= sizenode); + + if (h->metatable) + validateobjref(g, obj2gco(h), obj2gco(h->metatable)); + + for (int i = 0; i < h->sizearray; ++i) + validateref(g, obj2gco(h), &h->array[i]); + + for (int i = 0; i < sizenode; ++i) + { + LuaNode* n = &h->node[i]; + + LUAU_ASSERT(ttype(gkey(n)) != LUA_TDEADKEY || ttisnil(gval(n))); + LUAU_ASSERT(i + gnext(n) >= 0 && i + gnext(n) < sizenode); + + if (!ttisnil(gval(n))) + { + TValue k = {}; + k.tt = gkey(n)->tt; + k.value = gkey(n)->value; + + validateref(g, obj2gco(h), &k); + validateref(g, obj2gco(h), gval(n)); + } + } +} + +static void validateclosure(global_State* g, Closure* cl) +{ + validateobjref(g, obj2gco(cl), obj2gco(cl->env)); + + if (cl->isC) + { + for (int i = 0; i < cl->nupvalues; ++i) + validateref(g, obj2gco(cl), &cl->c.upvals[i]); + } + else + { + LUAU_ASSERT(cl->nupvalues == cl->l.p->nups); + + validateobjref(g, obj2gco(cl), obj2gco(cl->l.p)); + + for (int i = 0; i < cl->nupvalues; ++i) + validateref(g, obj2gco(cl), &cl->l.uprefs[i]); + } +} + +static void validatestack(global_State* g, lua_State* l) +{ + validateref(g, obj2gco(l), gt(l)); + + for (CallInfo* ci = l->base_ci; ci <= l->ci; ++ci) + { + LUAU_ASSERT(l->stack <= ci->base); + LUAU_ASSERT(ci->func <= ci->base && ci->base <= ci->top); + LUAU_ASSERT(ci->top <= l->stack_last); + } + + // note: stack refs can violate gc invariant so we only check for liveness + for (StkId o = l->stack; o < l->top; ++o) + checkliveness(g, o); + + if (l->namecall) + validateobjref(g, obj2gco(l), obj2gco(l->namecall)); + + for (GCObject* uv = l->openupval; uv; uv = uv->gch.next) + { + LUAU_ASSERT(uv->gch.tt == LUA_TUPVAL); + LUAU_ASSERT(gco2uv(uv)->v != &gco2uv(uv)->u.value); + } +} + +static void validateproto(global_State* g, Proto* f) +{ + if (f->source) + validateobjref(g, obj2gco(f), obj2gco(f->source)); + + if (f->debugname) + validateobjref(g, obj2gco(f), obj2gco(f->debugname)); + + for (int i = 0; i < f->sizek; ++i) + validateref(g, obj2gco(f), &f->k[i]); + + for (int i = 0; i < f->sizeupvalues; ++i) + if (f->upvalues[i]) + validateobjref(g, obj2gco(f), obj2gco(f->upvalues[i])); + + for (int i = 0; i < f->sizep; ++i) + if (f->p[i]) + validateobjref(g, obj2gco(f), obj2gco(f->p[i])); + + for (int i = 0; i < f->sizelocvars; i++) + if (f->locvars[i].varname) + validateobjref(g, obj2gco(f), obj2gco(f->locvars[i].varname)); +} + +static void validateobj(global_State* g, GCObject* o) +{ + /* dead objects can only occur during sweep */ + if (isdead(g, o)) + { + LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + return; + } + + switch (o->gch.tt) + { + case LUA_TSTRING: + break; + + case LUA_TTABLE: + validatetable(g, gco2h(o)); + break; + + case LUA_TFUNCTION: + validateclosure(g, gco2cl(o)); + break; + + case LUA_TUSERDATA: + if (gco2u(o)->metatable) + validateobjref(g, o, obj2gco(gco2u(o)->metatable)); + break; + + case LUA_TTHREAD: + validatestack(g, gco2th(o)); + break; + + case LUA_TPROTO: + validateproto(g, gco2p(o)); + break; + + case LUA_TUPVAL: + validateref(g, o, gco2uv(o)->v); + break; + + default: + LUAU_ASSERT(!"unexpected object type"); + } +} + +static void validatelist(global_State* g, GCObject* o) +{ + while (o) + { + validateobj(g, o); + + o = o->gch.next; + } +} + +static void validategraylist(global_State* g, GCObject* o) +{ + if (!keepinvariant(g)) + return; + + while (o) + { + LUAU_ASSERT(isgray(o)); + + switch (o->gch.tt) + { + case LUA_TTABLE: + o = gco2h(o)->gclist; + break; + case LUA_TFUNCTION: + o = gco2cl(o)->gclist; + break; + case LUA_TTHREAD: + o = gco2th(o)->gclist; + break; + case LUA_TPROTO: + o = gco2p(o)->gclist; + break; + default: + LUAU_ASSERT(!"unknown object in gray list"); + return; + } + } +} + +void luaC_validate(lua_State* L) +{ + global_State* g = L->global; + + LUAU_ASSERT(!isdead(g, obj2gco(g->mainthread))); + checkliveness(g, &g->registry); + + for (int i = 0; i < LUA_T_COUNT; ++i) + if (g->mt[i]) + LUAU_ASSERT(!isdead(g, obj2gco(g->mt[i]))); + + validategraylist(g, g->weak); + validategraylist(g, g->gray); + validategraylist(g, g->grayagain); + + for (int i = 0; i < g->strt.size; ++i) + validatelist(g, g->strt.hash[i]); + + validatelist(g, g->rootgc); + validatelist(g, g->strbufgc); + + for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + { + LUAU_ASSERT(uv->tt == LUA_TUPVAL); + LUAU_ASSERT(uv->v != &uv->u.value); + LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); + } +} + +inline bool safejson(char ch) +{ + return unsigned(ch) < 128 && ch >= 32 && ch != '\\' && ch != '\"'; +} + +static void dumpref(FILE* f, GCObject* o) +{ + fprintf(f, "\"%p\"", o); +} + +static void dumprefs(FILE* f, TValue* data, size_t size) +{ + bool first = true; + + for (size_t i = 0; i < size; ++i) + { + if (iscollectable(&data[i])) + { + if (!first) + fputc(',', f); + first = false; + + dumpref(f, gcvalue(&data[i])); + } + } +} + +static void dumpstringdata(FILE* f, const char* data, size_t len) +{ + for (size_t i = 0; i < len; ++i) + fputc(safejson(data[i]) ? data[i] : '?', f); +} + +static void dumpstring(FILE* f, TString* ts) +{ + fprintf(f, "{\"type\":\"string\",\"cat\":%d,\"size\":%d,\"data\":\"", ts->memcat, int(sizestring(ts->len))); + dumpstringdata(f, ts->data, ts->len); + fprintf(f, "\"}"); +} + +static void dumptable(FILE* f, Table* h) +{ + size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); + + fprintf(f, "{\"type\":\"table\",\"cat\":%d,\"size\":%d", h->memcat, int(size)); + + if (h->node != &luaH_dummynode) + { + fprintf(f, ",\"pairs\":["); + + bool first = true; + + for (int i = 0; i < sizenode(h); ++i) + { + const LuaNode& n = h->node[i]; + + if (!ttisnil(&n.val) && (iscollectable(&n.key) || iscollectable(&n.val))) + { + if (!first) + fputc(',', f); + first = false; + + if (iscollectable(&n.key)) + dumpref(f, gcvalue(&n.key)); + else + fprintf(f, "null"); + + fputc(',', f); + + if (iscollectable(&n.val)) + dumpref(f, gcvalue(&n.val)); + else + fprintf(f, "null"); + } + } + + fprintf(f, "]"); + } + if (h->sizearray) + { + fprintf(f, ",\"array\":["); + dumprefs(f, h->array, h->sizearray); + fprintf(f, "]"); + } + if (h->metatable) + { + fprintf(f, ",\"metatable\":"); + dumpref(f, obj2gco(h->metatable)); + } + fprintf(f, "}"); +} + +static void dumpclosure(FILE* f, Closure* cl) +{ + fprintf(f, "{\"type\":\"function\",\"cat\":%d,\"size\":%d", cl->memcat, + cl->isC ? int(sizeCclosure(cl->nupvalues)) : int(sizeLclosure(cl->nupvalues))); + + fprintf(f, ",\"env\":"); + dumpref(f, obj2gco(cl->env)); + if (cl->isC) + { + if (cl->nupvalues) + { + fprintf(f, ",\"upvalues\":["); + dumprefs(f, cl->c.upvals, cl->nupvalues); + fprintf(f, "]"); + } + } + else + { + fprintf(f, ",\"proto\":"); + dumpref(f, obj2gco(cl->l.p)); + if (cl->nupvalues) + { + fprintf(f, ",\"upvalues\":["); + dumprefs(f, cl->l.uprefs, cl->nupvalues); + fprintf(f, "]"); + } + } + fprintf(f, "}"); +} + +static void dumpudata(FILE* f, Udata* u) +{ + fprintf(f, "{\"type\":\"userdata\",\"cat\":%d,\"size\":%d,\"tag\":%d", u->memcat, int(sizeudata(u->len)), u->tag); + + if (u->metatable) + { + fprintf(f, ",\"metatable\":"); + dumpref(f, obj2gco(u->metatable)); + } + fprintf(f, "}"); +} + +static void dumpthread(FILE* f, lua_State* th) +{ + size_t size = sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; + + fprintf(f, "{\"type\":\"thread\",\"cat\":%d,\"size\":%d", th->memcat, int(size)); + + if (iscollectable(&th->l_gt)) + { + fprintf(f, ",\"env\":"); + dumpref(f, gcvalue(&th->l_gt)); + } + + Closure* tcl = 0; + for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) + { + if (ttisfunction(ci->func)) + { + tcl = clvalue(ci->func); + break; + } + } + + if (tcl && !tcl->isC && tcl->l.p->source) + { + Proto* p = tcl->l.p; + + fprintf(f, ",\"source\":\""); + dumpstringdata(f, p->source->data, p->source->len); + fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); + } + + if (th->top > th->stack) + { + fprintf(f, ",\"stack\":["); + dumprefs(f, th->stack, th->top - th->stack); + fprintf(f, "]"); + } + fprintf(f, "}"); +} + +static void dumpproto(FILE* f, Proto* p) +{ + size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + + sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; + + fprintf(f, "{\"type\":\"proto\",\"cat\":%d,\"size\":%d", p->memcat, int(size)); + + if (p->source) + { + fprintf(f, ",\"source\":\""); + dumpstringdata(f, p->source->data, p->source->len); + fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); + } + + if (p->sizek) + { + fprintf(f, ",\"constants\":["); + dumprefs(f, p->k, p->sizek); + fprintf(f, "]"); + } + + if (p->sizep) + { + fprintf(f, ",\"protos\":["); + for (int i = 0; i < p->sizep; ++i) + { + if (i != 0) + fputc(',', f); + dumpref(f, obj2gco(p->p[i])); + } + fprintf(f, "]"); + } + + fprintf(f, "}"); +} + +static void dumpupval(FILE* f, UpVal* uv) +{ + fprintf(f, "{\"type\":\"upvalue\",\"cat\":%d,\"size\":%d", uv->memcat, int(sizeof(UpVal))); + + if (iscollectable(uv->v)) + { + fprintf(f, ",\"object\":"); + dumpref(f, gcvalue(uv->v)); + } + fprintf(f, "}"); +} + +static void dumpobj(FILE* f, GCObject* o) +{ + switch (o->gch.tt) + { + case LUA_TSTRING: + return dumpstring(f, gco2ts(o)); + + case LUA_TTABLE: + return dumptable(f, gco2h(o)); + + case LUA_TFUNCTION: + return dumpclosure(f, gco2cl(o)); + + case LUA_TUSERDATA: + return dumpudata(f, gco2u(o)); + + case LUA_TTHREAD: + return dumpthread(f, gco2th(o)); + + case LUA_TPROTO: + return dumpproto(f, gco2p(o)); + + case LUA_TUPVAL: + return dumpupval(f, gco2uv(o)); + + default: + LUAU_ASSERT(0); + } +} + +static void dumplist(FILE* f, GCObject* o) +{ + while (o) + { + dumpref(f, o); + fputc(':', f); + dumpobj(f, o); + fputc(',', f); + fputc('\n', f); + + // thread has additional list containing collectable objects that are not present in rootgc + if (o->gch.tt == LUA_TTHREAD) + dumplist(f, gco2th(o)->openupval); + + o = o->gch.next; + } +} + +void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)) +{ + global_State* g = L->global; + FILE* f = static_cast(file); + + fprintf(f, "{\"objects\":{\n"); + dumplist(f, g->rootgc); + dumplist(f, g->strbufgc); + for (int i = 0; i < g->strt.size; ++i) + dumplist(f, g->strt.hash[i]); + + fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , + fprintf(f, "},\"roots\":{\n"); + fprintf(f, "\"mainthread\":"); + dumpref(f, obj2gco(g->mainthread)); + fprintf(f, ",\"registry\":"); + dumpref(f, gcvalue(&g->registry)); + + fprintf(f, "},\"stats\":{\n"); + + fprintf(f, "\"size\":%d,\n", int(g->totalbytes)); + + fprintf(f, "\"categories\":{\n"); + for (int i = 0; i < LUA_MEMORY_CATEGORIES; i++) + { + if (size_t bytes = g->memcatbytes[i]) + { + if (categoryName) + fprintf(f, "\"%d\":{\"name\":\"%s\", \"size\":%d},\n", i, categoryName(L, i), int(bytes)); + else + fprintf(f, "\"%d\":{\"size\":%d},\n", i, int(bytes)); + } + } + fprintf(f, "\"none\":{}\n"); // to avoid issues with trailing , + fprintf(f, "}\n"); + fprintf(f, "}}\n"); +} diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index 4e40165a..c93f431f 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -17,7 +17,7 @@ static const luaL_Reg lualibs[] = { {NULL, NULL}, }; -LUALIB_API void luaL_openlibs(lua_State* L) +void luaL_openlibs(lua_State* L) { const luaL_Reg* lib = lualibs; for (; lib->func; lib++) @@ -28,7 +28,7 @@ LUALIB_API void luaL_openlibs(lua_State* L) } } -LUALIB_API void luaL_sandbox(lua_State* L) +void luaL_sandbox(lua_State* L) { // set all libraries to read-only lua_pushnil(L); @@ -44,14 +44,14 @@ LUALIB_API void luaL_sandbox(lua_State* L) lua_pushliteral(L, ""); lua_getmetatable(L, -1); lua_setreadonly(L, -1, true); - lua_pop(L, 1); + lua_pop(L, 2); // set globals to readonly and activate safeenv since the env is immutable lua_setreadonly(L, LUA_GLOBALSINDEX, true); lua_setsafeenv(L, LUA_GLOBALSINDEX, true); } -LUALIB_API void luaL_sandboxthread(lua_State* L) +void luaL_sandboxthread(lua_State* L) { // create new global table that proxies reads to original table lua_newtable(L); @@ -81,7 +81,7 @@ static void* l_alloc(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsi return realloc(ptr, nsize); } -LUALIB_API lua_State* luaL_newstate(void) +lua_State* luaL_newstate(void) { return lua_newstate(l_alloc, NULL); } diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 8e476a52..a6e7b494 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -385,8 +385,7 @@ static int math_sign(lua_State* L) static int math_round(lua_State* L) { - double v = luaL_checknumber(L, 1); - lua_pushnumber(L, round(v)); + lua_pushnumber(L, round(luaL_checknumber(L, 1))); return 1; } @@ -429,7 +428,7 @@ static const luaL_Reg mathlib[] = { /* ** Open math library */ -LUALIB_API int luaopen_math(lua_State* L) +int luaopen_math(lua_State* L) { uint64_t seed = uintptr_t(L); seed ^= time(NULL); diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index d8b265cb..9f9d4a98 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -33,11 +33,17 @@ #define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : ms32) #endif +#if LUA_VECTOR_SIZE == 4 +static_assert(sizeof(TValue) == ABISWITCH(24, 24, 24), "size mismatch for value"); +static_assert(sizeof(LuaNode) == ABISWITCH(48, 48, 48), "size mismatch for table entry"); +#else static_assert(sizeof(TValue) == ABISWITCH(16, 16, 16), "size mismatch for value"); +static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table entry"); +#endif + static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); static_assert(offsetof(Udata, data) == ABISWITCH(24, 16, 16), "size mismatch for userdata header"); static_assert(sizeof(Table) == ABISWITCH(56, 36, 36), "size mismatch for table header"); -static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table entry"); const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index 43f8014b..67f832dc 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -18,12 +18,20 @@ inline bool luai_veceq(const float* a, const float* b) { +#if LUA_VECTOR_SIZE == 4 + return a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && a[3] == b[3]; +#else return a[0] == b[0] && a[1] == b[1] && a[2] == b[2]; +#endif } inline bool luai_vecisnan(const float* a) { +#if LUA_VECTOR_SIZE == 4 + return a[0] != a[0] || a[1] != a[1] || a[2] != a[2] || a[3] != a[3]; +#else return a[0] != a[0] || a[1] != a[1] || a[2] != a[2]; +#endif } LUAU_FASTMATH_BEGIN diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index bf13e6e9..370c7b28 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -15,7 +15,7 @@ -const TValue luaO_nilobject_ = {{NULL}, LUA_TNIL}; +const TValue luaO_nilobject_ = {{NULL}, {0}, LUA_TNIL}; int luaO_log2(unsigned int x) { diff --git a/VM/src/lobject.h b/VM/src/lobject.h index c5f2e2f4..ba040af6 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -47,7 +47,7 @@ typedef union typedef struct lua_TValue { Value value; - int extra; + int extra[LUA_EXTRA_SIZE]; int tt; } TValue; @@ -105,7 +105,19 @@ typedef struct lua_TValue i_o->tt = LUA_TNUMBER; \ } -#define setvvalue(obj, x, y, z) \ +#if LUA_VECTOR_SIZE == 4 +#define setvvalue(obj, x, y, z, w) \ + { \ + TValue* i_o = (obj); \ + float* i_v = i_o->value.v; \ + i_v[0] = (x); \ + i_v[1] = (y); \ + i_v[2] = (z); \ + i_v[3] = (w); \ + i_o->tt = LUA_TVECTOR; \ + } +#else +#define setvvalue(obj, x, y, z, w) \ { \ TValue* i_o = (obj); \ float* i_v = i_o->value.v; \ @@ -114,6 +126,7 @@ typedef struct lua_TValue i_v[2] = (z); \ i_o->tt = LUA_TVECTOR; \ } +#endif #define setpvalue(obj, x) \ { \ @@ -364,7 +377,7 @@ typedef struct Closure typedef struct TKey { ::Value value; - int extra; + int extra[LUA_EXTRA_SIZE]; unsigned tt : 4; int next : 28; /* for chaining */ } TKey; @@ -381,7 +394,7 @@ typedef struct LuaNode LuaNode* n_ = (node); \ const TValue* i_o = (obj); \ n_->key.value = i_o->value; \ - n_->key.extra = i_o->extra; \ + memcpy(n_->key.extra, i_o->extra, sizeof(n_->key.extra)); \ n_->key.tt = i_o->tt; \ checkliveness(L->global, i_o); \ } @@ -392,7 +405,7 @@ typedef struct LuaNode TValue* i_o = (obj); \ const LuaNode* n_ = (node); \ i_o->value = n_->key.value; \ - i_o->extra = n_->key.extra; \ + memcpy(i_o->extra, n_->key.extra, sizeof(i_o->extra)); \ i_o->tt = n_->key.tt; \ checkliveness(L->global, i_o); \ } diff --git a/VM/src/loslib.cpp b/VM/src/loslib.cpp index 8eaef60c..b5901865 100644 --- a/VM/src/loslib.cpp +++ b/VM/src/loslib.cpp @@ -186,7 +186,7 @@ static const luaL_Reg syslib[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_os(lua_State* L) +int luaopen_os(lua_State* L) { luaL_register(L, LUA_OSLIBNAME, syslib); return 1; diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index b576f809..0b3054ae 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -1657,7 +1657,7 @@ static void createmetatable(lua_State* L) /* ** Open string library */ -LUALIB_API int luaopen_string(lua_State* L) +int luaopen_string(lua_State* L) { luaL_register(L, LUA_STRLIBNAME, strlib); createmetatable(L); diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 07d22d59..0b55fcea 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -31,18 +31,19 @@ LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false) #define MAXSIZE (1 << MAXBITS) static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast in gval2slot is incorrect"); + // TKey is bitpacked for memory efficiency so we need to validate bit counts for worst case -static_assert(TKey{{NULL}, 0, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt"); -static_assert(TKey{{NULL}, 0, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); -static_assert(TKey{{NULL}, 0, LUA_TNIL, -(MAXSIZE - 1)}.next == -(MAXSIZE - 1), "not enough bits for next"); +static_assert(TKey{{NULL}, {0}, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt"); +static_assert(TKey{{NULL}, {0}, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); +static_assert(TKey{{NULL}, {0}, LUA_TNIL, -(MAXSIZE - 1)}.next == -(MAXSIZE - 1), "not enough bits for next"); // reset cache of absent metamethods, cache is updated in luaT_gettm #define invalidateTMcache(t) t->flags = 0 // empty hash data points to dummynode so that we can always dereference it const LuaNode luaH_dummynode = { - {{NULL}, 0, LUA_TNIL}, /* value */ - {{NULL}, 0, LUA_TNIL, 0} /* key */ + {{NULL}, {0}, LUA_TNIL}, /* value */ + {{NULL}, {0}, LUA_TNIL, 0} /* key */ }; #define dummynode (&luaH_dummynode) @@ -96,7 +97,7 @@ static LuaNode* hashnum(const Table* t, double n) static LuaNode* hashvec(const Table* t, const float* v) { - unsigned int i[3]; + unsigned int i[LUA_VECTOR_SIZE]; memcpy(i, v, sizeof(i)); // convert -0 to 0 to make sure they hash to the same value @@ -112,6 +113,12 @@ static LuaNode* hashvec(const Table* t, const float* v) // Optimized Spatial Hashing for Collision Detection of Deformable Objects unsigned int h = (i[0] * 73856093) ^ (i[1] * 19349663) ^ (i[2] * 83492791); +#if LUA_VECTOR_SIZE == 4 + i[3] = (i[3] == 0x8000000) ? 0 : i[3]; + i[3] ^= i[3] >> 17; + h ^= i[3] * 39916801; +#endif + return hashpow2(t, h); } diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 37025818..0d3374ef 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -527,7 +527,7 @@ static const luaL_Reg tab_funcs[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_table(lua_State* L) +int luaopen_table(lua_State* L) { luaL_register(L, LUA_TABLIBNAME, tab_funcs); diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp index 378de3d0..8bc8200a 100644 --- a/VM/src/lutf8lib.cpp +++ b/VM/src/lutf8lib.cpp @@ -283,7 +283,7 @@ static const luaL_Reg funcs[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_utf8(lua_State* L) +int luaopen_utf8(lua_State* L) { luaL_register(L, LUA_UTF8LIBNAME, funcs); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index eed2862b..bf8d493e 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -601,7 +601,13 @@ static void luau_execute(lua_State* L) const char* name = getstr(tsvalue(kv)); int ic = (name[0] | ' ') - 'x'; - if (unsigned(ic) < 3 && name[1] == '\0') +#if LUA_VECTOR_SIZE == 4 + // 'w' is before 'x' in ascii, so ic is -1 when indexing with 'w' + if (ic == -1) + ic = 3; +#endif + + if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0') { setnvalue(ra, rb->value.v[ic]); VM_NEXT(); @@ -1526,7 +1532,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2]); + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); VM_NEXT(); } else @@ -1572,7 +1578,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2]); + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); VM_NEXT(); } else @@ -1618,21 +1624,21 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(rc)); - setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc); + setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc, vb[3] * vc); VM_NEXT(); } else if (ttisvector(rb) && ttisvector(rc)) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2]); + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); VM_NEXT(); } else if (ttisnumber(rb) && ttisvector(rc)) { float vb = cast_to(float, nvalue(rb)); const float* vc = rc->value.v; - setvvalue(ra, vb * vc[0], vb * vc[1], vb * vc[2]); + setvvalue(ra, vb * vc[0], vb * vc[1], vb * vc[2], vb * vc[3]); VM_NEXT(); } else @@ -1679,21 +1685,21 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(rc)); - setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc); + setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc, vb[3] / vc); VM_NEXT(); } else if (ttisvector(rb) && ttisvector(rc)) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2]); + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); VM_NEXT(); } else if (ttisnumber(rb) && ttisvector(rc)) { float vb = cast_to(float, nvalue(rb)); const float* vc = rc->value.v; - setvvalue(ra, vb / vc[0], vb / vc[1], vb / vc[2]); + setvvalue(ra, vb / vc[0], vb / vc[1], vb / vc[2], vb / vc[3]); VM_NEXT(); } else @@ -1826,7 +1832,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(kv)); - setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc); + setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc, vb[3] * vc); VM_NEXT(); } else @@ -1872,7 +1878,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(kv)); - setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc); + setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc, vb[3] / vc); VM_NEXT(); } else @@ -2037,7 +2043,7 @@ static void luau_execute(lua_State* L) else if (ttisvector(rb)) { const float* vb = rb->value.v; - setvvalue(ra, -vb[0], -vb[1], -vb[2]); + setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); VM_NEXT(); } else diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index a168b652..add3588d 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -9,6 +9,7 @@ #include "lgc.h" #include "lmem.h" #include "lbytecode.h" +#include "lapi.h" #include @@ -162,9 +163,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size size_t GCthreshold = L->global->GCthreshold; L->global->GCthreshold = SIZE_MAX; - // env is 0 for current environment and a stack relative index otherwise - LUAU_ASSERT(env <= 0 && L->top - L->base >= -env); - Table* envt = (env == 0) ? hvalue(gt(L)) : hvalue(L->top + env); + // env is 0 for current environment and a stack index otherwise + Table* envt = (env == 0) ? hvalue(gt(L)) : hvalue(luaA_toobject(L, env)); TString* source = luaS_new(L, chunkname); diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index f52e8e74..740a4cfd 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -401,19 +401,19 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM switch (op) { case TM_ADD: - setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2]); + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); return; case TM_SUB: - setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2]); + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); return; case TM_MUL: - setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2]); + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); return; case TM_DIV: - setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2]); + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); return; case TM_UNM: - setvvalue(ra, -vb[0], -vb[1], -vb[2]); + setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); return; default: break; @@ -430,10 +430,10 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM switch (op) { case TM_MUL: - setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc); + setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc, vb[3] * nc); return; case TM_DIV: - setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc); + setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc, vb[3] / nc); return; default: break; @@ -451,10 +451,10 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM switch (op) { case TM_MUL: - setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2]); + setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2], nb * vc[3]); return; case TM_DIV: - setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2]); + setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2], nb / vc[3]); return; default: break; diff --git a/bench/gc/test_LB_mandel.lua b/bench/gc/test_LB_mandel.lua index 4be78502..fe5b4eb2 100644 --- a/bench/gc/test_LB_mandel.lua +++ b/bench/gc/test_LB_mandel.lua @@ -88,7 +88,7 @@ for i=1,N do local y=ymin+(j-1)*dy S = S + level(x,y) end - -- if i % 10 == 0 then print(collectgarbage"count") end + -- if i % 10 == 0 then print(collectgarbage("count")) end end print(S) diff --git a/bench/tests/shootout/mandel.lua b/bench/tests/shootout/mandel.lua index 4be78502..fe5b4eb2 100644 --- a/bench/tests/shootout/mandel.lua +++ b/bench/tests/shootout/mandel.lua @@ -88,7 +88,7 @@ for i=1,N do local y=ymin+(j-1)*dy S = S + level(x,y) end - -- if i % 10 == 0 then print(collectgarbage"count") end + -- if i % 10 == 0 then print(collectgarbage("count")) end end print(S) diff --git a/bench/tests/shootout/qt.lua b/bench/tests/shootout/qt.lua index de962a74..79cbe38b 100644 --- a/bench/tests/shootout/qt.lua +++ b/bench/tests/shootout/qt.lua @@ -275,7 +275,7 @@ local function memory(s) local t=os.clock() --local dt=string.format("%f",t-t0) local dt=t-t0 - --io.stdout:write(s,"\t",dt," sec\t",t," sec\t",math.floor(collectgarbage"count"/1024),"M\n") + --io.stdout:write(s,"\t",dt," sec\t",t," sec\t",math.floor(collectgarbage("count")/1024),"M\n") t0=t end @@ -286,7 +286,7 @@ local function do_(f,s) end local function julia(l,a,b) -memory"begin" +memory("begin") cx=a cy=b root=newcell() exterior=newcell() exterior.color=white @@ -297,14 +297,14 @@ memory"begin" do_(update,"update") repeat N=0 color(root,Rxmin,Rxmax,Rymin,Rymax) --print("color",N) - until N==0 memory"color" + until N==0 memory("color") repeat N=0 prewhite(root,Rxmin,Rxmax,Rymin,Rymax) --print("prewhite",N) - until N==0 memory"prewhite" + until N==0 memory("prewhite") do_(recolor,"recolor") do_(colorup,"colorup") --print("colorup",N) local g,b=do_(area,"area") --print("area",g,b,g+b) - show(i) memory"output" + show(i) memory("output") --print("edges",nE) end end diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index ae2399e4..27e53492 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -23,9 +23,11 @@ const bool kFuzzCompiler = true; const bool kFuzzLinter = true; const bool kFuzzTypeck = true; const bool kFuzzVM = true; -const bool kFuzzTypes = true; const bool kFuzzTranspile = true; +// Should we generate type annotations? +const bool kFuzzTypes = true; + static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); std::string protoprint(const luau::StatBlock& stat, bool types); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index aa53a92b..2090b014 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -78,3 +78,26 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("AstQuery"); + +TEST_CASE_FIXTURE(Fixture, "last_argument_function_call_type") +{ + ScopedFastFlag luauTailArgumentTypeInfo{"LuauTailArgumentTypeInfo", true}; + + check(R"( +local function foo() return 2 end +local function bar(a: number) return -a end +bar(foo()) + )"); + + auto oty = findTypeAtPosition(Position(3, 7)); + REQUIRE(oty); + CHECK_EQ("number", toString(*oty)); + + auto expectedOty = findExpectedTypeAtPosition(Position(3, 7)); + REQUIRE(expectedOty); + CHECK_EQ("number", toString(*expectedOty)); +} + +TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 5a7c8602..3b74a99e 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1935,6 +1935,39 @@ return target(b@1 CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::None); } +TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); + + check(R"( +local function bar(a: number) return -a end +local abc = b@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("bar")); + CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside); +} + +TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); + + check(R"( +local function foo() return 1 end +local function bar(a: number) return -a end +local abc = bar(@1) + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("foo")); + CHECK(ac.entryMap["foo"].parens == ParenthesesRecommendation::CursorAfter); +} + TEST_CASE_FIXTURE(ACFixture, "type_correct_sealed_table") { check(R"( @@ -2210,8 +2243,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_require") { - ScopedFastFlag luauResolveModuleNameWithoutACurrentModule("LuauResolveModuleNameWithoutACurrentModule", true); - std::string_view source = R"( local a = require(w -- Line 1 -- | Column 27 @@ -2287,8 +2318,6 @@ until TEST_CASE_FIXTURE(ACFixture, "if_then_else_elseif_completions") { - ScopedFastFlag sff{"ElseElseIfCompletionImprovements", true}; - check(R"( local elsewhere = false @@ -2585,9 +2614,6 @@ a = if temp then even elseif true then temp else e@9 TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - check(R"( type A = () -> T... local a: A<(number, s@1> diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 4ce8d08a..6ba39ada 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -1057,6 +1057,18 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("return if false then 10 else 20"), R"( LOADN R0 20 RETURN R0 1 +)"); + + // codegen for a true constant condition with non-constant expressions + CHECK_EQ("\n" + compileFunction0("return if true then {} else error()"), R"( +NEWTABLE R0 0 0 +RETURN R0 1 +)"); + + // codegen for a false constant condition with non-constant expressions + CHECK_EQ("\n" + compileFunction0("return if false then error() else {}"), R"( +NEWTABLE R0 0 0 +RETURN R0 1 )"); // codegen for a false (in this case 'nil') constant condition @@ -2360,6 +2372,58 @@ Foo:Bar( )"); } +TEST_CASE("DebugLineInfoCallChain") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local Foo = ... + +Foo +:Bar(1) +:Baz(2) +.Qux(3) +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: GETVARARGS R0 1 +5: LOADN R4 1 +5: NAMECALL R2 R0 K0 +5: CALL R2 2 1 +6: LOADN R4 2 +6: NAMECALL R2 R2 K1 +6: CALL R2 2 1 +7: GETTABLEKS R1 R2 K2 +7: LOADN R2 3 +7: CALL R1 1 0 +8: RETURN R0 0 +)"); +} + +TEST_CASE("DebugLineInfoFastCall") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local Foo, Bar = ... + +return + math.max( + Foo, + Bar) +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: GETVARARGS R0 2 +5: FASTCALL2 18 R0 R1 +5 +5: MOVE R3 R0 +5: MOVE R4 R1 +5: GETIMPORT R2 2 +5: CALL R2 2 -1 +5: RETURN R2 -1 +)"); +} + TEST_CASE("DebugSource") { const char* source = R"( @@ -3742,4 +3806,108 @@ RETURN R0 0 )"); } +TEST_CASE("ConstantsNoFolding") +{ + const char* source = "return nil, true, 42, 'hello'"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.optimizationLevel = 0; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +LOADNIL R0 +LOADB R1 1 +LOADK R2 K0 +LOADK R3 K1 +RETURN R0 4 +)"); +} + +TEST_CASE("VectorFastCall") +{ + const char* source = "return Vector3.new(1, 2, 3)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.vectorLib = "Vector3"; + options.vectorCtor = "new"; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +LOADN R1 1 +LOADN R2 2 +LOADN R3 3 +FASTCALL 54 +2 +GETIMPORT R0 2 +CALL R0 3 -1 +RETURN R0 -1 +)"); +} + +TEST_CASE("TypeAssertion") +{ + // validate that type assertions work with the compiler and that the code inside type assertion isn't evaluated + CHECK_EQ("\n" + compileFunction0(R"( +print(foo() :: typeof(error("compile time"))) +)"), + R"( +GETIMPORT R0 1 +GETIMPORT R1 3 +CALL R1 0 1 +CALL R0 1 0 +RETURN R0 0 +)"); + + // note that above, foo() is treated as single-arg function; removing type assertion changes the bytecode + CHECK_EQ("\n" + compileFunction0(R"( +print(foo()) +)"), + R"( +GETIMPORT R0 1 +GETIMPORT R1 3 +CALL R1 0 -1 +CALL R0 -1 0 +RETURN R0 0 +)"); +} + +TEST_CASE("Arithmetics") +{ + // basic arithmetics codegen with non-constants + CHECK_EQ("\n" + compileFunction0(R"( +local a, b = ... +return a + b, a - b, a / b, a * b, a % b, a ^ b +)"), + R"( +GETVARARGS R0 2 +ADD R2 R0 R1 +SUB R3 R0 R1 +DIV R4 R0 R1 +MUL R5 R0 R1 +MOD R6 R0 R1 +POW R7 R0 R1 +RETURN R2 6 +)"); + + // basic arithmetics codegen with constants on the right side + // note that we don't simplify these expressions as we don't know the type of a + CHECK_EQ("\n" + compileFunction0(R"( +local a = ... +return a + 1, a - 1, a / 1, a * 1, a % 1, a ^ 1 +)"), + R"( +GETVARARGS R0 1 +ADDK R1 R0 K0 +SUBK R2 R0 K0 +DIVK R3 R0 K0 +MULK R4 R0 K0 +MODK R5 R0 K0 +POWK R6 R0 K0 +RETURN R1 6 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index e495a213..b2aad316 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -67,44 +67,42 @@ static int lua_vector(lua_State* L) double y = luaL_checknumber(L, 2); double z = luaL_checknumber(L, 3); +#if LUA_VECTOR_SIZE == 4 + double w = luaL_optnumber(L, 4, 0.0); + lua_pushvector(L, float(x), float(y), float(z), float(w)); +#else lua_pushvector(L, float(x), float(y), float(z)); +#endif return 1; } static int lua_vector_dot(lua_State* L) { - const float* a = lua_tovector(L, 1); - const float* b = lua_tovector(L, 2); + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); - if (a && b) - { - lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); - return 1; - } - - throw std::runtime_error("invalid arguments to vector:Dot"); + lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); + return 1; } static int lua_vector_index(lua_State* L) { + const float* v = luaL_checkvector(L, 1); const char* name = luaL_checkstring(L, 2); - if (const float* v = lua_tovector(L, 1)) + if (strcmp(name, "Magnitude") == 0) { - if (strcmp(name, "Magnitude") == 0) - { - lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); - return 1; - } - - if (strcmp(name, "Dot") == 0) - { - lua_pushcfunction(L, lua_vector_dot, "Dot"); - return 1; - } + lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); + return 1; } - throw std::runtime_error(Luau::format("%s is not a valid member of vector", name)); + if (strcmp(name, "Dot") == 0) + { + lua_pushcfunction(L, lua_vector_dot, "Dot"); + return 1; + } + + luaL_error(L, "%s is not a valid member of vector", name); } static int lua_vector_namecall(lua_State* L) @@ -115,7 +113,7 @@ static int lua_vector_namecall(lua_State* L) return lua_vector_dot(L); } - throw std::runtime_error(Luau::format("%s is not a valid method of vector", luaL_checkstring(L, 1))); + luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); } int lua_silence(lua_State* L) @@ -373,11 +371,17 @@ TEST_CASE("Pack") TEST_CASE("Vector") { + ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + runConformance("vector.lua", [](lua_State* L) { lua_pushcfunction(L, lua_vector, "vector"); lua_setglobal(L, "vector"); +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); +#else lua_pushvector(L, 0.0f, 0.0f, 0.0f); +#endif luaL_newmetatable(L, "vector"); lua_pushstring(L, "__index"); @@ -504,6 +508,9 @@ TEST_CASE("Debugger") cb->debugbreak = [](lua_State* L, lua_Debug* ar) { breakhits++; + // make sure we can trace the stack for every breakpoint we hit + lua_debugtrace(L); + // for every breakpoint, we break on the first invocation and continue on second // this allows us to easily step off breakpoints // (real implementaiton may require singlestepping) @@ -524,7 +531,7 @@ TEST_CASE("Debugger") L, [](lua_State* L) -> int { int line = luaL_checkinteger(L, 1); - bool enabled = lua_isboolean(L, 2) ? lua_toboolean(L, 2) : true; + bool enabled = luaL_optboolean(L, 2, true); lua_Debug ar = {}; lua_getinfo(L, 1, "f", &ar); @@ -699,21 +706,52 @@ TEST_CASE("ApiFunctionCalls") StateRef globalState = runConformance("apicalls.lua"); lua_State* L = globalState.get(); - lua_getfield(L, LUA_GLOBALSINDEX, "add"); - lua_pushnumber(L, 40); - lua_pushnumber(L, 2); - lua_call(L, 2, 1); - CHECK(lua_isnumber(L, -1)); - CHECK(lua_tonumber(L, -1) == 42); - lua_pop(L, 1); + // lua_call + { + lua_getfield(L, LUA_GLOBALSINDEX, "add"); + lua_pushnumber(L, 40); + lua_pushnumber(L, 2); + lua_call(L, 2, 1); + CHECK(lua_isnumber(L, -1)); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + } - lua_getfield(L, LUA_GLOBALSINDEX, "add"); - lua_pushnumber(L, 40); - lua_pushnumber(L, 2); - lua_pcall(L, 2, 1, 0); - CHECK(lua_isnumber(L, -1)); - CHECK(lua_tonumber(L, -1) == 42); - lua_pop(L, 1); + // lua_pcall + { + lua_getfield(L, LUA_GLOBALSINDEX, "add"); + lua_pushnumber(L, 40); + lua_pushnumber(L, 2); + lua_pcall(L, 2, 1, 0); + CHECK(lua_isnumber(L, -1)); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + } + + // lua_equal with a sleeping thread wake up + { + ScopedFastFlag luauActivateBeforeExec("LuauActivateBeforeExec", true); + + lua_State* L2 = lua_newthread(L); + + lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); + lua_pushnumber(L2, 42); + lua_pcall(L2, 1, 1, 0); + + lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); + lua_pushnumber(L2, 42); + lua_pcall(L2, 1, 1, 0); + + // Reset GC + lua_gc(L2, LUA_GCCOLLECT, 0); + + // Try to mark 'L2' as sleeping + // Can't control GC precisely, even in tests + lua_gc(L2, LUA_GCSTEP, 8); + + CHECK(lua_equal(L2, -1, -2) == 1); + lua_pop(L2, 2); + } } static bool endsWith(const std::string& str, const std::string& suffix) @@ -727,8 +765,6 @@ static bool endsWith(const std::string& str, const std::string& suffix) #if !LUA_USE_LONGJMP TEST_CASE("ExceptionObject") { - ScopedFastFlag sff("LuauExceptionMessageFix", true); - struct ExceptionResult { bool exceptionGenerated; diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 29c33f7c..36d6f561 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -19,19 +19,6 @@ static const char* mainModuleName = "MainModule"; namespace Luau { -std::optional TestFileResolver::fromAstFragment(AstExpr* expr) const -{ - auto g = expr->as(); - if (!g) - return std::nullopt; - - std::string_view value = g->name.value; - if (value == "game" || value == "Game" || value == "workspace" || value == "Workspace" || value == "script" || value == "Script") - return ModuleName(value); - - return std::nullopt; -} - std::optional TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr) { if (AstExprGlobal* g = expr->as()) @@ -81,24 +68,6 @@ std::optional TestFileResolver::resolveModule(const ModuleInfo* cont return std::nullopt; } -ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const -{ - return lhs + "/" + ModuleName(rhs); -} - -std::optional TestFileResolver::getParentModuleName(const ModuleName& name) const -{ - std::string_view view = name; - const size_t lastSeparatorIndex = view.find_last_of('/'); - - if (lastSeparatorIndex != std::string_view::npos) - { - return ModuleName(view.substr(0, lastSeparatorIndex)); - } - - return std::nullopt; -} - std::string TestFileResolver::getHumanReadableModuleName(const ModuleName& name) const { return name; @@ -324,6 +293,13 @@ std::optional Fixture::findTypeAtPosition(Position position) return Luau::findTypeAtPosition(*module, *sourceModule, position); } +std::optional Fixture::findExpectedTypeAtPosition(Position position) +{ + ModulePtr module = getMainModule(); + SourceModule* sourceModule = getMainSourceModule(); + return Luau::findExpectedTypeAtPosition(*module, *sourceModule, position); +} + TypeId Fixture::requireTypeAtPosition(Position position) { auto ty = findTypeAtPosition(position); diff --git a/tests/Fixture.h b/tests/Fixture.h index 1480a7f6..de2b7381 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -64,12 +64,8 @@ struct TestFileResolver return SourceCode{it->second, sourceType}; } - std::optional fromAstFragment(AstExpr* expr) const override; std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override; - std::optional getParentModuleName(const ModuleName& name) const override; - std::string getHumanReadableModuleName(const ModuleName& name) const override; std::optional getEnvironmentForModule(const ModuleName& name) const override; @@ -126,6 +122,7 @@ struct Fixture std::optional findTypeAtPosition(Position position); TypeId requireTypeAtPosition(Position position); + std::optional findExpectedTypeAtPosition(Position position); std::optional lookupType(const std::string& name); std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index fbfec636..51fcd3d6 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -46,18 +46,6 @@ NaiveModuleResolver naiveModuleResolver; struct NaiveFileResolver : NullFileResolver { - std::optional fromAstFragment(AstExpr* expr) const override - { - AstExprGlobal* g = expr->as(); - if (g && g->name == "Modules") - return "Modules"; - - if (g && g->name == "game") - return "game"; - - return std::nullopt; - } - std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override { if (AstExprGlobal* g = expr->as()) @@ -86,11 +74,6 @@ struct NaiveFileResolver : NullFileResolver return std::nullopt; } - - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override - { - return lhs + "/" + ModuleName(rhs); - } }; } // namespace diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 7ba40c50..1d13df28 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1469,6 +1469,22 @@ _ = true and true or false -- no warning since this is is a common pattern used CHECK_EQ(result.warnings[6].location.begin.line + 1, 19); } +TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsExpr") +{ + LintResult result = lint(R"( +local correct, opaque = ... + +if correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls")}) then +elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls")}) then +elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls", false)}) then +end +)"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 4"); + CHECK_EQ(result.warnings[0].location.begin.line + 1, 5); +} + TEST_CASE_FIXTURE(Fixture, "DuplicateLocal") { LintResult result = lint(R"( diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 7a3543c7..2800d2fe 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -44,9 +44,10 @@ TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks); + TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks, cloneState); CHECK_EQ(newNumber, typeChecker.numberType); } @@ -56,12 +57,13 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; // Create a new number type that isn't persistent unfreeze(typeChecker.globalTypes); TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number}); freeze(typeChecker.globalTypes); - TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks); + TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks, cloneState); CHECK_NE(newNumber, oldNumber); CHECK_EQ(*oldNumber, *newNumber); @@ -89,9 +91,10 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; TypeArena dest; - TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks); + TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks, cloneState); TableTypeVar* ttv = getMutable(counterCopy); REQUIRE(ttv != nullptr); @@ -142,11 +145,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks); + TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks, cloneState); CHECK_NE(newUnion, oldUnion); CHECK_EQ("number | string", toString(newUnion)); @@ -159,11 +163,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks); + TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks, cloneState); CHECK_NE(newIntersection, oldIntersection); CHECK_EQ("number & string", toString(newIntersection)); @@ -188,8 +193,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; - TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks); + TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks, cloneState); const ClassTypeVar* ctv = get(cloned); REQUIRE(ctv != nullptr); @@ -211,16 +217,16 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") TypeArena dest; SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; - bool encounteredFreeType = false; - TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); + TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, cloneState); CHECK_EQ("any", toString(clonedTy)); - CHECK(encounteredFreeType); + CHECK(cloneState.encounteredFreeType); - encounteredFreeType = false; - TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); + cloneState = {}; + TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, cloneState); CHECK_EQ("...any", toString(clonedTp)); - CHECK(encounteredFreeType); + CHECK(cloneState.encounteredFreeType); } TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") @@ -232,12 +238,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") TypeArena dest; SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; - bool encounteredFreeType = false; - TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); + TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, cloneState); const TableTypeVar* clonedTtv = get(cloned); CHECK_EQ(clonedTtv->state, TableState::Sealed); - CHECK(encounteredFreeType); + CHECK(cloneState.encounteredFreeType); } TEST_CASE_FIXTURE(Fixture, "clone_self_property") @@ -267,4 +273,34 @@ TEST_CASE_FIXTURE(Fixture, "clone_self_property") "dot or pass 1 extra nil to suppress this warning"); } +TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") +{ +#if defined(_DEBUG) || defined(_NOOPT) + int limit = 250; +#else + int limit = 500; +#endif + ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; + + TypeArena src; + + TypeId table = src.addType(TableTypeVar{}); + TypeId nested = table; + + for (unsigned i = 0; i < limit + 100; i++) + { + TableTypeVar* ttv = getMutable(nested); + + ttv->props["a"].type = src.addType(TableTypeVar{}); + nested = ttv->props["a"].type; + } + + TypeArena dest; + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + CloneState cloneState; + + CHECK_THROWS_AS(clone(table, dest, seenTypes, seenTypePacks, cloneState), std::runtime_error); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index e3e6ce6d..72d3a9a6 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2518,8 +2518,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") { - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - AstStat* stat = parse(R"( type Packed = () -> T... diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp new file mode 100644 index 00000000..0ca9c994 --- /dev/null +++ b/tests/ToDot.test.cpp @@ -0,0 +1,366 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" +#include "Luau/ToDot.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +struct ToDotClassFixture : Fixture +{ + ToDotClassFixture() + { + TypeArena& arena = typeChecker.globalTypes; + + unfreeze(arena); + + TypeId baseClassMetaType = arena.addType(TableTypeVar{}); + + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}}); + getMutable(baseClassInstanceType)->props = { + {"BaseField", {typeChecker.numberType}}, + }; + typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}}); + getMutable(childClassInstanceType)->props = { + {"ChildField", {typeChecker.stringType}}, + }; + typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + + freeze(arena); + } +}; + +TEST_SUITE_BEGIN("ToDot"); + +TEST_CASE_FIXTURE(Fixture, "primitive") +{ + CheckResult result = check(R"( +local a: nil +local b: number +local c: any +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_NE("nil", toDot(requireType("a"))); + + CHECK_EQ(R"(digraph graphname { +n1 [label="number"]; +})", + toDot(requireType("b"))); + + CHECK_EQ(R"(digraph graphname { +n1 [label="any"]; +})", + toDot(requireType("c"))); + + ToDotOptions opts; + opts.showPointers = false; + opts.duplicatePrimitives = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="PrimitiveTypeVar number"]; +})", + toDot(requireType("b"), opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="AnyTypeVar 1"]; +})", + toDot(requireType("c"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound") +{ + CheckResult result = check(R"( +local a = 444 +local b = a +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = getType("b"); + REQUIRE(bool(ty)); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="BoundTypeVar 1"]; +n1 -> n2; +n2 [label="number"]; +})", + toDot(*ty, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "function") +{ + ScopedFastFlag luauQuantifyInPlace2{"LuauQuantifyInPlace2", true}; + + CheckResult result = check(R"( +local function f(a, ...: string) return a end +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FunctionTypeVar 1"]; +n1 -> n2 [label="arg"]; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="GenericTypeVar 3"]; +n2 -> n4 [label="tail"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n1 -> n6 [label="ret"]; +n6 [label="BoundTypePack 6"]; +n6 -> n7; +n7 [label="TypePack 7"]; +n7 -> n3; +})", + toDot(requireType("f"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "union") +{ + CheckResult result = check(R"( +local a: string | number +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="UnionTypeVar 1"]; +n1 -> n2; +n2 [label="string"]; +n1 -> n3; +n3 [label="number"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "intersection") +{ + CheckResult result = check(R"( +local a: string & number -- uninhabited +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="IntersectionTypeVar 1"]; +n1 -> n2; +n2 [label="string"]; +n1 -> n3; +n3 [label="number"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "table") +{ + CheckResult result = check(R"( +type A = { x: T, y: (U...) -> (), [string]: any } +local a: A +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="TableTypeVar A"]; +n1 -> n2 [label="x"]; +n2 [label="number"]; +n1 -> n3 [label="y"]; +n3 [label="FunctionTypeVar 3"]; +n3 -> n4 [label="arg"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n3 -> n6 [label="ret"]; +n6 [label="TypePack 6"]; +n1 -> n7 [label="[index]"]; +n7 [label="string"]; +n1 -> n8 [label="[value]"]; +n8 [label="any"]; +n1 -> n9 [label="typeParam"]; +n9 [label="number"]; +n1 -> n4 [label="typePackParam"]; +})", + toDot(requireType("a"), opts)); + + // Extra coverage with pointers (unstable values) + (void)toDot(requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "metatable") +{ + CheckResult result = check(R"( +local a: typeof(setmetatable({}, {})) +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="MetatableTypeVar 1"]; +n1 -> n2 [label="table"]; +n2 [label="TableTypeVar 2"]; +n1 -> n3 [label="metatable"]; +n3 [label="TableTypeVar 3"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "free") +{ + TypeVar type{TypeVariant{FreeTypeVar{TypeLevel{0, 0}}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FreeTypeVar 1"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "error") +{ + TypeVar type{TypeVariant{ErrorTypeVar{}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ErrorTypeVar 1"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "generic") +{ + TypeVar type{TypeVariant{GenericTypeVar{"T"}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypeVar T"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(ToDotClassFixture, "class") +{ + CheckResult result = check(R"( +local a: ChildClass +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ClassTypeVar ChildClass"]; +n1 -> n2 [label="ChildField"]; +n2 [label="string"]; +n1 -> n3 [label="[parent]"]; +n3 [label="ClassTypeVar BaseClass"]; +n3 -> n4 [label="BaseField"]; +n4 [label="number"]; +n3 -> n5 [label="[metatable]"]; +n5 [label="TableTypeVar 5"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "free_pack") +{ + TypePackVar pack{TypePackVariant{FreeTypePack{TypeLevel{0, 0}}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FreeTypePack 1"]; +})", + toDot(&pack, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "error_pack") +{ + TypePackVar pack{TypePackVariant{Unifiable::Error{}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ErrorTypePack 1"]; +})", + toDot(&pack, opts)); + + // Extra coverage with pointers (unstable values) + (void)toDot(&pack); +} + +TEST_CASE_FIXTURE(Fixture, "generic_pack") +{ + TypePackVar pack1{TypePackVariant{GenericTypePack{}}}; + TypePackVar pack2{TypePackVariant{GenericTypePack{"T"}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypePack 1"]; +})", + toDot(&pack1, opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypePack T"]; +})", + toDot(&pack2, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound_pack") +{ + TypePackVar pack{TypePackVariant{TypePack{{typeChecker.numberType}, {}}}}; + TypePackVar bound{TypePackVariant{BoundTypePack{&pack}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="BoundTypePack 1"]; +n1 -> n2; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="number"]; +})", + toDot(&bound, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound_table") +{ + CheckResult result = check(R"( +local a = {x=2} +local b +b.x = 2 +b = a +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = getType("b"); + REQUIRE(bool(ty)); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="TableTypeVar 1"]; +n1 -> n2 [label="boundTo"]; +n2 [label="TableTypeVar a"]; +n2 -> n3 [label="x"]; +n3 [label="number"]; +})", + toDot(*ty, opts)); +} + +TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 928c03a3..327fa0bb 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -445,9 +445,6 @@ local a: Import.Type TEST_CASE_FIXTURE(Fixture, "transpile_type_packs") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - std::string code = R"( type Packed = (T...)->(T...) local a: Packed<> diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 74ce155c..822bd727 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -537,8 +537,6 @@ TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") { - ScopedFastFlag sff1{"LuauSubstitutionDontReplaceIgnoredTypes", true}; - CheckResult result = check(R"( type Array = { [number]: T } type Tuple = Array diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 88c2dc85..aba50891 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -609,8 +609,6 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") { - ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true}; - CheckResult result = check(R"( local exports = {} local nested = {} @@ -627,4 +625,23 @@ return exports LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names") +{ + ScopedFastFlag luauFunctionArgumentNameSize{"LuauFunctionArgumentNameSize", true}; + + CheckResult result = check(R"( +local function f(a: T, ...: U...) end + +f(1, 2, 3) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto ty = findTypeAtPosition(Position(3, 0)); + REQUIRE(ty); + ToStringOptions opts; + opts.functionTypeArguments = true; + CHECK_EQ(toString(*ty, opts), "(a: number, number, number) -> ()"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index e5c14dde..e6d3d4d4 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -31,8 +31,6 @@ TEST_SUITE_BEGIN("ProvisionalTests"); */ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - const std::string code = R"( function f(a) if type(a) == "boolean" then diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index c3694be7..cb72faaf 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2022,4 +2022,74 @@ caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); } +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") +{ + ScopedFastFlag sffs[] { + {"LuauPropertiesGetExpectedType", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauTableSubtypingVariance", true}, + }; + + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { p : Super } +type HasSub = { p : Sub } +local a: HasSuper = { p = { x = 5, y = 7 }} +a.p = { x = 9 } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") +{ + ScopedFastFlag sffs[] { + {"LuauPropertiesGetExpectedType", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauTableSubtypingVariance", true}, + {"LuauExtendedTypeMismatchError", true}, + }; + + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { p : Super } +type HasSub = { p : Sub } +local tmp = { p = { x = 5, y = 7 }} +local a: HasSuper = tmp +a.p = { x = 9 } +-- needs to be an error because +local y: number = tmp.p.y + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' +caused by: + Property 'p' is not compatible. Table type '{| x: number, y: number |}' not compatible with type 'Super' because the former has extra field 'y')"); +} + +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") +{ + ScopedFastFlag sffs[] { + {"LuauPropertiesGetExpectedType", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauTableSubtypingVariance", true}, + }; + + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { [string] : Super } +type HasSub = { [string] : Sub } +local a: HasSuper = { p = { x = 5, y = 7 }} +a.p = { x = 9 } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 99fd8339..e3222a41 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4779,4 +4779,24 @@ local bar = foo.nutrition + 100 // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); } +TEST_CASE_FIXTURE(Fixture, "require_failed_module") +{ + ScopedFastFlag luauModuleRequireErrorPack{"LuauModuleRequireErrorPack", true}; + + fileResolver.source["game/A"] = R"( +return unfortunately() + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(aResult); + + CheckResult result = check(R"( +local ModuleA = require(game.A) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional oty = requireType("ModuleA"); + CHECK_EQ("*unknown*", toString(*oty)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index c6de0abf..3f4420cd 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -296,9 +296,6 @@ end TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Packed = (T...) -> T... local a: Packed<> @@ -360,9 +357,6 @@ local c: Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } return {} @@ -393,9 +387,6 @@ local d: { a: typeof(c) } TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } return {} @@ -431,9 +422,6 @@ type C = Import.Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Packed1 = (T...) -> (T...) type Packed2 = (Packed1, T...) -> (Packed1, T...) @@ -452,9 +440,6 @@ type Packed4 = (Packed3, T...) -> (Packed3, T...) TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type X = (T...) -> (string, T...) @@ -470,9 +455,6 @@ type E = X<(number, ...string)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Y = (T...) -> (U...) type A = Y @@ -501,9 +483,6 @@ type I = W TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type X = (T...) -> (T...) @@ -527,9 +506,6 @@ type F = X<(string, ...number)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Y = (T...) -> (U...) @@ -549,9 +525,6 @@ type D = Y TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Y = { f: (T...) -> (U...) } @@ -567,9 +540,6 @@ local b: Y<(), ()> TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type X = () -> T type Y = (T) -> U @@ -588,9 +558,6 @@ type C = Y<(number), boolean> TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Packed = (T, U) -> (V...) local b: Packed diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 91efa818..13db923e 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -3,6 +3,7 @@ #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -323,4 +324,48 @@ TEST_CASE("tagging_props") CHECK(Luau::hasTag(prop, "foo")); } +struct VisitCountTracker +{ + std::unordered_map tyVisits; + std::unordered_map tpVisits; + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + template + bool operator()(TypeId ty, const T& t) + { + tyVisits[ty]++; + return true; + } + + template + bool operator()(TypePackId tp, const T&) + { + tpVisits[tp]++; + return true; + } +}; + +TEST_CASE_FIXTURE(Fixture, "visit_once") +{ + CheckResult result = check(R"( +type T = { a: number, b: () -> () } +local b: (T, T, T) -> T +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId bType = requireType("b"); + + VisitCountTracker tester; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(bType, tester, seen); + + for (auto [_, count] : tester.tyVisits) + CHECK_EQ(count, 1); + + for (auto [_, count] : tester.tpVisits) + CHECK_EQ(count, 1); +} + TEST_SUITE_END(); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 5e03b055..7a4058b5 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -2,7 +2,13 @@ print('testing function calls through API') function add(a, b) - return a + b + return a + b +end + +local m = { __eq = function(a, b) return a.a == b.a end } + +function create_with_tm(x) + return setmetatable({ a = x }, m) end return('OK') diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 687fff1e..188b8ebc 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -441,7 +441,8 @@ assert((function() a = {} b = {} mt = { __eq = function(l, r) return #l == #r en assert((function() a = {} b = {} function eq(l, r) return #l == #r end setmetatable(a, {__eq = eq}) setmetatable(b, {__eq = eq}) return concat(a == b, a ~= b) end)() == "true,false") assert((function() a = {} b = {} setmetatable(a, {__eq = function(l, r) return #l == #r end}) setmetatable(b, {__eq = function(l, r) return #l == #r end}) return concat(a == b, a ~= b) end)() == "false,true") --- userdata, reference equality (no mt) +-- userdata, reference equality (no mt or mt.__eq) +assert((function() a = newproxy() return concat(a == newproxy(),a ~= newproxy()) end)() == "false,true") assert((function() a = newproxy(true) return concat(a == newproxy(true),a ~= newproxy(true)) end)() == "false,true") -- rawequal @@ -876,4 +877,4 @@ assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number testgetfenv() -- DONT MOVE THIS LINE -return'OK' +return 'OK' diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index aac42c56..f32d5bdc 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -419,11 +419,5 @@ co = coroutine.create(function () return loadstring("return a")() end) -a = {a = 15} --- debug.setfenv(co, a) --- assert(debug.getfenv(co) == a) --- assert(select(2, coroutine.resume(co)) == a) --- assert(select(2, coroutine.resume(co)) == a.a) - -return'OK' +return 'OK' diff --git a/tests/conformance/constructs.lua b/tests/conformance/constructs.lua index 16c63b00..f133501f 100644 --- a/tests/conformance/constructs.lua +++ b/tests/conformance/constructs.lua @@ -237,4 +237,4 @@ repeat i = i+1 until i==c -return'OK' +return 'OK' diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 4d9b1295..f2ecc96b 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -373,4 +373,4 @@ do assert(f() == 42) end -return'OK' +return 'OK' diff --git a/tests/conformance/datetime.lua b/tests/conformance/datetime.lua index 21ef60d7..ca35cf2f 100644 --- a/tests/conformance/datetime.lua +++ b/tests/conformance/datetime.lua @@ -74,4 +74,4 @@ assert(os.difftime(t1,t2) == 60*2-19) assert(os.time({ year = 1970, day = 1, month = 1, hour = 0}) == 0) -return'OK' +return 'OK' diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index ee79a14f..9cf3c742 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -98,4 +98,4 @@ assert(quuz(function(...) end) == "0 true") assert(quuz(function(a, b) end) == "2 false") assert(quuz(function(a, b, ...) end) == "2 true") -return'OK' +return 'OK' diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index eded14e9..d5ff215b 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -34,15 +34,15 @@ assert(doit("error('hi', 0)") == 'hi') assert(doit("unpack({}, 1, n=2^30)")) assert(doit("a=math.sin()")) assert(not doit("tostring(1)") and doit("tostring()")) -assert(doit"tonumber()") -assert(doit"repeat until 1; a") +assert(doit("tonumber()")) +assert(doit("repeat until 1; a")) checksyntax("break label", "", "label", 1) -assert(doit";") -assert(doit"a=1;;") -assert(doit"return;;") -assert(doit"assert(false)") -assert(doit"assert(nil)") -assert(doit"a=math.sin\n(3)") +assert(doit(";")) +assert(doit("a=1;;")) +assert(doit("return;;")) +assert(doit("assert(false)")) +assert(doit("assert(nil)")) +assert(doit("a=math.sin\n(3)")) assert(doit("function a (... , ...) end")) assert(doit("function a (, ...) end")) @@ -59,7 +59,7 @@ checkmessage("a=1; local a,bbbb=2,3; a = math.sin(1) and bbbb(3)", "local 'bbbb'") checkmessage("a={}; do local a=1 end a:bbbb(3)", "method 'bbbb'") checkmessage("local a={}; a.bbbb(3)", "field 'bbbb'") -assert(not string.find(doit"a={13}; local bbbb=1; a[bbbb](3)", "'bbbb'")) +assert(not string.find(doit("a={13}; local bbbb=1; a[bbbb](3)"), "'bbbb'")) checkmessage("a={13}; local bbbb=1; a[bbbb](3)", "number") aaa = nil @@ -67,14 +67,14 @@ checkmessage("aaa.bbb:ddd(9)", "global 'aaa'") checkmessage("local aaa={bbb=1}; aaa.bbb:ddd(9)", "field 'bbb'") checkmessage("local aaa={bbb={}}; aaa.bbb:ddd(9)", "method 'ddd'") checkmessage("local a,b,c; (function () a = b+1 end)()", "upvalue 'b'") -assert(not doit"local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)") +assert(not doit("local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)")) checkmessage("b=1; local aaa='a'; x=aaa+b", "local 'aaa'") checkmessage("aaa={}; x=3/aaa", "global 'aaa'") checkmessage("aaa='2'; b=nil;x=aaa*b", "global 'b'") checkmessage("aaa={}; x=-aaa", "global 'aaa'") -assert(not string.find(doit"aaa={}; x=(aaa or aaa)+(aaa and aaa)", "'aaa'")) -assert(not string.find(doit"aaa={}; (aaa or aaa)()", "'aaa'")) +assert(not string.find(doit("aaa={}; x=(aaa or aaa)+(aaa and aaa)"), "'aaa'")) +assert(not string.find(doit("aaa={}; (aaa or aaa)()"), "'aaa'")) checkmessage([[aaa=9 repeat until 3==3 @@ -122,10 +122,10 @@ function lineerror (s) return line and line+0 end -assert(lineerror"local a\n for i=1,'a' do \n print(i) \n end" == 2) --- assert(lineerror"\n local a \n for k,v in 3 \n do \n print(k) \n end" == 3) --- assert(lineerror"\n\n for k,v in \n 3 \n do \n print(k) \n end" == 4) -assert(lineerror"function a.x.y ()\na=a+1\nend" == 1) +assert(lineerror("local a\n for i=1,'a' do \n print(i) \n end") == 2) +-- assert(lineerror("\n local a \n for k,v in 3 \n do \n print(k) \n end") == 3) +-- assert(lineerror("\n\n for k,v in \n 3 \n do \n print(k) \n end") == 4) +assert(lineerror("function a.x.y ()\na=a+1\nend") == 1) local p = [[ function g() f() end diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index 4263dfda..6d9eb854 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -77,7 +77,7 @@ end local function dosteps (siz) collectgarbage() - collectgarbage"stop" + collectgarbage("stop") local a = {} for i=1,100 do a[i] = {{}}; local b = {} end local x = gcinfo() @@ -99,11 +99,11 @@ assert(dosteps(10000) == 1) do local x = gcinfo() collectgarbage() - collectgarbage"stop" + collectgarbage("stop") repeat local a = {} until gcinfo() > 1000 - collectgarbage"restart" + collectgarbage("restart") repeat local a = {} until gcinfo() < 1000 @@ -123,7 +123,7 @@ for n in pairs(b) do end b = nil collectgarbage() -for n in pairs(a) do error'cannot be here' end +for n in pairs(a) do error("cannot be here") end for i=1,lim do a[i] = i end for i=1,lim do assert(a[i] == i) end diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index 7f9b7596..94ba5ccf 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -368,9 +368,9 @@ assert(next(a,nil) == 1000 and next(a,1000) == nil) assert(next({}) == nil) assert(next({}, nil) == nil) -for a,b in pairs{} do error"not here" end -for i=1,0 do error'not here' end -for i=0,1,-1 do error'not here' end +for a,b in pairs{} do error("not here") end +for i=1,0 do error("not here") end +for i=0,1,-1 do error("not here") end a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index a2072d2c..84ac2ba1 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -144,4 +144,4 @@ coroutine.resume(co) resumeerror(co, "fail") checkresults({ true, false, "fail" }, coroutine.resume(co)) -return'OK' +return 'OK' diff --git a/tests/conformance/utf8.lua b/tests/conformance/utf8.lua index 024cb16d..bfd7a1ac 100644 --- a/tests/conformance/utf8.lua +++ b/tests/conformance/utf8.lua @@ -205,4 +205,4 @@ for p, c in string.gmatch(x, "()(" .. utf8.charpattern .. ")") do end end -return'OK' +return 'OK' diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 620f646a..7d18bda3 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -1,6 +1,9 @@ -- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details print('testing vectors') +-- detect vector size +local vector_size = if pcall(function() return vector(0, 0, 0).w end) then 4 else 3 + -- equality assert(vector(1, 2, 3) == vector(1, 2, 3)) assert(vector(0, 1, 2) == vector(-0, 1, 2)) @@ -13,8 +16,14 @@ assert(not rawequal(vector(1, 2, 3), vector(1, 2, 4))) -- type & tostring assert(type(vector(1, 2, 3)) == "vector") -assert(tostring(vector(1, 2, 3)) == "1, 2, 3") -assert(tostring(vector(-1, 2, 0.5)) == "-1, 2, 0.5") + +if vector_size == 4 then + assert(tostring(vector(1, 2, 3, 4)) == "1, 2, 3, 4") + assert(tostring(vector(-1, 2, 0.5, 0)) == "-1, 2, 0.5, 0") +else + assert(tostring(vector(1, 2, 3)) == "1, 2, 3") + assert(tostring(vector(-1, 2, 0.5)) == "-1, 2, 0.5") +end local t = {} @@ -42,12 +51,19 @@ assert(8 * vector(8, 16, 24) == vector(64, 128, 192)); assert(vector(1, 2, 4) * '8' == vector(8, 16, 32)); assert('8' * vector(8, 16, 24) == vector(64, 128, 192)); -assert(vector(1, 2, 4) / vector(8, 16, 24) == vector(1/8, 2/16, 4/24)); +if vector_size == 4 then + assert(vector(1, 2, 4, 8) / vector(8, 16, 24, 32) == vector(1/8, 2/16, 4/24, 8/32)); + assert(8 / vector(8, 16, 24, 32) == vector(1, 1/2, 1/3, 1/4)); + assert('8' / vector(8, 16, 24, 32) == vector(1, 1/2, 1/3, 1/4)); +else + assert(vector(1, 2, 4) / vector(8, 16, 24, 1) == vector(1/8, 2/16, 4/24)); + assert(8 / vector(8, 16, 24) == vector(1, 1/2, 1/3)); + assert('8' / vector(8, 16, 24) == vector(1, 1/2, 1/3)); +end + assert(vector(1, 2, 4) / 8 == vector(1/8, 1/4, 1/2)); assert(vector(1, 2, 4) / (1 / val) == vector(1/8, 2/8, 4/8)); -assert(8 / vector(8, 16, 24) == vector(1, 1/2, 1/3)); assert(vector(1, 2, 4) / '8' == vector(1/8, 1/4, 1/2)); -assert('8' / vector(8, 16, 24) == vector(1, 1/2, 1/3)); assert(-vector(1, 2, 4) == vector(-1, -2, -4)); @@ -71,4 +87,9 @@ assert(pcall(function() local t = {} rawset(t, vector(0/0, 2, 3), 1) end) == fal -- make sure we cover both builtin and C impl assert(vector(1, 2, 4) == vector("1", "2", "4")) +-- additional checks for 4-component vectors +if vector_size == 4 then + assert(vector(1, 2, 3, 4).w == 4) +end + return 'OK' diff --git a/tools/svg.py b/tools/svg.py index 3b3bb28c..99853fb6 100644 --- a/tools/svg.py +++ b/tools/svg.py @@ -458,13 +458,16 @@ def display(root, title, colors, flip = False): framewidth = 1200 - 20 + def pixels(x): + return float(x) / root.width * framewidth if root.width > 0 else 0 + for n in root.subtree(): - if n.width / root.width * framewidth < 0.1: + if pixels(n.width) < 0.1: continue - x = 10 + n.offset / root.width * framewidth + x = 10 + pixels(n.offset) y = (maxdepth - 1 - n.depth if flip else n.depth) * 16 + 3 * 16 - width = n.width / root.width * framewidth + width = pixels(n.width) height = 15 if colors == "cold": From e440729e2bb98aba0bdb41109e257de522b857e4 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 2 Dec 2021 15:46:33 -0800 Subject: [PATCH 009/102] Fix signed/unsigned mismatch warning + lower limit to match upstream --- tests/Module.test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 2800d2fe..e3993cc5 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -278,7 +278,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") #if defined(_DEBUG) || defined(_NOOPT) int limit = 250; #else - int limit = 500; + int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; @@ -287,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TypeId table = src.addType(TableTypeVar{}); TypeId nested = table; - for (unsigned i = 0; i < limit + 100; i++) + for (int i = 0; i < limit + 100; i++) { TableTypeVar* ttv = getMutable(nested); From a8673f0f99885da92597d6efb4feaf0f14d59991 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 10 Dec 2021 13:17:10 -0800 Subject: [PATCH 010/102] Sync to upstream/release/507-pre This doesn't contain all changes for 507 yet but we might want to do the Luau 0.507 release a bit earlier to end the year sooner. --- Analysis/include/Luau/Error.h | 11 +- Analysis/include/Luau/IostreamHelpers.h | 4 + Analysis/include/Luau/ToString.h | 4 +- Analysis/include/Luau/TypeInfer.h | 7 +- Analysis/include/Luau/TypeVar.h | 8 +- Analysis/include/Luau/Unifiable.h | 11 +- Analysis/include/Luau/Unifier.h | 3 + Analysis/src/Autocomplete.cpp | 115 +++++++-- Analysis/src/BuiltinDefinitions.cpp | 9 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 12 +- Analysis/src/Error.cpp | 17 +- Analysis/src/IostreamHelpers.cpp | 6 + Analysis/src/JsonEncoder.cpp | 4 +- Analysis/src/Module.cpp | 19 +- Analysis/src/Predicate.cpp | 2 +- Analysis/src/ToString.cpp | 38 ++- Analysis/src/TypeInfer.cpp | 191 +++++++++----- Analysis/src/TypeUtils.cpp | 6 +- Analysis/src/TypeVar.cpp | 220 ++++------------ Analysis/src/Unifier.cpp | 266 +++++++++++++++++--- Ast/src/Parser.cpp | 5 +- CLI/Analyze.cpp | 20 +- CLI/Coverage.cpp | 88 +++++++ CLI/Coverage.h | 10 + CLI/FileUtils.cpp | 4 + CLI/Profiler.cpp | 8 +- CLI/Profiler.h | 2 +- CLI/Repl.cpp | 35 ++- Compiler/src/Compiler.cpp | 22 +- Makefile | 6 +- Sources.cmake | 4 + VM/include/lua.h | 4 + VM/src/lapi.cpp | 162 ++++++------ VM/src/ldebug.cpp | 63 +++++ VM/src/lgc.cpp | 49 +--- VM/src/lgcdebug.cpp | 1 + VM/src/lobject.h | 10 +- VM/src/lstring.cpp | 29 --- VM/src/lstring.h | 7 - VM/src/ludata.cpp | 37 +++ VM/src/ludata.h | 13 + VM/src/lvmexecute.cpp | 14 +- VM/src/lvmutils.cpp | 4 +- bench/tests/sunspider/3d-raytrace.lua | 15 +- tests/Autocomplete.test.cpp | 55 +++- tests/Compiler.test.cpp | 49 +--- tests/Conformance.test.cpp | 100 ++++++-- tests/Module.test.cpp | 4 +- tests/Parser.test.cpp | 4 - tests/TypeInfer.annotations.test.cpp | 30 +++ tests/TypeInfer.builtins.test.cpp | 26 ++ tests/TypeInfer.generics.test.cpp | 40 +++ tests/TypeInfer.refinements.test.cpp | 17 +- tests/TypeInfer.singletons.test.cpp | 50 ++++ tests/TypeInfer.tables.test.cpp | 61 +++-- tests/TypeInfer.test.cpp | 239 +++++++++++++++++- tests/TypeInfer.tryUnify.test.cpp | 28 +++ tests/TypeInfer.unionTypes.test.cpp | 16 ++ tests/conformance/coverage.lua | 64 +++++ 59 files changed, 1704 insertions(+), 644 deletions(-) create mode 100644 CLI/Coverage.cpp create mode 100644 CLI/Coverage.h create mode 100644 VM/src/ludata.cpp create mode 100644 VM/src/ludata.h create mode 100644 tests/conformance/coverage.lua diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 9ee75004..aff3c4d9 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -277,11 +277,20 @@ struct MissingUnionProperty bool operator==(const MissingUnionProperty& rhs) const; }; +struct TypesAreUnrelated +{ + TypeId left; + TypeId right; + + bool operator==(const TypesAreUnrelated& rhs) const; +}; + using TypeErrorData = Variant; + DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, + TypesAreUnrelated>; struct TypeError { diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h index f9e9cd48..ee994296 100644 --- a/Analysis/include/Luau/IostreamHelpers.h +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -36,6 +36,10 @@ std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error); std::ostream& operator<<(std::ostream& lhs, const ModuleHasCyclicDependency& error); std::ostream& operator<<(std::ostream& lhs, const DuplicateGenericParameter& error); std::ostream& operator<<(std::ostream& lhs, const CannotInferBinaryOperation& error); +std::ostream& operator<<(std::ostream& lhs, const SwappedGenericTypeParameter& error); +std::ostream& operator<<(std::ostream& lhs, const OptionalValueAccess& error); +std::ostream& operator<<(std::ostream& lhs, const MissingUnionProperty& error); +std::ostream& operator<<(std::ostream& lhs, const TypesAreUnrelated& error); std::ostream& operator<<(std::ostream& lhs, const TableState& tv); std::ostream& operator<<(std::ostream& lhs, const TypeVar& tv); diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 50379c1c..a97bf6d6 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -69,8 +69,8 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression -void dump(TypeId ty); -void dump(TypePackId ty); +std::string dump(TypeId ty); +std::string dump(TypePackId ty); std::string generateName(size_t n); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 9f553bc1..451976e4 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -156,13 +156,14 @@ struct TypeChecker // Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding). // Note: the binding may be null. + // TODO: remove second return value with FFlagLuauUpdateFunctionNameBinding std::pair checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); - TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName); + TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::optional originalNameLoc, std::optional expectedType); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); @@ -174,7 +175,7 @@ struct TypeChecker ExprResult checkExprPack(const ScopePtr& scope, const AstExprCall& expr); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); @@ -277,7 +278,7 @@ public: [[noreturn]] void ice(const std::string& message); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); - ScopePtr childScope(const ScopePtr& parent, const Location& location, int subLevel = 0); + ScopePtr childScope(const ScopePtr& parent, const Location& location); // Wrapper for merge(l, r, toUnion) but without the lambda junk. void merge(RefinementMap& l, const RefinementMap& r); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 8c4c2f34..f6829ec3 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -499,6 +499,7 @@ struct SingletonTypes const TypePackId anyTypePack; SingletonTypes(); + ~SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete; @@ -509,10 +510,12 @@ struct SingletonTypes private: std::unique_ptr arena; + bool debugFreezeArena = false; + TypeId makeStringMetatable(); }; -extern SingletonTypes singletonTypes; +SingletonTypes& getSingletonTypes(); void persist(TypeId ty); void persist(TypePackId tp); @@ -523,9 +526,6 @@ TypeLevel* getMutableLevel(TypeId ty); const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); -bool hasGeneric(TypeId ty); -bool hasGeneric(TypePackId tp); - TypeVar* asMutable(TypeId ty); template diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index b47610fc..e8eafe68 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -24,7 +24,7 @@ struct TypeLevel int level = 0; int subLevel = 0; - // Returns true if the typelevel "this" is "bigger" than rhs + // Returns true if the level of "this" belongs to an equal or larger scope than that of rhs bool subsumes(const TypeLevel& rhs) const { if (level < rhs.level) @@ -38,6 +38,15 @@ struct TypeLevel return false; } + // Returns true if the level of "this" belongs to a larger (not equal) scope than that of rhs + bool subsumesStrict(const TypeLevel& rhs) const + { + if (level == rhs.level && subLevel == rhs.subLevel) + return false; + else + return subsumes(rhs); + } + TypeLevel incr() const { TypeLevel result; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 4588cdd8..7681b966 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -91,6 +91,9 @@ private: [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + + // Available after regular type pack unification errors + std::optional firstPackErrorPos; }; } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index db2d1d0e..4b583792 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -190,7 +191,48 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve return ParenthesesRecommendation::None; } -static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, TypeId ty) +static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) +{ + LUAU_ASSERT(FFlag::LuauAutocompleteFirstArg); + + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + // Extra care for first function call argument location + // When we don't have anything inside () yet, we also don't have an AST node to base our lookup + if (AstExprCall* exprCall = expr->as()) + { + if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) + { + auto it = module.astTypes.find(exprCall->func); + + if (!it) + return std::nullopt; + + const FunctionTypeVar* ftv = get(follow(*it)); + + if (!ftv) + return std::nullopt; + + auto [head, tail] = flatten(ftv->argTypes); + unsigned index = exprCall->self ? 1 : 0; + + if (index < head.size()) + return head[index]; + + return std::nullopt; + } + } + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + return *it; +} + +static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, Position position, TypeId ty) { ty = follow(ty); @@ -220,15 +262,29 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } }; - auto expr = node->asExpr(); - if (!expr) - return TypeCorrectKind::None; + TypeId expectedType; - auto it = module.astExpectedTypes.find(expr); - if (!it) - return TypeCorrectKind::None; + if (FFlag::LuauAutocompleteFirstArg) + { + auto typeAtPosition = findExpectedTypeAt(module, node, position); - TypeId expectedType = follow(*it); + if (!typeAtPosition) + return TypeCorrectKind::None; + + expectedType = follow(*typeAtPosition); + } + else + { + auto expr = node->asExpr(); + if (!expr) + return TypeCorrectKind::None; + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return TypeCorrectKind::None; + + expectedType = follow(*it); + } if (FFlag::LuauAutocompletePreferToCallFunctions) { @@ -333,8 +389,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId if (result.count(name) == 0 && name != Parser::errorName) { Luau::TypeId type = Luau::follow(prop.type); - TypeCorrectKind typeCorrect = - indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, nodes.back(), type); + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct + : checkTypeCorrectKind(module, typeArena, nodes.back(), {{}, {}}, type); ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); @@ -692,17 +748,31 @@ std::optional returnFirstNonnullOptionOfType(const UnionTypeVar* utv) return ret; } -static std::optional functionIsExpectedAt(const Module& module, AstNode* node) +static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) { - auto expr = node->asExpr(); - if (!expr) - return std::nullopt; + TypeId expectedType; - auto it = module.astExpectedTypes.find(expr); - if (!it) - return std::nullopt; + if (FFlag::LuauAutocompleteFirstArg) + { + auto typeAtPosition = findExpectedTypeAt(module, node, position); - TypeId expectedType = follow(*it); + if (!typeAtPosition) + return std::nullopt; + + expectedType = follow(*typeAtPosition); + } + else + { + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + expectedType = follow(*it); + } if (get(expectedType)) return true; @@ -1171,7 +1241,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul std::string n = toString(name); if (!result.count(n)) { - TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, binding.typeId); + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, position, binding.typeId); result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; @@ -1181,9 +1251,10 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul scope = scope->parent; } - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, typeChecker.nilType); - TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, typeChecker.booleanType); - TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); + TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; if (FFlag::LuauIfElseExpressionAnalysisSupport) result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index bac94a2b..d527414a 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -217,9 +217,9 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); - TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); + TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); - std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); + std::optional stringMetatableTy = getMetatable(getSingletonTypes().stringType); LUAU_ASSERT(stringMetatableTy); const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); LUAU_ASSERT(stringMetatableTable); @@ -271,7 +271,10 @@ void registerBuiltinTypes(TypeChecker& typeChecker) persist(pair.second.typeId); if (TableTypeVar* ttv = getMutable(pair.second.typeId)) - ttv->name = toString(pair.first); + { + if (!ttv->name) + ttv->name = toString(pair.first); + } } attachMagicFunction(getGlobalBinding(typeChecker, "assert"), magicFunctionAssert); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 9f5c8250..d0afa742 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAGVARIABLE(LuauFixTonumberReturnType, false) + namespace Luau { @@ -113,7 +115,6 @@ declare function gcinfo(): number declare function error(message: T, level: number?) declare function tostring(value: T): string - declare function tonumber(value: T, radix: number?): number declare function rawequal(a: T1, b: T2): boolean declare function rawget(tab: {[K]: V}, k: K): V @@ -204,7 +205,14 @@ declare function gcinfo(): number std::string getBuiltinDefinitionSource() { - return kBuiltinDefinitionLuaSrc; + std::string result = kBuiltinDefinitionLuaSrc; + + if (FFlag::LuauFixTonumberReturnType) + result += "declare function tonumber(value: T, radix: number?): number?\n"; + else + result += "declare function tonumber(value: T, radix: number?): number\n"; + + return result; } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 8334bd62..ce832c6b 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -58,7 +58,7 @@ struct ErrorConverter result += "\ncaused by:\n "; if (!tm.reason.empty()) - result += tm.reason + ". "; + result += tm.reason + " "; result += Luau::toString(*tm.error); } @@ -410,6 +410,11 @@ struct ErrorConverter return ss + " in the type '" + toString(e.type) + "'"; } + + std::string operator()(const TypesAreUnrelated& e) const + { + return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated"; + } }; struct InvalidNameChecker @@ -658,6 +663,11 @@ bool MissingUnionProperty::operator==(const MissingUnionProperty& rhs) const return *type == *rhs.type && key == rhs.key; } +bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const +{ + return left == rhs.left && right == rhs.right; +} + std::string toString(const TypeError& error) { ErrorConverter converter; @@ -793,6 +803,11 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& for (auto& ty : e.missing) ty = clone(ty); } + else if constexpr (std::is_same_v) + { + e.left = clone(e.left); + e.right = clone(e.right); + } else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index ac46b5a4..5bc76ade 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -262,6 +262,12 @@ std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error return stream << " }, key = '" + error.key + "' }"; } +std::ostream& operator<<(std::ostream& stream, const TypesAreUnrelated& error) +{ + stream << "TypesAreUnrelated { left = '" + toString(error.left) + "', right = '" + toString(error.right) + "' }"; + return stream; +} + std::ostream& operator<<(std::ostream& stream, const TableState& tv) { return stream << static_cast::type>(tv); diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index c7f623ee..23491a5a 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -262,7 +262,7 @@ struct AstJsonEncoder : public AstVisitor if (comma) writeRaw(","); else - comma = false; + comma = true; write(a); } @@ -379,7 +379,7 @@ struct AstJsonEncoder : public AstVisitor if (comma) writeRaw(","); else - comma = false; + comma = true; write(prop); } }); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index b4b6eb42..e1e53c97 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -13,7 +13,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 0) namespace Luau @@ -23,7 +22,7 @@ static bool contains(Position pos, Comment comment) { if (comment.location.contains(pos)) return true; - else if (FFlag::LuauCaptureBrokenCommentSpans && comment.type == Lexeme::BrokenComment && + else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end return true; else if (comment.type == Lexeme::Comment && comment.location.end == pos) @@ -194,7 +193,7 @@ struct TypePackCloner { cloneState.encounteredFreeType = true; - TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); + TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); TypePackId cloned = dest.addTypePack(*err); seenTypePacks[typePackId] = cloned; } @@ -247,7 +246,7 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { cloneState.encounteredFreeType = true; - TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); + TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); TypeId cloned = dest.addType(*err); seenTypes[typeId] = cloned; } @@ -421,9 +420,6 @@ TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypeP Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. } - if (FFlag::DebugLuauTrackOwningArena) - asMutable(res)->owningArena = &dest; - return res; } @@ -440,12 +436,11 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks { TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. + + // TODO: Make this work when the arena of 'res' might be frozen asMutable(res)->documentationSymbol = typeId->documentationSymbol; } - if (FFlag::DebugLuauTrackOwningArena) - asMutable(res)->owningArena = &dest; - return res; } @@ -508,8 +503,8 @@ bool Module::clonePublicInterface() if (moduleScope->varargPack) moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); - for (auto& pair : moduleScope->exportedTypeBindings) - pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, cloneState); + for (auto& [name, tf] : moduleScope->exportedTypeBindings) + tf = clone(tf, interfaceTypes, seenTypes, seenTypePacks, cloneState); for (TypeId ty : moduleScope->returnType) if (get(follow(ty))) diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/Predicate.cpp index 848627cf..7bd8001e 100644 --- a/Analysis/src/Predicate.cpp +++ b/Analysis/src/Predicate.cpp @@ -24,7 +24,7 @@ std::optional tryGetLValue(const AstExpr& node) else if (auto indexexpr = expr->as()) { if (auto lvalue = tryGetLValue(*indexexpr->expr)) - if (auto string = indexexpr->expr->as()) + if (auto string = indexexpr->index->as()) return Field{std::make_shared(*lvalue), std::string(string->value.data, string->value.size)}; } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 6322096c..a6be5348 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -13,6 +13,13 @@ LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauFunctionArgumentNameSize, false) +/* + * Prefix generic typenames with gen- + * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 + * Fair warning: Setting this will break a lot of Luau unit tests. + */ +LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) + namespace Luau { @@ -290,7 +297,15 @@ struct TypeVarStringifier void operator()(TypeId ty, const Unifiable::Free& ftv) { state.result.invalid = true; + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("free-"); state.emit(state.getName(ty)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + state.emit(std::to_string(ftv.level.level)); + } } void operator()(TypeId, const BoundTypeVar& btv) @@ -802,6 +817,8 @@ struct TypePackStringifier void operator()(TypePackId tp, const GenericTypePack& pack) { + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("gen-"); if (pack.explicitName) { state.result.nameMap.typePacks[tp] = pack.name; @@ -817,7 +834,16 @@ struct TypePackStringifier void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("free-"); state.emit(state.getName(tp)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + state.emit(std::to_string(pack.level.level)); + } + state.emit("..."); } @@ -1181,20 +1207,24 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV return s; } -void dump(TypeId ty) +std::string dump(TypeId ty) { ToStringOptions opts; opts.exhaustive = true; opts.functionTypeArguments = true; - printf("%s\n", toString(ty, opts).c_str()); + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; } -void dump(TypePackId ty) +std::string dump(TypePackId ty) { ToStringOptions opts; opts.exhaustive = true; opts.functionTypeArguments = true; - printf("%s\n", toString(ty, opts).c_str()); + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; } std::string generateName(size_t i) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 617bf482..abbc2901 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -9,9 +9,9 @@ #include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" -#include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" +#include "Luau/ToString.h" #include "Luau/TypeVar.h" #include "Luau/TimeTrace.h" @@ -29,7 +29,6 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) -LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) @@ -37,6 +36,12 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) +LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) +LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) +LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) +LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) +LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) +LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau { @@ -206,14 +211,14 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan : resolver(resolver) , iceHandler(iceHandler) , unifierState(iceHandler) - , nilType(singletonTypes.nilType) - , numberType(singletonTypes.numberType) - , stringType(singletonTypes.stringType) - , booleanType(singletonTypes.booleanType) - , threadType(singletonTypes.threadType) - , anyType(singletonTypes.anyType) - , optionalNumberType(singletonTypes.optionalNumberType) - , anyTypePack(singletonTypes.anyTypePack) + , nilType(getSingletonTypes().nilType) + , numberType(getSingletonTypes().numberType) + , stringType(getSingletonTypes().stringType) + , booleanType(getSingletonTypes().booleanType) + , threadType(getSingletonTypes().threadType) + , anyType(getSingletonTypes().anyType) + , optionalNumberType(getSingletonTypes().optionalNumberType) + , anyTypePack(getSingletonTypes().anyTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -443,7 +448,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) functionDecls[*protoIter] = pair; ++subLevel; - TypeId leftType = checkFunctionName(scope, *fun->name); + TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); unify(leftType, funTy, fun->location); } else if (auto fun = (*protoIter)->as()) @@ -711,14 +716,15 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } else if (auto tail = valueIter.tail()) { - if (get(*tail)) + TypePackId tailPack = follow(*tail); + if (get(tailPack)) right = errorRecoveryType(scope); - else if (auto vtp = get(*tail)) + else if (auto vtp = get(tailPack)) right = vtp->ty; - else if (get(*tail)) + else if (get(tailPack)) { - *asMutable(*tail) = TypePack{{left}}; - growingPack = getMutable(*tail); + *asMutable(tailPack) = TypePack{{left}}; + growingPack = getMutable(tailPack); } } @@ -1107,8 +1113,27 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco unify(leftType, ty, function.location); - if (leftTypeBinding) - *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); + if (FFlag::LuauUpdateFunctionNameBinding) + { + LUAU_ASSERT(function.name->is() || function.name->is()); + + if (auto exprIndexName = function.name->as()) + { + if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) + { + if (auto ttv = getMutableTableType(*typeIt)) + { + if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) + it->second.type = follow(quantify(funScope, leftType, function.name->location)); + } + } + } + } + else + { + if (leftTypeBinding) + *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); + } } } @@ -1148,8 +1173,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { - ScopePtr aliasScope = - FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); + ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); + if (FFlag::LuauProperTypeLevels) + aliasScope->level.subLevel = subLevel; auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); @@ -1166,6 +1193,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ice("Not predeclared"); ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); for (TypeId ty : binding->typeParams) { @@ -1505,9 +1533,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) - ice("Unexpected abstract type pack!"); + ice("Unexpected abstract type pack!", expr.location); else - ice("Unknown TypePack type!"); + ice("Unknown TypePack type!", expr.location); } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) @@ -1574,7 +1602,7 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (tableType->state == TableState::Free) { - TypeId result = freshType(scope); + TypeId result = FFlag::LuauAscribeCorrectLevelToInferredProperitesOfFreeTables ? freshType(tableType->level) : freshType(scope); tableType->props[name] = {result}; return result; } @@ -1738,7 +1766,16 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { - return {checkLValue(scope, expr)}; + TypeId ty = checkLValue(scope, expr); + + if (FFlag::LuauRefiLookupFromIndexExpr) + { + if (std::optional lvalue = tryGetLValue(expr)) + if (std::optional refiTy = resolveLValue(scope, *lvalue)) + return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; + } + + return {ty}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) @@ -2421,12 +2458,27 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy TypeId annotationType = resolveType(scope, *expr.annotation); ExprResult result = checkExpr(scope, *expr.expr, annotationType); - ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); - reportErrors(errorVec); - if (!errorVec.empty()) - annotationType = errorRecoveryType(annotationType); + if (FFlag::LuauBidirectionalAsExpr) + { + // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. + if (canUnify(result.type, annotationType, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; - return {annotationType, std::move(result.predicates)}; + if (canUnify(annotationType, result.type, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; + + reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); + return {errorRecoveryType(annotationType), std::move(result.predicates)}; + } + else + { + ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + reportErrors(errorVec); + if (!errorVec.empty()) + annotationType = errorRecoveryType(annotationType); + + return {annotationType, std::move(result.predicates)}; + } } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) @@ -2674,8 +2726,15 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope // Answers the question: "Can I define another function with this name?" // Primarily about detecting duplicates. -TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) +TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level) { + auto freshTy = [&]() { + if (FFlag::LuauProperTypeLevels) + return freshType(level); + else + return freshType(scope); + }; + if (auto globalName = funName.as()) { const ScopePtr& globalScope = currentModule->getModuleScope(); @@ -2689,7 +2748,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) } else { - TypeId ty = freshType(scope); + TypeId ty = freshTy(); globalScope->bindings[name] = {ty, funName.location}; return ty; } @@ -2699,7 +2758,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Symbol name = localName->local; Binding& binding = scope->bindings[name]; if (binding.typeId == nullptr) - binding = {freshType(scope), funName.location}; + binding = {freshTy(), funName.location}; return binding.typeId; } @@ -2730,7 +2789,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Property& property = ttv->props[name]; - property.type = freshType(scope); + property.type = freshTy(); property.location = indexName->indexLocation; ttv->methodDefinitionLocations[name] = funName.location; return property.type; @@ -3327,7 +3386,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A fn = follow(fn); if (auto ret = checkCallOverload( - scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) + scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) return *ret; } @@ -3402,9 +3461,11 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { + LUAU_ASSERT(argLocations); + fn = stripFromNilAndReport(fn, expr.func->location); if (get(fn)) @@ -3428,31 +3489,44 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {{retPack}}; } - const FunctionTypeVar* ftv = get(fn); - if (!ftv) + std::vector metaArgLocations; + + // Might be a callable table + if (const MetatableTypeVar* mttv = get(fn)) { - // Might be a callable table - if (const MetatableTypeVar* mttv = get(fn)) + if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) + // Construct arguments with 'self' added in front + TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); + + TypePack* metaCallArgs = getMutable(metaCallArgPack); + metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); + + metaArgLocations = *argLocations; + metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + + if (FFlag::LuauFixRecursiveMetatableCall) { - // Construct arguments with 'self' added in front - TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); - - TypePack* metaCallArgs = getMutable(metaCallArgPack); - metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); - - std::vector metaArgLocations = argLocations; - metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + fn = instantiate(scope, *ty, expr.func->location); + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; + } + else + { TypeId fn = *ty; fn = instantiate(scope, fn, expr.func->location); - return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, + return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, &metaArgLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors); } } + } + const FunctionTypeVar* ftv = get(fn); + if (!ftv) + { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); unify(retPack, errorRecoveryTypePack(scope), expr.func->location); return {{errorRecoveryTypePack(retPack)}}; @@ -3477,7 +3551,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {}; } - checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); + checkArgumentList(scope, state, argPack, ftv->argTypes, *argLocations); if (!state.errors.empty()) { @@ -3772,7 +3846,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (moduleInfo.name.empty()) { - if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) + if (currentModule->mode == Mode::Strict) { reportError(TypeError{location, UnknownRequire{}}); return errorRecoveryType(anyType); @@ -4268,9 +4342,11 @@ ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& } // Creates a new Scope and carries forward the varargs from the parent. -ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location, int subLevel) +ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location) { - ScopePtr scope = std::make_shared(parent, subLevel); + ScopePtr scope = std::make_shared(parent); + if (FFlag::LuauProperTypeLevels) + scope->level = parent->level; scope->varargPack = parent->varargPack; currentModule->scopes.push_back(std::make_pair(location, scope)); @@ -4329,22 +4405,22 @@ TypeId TypeChecker::singletonType(std::string value) TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) { - return singletonTypes.errorRecoveryType(); + return getSingletonTypes().errorRecoveryType(); } TypeId TypeChecker::errorRecoveryType(TypeId guess) { - return singletonTypes.errorRecoveryType(guess); + return getSingletonTypes().errorRecoveryType(guess); } TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) { - return singletonTypes.errorRecoveryTypePack(); + return getSingletonTypes().errorRecoveryTypePack(); } TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) { - return singletonTypes.errorRecoveryTypePack(guess); + return getSingletonTypes().errorRecoveryTypePack(guess); } std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) @@ -4547,6 +4623,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (const auto& func = annotation.as()) { ScopePtr funcScope = childScope(scope, func->location); + funcScope->level = scope->level.incr(); auto [generics, genericPacks] = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 0d9d91e0..8c6d5e49 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -19,7 +19,7 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa TypeId unwrapped = follow(*metatable); if (get(unwrapped)) - return singletonTypes.anyType; + return getSingletonTypes().anyType; const TableTypeVar* mtt = getTableType(unwrapped); if (!mtt) @@ -61,12 +61,12 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc { std::optional r = first(follow(itf->retType)); if (!r) - return singletonTypes.nilType; + return getSingletonTypes().nilType; else return *r; } else if (get(index)) - return singletonTypes.anyType; + return getSingletonTypes().anyType; else errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 62715af5..571b13ca 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) +LUAU_FASTFLAG(DebugLuauFreezeArena) namespace Luau { @@ -579,11 +580,25 @@ SingletonTypes::SingletonTypes() , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); - stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, makeStringMetatable()}; + stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, stringMetatable}; persist(stringMetatable); + + debugFreezeArena = FFlag::DebugLuauFreezeArena; freeze(*arena); } +SingletonTypes::~SingletonTypes() +{ + // Destroy the arena with the same memory management flags it was created with + bool prevFlag = FFlag::DebugLuauFreezeArena; + FFlag::DebugLuauFreezeArena.value = debugFreezeArena; + + unfreeze(*arena); + arena.reset(nullptr); + + FFlag::DebugLuauFreezeArena.value = prevFlag; +} + TypeId SingletonTypes::makeStringMetatable() { const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); @@ -641,6 +656,9 @@ TypeId SingletonTypes::makeStringMetatable() TypeId tableType = arena->addType(TableTypeVar{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + if (TableTypeVar* ttv = getMutable(tableType)) + ttv->name = "string"; + return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } @@ -670,7 +688,11 @@ TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) return &errorTypePack_; } -SingletonTypes singletonTypes; +SingletonTypes& getSingletonTypes() +{ + static SingletonTypes singletonTypes; + return singletonTypes; +} void persist(TypeId ty) { @@ -719,6 +741,18 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } + else if (auto mtv = get(t)) + { + queue.push_back(mtv->table); + queue.push_back(mtv->metatable); + } + else if (get(t) || get(t) || get(t) || get(t) || get(t)) + { + } + else + { + LUAU_ASSERT(!"TypeId is not supported in a persist call"); + } } } @@ -736,6 +770,17 @@ void persist(TypePackId tp) if (p->tail) persist(*p->tail); } + else if (auto vtp = get(tp)) + { + persist(vtp->ty); + } + else if (get(tp)) + { + } + else + { + LUAU_ASSERT(!"TypePackId is not supported in a persist call"); + } } const TypeLevel* getLevel(TypeId ty) @@ -757,167 +802,6 @@ TypeLevel* getMutableLevel(TypeId ty) return const_cast(getLevel(ty)); } -struct QVarFinder -{ - mutable DenseHashSet seen; - - QVarFinder() - : seen(nullptr) - { - } - - bool hasSeen(const void* tv) const - { - if (seen.contains(tv)) - return true; - - seen.insert(tv); - return false; - } - - bool hasGeneric(TypeId tid) const - { - if (hasSeen(&tid->ty)) - return false; - - return Luau::visit(*this, tid->ty); - } - - bool hasGeneric(TypePackId tp) const - { - if (hasSeen(&tp->ty)) - return false; - - return Luau::visit(*this, tp->ty); - } - - bool operator()(const Unifiable::Free&) const - { - return false; - } - - bool operator()(const Unifiable::Bound& bound) const - { - return hasGeneric(bound.boundTo); - } - - bool operator()(const Unifiable::Generic&) const - { - return true; - } - bool operator()(const Unifiable::Error&) const - { - return false; - } - bool operator()(const PrimitiveTypeVar&) const - { - return false; - } - - bool operator()(const SingletonTypeVar&) const - { - return false; - } - - bool operator()(const FunctionTypeVar& ftv) const - { - if (hasGeneric(ftv.argTypes)) - return true; - return hasGeneric(ftv.retType); - } - - bool operator()(const TableTypeVar& ttv) const - { - if (ttv.state == TableState::Generic) - return true; - - if (ttv.indexer) - { - if (hasGeneric(ttv.indexer->indexType)) - return true; - if (hasGeneric(ttv.indexer->indexResultType)) - return true; - } - - for (const auto& [_name, prop] : ttv.props) - { - if (hasGeneric(prop.type)) - return true; - } - - return false; - } - - bool operator()(const MetatableTypeVar& mtv) const - { - return hasGeneric(mtv.table) || hasGeneric(mtv.metatable); - } - - bool operator()(const ClassTypeVar& ctv) const - { - for (const auto& [name, prop] : ctv.props) - { - if (hasGeneric(prop.type)) - return true; - } - - if (ctv.parent) - return hasGeneric(*ctv.parent); - - return false; - } - - bool operator()(const AnyTypeVar&) const - { - return false; - } - - bool operator()(const UnionTypeVar& utv) const - { - for (TypeId tid : utv.options) - if (hasGeneric(tid)) - return true; - - return false; - } - - bool operator()(const IntersectionTypeVar& utv) const - { - for (TypeId tid : utv.parts) - if (hasGeneric(tid)) - return true; - - return false; - } - - bool operator()(const LazyTypeVar&) const - { - return false; - } - - bool operator()(const Unifiable::Bound& bound) const - { - return hasGeneric(bound.boundTo); - } - - bool operator()(const TypePack& pack) const - { - for (TypeId ty : pack.head) - if (hasGeneric(ty)) - return true; - - if (pack.tail) - return hasGeneric(*pack.tail); - - return false; - } - - bool operator()(const VariadicTypePack& pack) const - { - return hasGeneric(pack.ty); - } -}; - const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) { while (cls) @@ -953,16 +837,6 @@ bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) return false; } -bool hasGeneric(TypeId ty) -{ - return Luau::visit(QVarFinder{}, ty->ty); -} - -bool hasGeneric(TypePackId tp) -{ - return Luau::visit(QVarFinder{}, tp->ty); -} - UnionTypeVarIterator::UnionTypeVarIterator(const UnionTypeVar* utv) { LUAU_ASSERT(utv); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index d0b18837..c5aab856 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,7 +14,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); -LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); +LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) @@ -22,9 +22,82 @@ LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) LUAU_FASTFLAG(LuauErrorRecoveryType); +LUAU_FASTFLAG(LuauProperTypeLevels); +LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) +LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) namespace Luau { + +struct PromoteTypeLevels +{ + TxnLog& log; + TypeLevel minLevel; + + explicit PromoteTypeLevels(TxnLog& log, TypeLevel minLevel) + : log(log) + , minLevel(minLevel) + {} + + template + void promote(TID ty, T* t) + { + LUAU_ASSERT(t); + if (minLevel.subsumesStrict(t->level)) + { + log(ty); + t->level = minLevel; + } + } + + template + void cycle(TID) {} + + template + bool operator()(TID, const T&) + { + return true; + } + + bool operator()(TypeId ty, const FreeTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypeId ty, const FunctionTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack&) + { + promote(tp, getMutable(tp)); + return true; + } +}; + +void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypeId ty) +{ + PromoteTypeLevels ptl{log, minLevel}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, ptl, seen); +} + +void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypePackId tp) +{ + PromoteTypeLevels ptl{log, minLevel}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(tp, ptl, seen); +} + struct SkipCacheForType { SkipCacheForType(const DenseHashMap& skipCacheForType) @@ -127,6 +200,29 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } +// Used for tagged union matching heuristic, returns first singleton type field +static std::optional> getTableMatchTag(TypeId type) +{ + LUAU_ASSERT(FFlag::LuauExtendedUnionMismatchError); + + type = follow(type); + + if (auto ttv = get(type)) + { + for (auto&& [name, prop] : ttv->props) + { + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; + } + } + else if (auto mttv = get(type)) + { + return getTableMatchTag(mttv->table); + } + + return std::nullopt; +} + Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) @@ -214,9 +310,11 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { occursCheck(superTy, subTy); + TypeLevel superLevel = l->level; + // Unification can't change the level of a generic. auto rightGeneric = get(subTy); - if (rightGeneric && !rightGeneric->level.subsumes(l->level)) + if (rightGeneric && !rightGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -226,7 +324,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // The occurrence check might have caused superTy no longer to be a free type if (!get(superTy)) { - if (auto rightLevel = getMutableLevel(subTy)) + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(log, superLevel, subTy); + else if (auto rightLevel = getMutableLevel(subTy)) { if (!rightLevel->subsumes(l->level)) *rightLevel = l->level; @@ -240,6 +340,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } else if (r) { + TypeLevel subLevel = r->level; + occursCheck(subTy, superTy); // Unification can't change the level of a generic. @@ -253,10 +355,16 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (!get(subTy)) { - if (auto leftLevel = getMutableLevel(superTy)) + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(log, subLevel, superTy); + + if (auto superLevel = getMutableLevel(superTy)) { - if (!leftLevel->subsumes(r->level)) - *leftLevel = r->level; + if (!superLevel->subsumes(r->level)) + { + log(superTy); + *superLevel = r->level; + } } log(subTy); @@ -327,7 +435,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (failed) { if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible", *firstFailedOption}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } @@ -338,28 +446,46 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool bool found = false; std::optional unificationTooComplex; + size_t failedOptionCount = 0; + std::optional failedOption; + + bool foundHeuristic = false; size_t startIndex = 0; if (FFlag::LuauUnionHeuristic) { - bool found = false; - - const std::string* subName = getName(subTy); - if (subName) + if (const std::string* subName = getName(subTy)) { for (size_t i = 0; i < uv->options.size(); ++i) { const std::string* optionName = getName(uv->options[i]); if (optionName && *optionName == *subName) { - found = true; + foundHeuristic = true; startIndex = i; break; } } } - if (!found && cacheEnabled) + if (FFlag::LuauExtendedUnionMismatchError) + { + if (auto subMatchTag = getTableMatchTag(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + auto optionMatchTag = getTableMatchTag(uv->options[i]); + if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + } + + if (!foundHeuristic && cacheEnabled) { for (size_t i = 0; i < uv->options.size(); ++i) { @@ -390,15 +516,27 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { unificationTooComplex = e; } + else if (FFlag::LuauExtendedUnionMismatchError && !isNil(type)) + { + failedOptionCount++; + + if (!failedOption) + failedOption = {innerState.errors.front()}; + } innerState.log.rollback(); } if (unificationTooComplex) + { errors.push_back(*unificationTooComplex); + } else if (!found) { - if (FFlag::LuauExtendedTypeMismatchError) + if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); + else if (FFlag::LuauExtendedTypeMismatchError) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); @@ -431,7 +569,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible", *firstFailedOption}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } else { @@ -771,6 +909,10 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal if (superIter.good() && subIter.good()) { tryUnify_(*superIter, *subIter); + + if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; + superIter.advance(); subIter.advance(); continue; @@ -853,13 +995,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(singletonTypes.errorRecoveryType(), *superIter); + tryUnify_(getSingletonTypes().errorRecoveryType(), *superIter); superIter.advance(); } while (subIter.good()) { - tryUnify_(singletonTypes.errorRecoveryType(), *subIter); + tryUnify_(getSingletonTypes().errorRecoveryType(), *subIter); subIter.advance(); } @@ -917,14 +1059,22 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal if (numGenerics != rf->generics.size()) { numGenerics = std::min(lf->generics.size(), rf->generics.size()); - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + + if (FFlag::LuauExtendedFunctionMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } size_t numGenericPacks = lf->genericPacks.size(); if (numGenericPacks != rf->genericPacks.size()) { numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + + if (FFlag::LuauExtendedFunctionMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } for (size_t i = 0; i < numGenerics; i++) @@ -936,13 +1086,49 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal { Unifier innerState = makeChildUnifier(); - ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + if (FFlag::LuauExtendedFunctionMismatchError) + { + innerState.ctx = CountMismatch::Arg; + innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); - ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + bool reported = !innerState.errors.empty(); - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + + innerState.ctx = CountMismatch::Result; + innerState.tryUnify_(lf->retType, rf->retType); + + if (!reported) + { + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty() && size(lf->retType) == 1 && finite(lf->retType)) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + } + } + else + { + ctx = CountMismatch::Arg; + innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + + ctx = CountMismatch::Result; + innerState.tryUnify_(lf->retType, rf->retType); + + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + } log.concat(std::move(innerState.log)); } @@ -994,7 +1180,7 @@ struct Resetter void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - if (!FFlag::LuauTableSubtypingVariance) + if (!FFlag::LuauTableSubtypingVariance2) return DEPRECATED_tryUnifyTables(left, right, isIntersection); TableTypeVar* lt = getMutable(left); @@ -1133,7 +1319,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // TODO: hopefully readonly/writeonly properties will fix this. Property clone = prop; clone.type = deeplyOptional(clone.type); - log(lt); + log(left); lt->props[name] = clone; } else if (variance == Covariant) @@ -1146,7 +1332,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) } else if (lt->state == TableState::Free) { - log(lt); + log(left); lt->props[name] = prop; } else @@ -1176,7 +1362,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. // TODO: we only need to do this if the supertype's indexer is read/write // since that can add indexed elements. - log(rt); + log(right); rt->indexer = lt->indexer; } } @@ -1185,7 +1371,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // Symmetric if we are invariant if (lt->state == TableState::Unsealed || lt->state == TableState::Free) { - log(lt); + log(left); lt->indexer = rt->indexer; } } @@ -1241,15 +1427,15 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see TableTypeVar* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) prop.type = deeplyOptional(prop.type, seen); - return types->addType(UnionTypeVar{{singletonTypes.nilType, result}}); + return types->addType(UnionTypeVar{{getSingletonTypes().nilType, result}}); } else - return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}}); + return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); } void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance); + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); Resetter resetter{&variance}; variance = Invariant; @@ -1467,7 +1653,7 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio } else if (lt->indexer) { - innerState.tryUnify_(lt->indexer->indexType, singletonTypes.stringType); + innerState.tryUnify_(lt->indexer->indexType, getSingletonTypes().stringType); // We already try to unify properties in both tables. // Skip those and just look for the ones remaining and see if they fit into the indexer. for (const auto& [name, type] : rt->props) @@ -1636,7 +1822,7 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, singletonTypes.errorRecoveryType()); + tryUnify_(prop.type, getSingletonTypes().errorRecoveryType()); } else { @@ -1825,7 +2011,7 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) if (get(ty) || get(ty) || get(ty)) return; - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); @@ -1834,14 +2020,14 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, getSingletonTypes().anyType, anyTP); } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { LUAU_ASSERT(get(any)); - const TypeId anyTy = singletonTypes.errorRecoveryType(); + const TypeId anyTy = getSingletonTypes().errorRecoveryType(); std::vector queue; @@ -1887,7 +2073,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = *singletonTypes.errorRecoveryType(); + *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); return; } @@ -1951,7 +2137,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = *singletonTypes.errorRecoveryTypePack(); + *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); return; } @@ -2005,7 +2191,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const s errors.push_back(*e); else if (!innerErrors.empty()) errors.push_back( - TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible", prop.c_str()), innerErrors.front()}}); + TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front()}}); } void Unifier::ice(const std::string& message, const Location& location) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 3d0d5b7e..dd24f27c 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,7 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) @@ -159,7 +158,7 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n { std::vector hotcomments; - while (isComment(p.lexer.current()) || (FFlag::LuauCaptureBrokenCommentSpans && p.lexer.current().type == Lexeme::BrokenComment)) + while (isComment(p.lexer.current()) || p.lexer.current().type == Lexeme::BrokenComment) { const char* text = p.lexer.current().data; unsigned int length = p.lexer.current().length; @@ -2780,7 +2779,7 @@ const Lexeme& Parser::nextLexeme() const Lexeme& lexeme = lexer.next(/*skipComments*/ false); // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. // The parser will turn this into a proper syntax error. - if (FFlag::LuauCaptureBrokenCommentSpans && lexeme.type == Lexeme::BrokenComment) + if (lexeme.type == Lexeme::BrokenComment) commentLocations.push_back(Comment{lexeme.type, lexeme.location}); if (isComment(lexeme)) commentLocations.push_back(Comment{lexeme.type, lexeme.location}); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 9230d80d..aecb619a 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -11,26 +11,33 @@ enum class ReportFormat { Default, - Luacheck + Luacheck, + Gnu, }; -static void report(ReportFormat format, const char* name, const Luau::Location& location, const char* type, const char* message) +static void report(ReportFormat format, const char* name, const Luau::Location& loc, const char* type, const char* message) { switch (format) { case ReportFormat::Default: - fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, type, message); break; case ReportFormat::Luacheck: { // Note: luacheck's end column is inclusive but our end column is exclusive // In addition, luacheck doesn't support multi-line messages, so if the error is multiline we'll fake end column as 100 and hope for the best - int columnEnd = (location.begin.line == location.end.line) ? location.end.column : 100; + int columnEnd = (loc.begin.line == loc.end.line) ? loc.end.column : 100; - fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, columnEnd, type, message); + // Use stdout to match luacheck behavior + fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, columnEnd, type, message); break; } + + case ReportFormat::Gnu: + // Note: GNU end column is inclusive but our end column is exclusive + fprintf(stderr, "%s:%d.%d-%d.%d: %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, loc.end.line + 1, loc.end.column, type, message); + break; } } @@ -97,6 +104,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); + printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); } static int assertionHandler(const char* expr, const char* file, int line) @@ -201,6 +209,8 @@ int main(int argc, char** argv) if (strcmp(argv[i], "--formatter=plain") == 0) format = ReportFormat::Luacheck; + else if (strcmp(argv[i], "--formatter=gnu") == 0) + format = ReportFormat::Gnu; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; } diff --git a/CLI/Coverage.cpp b/CLI/Coverage.cpp new file mode 100644 index 00000000..254df3f0 --- /dev/null +++ b/CLI/Coverage.cpp @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Coverage.h" + +#include "lua.h" + +#include +#include + +struct Coverage +{ + lua_State* L = nullptr; + std::vector functions; +} gCoverage; + +void coverageInit(lua_State* L) +{ + gCoverage.L = lua_mainthread(L); +} + +bool coverageActive() +{ + return gCoverage.L != nullptr; +} + +void coverageTrack(lua_State* L, int funcindex) +{ + int ref = lua_ref(L, funcindex); + gCoverage.functions.push_back(ref); +} + +static void coverageCallback(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) +{ + FILE* f = static_cast(context); + + std::string name; + + if (depth == 0) + name = "
"; + else if (function) + name = std::string(function) + ":" + std::to_string(linedefined); + else + name = ":" + std::to_string(linedefined); + + fprintf(f, "FN:%d,%s\n", linedefined, name.c_str()); + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + fprintf(f, "FNDA:%d,%s\n", hits[i], name.c_str()); + break; + } + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + fprintf(f, "DA:%d,%d\n", int(i), hits[i]); +} + +void coverageDump(const char* path) +{ + lua_State* L = gCoverage.L; + + FILE* f = fopen(path, "w"); + if (!f) + { + fprintf(stderr, "Error opening coverage %s\n", path); + return; + } + + fprintf(f, "TN:\n"); + + for (int fref: gCoverage.functions) + { + lua_getref(L, fref); + + lua_Debug ar = {}; + lua_getinfo(L, -1, "s", &ar); + + fprintf(f, "SF:%s\n", ar.short_src); + lua_getcoverage(L, -1, f, coverageCallback); + fprintf(f, "end_of_record\n"); + + lua_pop(L, 1); + } + + fclose(f); + + printf("Coverage dump written to %s (%d functions)\n", path, int(gCoverage.functions.size())); +} diff --git a/CLI/Coverage.h b/CLI/Coverage.h new file mode 100644 index 00000000..74be4e5c --- /dev/null +++ b/CLI/Coverage.h @@ -0,0 +1,10 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +struct lua_State; + +void coverageInit(lua_State* L); +bool coverageActive(); + +void coverageTrack(lua_State* L, int funcindex); +void coverageDump(const char* path); diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index b3c9557b..cb993dfe 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -67,6 +67,10 @@ std::optional readFile(const std::string& name) if (read != size_t(length)) return std::nullopt; + // Skip first line if it's a shebang + if (length > 2 && result[0] == '#' && result[1] == '!') + result.erase(0, result.find('\n')); + return result; } diff --git a/CLI/Profiler.cpp b/CLI/Profiler.cpp index c6d15a7f..30a171f0 100644 --- a/CLI/Profiler.cpp +++ b/CLI/Profiler.cpp @@ -110,12 +110,12 @@ void profilerStop() gProfiler.thread.join(); } -void profilerDump(const char* name) +void profilerDump(const char* path) { - FILE* f = fopen(name, "wb"); + FILE* f = fopen(path, "wb"); if (!f) { - fprintf(stderr, "Error opening profile %s\n", name); + fprintf(stderr, "Error opening profile %s\n", path); return; } @@ -129,7 +129,7 @@ void profilerDump(const char* name) fclose(f); - printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", name, double(total) / 1e6, + printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", path, double(total) / 1e6, static_cast(gProfiler.samples.load()), static_cast(gProfiler.data.size())); uint64_t totalgc = 0; diff --git a/CLI/Profiler.h b/CLI/Profiler.h index 0a407e47..67b1acfd 100644 --- a/CLI/Profiler.h +++ b/CLI/Profiler.h @@ -5,4 +5,4 @@ struct lua_State; void profilerStart(lua_State* L, int frequency); void profilerStop(); -void profilerDump(const char* name); \ No newline at end of file +void profilerDump(const char* path); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 2cdd0062..35c02f2c 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -8,6 +8,7 @@ #include "FileUtils.h" #include "Profiler.h" +#include "Coverage.h" #include "linenoise.hpp" @@ -24,6 +25,16 @@ enum class CompileFormat Binary }; +static Luau::CompileOptions copts() +{ + Luau::CompileOptions result = {}; + result.optimizationLevel = 1; + result.debugLevel = 1; + result.coverageLevel = coverageActive() ? 2 : 0; + + return result; +} + static int lua_loadstring(lua_State* L) { size_t l = 0; @@ -32,7 +43,7 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); + std::string bytecode = Luau::compile(std::string(s, l), copts()); if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0) return 1; @@ -79,9 +90,12 @@ static int lua_require(lua_State* L) luaL_sandboxthread(ML); // now we can compile & run module on the new thread - std::string bytecode = Luau::compile(*source); + std::string bytecode = Luau::compile(*source, copts()); if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (coverageActive()) + coverageTrack(ML, -1); + int status = lua_resume(ML, L, 0); if (status == 0) @@ -149,7 +163,7 @@ static void setupState(lua_State* L) static std::string runCode(lua_State* L, const std::string& source) { - std::string bytecode = Luau::compile(source); + std::string bytecode = Luau::compile(source, copts()); if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) { @@ -329,11 +343,14 @@ static bool runFile(const char* name, lua_State* GL) std::string chunkname = "=" + std::string(name); - std::string bytecode = Luau::compile(*source); + std::string bytecode = Luau::compile(*source, copts()); int status = 0; if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (coverageActive()) + coverageTrack(L, -1); + status = lua_resume(L, NULL, 0); } else @@ -437,6 +454,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available options:\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); + printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); } static int assertionHandler(const char* expr, const char* file, int line) @@ -495,6 +513,7 @@ int main(int argc, char** argv) setupState(L); int profile = 0; + bool coverage = false; for (int i = 1; i < argc; ++i) { @@ -505,11 +524,16 @@ int main(int argc, char** argv) profile = 10000; // default to 10 KHz else if (strncmp(argv[i], "--profile=", 10) == 0) profile = atoi(argv[i] + 10); + else if (strcmp(argv[i], "--coverage") == 0) + coverage = true; } if (profile) profilerStart(L, profile); + if (coverage) + coverageInit(L); + std::vector files = getSourceFiles(argc, argv); int failed = 0; @@ -523,6 +547,9 @@ int main(int argc, char** argv) profilerDump("profile.out"); } + if (coverage) + coverageDump("coverage.out"); + return failed; } } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 2c1e85ff..8f74ffed 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) @@ -462,20 +461,17 @@ struct Compiler bool shared = false; - if (FFlag::LuauPreloadClosures) + // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure + // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it + // is used) + if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) { - // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure - // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it - // is used) - if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); + int32_t cid = bytecode.addConstantClosure(f->id); - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); - shared = true; - } + if (cid >= 0 && cid < 32768) + { + bytecode.emitAD(LOP_DUPCLOSURE, target, cid); + shared = true; } } diff --git a/Makefile b/Makefile index 15c7ff7a..b144cac6 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ TESTS_SOURCES=$(wildcard tests/*.cpp) TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Repl.cpp +REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau @@ -128,10 +128,10 @@ luau-size: luau # executable target aliases luau: $(REPL_CLI_TARGET) - cp $^ $@ + ln -fs $^ $@ luau-analyze: $(ANALYZE_CLI_TARGET) - cp $^ $@ + ln -fs $^ $@ # executable targets $(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) diff --git a/Sources.cmake b/Sources.cmake index 57df9b91..14834b3a 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -133,6 +133,7 @@ target_sources(Luau.VM PRIVATE VM/src/ltable.cpp VM/src/ltablib.cpp VM/src/ltm.cpp + VM/src/ludata.cpp VM/src/lutf8lib.cpp VM/src/lvmexecute.cpp VM/src/lvmload.cpp @@ -152,12 +153,15 @@ target_sources(Luau.VM PRIVATE VM/src/lstring.h VM/src/ltable.h VM/src/ltm.h + VM/src/ludata.h VM/src/lvm.h ) if(TARGET Luau.Repl.CLI) # Luau.Repl.CLI Sources target_sources(Luau.Repl.CLI PRIVATE + CLI/Coverage.h + CLI/Coverage.cpp CLI/FileUtils.h CLI/FileUtils.cpp CLI/Profiler.h diff --git a/VM/include/lua.h b/VM/include/lua.h index 7078acd0..55902160 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -334,6 +334,10 @@ LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); LUA_API void lua_singlestep(lua_State* L, int enabled); LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); +typedef void (*lua_Coverage)(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size); + +LUA_API void lua_getcoverage(lua_State* L, int funcindex, void* context, lua_Coverage callback); + /* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ LUA_API const char* lua_debugtrace(lua_State* L); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 76043b9c..a65b0325 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -8,6 +8,7 @@ #include "lfunc.h" #include "lgc.h" #include "ldo.h" +#include "ludata.h" #include "lvm.h" #include "lnumutils.h" @@ -43,36 +44,30 @@ static Table* getcurrenv(lua_State* L) } } -static LUAU_NOINLINE TValue* index2adrslow(lua_State* L, int idx) +static LUAU_NOINLINE TValue* pseudo2addr(lua_State* L, int idx) { - api_check(L, idx <= 0); - if (idx > LUA_REGISTRYINDEX) + api_check(L, lua_ispseudo(idx)); + switch (idx) + { /* pseudo-indices */ + case LUA_REGISTRYINDEX: + return registry(L); + case LUA_ENVIRONINDEX: { - api_check(L, idx != 0 && -idx <= L->top - L->base); - return L->top + idx; + sethvalue(L, &L->env, getcurrenv(L)); + return &L->env; + } + case LUA_GLOBALSINDEX: + return gt(L); + default: + { + Closure* func = curr_func(L); + idx = LUA_GLOBALSINDEX - idx; + return (idx <= func->nupvalues) ? &func->c.upvals[idx - 1] : cast_to(TValue*, luaO_nilobject); + } } - else - switch (idx) - { /* pseudo-indices */ - case LUA_REGISTRYINDEX: - return registry(L); - case LUA_ENVIRONINDEX: - { - sethvalue(L, &L->env, getcurrenv(L)); - return &L->env; - } - case LUA_GLOBALSINDEX: - return gt(L); - default: - { - Closure* func = curr_func(L); - idx = LUA_GLOBALSINDEX - idx; - return (idx <= func->nupvalues) ? &func->c.upvals[idx - 1] : cast_to(TValue*, luaO_nilobject); - } - } } -static LUAU_FORCEINLINE TValue* index2adr(lua_State* L, int idx) +static LUAU_FORCEINLINE TValue* index2addr(lua_State* L, int idx) { if (idx > 0) { @@ -83,15 +78,20 @@ static LUAU_FORCEINLINE TValue* index2adr(lua_State* L, int idx) else return o; } + else if (idx > LUA_REGISTRYINDEX) + { + api_check(L, idx != 0 && -idx <= L->top - L->base); + return L->top + idx; + } else { - return index2adrslow(L, idx); + return pseudo2addr(L, idx); } } const TValue* luaA_toobject(lua_State* L, int idx) { - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); return (p == luaO_nilobject) ? NULL : p; } @@ -145,7 +145,7 @@ void lua_xpush(lua_State* from, lua_State* to, int idx) { api_check(from, from->global == to->global); luaC_checkthreadsleep(to); - setobj2s(to, to->top, index2adr(from, idx)); + setobj2s(to, to->top, index2addr(from, idx)); api_incr_top(to); return; } @@ -202,7 +202,7 @@ void lua_settop(lua_State* L, int idx) void lua_remove(lua_State* L, int idx) { - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); api_checkvalidindex(L, p); while (++p < L->top) setobjs2s(L, p - 1, p); @@ -213,7 +213,7 @@ void lua_remove(lua_State* L, int idx) void lua_insert(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); api_checkvalidindex(L, p); for (StkId q = L->top; q > p; q--) setobjs2s(L, q, q - 1); @@ -228,7 +228,7 @@ void lua_replace(lua_State* L, int idx) luaG_runerror(L, "no calling environment"); api_checknelems(L, 1); luaC_checkthreadsleep(L); - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); api_checkvalidindex(L, o); if (idx == LUA_ENVIRONINDEX) { @@ -250,7 +250,7 @@ void lua_replace(lua_State* L, int idx) void lua_pushvalue(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); setobj2s(L, L->top, o); api_incr_top(L); return; @@ -262,7 +262,7 @@ void lua_pushvalue(lua_State* L, int idx) int lua_type(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (o == luaO_nilobject) ? LUA_TNONE : ttype(o); } @@ -273,20 +273,20 @@ const char* lua_typename(lua_State* L, int t) int lua_iscfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return iscfunction(o); } int lua_isLfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return isLfunction(o); } int lua_isnumber(lua_State* L, int idx) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return tonumber(o, &n); } @@ -298,14 +298,14 @@ int lua_isstring(lua_State* L, int idx) int lua_isuserdata(lua_State* L, int idx) { - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return (ttisuserdata(o) || ttislightuserdata(o)); } int lua_rawequal(lua_State* L, int index1, int index2) { - StkId o1 = index2adr(L, index1); - StkId o2 = index2adr(L, index2); + StkId o1 = index2addr(L, index1); + StkId o2 = index2addr(L, index2); return (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : luaO_rawequalObj(o1, o2); } @@ -313,8 +313,8 @@ int lua_equal(lua_State* L, int index1, int index2) { StkId o1, o2; int i; - o1 = index2adr(L, index1); - o2 = index2adr(L, index2); + o1 = index2addr(L, index1); + o2 = index2addr(L, index2); i = (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : equalobj(L, o1, o2); return i; } @@ -323,8 +323,8 @@ int lua_lessthan(lua_State* L, int index1, int index2) { StkId o1, o2; int i; - o1 = index2adr(L, index1); - o2 = index2adr(L, index2); + o1 = index2addr(L, index1); + o2 = index2addr(L, index2); i = (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : luaV_lessthan(L, o1, o2); return i; } @@ -332,7 +332,7 @@ int lua_lessthan(lua_State* L, int index1, int index2) double lua_tonumberx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { if (isnum) @@ -350,7 +350,7 @@ double lua_tonumberx(lua_State* L, int idx, int* isnum) int lua_tointegerx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { int res; @@ -371,7 +371,7 @@ int lua_tointegerx(lua_State* L, int idx, int* isnum) unsigned lua_tounsignedx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { unsigned res; @@ -391,13 +391,13 @@ unsigned lua_tounsignedx(lua_State* L, int idx, int* isnum) int lua_toboolean(lua_State* L, int idx) { - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return !l_isfalse(o); } const char* lua_tolstring(lua_State* L, int idx, size_t* len) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisstring(o)) { luaC_checkthreadsleep(L); @@ -408,7 +408,7 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len) return NULL; } luaC_checkGC(L); - o = index2adr(L, idx); /* previous call may reallocate the stack */ + o = index2addr(L, idx); /* previous call may reallocate the stack */ } if (len != NULL) *len = tsvalue(o)->len; @@ -417,7 +417,7 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len) const char* lua_tostringatom(lua_State* L, int idx, int* atom) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisstring(o)) return NULL; const TString* s = tsvalue(o); @@ -438,7 +438,7 @@ const char* lua_namecallatom(lua_State* L, int* atom) const float* lua_tovector(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisvector(o)) { return NULL; @@ -448,7 +448,7 @@ const float* lua_tovector(lua_State* L, int idx) int lua_objlen(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TSTRING: @@ -469,13 +469,13 @@ int lua_objlen(lua_State* L, int idx) lua_CFunction lua_tocfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (!iscfunction(o)) ? NULL : cast_to(lua_CFunction, clvalue(o)->c.f); } void* lua_touserdata(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TUSERDATA: @@ -489,13 +489,13 @@ void* lua_touserdata(lua_State* L, int idx) void* lua_touserdatatagged(lua_State* L, int idx, int tag) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (ttisuserdata(o) && uvalue(o)->tag == tag) ? uvalue(o)->data : NULL; } int lua_userdatatag(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (ttisuserdata(o)) return uvalue(o)->tag; return -1; @@ -503,13 +503,13 @@ int lua_userdatatag(lua_State* L, int idx) lua_State* lua_tothread(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (!ttisthread(o)) ? NULL : thvalue(o); } const void* lua_topointer(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TTABLE: @@ -657,7 +657,7 @@ int lua_pushthread(lua_State* L) void lua_gettable(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_gettable(L, t, L->top - 1, L->top - 1); return; @@ -666,7 +666,7 @@ void lua_gettable(lua_State* L, int idx) void lua_getfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); TValue key; setsvalue(L, &key, luaS_new(L, k)); @@ -678,7 +678,7 @@ void lua_getfield(lua_State* L, int idx, const char* k) void lua_rawgetfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); TValue key; setsvalue(L, &key, luaS_new(L, k)); @@ -690,7 +690,7 @@ void lua_rawgetfield(lua_State* L, int idx, const char* k) void lua_rawget(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top - 1, luaH_get(hvalue(t), L->top - 1)); return; @@ -699,7 +699,7 @@ void lua_rawget(lua_State* L, int idx) void lua_rawgeti(lua_State* L, int idx, int n) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top, luaH_getnum(hvalue(t), n)); api_incr_top(L); @@ -717,7 +717,7 @@ void lua_createtable(lua_State* L, int narray, int nrec) void lua_setreadonly(lua_State* L, int objindex, int enabled) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); api_check(L, t != hvalue(registry(L))); @@ -727,7 +727,7 @@ void lua_setreadonly(lua_State* L, int objindex, int enabled) int lua_getreadonly(lua_State* L, int objindex) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); int res = t->readonly; @@ -736,7 +736,7 @@ int lua_getreadonly(lua_State* L, int objindex) void lua_setsafeenv(lua_State* L, int objindex, int enabled) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); t->safeenv = bool(enabled); @@ -748,7 +748,7 @@ int lua_getmetatable(lua_State* L, int objindex) const TValue* obj; Table* mt = NULL; int res; - obj = index2adr(L, objindex); + obj = index2addr(L, objindex); switch (ttype(obj)) { case LUA_TTABLE: @@ -775,7 +775,7 @@ int lua_getmetatable(lua_State* L, int objindex) void lua_getfenv(lua_State* L, int idx) { StkId o; - o = index2adr(L, idx); + o = index2addr(L, idx); api_checkvalidindex(L, o); switch (ttype(o)) { @@ -801,7 +801,7 @@ void lua_settable(lua_State* L, int idx) { StkId t; api_checknelems(L, 2); - t = index2adr(L, idx); + t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_settable(L, t, L->top - 2, L->top - 1); L->top -= 2; /* pop index and value */ @@ -813,7 +813,7 @@ void lua_setfield(lua_State* L, int idx, const char* k) StkId t; TValue key; api_checknelems(L, 1); - t = index2adr(L, idx); + t = index2addr(L, idx); api_checkvalidindex(L, t); setsvalue(L, &key, luaS_new(L, k)); luaV_settable(L, t, &key, L->top - 1); @@ -825,7 +825,7 @@ void lua_rawset(lua_State* L, int idx) { StkId t; api_checknelems(L, 2); - t = index2adr(L, idx); + t = index2addr(L, idx); api_check(L, ttistable(t)); if (hvalue(t)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -839,7 +839,7 @@ void lua_rawseti(lua_State* L, int idx, int n) { StkId o; api_checknelems(L, 1); - o = index2adr(L, idx); + o = index2addr(L, idx); api_check(L, ttistable(o)); if (hvalue(o)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -854,7 +854,7 @@ int lua_setmetatable(lua_State* L, int objindex) TValue* obj; Table* mt; api_checknelems(L, 1); - obj = index2adr(L, objindex); + obj = index2addr(L, objindex); api_checkvalidindex(L, obj); if (ttisnil(L->top - 1)) mt = NULL; @@ -896,7 +896,7 @@ int lua_setfenv(lua_State* L, int idx) StkId o; int res = 1; api_checknelems(L, 1); - o = index2adr(L, idx); + o = index2addr(L, idx); api_checkvalidindex(L, o); api_check(L, ttistable(L->top - 1)); switch (ttype(o)) @@ -987,7 +987,7 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) func = 0; else { - StkId o = index2adr(L, errfunc); + StkId o = index2addr(L, errfunc); api_checkvalidindex(L, o); func = savestack(L, o); } @@ -1150,7 +1150,7 @@ l_noret lua_error(lua_State* L) int lua_next(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); int more = luaH_next(L, hvalue(t), L->top - 1); if (more) @@ -1187,7 +1187,7 @@ void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); luaC_checkGC(L); luaC_checkthreadsleep(L); - Udata* u = luaS_newudata(L, sz, tag); + Udata* u = luaU_newudata(L, sz, tag); setuvalue(L, L->top, u); api_incr_top(L); return u->data; @@ -1197,7 +1197,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) { luaC_checkGC(L); luaC_checkthreadsleep(L); - Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); + Udata* u = luaU_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); memcpy(&u->data + sz, &dtor, sizeof(dtor)); setuvalue(L, L->top, u); api_incr_top(L); @@ -1232,7 +1232,7 @@ const char* lua_getupvalue(lua_State* L, int funcindex, int n) { luaC_checkthreadsleep(L); TValue* val; - const char* name = aux_upvalue(index2adr(L, funcindex), n, &val); + const char* name = aux_upvalue(index2addr(L, funcindex), n, &val); if (name) { setobj2s(L, L->top, val); @@ -1246,7 +1246,7 @@ const char* lua_setupvalue(lua_State* L, int funcindex, int n) const char* name; TValue* val; StkId fi; - fi = index2adr(L, funcindex); + fi = index2addr(L, funcindex); api_checknelems(L, 1); name = aux_upvalue(fi, n, &val); if (name) @@ -1270,7 +1270,7 @@ int lua_ref(lua_State* L, int idx) api_check(L, idx != LUA_REGISTRYINDEX); /* idx is a stack index for value */ int ref = LUA_REFNIL; global_State* g = L->global; - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); if (!ttisnil(p)) { Table* reg = hvalue(registry(L)); diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index d77f84ef..9fe1885f 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -370,6 +370,69 @@ void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); } +static int getmaxline(Proto* p) +{ + int result = -1; + + for (int i = 0; i < p->sizecode; ++i) + { + int line = luaG_getline(p, i); + result = result < line ? line : result; + } + + for (int i = 0; i < p->sizep; ++i) + { + int psize = getmaxline(p->p[i]); + result = result < psize ? psize : result; + } + + return result; +} + +static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* context, lua_Coverage callback) +{ + memset(buffer, -1, size * sizeof(int)); + + for (int i = 0; i < p->sizecode; ++i) + { + Instruction insn = p->code[i]; + if (LUAU_INSN_OP(insn) != LOP_COVERAGE) + continue; + + int line = luaG_getline(p, i); + int hits = LUAU_INSN_E(insn); + + LUAU_ASSERT(size_t(line) < size); + buffer[line] = buffer[line] < hits ? hits : buffer[line]; + } + + const char* debugname = p->debugname ? getstr(p->debugname) : NULL; + int linedefined = luaG_getline(p, 0); + + callback(context, debugname, linedefined, depth, buffer, size); + + for (int i = 0; i < p->sizep; ++i) + getcoverage(p->p[i], depth + 1, buffer, size, context, callback); +} + +void lua_getcoverage(lua_State* L, int funcindex, void* context, lua_Coverage callback) +{ + const TValue* func = luaA_toobject(L, funcindex); + api_check(L, ttisfunction(func) && !clvalue(func)->isC); + + Proto* p = clvalue(func)->l.p; + + size_t size = getmaxline(p) + 1; + if (size == 0) + return; + + int* buffer = luaM_newarray(L, size, int, 0); + + getcoverage(p, 0, buffer, size, context, callback); + + luaM_freearray(L, buffer, size, int, 0); +} + static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) { size_t size = strlen(data); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index ab416041..7393fc74 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -8,11 +8,10 @@ #include "lfunc.h" #include "lstring.h" #include "ldo.h" +#include "ludata.h" #include -LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) - LUAU_FASTFLAG(LuauArrayBoundary) #define GC_SWEEPMAX 40 @@ -59,10 +58,6 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, case GCSpropagate: case GCSpropagateagain: g->gcstats.currcycle.marktime += seconds; - - // atomic step had to be performed during the switch and it's tracked separately - if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring) - g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime; break; case GCSatomic: g->gcstats.currcycle.atomictime += seconds; @@ -488,7 +483,7 @@ static void freeobj(lua_State* L, GCObject* o) luaS_free(L, gco2ts(o)); break; case LUA_TUSERDATA: - luaS_freeudata(L, gco2u(o)); + luaU_freeudata(L, gco2u(o)); break; default: LUAU_ASSERT(0); @@ -632,17 +627,9 @@ static size_t remarkupvals(global_State* g) static size_t atomic(lua_State* L) { global_State* g = L->global; + LUAU_ASSERT(g->gcstate == GCSatomic); + size_t work = 0; - - if (FFlag::LuauSeparateAtomic) - { - LUAU_ASSERT(g->gcstate == GCSatomic); - } - else - { - g->gcstate = GCSatomic; - } - /* remark occasional upvalues of (maybe) dead threads */ work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ @@ -666,11 +653,6 @@ static size_t atomic(lua_State* L) g->sweepgc = &g->rootgc; g->gcstate = GCSsweepstring; - if (!FFlag::LuauSeparateAtomic) - { - GC_INTERRUPT(GCSatomic); - } - return work; } @@ -716,22 +698,7 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { - if (FFlag::LuauSeparateAtomic) - { - g->gcstate = GCSatomic; - } - else - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSatomic; } break; } @@ -853,7 +820,7 @@ static size_t getheaptrigger(global_State* g, size_t heapgoal) void luaC_step(lua_State* L, bool assist) { global_State* g = L->global; - ptrdiff_t lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ + int lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -908,7 +875,7 @@ void luaC_fullgc(lua_State* L) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) + if (g->gcstate <= GCSatomic) { /* reset sweep marks to sweep all elements (returning them to white) */ g->sweepstrgc = 0; @@ -1049,7 +1016,7 @@ int64_t luaC_allocationrate(lua_State* L) global_State* g = L->global; const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms - if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) + if (g->gcstate <= GCSatomic) { double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index a79e7b95..f6f7a878 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -7,6 +7,7 @@ #include "ltable.h" #include "lfunc.h" #include "lstring.h" +#include "ludata.h" #include #include diff --git a/VM/src/lobject.h b/VM/src/lobject.h index ba040af6..fd0a15b7 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -78,15 +78,7 @@ typedef struct lua_TValue #define thvalue(o) check_exp(ttisthread(o), &(o)->value.gc->th) #define upvalue(o) check_exp(ttisupval(o), &(o)->value.gc->uv) -// beware bit magic: a value is false if it's nil or boolean false -// baseline implementation: (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) -// we'd like a branchless version of this which helps with performance, and a very fast version -// so our strategy is to always read the boolean value (not using bvalue(o) because that asserts when type isn't boolean) -// we then combine it with type to produce 0/1 as follows: -// - when type is nil (0), & makes the result 0 -// - when type is boolean (1), we effectively only look at the bottom bit, so result is 0 iff boolean value is 0 -// - when type is different, it must have some of the top bits set - we keep all top bits of boolean value so the result is non-0 -#define l_isfalse(o) (!(((o)->value.b | ~1) & ttype(o))) +#define l_isfalse(o) (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) /* ** for internal debug only diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index 18ee1cda..a9e90d17 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -206,32 +206,3 @@ void luaS_free(lua_State* L, TString* ts) L->global->strt.nuse--; luaM_free(L, ts, sizestring(ts->len), ts->memcat); } - -Udata* luaS_newudata(lua_State* L, size_t s, int tag) -{ - if (s > INT_MAX - sizeof(Udata)) - luaM_toobig(L); - Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); - luaC_link(L, u, LUA_TUSERDATA); - u->len = int(s); - u->metatable = NULL; - LUAU_ASSERT(tag >= 0 && tag <= 255); - u->tag = uint8_t(tag); - return u; -} - -void luaS_freeudata(lua_State* L, Udata* u) -{ - LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); - - void (*dtor)(void*) = nullptr; - if (u->tag == UTAG_IDTOR) - memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); - else if (u->tag) - dtor = L->global->udatagc[u->tag]; - - if (dtor) - dtor(u->data); - - luaM_free(L, u, sizeudata(u->len), u->memcat); -} diff --git a/VM/src/lstring.h b/VM/src/lstring.h index 612da28d..3fd0bd39 100644 --- a/VM/src/lstring.h +++ b/VM/src/lstring.h @@ -8,11 +8,7 @@ /* string size limit */ #define MAXSSIZE (1 << 30) -/* special tag value is used for user data with inline dtors */ -#define UTAG_IDTOR LUA_UTAG_LIMIT - #define sizestring(len) (offsetof(TString, data) + len + 1) -#define sizeudata(len) (offsetof(Udata, data) + len) #define luaS_new(L, s) (luaS_newlstr(L, s, strlen(s))) #define luaS_newliteral(L, s) (luaS_newlstr(L, "" s, (sizeof(s) / sizeof(char)) - 1)) @@ -26,8 +22,5 @@ LUAI_FUNC void luaS_resize(lua_State* L, int newsize); LUAI_FUNC TString* luaS_newlstr(lua_State* L, const char* str, size_t l); LUAI_FUNC void luaS_free(lua_State* L, TString* ts); -LUAI_FUNC Udata* luaS_newudata(lua_State* L, size_t s, int tag); -LUAI_FUNC void luaS_freeudata(lua_State* L, Udata* u); - LUAI_FUNC TString* luaS_bufstart(lua_State* L, size_t size); LUAI_FUNC TString* luaS_buffinish(lua_State* L, TString* ts); diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp new file mode 100644 index 00000000..d180c388 --- /dev/null +++ b/VM/src/ludata.cpp @@ -0,0 +1,37 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "ludata.h" + +#include "lgc.h" +#include "lmem.h" + +#include + +Udata* luaU_newudata(lua_State* L, size_t s, int tag) +{ + if (s > INT_MAX - sizeof(Udata)) + luaM_toobig(L); + Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); + luaC_link(L, u, LUA_TUSERDATA); + u->len = int(s); + u->metatable = NULL; + LUAU_ASSERT(tag >= 0 && tag <= 255); + u->tag = uint8_t(tag); + return u; +} + +void luaU_freeudata(lua_State* L, Udata* u) +{ + LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); + + void (*dtor)(void*) = nullptr; + if (u->tag == UTAG_IDTOR) + memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); + else if (u->tag) + dtor = L->global->udatagc[u->tag]; + + if (dtor) + dtor(u->data); + + luaM_free(L, u, sizeudata(u->len), u->memcat); +} diff --git a/VM/src/ludata.h b/VM/src/ludata.h new file mode 100644 index 00000000..59cb85bd --- /dev/null +++ b/VM/src/ludata.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" + +/* special tag value is used for user data with inline dtors */ +#define UTAG_IDTOR LUA_UTAG_LIMIT + +#define sizeudata(len) (offsetof(Udata, data) + len) + +LUAI_FUNC Udata* luaU_newudata(lua_State* L, size_t s, int tag); +LUAI_FUNC void luaU_freeudata(lua_State* L, Udata* u); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index bf8d493e..cebeeb58 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -63,7 +63,8 @@ #define VM_KV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->l.p->sizek)), &k[i]) #define VM_UV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->nupvalues)), &cl->l.uprefs[i]) -#define VM_PATCH_C(pc, slot) ((uint8_t*)(pc))[3] = uint8_t(slot) +#define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) +#define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) // NOTE: If debugging the Luau code, disable this macro to prevent timeouts from // occurring when tracing code in Visual Studio / XCode @@ -120,7 +121,7 @@ */ #if VM_USE_CGOTO #define VM_CASE(op) CASE_##op: -#define VM_NEXT() goto*(SingleStep ? &&dispatch : kDispatchTable[*(uint8_t*)pc]) +#define VM_NEXT() goto*(SingleStep ? &&dispatch : kDispatchTable[LUAU_INSN_OP(*pc)]) #define VM_CONTINUE(op) goto* kDispatchTable[uint8_t(op)] #else #define VM_CASE(op) case op: @@ -325,7 +326,7 @@ static void luau_execute(lua_State* L) // ... and singlestep logic :) if (SingleStep) { - if (L->global->cb.debugstep && !luau_skipstep(*(uint8_t*)pc)) + if (L->global->cb.debugstep && !luau_skipstep(LUAU_INSN_OP(*pc))) { VM_PROTECT(luau_callhook(L, L->global->cb.debugstep, NULL)); @@ -335,13 +336,12 @@ static void luau_execute(lua_State* L) } #if VM_USE_CGOTO - VM_CONTINUE(*(uint8_t*)pc); + VM_CONTINUE(LUAU_INSN_OP(*pc)); #endif } #if !VM_USE_CGOTO - // Note: this assumes that LUAU_INSN_OP() decodes the first byte (aka least significant byte in the little endian encoding) - size_t dispatchOp = *(uint8_t*)pc; + size_t dispatchOp = LUAU_INSN_OP(*pc); dispatchContinue: switch (dispatchOp) @@ -2577,7 +2577,7 @@ static void luau_execute(lua_State* L) // update hits with saturated add and patch the instruction in place hits = (hits < (1 << 23) - 1) ? hits + 1 : hits; - ((uint32_t*)pc)[-1] = LOP_COVERAGE | (uint32_t(hits) << 8); + VM_PATCH_E(pc - 1, hits); VM_NEXT(); } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 740a4cfd..5d802277 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -53,7 +53,7 @@ const float* luaV_tovector(const TValue* obj) static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) { ptrdiff_t result = savestack(L, res); - // RBOLOX: using stack room beyond top is technically safe here, but for very complicated reasons: + // using stack room beyond top is technically safe here, but for very complicated reasons: // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers @@ -74,7 +74,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3) { - // RBOLOX: using stack room beyond top is technically safe here, but for very complicated reasons: + // using stack room beyond top is technically safe here, but for very complicated reasons: // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua index 60e4f61e..c8f6b5dc 100644 --- a/bench/tests/sunspider/3d-raytrace.lua +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -451,15 +451,16 @@ function raytraceScene() end function arrayToCanvasCommands(pixels) - local s = 'Test\nvar pixels = ['; + local s = {}; + table.insert(s, 'Test\nvar pixels = ['); for y = 0,size-1 do - s = s .. "["; + table.insert(s, "["); for x = 0,size-1 do - s = s .. "[" .. math.floor(pixels[y + 1][x + 1][1] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][2] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][3] * 255) .. "],"; + table.insert(s, "[" .. math.floor(pixels[y + 1][x + 1][1] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][2] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][3] * 255) .. "],"); end - s = s .. "],"; + table.insert(s, "],"); end - s = s .. '];\n var canvas = document.getElementById("renderCanvas").getContext("2d");\n\ + table.insert(s, '];\n var canvas = document.getElementById("renderCanvas").getContext("2d");\n\ \n\ \n\ var size = ' .. size .. ';\n\ @@ -479,9 +480,9 @@ for (var y = 0; y < size; y++) {\n\ canvas.setFillColor(l[0], l[1], l[2], 1);\n\ canvas.fillRect(x, y, 1, 1);\n\ }\n\ -}'; +}'); - return s; + return table.concat(s); end testOutput = arrayToCanvasCommands(raytraceScene()); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 3b74a99e..62a9999b 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -513,8 +513,6 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_the_end_of_a_comme TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check(R"( --[[ @1 )"); @@ -526,8 +524,6 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment_at_the_very_end_of_the_file") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check("--[[@1"); auto ac = autocomplete('1'); @@ -2625,4 +2621,55 @@ local a: A<(number, s@1> CHECK(ac.entryMap.count("string")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag luauAutocompleteFirstArg("LuauAutocompleteFirstArg", true); + + check(R"( +local function foo1() return 1 end +local function foo2() return "1" end + +local function bar0() return "got" .. a end +local function bar1(a: number) return "got " .. a end +local function bar2(a: number, b: string) return "got " .. a .. b end + +local t = {} +function t:bar1(a: number) return "got " .. a end + +local r1 = bar0(@1) +local r2 = bar1(@2) +local r3 = bar2(@3) +local r4 = t:bar1(@4) + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::None); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('2'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('3'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('4'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 6ba39ada..95811b3f 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -10,9 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauPreloadClosures) -LUAU_FASTFLAG(LuauGenericSpecialGlobals) - using namespace Luau; static std::string compileFunction(const char* source, uint32_t id) @@ -74,20 +71,10 @@ TEST_CASE("BasicFunction") bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); Luau::compileOrThrow(bcb, "local function foo(a, b) return b end"); - if (FFlag::LuauPreloadClosures) - { - CHECK_EQ("\n" + bcb.dumpFunction(1), R"( + CHECK_EQ("\n" + bcb.dumpFunction(1), R"( DUPCLOSURE R0 K0 RETURN R0 0 )"); - } - else - { - CHECK_EQ("\n" + bcb.dumpFunction(1), R"( -NEWCLOSURE R0 P0 -RETURN R0 0 -)"); - } CHECK_EQ("\n" + bcb.dumpFunction(0), R"( RETURN R1 1 @@ -2859,47 +2846,35 @@ CAPTURE UPVAL U1 RETURN R0 1 )"); - if (FFlag::LuauPreloadClosures) - { - // recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( + // recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( DUPCLOSURE R0 K0 CAPTURE VAL R0 RETURN R0 0 )"); - // multi-level recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( + // multi-level recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( DUPCLOSURE R0 K0 CAPTURE UPVAL U0 RETURN R0 1 )"); - // multi-level recursive capture where function isn't top-level - // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler - CHECK_EQ("\n" + compileFunction(R"( + // multi-level recursive capture where function isn't top-level + // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler + CHECK_EQ("\n" + compileFunction(R"( local function foo() local function bar() return function() return bar() end end end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 CAPTURE UPVAL U0 RETURN R0 1 )"); - } - else - { - // recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( -NEWCLOSURE R0 P0 -CAPTURE VAL R0 -RETURN R0 0 -)"); - } } TEST_CASE("OutOfLocals") @@ -3504,8 +3479,6 @@ local t = { TEST_CASE("ConstantClosure") { - ScopedFastFlag sff("LuauPreloadClosures", true); - // closures without upvalues are created when bytecode is loaded CHECK_EQ("\n" + compileFunction(R"( return function() end @@ -3570,8 +3543,6 @@ RETURN R0 1 TEST_CASE("SharedClosure") { - ScopedFastFlag sff1("LuauPreloadClosures", true); - // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index b2aad316..b055a38e 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -123,8 +123,8 @@ int lua_silence(lua_State* L) using StateRef = std::unique_ptr; -static StateRef runConformance( - const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, lua_State* initialLuaState = nullptr) +static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, + lua_State* initialLuaState = nullptr, lua_CompileOptions* copts = nullptr) { std::string path = __FILE__; path.erase(path.find_last_of("\\/")); @@ -180,13 +180,8 @@ static StateRef runConformance( std::string chunkname = "=" + std::string(name); - lua_CompileOptions copts = {}; - copts.optimizationLevel = 1; // default - copts.debugLevel = 2; // for debugger tests - copts.vectorCtor = "vector"; // for vector tests - size_t bytecodeSize = 0; - char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize); + char* bytecode = luau_compile(source.data(), source.size(), copts, &bytecodeSize); int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); free(bytecode); @@ -373,29 +368,37 @@ TEST_CASE("Vector") { ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - runConformance("vector.lua", [](lua_State* L) { - lua_pushcfunction(L, lua_vector, "vector"); - lua_setglobal(L, "vector"); + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 1; + copts.vectorCtor = "vector"; + + runConformance( + "vector.lua", + [](lua_State* L) { + lua_pushcfunction(L, lua_vector, "vector"); + lua_setglobal(L, "vector"); #if LUA_VECTOR_SIZE == 4 - lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); #else - lua_pushvector(L, 0.0f, 0.0f, 0.0f); + lua_pushvector(L, 0.0f, 0.0f, 0.0f); #endif - luaL_newmetatable(L, "vector"); + luaL_newmetatable(L, "vector"); - lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index, nullptr); - lua_settable(L, -3); + lua_pushstring(L, "__index"); + lua_pushcfunction(L, lua_vector_index, nullptr); + lua_settable(L, -3); - lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall, nullptr); - lua_settable(L, -3); + lua_pushstring(L, "__namecall"); + lua_pushcfunction(L, lua_vector_namecall, nullptr); + lua_settable(L, -3); - lua_setreadonly(L, -1, true); - lua_setmetatable(L, -2); - lua_pop(L, 1); - }); + lua_setreadonly(L, -1, true); + lua_setmetatable(L, -2); + lua_pop(L, 1); + }, + nullptr, nullptr, &copts); } static void populateRTTI(lua_State* L, Luau::TypeId type) @@ -499,6 +502,10 @@ TEST_CASE("Debugger") breakhits = 0; interruptedthread = nullptr; + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 2; + runConformance( "debugger.lua", [](lua_State* L) { @@ -614,7 +621,8 @@ TEST_CASE("Debugger") lua_resume(interruptedthread, nullptr, 0); interruptedthread = nullptr; } - }); + }, + nullptr, &copts); CHECK(breakhits == 10); // 2 hits per breakpoint } @@ -863,4 +871,46 @@ TEST_CASE("TagMethodError") }); } +TEST_CASE("Coverage") +{ + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 1; + copts.coverageLevel = 2; + + runConformance( + "coverage.lua", + [](lua_State* L) { + lua_pushcfunction( + L, + [](lua_State* L) -> int { + luaL_argexpected(L, lua_isLfunction(L, 1), 1, "function"); + + lua_newtable(L); + lua_getcoverage(L, 1, L, [](void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) { + lua_State* L = static_cast(context); + + lua_newtable(L); + + lua_pushstring(L, function); + lua_setfield(L, -2, "name"); + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + lua_pushinteger(L, hits[i]); + lua_rawseti(L, -2, int(i)); + } + + lua_rawseti(L, -2, lua_objlen(L, -2) + 1); + }); + + return 1; + }, + "getcoverage"); + lua_setglobal(L, "getcoverage"); + }, + nullptr, nullptr, &copts); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 2800d2fe..e3993cc5 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -278,7 +278,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") #if defined(_DEBUG) || defined(_NOOPT) int limit = 250; #else - int limit = 500; + int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; @@ -287,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TypeId table = src.addType(TableTypeVar{}); TypeId nested = table; - for (unsigned i = 0; i < limit + 100; i++) + for (int i = 0; i < limit + 100; i++) { TableTypeVar* ttv = getMutable(nested); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 72d3a9a6..5abcb09a 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2303,8 +2303,6 @@ TEST_CASE_FIXTURE(Fixture, "capture_comments") TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - ParseOptions options; options.captureComments = true; @@ -2319,8 +2317,6 @@ TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") TEST_CASE_FIXTURE(Fixture, "capture_broken_comment") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - ParseOptions options; options.captureComments = true; diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 091c2f01..71ff4e1b 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -207,6 +207,36 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_does_not_propagate_type_info") CHECK_EQ("number", toString(requireType("b"))); } +TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") +{ + ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; + + CheckResult result = check(R"( + local a = 55 :: number? + local b = a :: number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number?", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "as_expr_warns_on_unrelated_cast") +{ + ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; + ScopedFastFlag sff2{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + local a = 55 :: string + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Cannot cast 'number' into 'string' because the types are unrelated", toString(result.errors[0])); + CHECK_EQ("string", toString(requireType("a"))); +} + TEST_CASE_FIXTURE(Fixture, "type_annotations_inside_function_bodies") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 1e2eae14..1d8135d4 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauFixTonumberReturnType) + using namespace Luau; TEST_SUITE_BEGIN("BuiltinTests"); @@ -814,6 +816,30 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[2].data); } +TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") +{ + CheckResult result = check(R"( + --!strict + local b: number = tonumber('asdf') + )"); + + if (FFlag::LuauFixTonumberReturnType) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); + } +} + +TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") +{ + CheckResult result = check(R"( + --!strict + local b: number = tonumber('asdf') or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index aba50891..b62044fa 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) + TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -644,4 +646,42 @@ f(1, 2, 3) CHECK_EQ(toString(*ty, opts), "(a: number, number, number) -> ()"); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_types") +{ + CheckResult result = check(R"( +type C = () -> () +type D = () -> () + +local c: C +local d: D = c + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauExtendedFunctionMismatchError) + CHECK_EQ( + toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack") +{ + CheckResult result = check(R"( +type C = () -> () +type D = () -> () + +local c: C +local d: D = c + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauExtendedFunctionMismatchError) + CHECK_EQ(toString(result.errors[0]), + R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index fe8e7ff9..688680c1 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -279,7 +279,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1085,4 +1085,19 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } +TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") +{ + ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; + + CheckResult result = check(R"( + type T = { [string]: { prop: number }? } + local t: T = {} + if t["hello"] then + local foo = t["hello"].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 5f95efd5..1621ef32 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -374,4 +374,54 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + }; + + CheckResult result = check(R"( +type Cat = { tag: 'cat', catfood: string } +type Dog = { tag: 'dog', dogfood: string } +type Animal = Cat | Dog + +local a: Animal = { tag = 'cat', cafood = 'something' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'a' could not be converted into 'Cat | Dog' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'Cat' because the former is missing field 'catfood')", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + }; + + CheckResult result = check(R"( +type Good = { success: true, result: string } +type Bad = { success: false, error: string } +type Result = Good | Bad + +local a: Result = { success = false, result = 'something' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'a' could not be converted into 'Bad | Good' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'Bad' because the former is missing field 'error')", + toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index cb72faaf..3ea9b80c 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -12,6 +12,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) + TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -275,7 +277,7 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local a = {} @@ -346,7 +348,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -369,7 +371,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local T = {} @@ -476,7 +478,7 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -511,7 +513,7 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_ TEST_CASE_FIXTURE(Fixture, "width_subtyping") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -771,7 +773,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } @@ -782,7 +784,7 @@ TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") TEST_CASE_FIXTURE(Fixture, "array_factory_function") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( function empty() return {} end @@ -1465,7 +1467,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end @@ -1550,7 +1552,7 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multipl TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local vec3 = {x = 1, y = 2, z = 3} @@ -1937,7 +1939,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -1952,7 +1954,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1971,7 +1973,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1995,7 +1997,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -2015,11 +2017,22 @@ caused by: caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + } + else + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); + } } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") @@ -2027,7 +2040,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, }; CheckResult result = check(R"( @@ -2048,7 +2061,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, {"LuauExtendedTypeMismatchError", true}, }; @@ -2076,7 +2089,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, }; CheckResult result = check(R"( @@ -2092,4 +2105,18 @@ a.p = { x = 9 } LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") +{ + ScopedFastFlag luauFixRecursiveMetatableCall{"LuauFixRecursiveMetatableCall", true}; + + CheckResult result = check(R"( +local b +b = setmetatable({}, {__call = b}) +b() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable {| __call: t1 |}, { } })"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index e3222a41..ad9ea827 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) using namespace Luau; @@ -2084,7 +2085,7 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") { CheckResult result = check(R"( function add(a: number, b: string) - return a + tonumber(b), a .. b + return a + (tonumber(b) :: number), a .. b end local n, s = add(2,"3") )"); @@ -2485,7 +2486,7 @@ TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") CheckResult result = check(R"( --!strict function f(U) - U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() + U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() end )"); @@ -3329,7 +3330,7 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable { CheckResult result = check(R"( local x - print((x == true and (x .. "y")) .. 1) + print((x == true and (x .. "y")) .. 1) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -4473,7 +4474,18 @@ f(function(a, b, c, ...) return a + b end) )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'", toString(result.errors[0])); + + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number')", toString(result.errors[0])); + } // Infer from variadic packs into elements result = check(R"( @@ -4604,7 +4616,17 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ( + "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); + } + else + { + CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") @@ -4799,4 +4821,211 @@ local ModuleA = require(game.A) CHECK_EQ("*unknown*", toString(*oty)); } +/* + * If it wasn't instantly obvious, we have the fuzzer to thank for this gem of a test. + * + * We had an issue here where the scope for the `if` block here would + * have an elevated TypeLevel even though there is no function nesting going on. + * This would result in a free typevar for the type of _ that was much higher than + * it should be. This type would be erroneously quantified in the definition of `aaa`. + * This in turn caused an ice when evaluating `_()` in the while loop. + */ +TEST_CASE_FIXTURE(Fixture, "free_typevars_introduced_within_control_flow_constructs_do_not_get_an_elevated_TypeLevel") +{ + check(R"( + --!strict + if _ then + _[_], _ = nil + _() + end + + local aaa = function():typeof(_) return 1 end + + if aaa then + while _() do + end + end + )"); + + // No ice()? No problem. +} + +/* + * This is a bit elaborate. Bear with me. + * + * The type of _ becomes free with the first statement. With the second, we unify it with a function. + * + * At this point, it is important that the newly created fresh types of this new function type are promoted + * to the same level as the original free type. If we do not, they are incorrectly ascribed the level of the + * containing function. + * + * If this is allowed to happen, the final lambda erroneously quantifies the type of _ to something ridiculous + * just before we typecheck the invocation to _. + */ +TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") +{ + check(R"( + l0, _ = nil + + local function p() + _() + end + + a = _( + function():(typeof(p),typeof(_)) + end + )[nil] + )"); +} + +/* + * We had an issue where part of the type of pairs() was an unsealed table. + * This test depends on FFlagDebugLuauFreezeArena to trigger it. + */ +TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") +{ + check(R"( + function _(l0:{n0:any}) + _ = pairs + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table") +{ + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + + Base64FileReader() + + function ReadMidiEvents(data) + + local reader = Base64FileReader(data) + + while reader:HasMore() do + (reader:Byte() % 128) + end + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, string) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' +caused by: + Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> (number) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Function only returns 1 value. 2 are required here)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, number) -> number + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' +caused by: + Return type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> (number, string) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + toString(result.errors[0]), R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); +} + +TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") +{ + ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true}; + + CheckResult result = check(R"( +local t = {} + +function t.x(value) + for k,v in pairs(t) do end +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 9f9a007f..f55b46a4 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -214,4 +214,32 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unifica CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } +TEST_CASE_FIXTURE(TryUnifyFixture, "undo_new_prop_on_unsealed_table") +{ + ScopedFastFlag flags[] = { + {"LuauTableSubtypingVariance2", true}, + }; + // I am not sure how to make this happen in Luau code. + + TypeId unsealedTable = arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); + TypeId sealedTable = arena.addType(TableTypeVar{ + {{"prop", Property{getSingletonTypes().numberType}}}, + std::nullopt, + TypeLevel{}, + TableState::Sealed + }); + + const TableTypeVar* ttv = get(unsealedTable); + REQUIRE(ttv); + + state.tryUnify(unsealedTable, sealedTable); + + // To be honest, it's really quite spooky here that we're amending an unsealed table in this case. + CHECK(!ttv->props.empty()); + + state.log.rollback(); + + CHECK(ttv->props.empty()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 48496b89..b095a0db 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -462,4 +462,20 @@ local a: XYZ = { w = 4 } CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X | Y | Z'; none of the union options are compatible)"); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_optional") +{ + ScopedFastFlag luauExtendedUnionMismatchError{"LuauExtendedUnionMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } + +local a: X? = { w = 4 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X?' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'X' because the former is missing field 'x')"); +} + TEST_SUITE_END(); diff --git a/tests/conformance/coverage.lua b/tests/conformance/coverage.lua new file mode 100644 index 00000000..f899603f --- /dev/null +++ b/tests/conformance/coverage.lua @@ -0,0 +1,64 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing coverage") + +function foo() + local x = 1 + local y = 2 + assert(x + y) +end + +function bar() + local function one(x) + return x + end + + local two = function(x) + return x + end + + one(1) +end + +function validate(stats, hits, misses) + local checked = {} + + for _,l in ipairs(hits) do + if not (stats[l] and stats[l] > 0) then + return false, string.format("expected line %d to be hit", l) + end + checked[l] = true + end + + for _,l in ipairs(misses) do + if not (stats[l] and stats[l] == 0) then + return false, string.format("expected line %d to be missed", l) + end + checked[l] = true + end + + for k,v in pairs(stats) do + if type(k) == "number" and not checked[k] then + return false, string.format("expected line %d to be absent", k) + end + end + + return true +end + +foo() +c = getcoverage(foo) +assert(#c == 1) +assert(c[1].name == "foo") +assert(validate(c[1], {5, 6, 7}, {})) + +bar() +c = getcoverage(bar) +assert(#c == 3) +assert(c[1].name == "bar") +assert(validate(c[1], {11, 15, 19}, {})) +assert(c[2].name == "one") +assert(validate(c[2], {12}, {})) +assert(c[3].name == nil) +assert(validate(c[3], {}, {16})) + +return 'OK' From a9aa4faf24e6cea1ac0e33d0054a7328a35f9d4a Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 6 Jan 2022 14:08:56 -0800 Subject: [PATCH 011/102] Sync to upstream/release/508 This version isn't for release because we've skipped some internal numbers due to year-end schedule changes, but it's better to merge separately. --- Analysis/include/Luau/LValue.h | 63 +++++++ Analysis/include/Luau/Predicate.h | 34 +--- Analysis/include/Luau/TypeInfer.h | 1 + Analysis/src/{Predicate.cpp => LValue.cpp} | 88 ++++++++- Analysis/src/TypeInfer.cpp | 83 ++++++++- Analysis/src/TypeVar.cpp | 11 +- Analysis/src/Unifier.cpp | 123 ++++--------- Ast/src/Parser.cpp | 4 - Sources.cmake | 5 +- VM/src/lbaselib.cpp | 8 +- tests/Autocomplete.test.cpp | 6 +- tests/LValue.test.cpp | 198 +++++++++++++++++++++ tests/Predicate.test.cpp | 117 ------------ tests/Symbol.test.cpp | 33 +++- tests/Transpiler.test.cpp | 2 - tests/TypeInfer.builtins.test.cpp | 12 +- tests/TypeInfer.classes.test.cpp | 2 - tests/TypeInfer.intersectionTypes.test.cpp | 4 - tests/TypeInfer.refinements.test.cpp | 37 +++- tests/TypeInfer.singletons.test.cpp | 1 - tests/TypeInfer.tables.test.cpp | 4 - tests/TypeInfer.test.cpp | 13 ++ tests/TypeInfer.unionTypes.test.cpp | 4 - tests/conformance/math.lua | 1 + 24 files changed, 570 insertions(+), 284 deletions(-) create mode 100644 Analysis/include/Luau/LValue.h rename Analysis/src/{Predicate.cpp => LValue.cpp} (50%) create mode 100644 tests/LValue.test.cpp delete mode 100644 tests/Predicate.test.cpp diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h new file mode 100644 index 00000000..8fd96f05 --- /dev/null +++ b/Analysis/include/Luau/LValue.h @@ -0,0 +1,63 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Variant.h" +#include "Luau/Symbol.h" + +#include // TODO: Kill with LuauLValueAsKey. +#include +#include + +namespace Luau +{ + +struct TypeVar; +using TypeId = const TypeVar*; + +struct Field; +using LValue = Variant; + +struct Field +{ + std::shared_ptr parent; + std::string key; + + bool operator==(const Field& rhs) const; + bool operator!=(const Field& rhs) const; +}; + +struct LValueHasher +{ + size_t operator()(const LValue& lvalue) const; +}; + +const LValue* baseof(const LValue& lvalue); + +std::optional tryGetLValue(const class AstExpr& expr); + +// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. +std::pair> getFullName(const LValue& lvalue); + +// Kill with LuauLValueAsKey. +std::string toString(const LValue& lvalue); + +template +const T* get(const LValue& lvalue) +{ + return get_if(&lvalue); +} + +using NEW_RefinementMap = std::unordered_map; +using DEPRECATED_RefinementMap = std::map; + +// Transient. Kill with LuauLValueAsKey. +struct RefinementMap +{ + NEW_RefinementMap NEW_refinements; + DEPRECATED_RefinementMap DEPRECATED_refinements; +}; + +void merge(RefinementMap& l, const RefinementMap& r, std::function f); +void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/Predicate.h b/Analysis/include/Luau/Predicate.h index a5e8b6ae..df93b4f4 100644 --- a/Analysis/include/Luau/Predicate.h +++ b/Analysis/include/Luau/Predicate.h @@ -1,12 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Variant.h" #include "Luau/Location.h" -#include "Luau/Symbol.h" +#include "Luau/LValue.h" +#include "Luau/Variant.h" -#include -#include #include namespace Luau @@ -15,34 +13,6 @@ namespace Luau struct TypeVar; using TypeId = const TypeVar*; -struct Field; -using LValue = Variant; - -struct Field -{ - std::shared_ptr parent; // TODO: Eventually use unique_ptr to enforce non-copyable trait. - std::string key; -}; - -std::optional tryGetLValue(const class AstExpr& expr); - -// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. -std::pair> getFullName(const LValue& lvalue); - -std::string toString(const LValue& lvalue); - -template -const T* get(const LValue& lvalue) -{ - return get_if(&lvalue); -} - -// Key is a stringified encoding of an LValue. -using RefinementMap = std::map; - -void merge(RefinementMap& l, const RefinementMap& r, std::function f); -void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); - struct TruthyPredicate; struct IsAPredicate; struct TypeGuardPredicate; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 451976e4..862f50d7 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -350,6 +350,7 @@ public: private: std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); + std::optional DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/LValue.cpp similarity index 50% rename from Analysis/src/Predicate.cpp rename to Analysis/src/LValue.cpp index 7bd8001e..da6804c6 100644 --- a/Analysis/src/Predicate.cpp +++ b/Analysis/src/LValue.cpp @@ -1,11 +1,59 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Predicate.h" +#include "Luau/LValue.h" #include "Luau/Ast.h" +#include + +LUAU_FASTFLAG(LuauLValueAsKey) + namespace Luau { +bool Field::operator==(const Field& rhs) const +{ + LUAU_ASSERT(parent && rhs.parent); + return key == rhs.key && (parent == rhs.parent || *parent == *rhs.parent); +} + +bool Field::operator!=(const Field& rhs) const +{ + return !(*this == rhs); +} + +size_t LValueHasher::operator()(const LValue& lvalue) const +{ + // Most likely doesn't produce high quality hashes, but we're probably ok enough with it. + // When an evidence is shown that operator==(LValue) is used more often than it should, we can have a look at improving the hash quality. + size_t acc = 0; + size_t offset = 0; + + const LValue* current = &lvalue; + while (current) + { + if (auto field = get(*current)) + acc ^= (std::hash{}(field->key) << 1) >> ++offset; + else if (auto symbol = get(*current)) + acc ^= std::hash{}(*symbol) << 1; + else + LUAU_ASSERT(!"Hash not accumulated for this new LValue alternative."); + + current = baseof(*current); + } + + return acc; +} + +const LValue* baseof(const LValue& lvalue) +{ + if (auto field = get(lvalue)) + return field->parent.get(); + + auto symbol = get(lvalue); + LUAU_ASSERT(symbol); + return nullptr; // Base of root is null. +} + std::optional tryGetLValue(const AstExpr& node) { const AstExpr* expr = &node; @@ -38,15 +86,15 @@ std::pair> getFullName(const LValue& lvalue) while (auto field = get(*current)) { keys.push_back(field->key); - current = field->parent.get(); - if (!current) - LUAU_ASSERT(!"LValue root is a Field?"); + current = baseof(*current); } const Symbol* symbol = get(*current); + LUAU_ASSERT(symbol); return {*symbol, std::vector(keys.rbegin(), keys.rend())}; } +// Kill with LuauLValueAsKey. std::string toString(const LValue& lvalue) { auto [symbol, keys] = getFullName(lvalue); @@ -56,7 +104,18 @@ std::string toString(const LValue& lvalue) return s; } -void merge(RefinementMap& l, const RefinementMap& r, std::function f) +static void merge(NEW_RefinementMap& l, const NEW_RefinementMap& r, std::function f) +{ + for (const auto& [k, a] : r) + { + if (auto it = l.find(k); it != l.end()) + l[k] = f(it->second, a); + else + l[k] = a; + } +} + +static void merge(DEPRECATED_RefinementMap& l, const DEPRECATED_RefinementMap& r, std::function f) { auto itL = l.begin(); auto itR = r.begin(); @@ -69,21 +128,32 @@ void merge(RefinementMap& l, const RefinementMap& r, std::functionfirst > k) + else if (itL->first < k) + ++itL; + else { l[k] = a; ++itR; } - else - ++itL; } l.insert(itR, r.end()); } +void merge(RefinementMap& l, const RefinementMap& r, std::function f) +{ + if (FFlag::LuauLValueAsKey) + return merge(l.NEW_refinements, r.NEW_refinements, f); + else + return merge(l.DEPRECATED_refinements, r.DEPRECATED_refinements, f); +} + void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty) { - refis[toString(lvalue)] = ty; + if (FFlag::LuauLValueAsKey) + refis.NEW_refinements[lvalue] = ty; + else + refis.DEPRECATED_refinements[toString(lvalue)] = ty; } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index abbc2901..e29b6ec6 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) +LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) @@ -1626,6 +1627,10 @@ std::optional TypeChecker::getIndexTypeFromType( { RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + // Not needed when we normalize types. + if (FFlag::LuauLValueAsKey && get(follow(t))) + return t; + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) goodOptions.push_back(*ty); else @@ -4967,13 +4972,83 @@ std::pair, std::vector> TypeChecker::createGener std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue) { - std::string path = toString(lvalue); + if (!FFlag::LuauLValueAsKey) + return DEPRECATED_resolveLValue(scope, lvalue); + + // We want to be walking the Scope parents. + // We'll also want to walk up the LValue path. As we do this, we need to save each LValue because we must walk back. + // For example: + // There exists an entry t.x. + // We are asked to look for t.x.y. + // We need to search in the provided Scope. Find t.x.y first. + // We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x. + // If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate. + const auto& [symbol, keys] = getFullName(lvalue); + + ScopePtr currentScope = scope; + while (currentScope) + { + std::optional found; + + std::vector childKeys; + const LValue* currentLValue = &lvalue; + while (currentLValue) + { + if (auto it = currentScope->refinements.NEW_refinements.find(*currentLValue); it != currentScope->refinements.NEW_refinements.end()) + { + found = it->second; + break; + } + + childKeys.push_back(*currentLValue); + currentLValue = baseof(*currentLValue); + } + + if (!found) + { + // Should not be using scope->lookup. This is already recursive. + if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) + found = it->second.typeId; + else + { + // Nothing exists in this Scope. Just skip and try the parent one. + currentScope = currentScope->parent; + continue; + } + } + + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) + { + const LValue& key = *it; + + // Symbol can happen. Skip. + if (get(key)) + continue; + else if (auto field = get(key)) + { + found = getIndexTypeFromType(scope, *found, field->key, Location(), false); + if (!found) + return std::nullopt; // Turns out this type doesn't have the property at all. We're done. + } + else + LUAU_ASSERT(!"New LValue alternative not handled here."); + } + + return found; + } + + // No entry for it at all. Can happen when LValue root is a global. + return std::nullopt; +} + +std::optional TypeChecker::DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue) +{ auto [symbol, keys] = getFullName(lvalue); ScopePtr currentScope = scope; while (currentScope) { - if (auto it = currentScope->refinements.find(path); it != currentScope->refinements.end()) + if (auto it = currentScope->refinements.DEPRECATED_refinements.find(toString(lvalue)); it != currentScope->refinements.DEPRECATED_refinements.end()) return it->second; // Should not be using scope->lookup. This is already recursive. @@ -5000,7 +5075,9 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV std::optional TypeChecker::resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue) { - if (auto it = refis.find(toString(lvalue)); it != refis.end()) + if (auto it = refis.DEPRECATED_refinements.find(toString(lvalue)); it != refis.DEPRECATED_refinements.end()) + return it->second; + else if (auto it = refis.NEW_refinements.find(lvalue); it != refis.NEW_refinements.end()) return it->second; else return resolveLValue(scope, lvalue); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 571b13ca..fb75aa02 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -996,18 +996,19 @@ std::optional> magicFunctionFormat( std::vector expected = parseFormatString(typechecker, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(paramPack); - const size_t dataOffset = 1; + size_t paramOffset = 1; + size_t dataOffset = expr.self ? 0 : 1; // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + dataOffset < params.size(); ++i) + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) { - Location location = expr.args.data[std::min(i, expr.args.size - 1)]->location; + Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - typechecker.unify(expected[i], params[i + dataOffset], location); + typechecker.unify(expected[i], params[i + paramOffset], location); } // if we know the argument count or if we have too many arguments for sure, we can issue an error - const size_t actualParamSize = params.size() - dataOffset; + size_t actualParamSize = params.size() - paramOffset; if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index c5aab856..43ea37e7 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,9 +18,7 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAG(LuauSingletonTypes) -LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) @@ -416,7 +414,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (!innerState.errors.empty()) { // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' - if (FFlag::LuauExtendedTypeMismatchError && !firstFailedOption && !isNil(type)) + if (!firstFailedOption && !isNil(type)) firstFailedOption = {innerState.errors.front()}; failed = true; @@ -434,7 +432,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(*unificationTooComplex); else if (failed) { - if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption) + if (firstFailedOption) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); @@ -536,49 +534,36 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption) errors.push_back( TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); - else if (FFlag::LuauExtendedTypeMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); } } else if (const IntersectionTypeVar* uv = get(superTy)) { - if (FFlag::LuauExtendedTypeMismatchError) + std::optional unificationTooComplex; + std::optional firstFailedOption; + + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) { - std::optional unificationTooComplex; - std::optional firstFailedOption; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); - - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - else if (!innerState.errors.empty()) - { - if (!firstFailedOption) - firstFailedOption = {innerState.errors.front()}; - } - - log.concat(std::move(innerState.log)); + if (!firstFailedOption) + firstFailedOption = {innerState.errors.front()}; } - if (unificationTooComplex) - errors.push_back(*unificationTooComplex); - else if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); - } - else - { - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) - { - tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); - } + log.concat(std::move(innerState.log)); } + + if (unificationTooComplex) + errors.push_back(*unificationTooComplex); + else if (firstFailedOption) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } else if (const IntersectionTypeVar* uv = get(subTy)) { @@ -626,10 +611,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(*unificationTooComplex); else if (!found) { - if (FFlag::LuauExtendedTypeMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); - else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } else if (get(superTy) && get(subTy)) @@ -1241,10 +1223,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, r->second.type); - if (FFlag::LuauExtendedTypeMismatchError) - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); - else - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1261,10 +1240,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, rt->indexer->indexResultType); - if (FFlag::LuauExtendedTypeMismatchError) - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); - else - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1302,10 +1278,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, lt->indexer->indexResultType); - if (FFlag::LuauExtendedTypeMismatchError) - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); - else - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1723,18 +1696,11 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse innerState.tryUnify_(lhs->table, rhs->table); innerState.tryUnify_(lhs->metatable, rhs->metatable); - if (FFlag::LuauExtendedTypeMismatchError) - { - if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); - else if (!innerState.errors.empty()) - errors.push_back( - TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); - } - else - { - checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); - } + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty()) + errors.push_back( + TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); log.concat(std::move(innerState.log)); } @@ -1821,31 +1787,22 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, getSingletonTypes().errorRecoveryType()); } else { - if (FFlag::LuauExtendedClassMismatchError) + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, classProp->type); + + checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); + + if (innerState.errors.empty()) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, classProp->type); - - checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - - if (innerState.errors.empty()) - { - log.concat(std::move(innerState.log)); - } - else - { - ok = false; - innerState.log.rollback(); - } + log.concat(std::move(innerState.log)); } else { - tryUnify_(prop.type, classProp->type); + ok = false; + innerState.log.rollback(); } } } @@ -2185,8 +2142,6 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) { - LUAU_ASSERT(FFlag::LuauExtendedTypeMismatchError || FFlag::LuauExtendedClassMismatchError); - if (auto e = hasUnificationTooComplex(innerErrors)) errors.push_back(*e); else if (!innerErrors.empty()) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index dd24f27c..72f61649 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau { @@ -1368,9 +1367,6 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) Lexeme parameterStart = lexer.current(); - if (!FFlag::LuauParseGenericFunctionTypeBegin) - begin = parameterStart; - expectAndConsume('(', "function parameters"); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; diff --git a/Sources.cmake b/Sources.cmake index 14834b3a..a7153eb3 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -45,6 +45,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/IostreamHelpers.h Analysis/include/Luau/JsonEncoder.h Analysis/include/Luau/Linter.h + Analysis/include/Luau/LValue.h Analysis/include/Luau/Module.h Analysis/include/Luau/ModuleResolver.h Analysis/include/Luau/Predicate.h @@ -80,8 +81,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/IostreamHelpers.cpp Analysis/src/JsonEncoder.cpp Analysis/src/Linter.cpp + Analysis/src/LValue.cpp Analysis/src/Module.cpp - Analysis/src/Predicate.cpp Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp @@ -194,10 +195,10 @@ if(TARGET Luau.UnitTest) tests/Frontend.test.cpp tests/JsonEncoder.test.cpp tests/Linter.test.cpp + tests/LValue.test.cpp tests/Module.test.cpp tests/NonstrictMode.test.cpp tests/Parser.test.cpp - tests/Predicate.test.cpp tests/RequireTracer.test.cpp tests/StringUtils.test.cpp tests/Symbol.test.cpp diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 881c804d..988fd315 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -36,12 +36,14 @@ static int luaB_tonumber(lua_State* L) int base = luaL_optinteger(L, 2, 10); if (base == 10) { /* standard conversion */ - luaL_checkany(L, 1); - if (lua_isnumber(L, 1)) + int isnum = 0; + double n = lua_tonumberx(L, 1, &isnum); + if (isnum) { - lua_pushnumber(L, lua_tonumber(L, 1)); + lua_pushnumber(L, n); return 1; } + luaL_checkany(L, 1); /* error if we don't have any argument */ } else { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 62a9999b..8ca09c0e 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1394,7 +1394,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_function_return_types") check(R"( local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end return target(b@1 )"); @@ -1422,7 +1422,7 @@ return target(bar1, b@1 check(R"( local function target(a: number, b: string) return a + #b end local function bar1(a: number): (...number) return -a, a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end return target(b@1 )"); @@ -1918,7 +1918,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end return target(b@1 )"); diff --git a/tests/LValue.test.cpp b/tests/LValue.test.cpp new file mode 100644 index 00000000..8a092779 --- /dev/null +++ b/tests/LValue.test.cpp @@ -0,0 +1,198 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeInfer.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) +{ + Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId { + // TODO: normalize here also. + std::unordered_set s; + + if (auto utv = get(follow(a))) + s.insert(begin(utv), end(utv)); + else + s.insert(a); + + if (auto utv = get(follow(b))) + s.insert(begin(utv), end(utv)); + else + s.insert(b); + + std::vector options(s.begin(), s.end()); + return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)}); + }); +} + +static LValue mkSymbol(const std::string& s) +{ + return Symbol{AstName{s.data()}}; +} + +TEST_SUITE_BEGIN("LValue"); + +TEST_CASE("Luau_merge_hashmap_order") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + + RefinementMap m{{ + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().numberType}, + }}; + + RefinementMap other{{ + {mkSymbol(a), getSingletonTypes().stringType}, + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(3, m.NEW_refinements.size()); + REQUIRE(m.NEW_refinements.count(mkSymbol(a))); + REQUIRE(m.NEW_refinements.count(mkSymbol(b))); + REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)])); +} + +TEST_CASE("Luau_merge_hashmap_order2") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + + RefinementMap m{{ + {mkSymbol(a), getSingletonTypes().stringType}, + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().numberType}, + }}; + + RefinementMap other{{ + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(3, m.NEW_refinements.size()); + REQUIRE(m.NEW_refinements.count(mkSymbol(a))); + REQUIRE(m.NEW_refinements.count(mkSymbol(b))); + REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)])); +} + +TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + std::string d = "d"; + std::string e = "e"; + + RefinementMap m{{ + {mkSymbol(a), getSingletonTypes().stringType}, + {mkSymbol(b), getSingletonTypes().numberType}, + {mkSymbol(c), getSingletonTypes().booleanType}, + }}; + + RefinementMap other{{ + {mkSymbol(c), getSingletonTypes().stringType}, + {mkSymbol(d), getSingletonTypes().numberType}, + {mkSymbol(e), getSingletonTypes().booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(5, m.NEW_refinements.size()); + REQUIRE(m.NEW_refinements.count(mkSymbol(a))); + REQUIRE(m.NEW_refinements.count(mkSymbol(b))); + REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + REQUIRE(m.NEW_refinements.count(mkSymbol(d))); + REQUIRE(m.NEW_refinements.count(mkSymbol(e))); + + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); + CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(b)])); + CHECK_EQ("boolean | string", toString(m.NEW_refinements[mkSymbol(c)])); + CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(d)])); + CHECK_EQ("boolean", toString(m.NEW_refinements[mkSymbol(e)])); +} + +TEST_CASE("hashing_lvalue_global_prop_access") +{ + std::string t1 = "t"; + std::string x1 = "x"; + + LValue t_x1{Field{std::make_shared(Symbol{AstName{t1.data()}}), x1}}; + + std::string t2 = "t"; + std::string x2 = "x"; + + LValue t_x2{Field{std::make_shared(Symbol{AstName{t2.data()}}), x2}}; + + CHECK_EQ(t_x1, t_x1); + CHECK_EQ(t_x1, t_x2); + CHECK_EQ(t_x2, t_x2); + + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x1)); + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); + CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); + + NEW_RefinementMap m; + m[t_x1] = getSingletonTypes().stringType; + m[t_x2] = getSingletonTypes().numberType; + + CHECK_EQ(1, m.size()); +} + +TEST_CASE("hashing_lvalue_local_prop_access") +{ + std::string t1 = "t"; + std::string x1 = "x"; + + AstLocal localt1{AstName{t1.data()}, Location(), nullptr, 0, 0, nullptr}; + LValue t_x1{Field{std::make_shared(Symbol{&localt1}), x1}}; + + std::string t2 = "t"; + std::string x2 = "x"; + + AstLocal localt2{AstName{t2.data()}, Location(), &localt1, 0, 0, nullptr}; + LValue t_x2{Field{std::make_shared(Symbol{&localt2}), x2}}; + + CHECK_EQ(t_x1, t_x1); + CHECK_NE(t_x1, t_x2); + CHECK_EQ(t_x2, t_x2); + + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x1)); + CHECK_NE(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); + CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); + + NEW_RefinementMap m; + m[t_x1] = getSingletonTypes().stringType; + m[t_x2] = getSingletonTypes().numberType; + + CHECK_EQ(2, m.size()); +} + +TEST_SUITE_END(); diff --git a/tests/Predicate.test.cpp b/tests/Predicate.test.cpp deleted file mode 100644 index 7081693e..00000000 --- a/tests/Predicate.test.cpp +++ /dev/null @@ -1,117 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeInfer.h" - -#include "Fixture.h" -#include "ScopedFlags.h" - -#include "doctest.h" - -using namespace Luau; - -static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) -{ - Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId { - // TODO: normalize here also. - std::unordered_set s; - - if (auto utv = get(follow(a))) - s.insert(begin(utv), end(utv)); - else - s.insert(a); - - if (auto utv = get(follow(b))) - s.insert(begin(utv), end(utv)); - else - s.insert(b); - - std::vector options(s.begin(), s.end()); - return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)}); - }); -} - -TEST_SUITE_BEGIN("Predicate"); - -TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") -{ - RefinementMap m{ - {"b", typeChecker.stringType}, - {"c", typeChecker.numberType}, - }; - - RefinementMap other{ - {"a", typeChecker.stringType}, - {"b", typeChecker.stringType}, - {"c", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(3, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("string", toString(m["b"])); - CHECK_EQ("boolean | number", toString(m["c"])); -} - -TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") -{ - RefinementMap m{ - {"a", typeChecker.stringType}, - {"b", typeChecker.stringType}, - {"c", typeChecker.numberType}, - }; - - RefinementMap other{ - {"b", typeChecker.stringType}, - {"c", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(3, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("string", toString(m["b"])); - CHECK_EQ("boolean | number", toString(m["c"])); -} - -TEST_CASE_FIXTURE(Fixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start") -{ - RefinementMap m{ - {"a", typeChecker.stringType}, - {"b", typeChecker.numberType}, - {"c", typeChecker.booleanType}, - }; - - RefinementMap other{ - {"c", typeChecker.stringType}, - {"d", typeChecker.numberType}, - {"e", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(5, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - REQUIRE(m.count("d")); - REQUIRE(m.count("e")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("number", toString(m["b"])); - CHECK_EQ("boolean | string", toString(m["c"])); - CHECK_EQ("number", toString(m["d"])); - CHECK_EQ("boolean", toString(m["e"])); -} - -TEST_SUITE_END(); diff --git a/tests/Symbol.test.cpp b/tests/Symbol.test.cpp index 44fe3a3c..e7d2973b 100644 --- a/tests/Symbol.test.cpp +++ b/tests/Symbol.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; TEST_SUITE_BEGIN("SymbolTests"); -TEST_CASE("hashing") +TEST_CASE("hashing_globals") { std::string s1 = "name"; std::string s2 = "name"; @@ -31,10 +31,37 @@ TEST_CASE("hashing") CHECK_EQ(std::hash()(two), std::hash()(two)); std::unordered_map theMap; - theMap[AstName{s1.data()}] = 5; - theMap[AstName{s2.data()}] = 1; + theMap[n1] = 5; + theMap[n2] = 1; REQUIRE_EQ(1, theMap.size()); } +TEST_CASE("hashing_locals") +{ + std::string s1 = "name"; + std::string s2 = "name"; + + // These two names point to distinct memory areas. + AstLocal one{AstName{s1.data()}, Location(), nullptr, 0, 0, nullptr}; + AstLocal two{AstName{s2.data()}, Location(), &one, 0, 0, nullptr}; + + Symbol n1{&one}; + Symbol n2{&two}; + + CHECK(n1 == n1); + CHECK(n1 != n2); + CHECK(n2 == n2); + + CHECK_EQ(std::hash()(&one), std::hash()(&one)); + CHECK_NE(std::hash()(&one), std::hash()(&two)); + CHECK_EQ(std::hash()(&two), std::hash()(&two)); + + std::unordered_map theMap; + theMap[n1] = 5; + theMap[n2] = 1; + + REQUIRE_EQ(2, theMap.size()); +} + TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 327fa0bb..47c3883c 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -555,8 +555,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") TEST_CASE_FIXTURE(Fixture, "transpile_generic_function") { - ScopedFastFlag luauParseGenericFunctionTypeBegin("LuauParseGenericFunctionTypeBegin", true); - std::string code = R"( local function foo(a: T, ...: S...) return 1 end local f: (T, S...)->(number) = foo diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 1d8135d4..506279b9 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -798,13 +798,14 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi { CheckResult result = check(R"( ("%s%d%s"):format(1, "hello", true) + string.format("%s%d%s", 1, "hello", true) )"); TypeId stringType = typeChecker.stringType; TypeId numberType = typeChecker.numberType; TypeId booleanType = typeChecker.booleanType; - LUAU_REQUIRE_ERROR_COUNT(3, result); + LUAU_REQUIRE_ERROR_COUNT(6, result); CHECK_EQ(Location(Position{1, 26}, Position{1, 27}), result.errors[0].location); CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[0].data); @@ -814,6 +815,15 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(Location(Position{1, 38}, Position{1, 42}), result.errors[2].location); CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[2].data); + + CHECK_EQ(Location(Position{2, 32}, Position{2, 33}), result.errors[3].location); + CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[3].data); + + CHECK_EQ(Location(Position{2, 35}, Position{2, 42}), result.errors[4].location); + CHECK_EQ(TypeErrorData(TypeMismatch{numberType, stringType}), result.errors[4].data); + + CHECK_EQ(Location(Position{2, 44}, Position{2, 48}), result.errors[5].location); + CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[5].data); } TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 1ff23fe6..0283ae19 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -449,8 +449,6 @@ b.X = 2 -- real Vector2.X is also read-only TEST_CASE_FIXTURE(ClassFixture, "detailed_class_unification_error") { - ScopedFastFlag luauExtendedClassMismatchError{"LuauExtendedClassMismatchError", true}; - CheckResult result = check(R"( local function foo(v) return v.X :: number + string.len(v.Y) diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 893bc2b3..93c0baf6 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -343,8 +343,6 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -363,8 +361,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 688680c1..503b613f 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -280,7 +280,6 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( local t: {x: number?} = {x = nil} @@ -1085,6 +1084,41 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + local foo: string? = "hi" + assert(foo) + local foo: number = 5 + print(foo:sub(1, 1)) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'number' does not have key 'sub'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + type T = {x: string | number} + local t: T? = {x = "hi"} + if t then + if type(t.x) == "string" then + local foo = t.x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({5, 30}))); +} + TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") { ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; @@ -1092,6 +1126,7 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip CheckResult result = check(R"( type T = { [string]: { prop: number }? } local t: T = {} + if t["hello"] then local foo = t["hello"].prop end diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 1621ef32..68dc1b4f 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -202,7 +202,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauExtendedTypeMismatchError", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 3ea9b80c..80f40407 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1955,7 +1955,6 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( type A = { x: number, y: number } @@ -1974,7 +1973,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( type AS = { x: number, y: number } @@ -1998,7 +1996,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); @@ -2062,7 +2059,6 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, - {"LuauExtendedTypeMismatchError", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index ad9ea827..76324556 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5013,6 +5013,19 @@ caused by: Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); } +TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + local function f(thing: any | string) + local foo = thing.SomeRandomKey + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") { ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true}; diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index b095a0db..2357869e 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -425,8 +425,6 @@ y = x TEST_CASE_FIXTURE(Fixture, "error_detailed_union_part") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -446,8 +444,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_union_all") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index d5bca44f..bfea0e1f 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -26,6 +26,7 @@ function f(...) end end +assert(pcall(tonumber) == false) assert(tonumber{} == nil) assert(tonumber'+0.01' == 1/100 and tonumber'+.01' == 0.01 and tonumber'.01' == 0.01 and tonumber'-1.' == -1 and From 44ccd8282244228a3a3108cb8e5237c45b9d92c2 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 6 Jan 2022 14:10:07 -0800 Subject: [PATCH 012/102] Sync to upstream/release/509 --- .gitignore | 18 +- Analysis/include/Luau/Frontend.h | 10 +- Analysis/include/Luau/TxnLog.h | 280 ++- Analysis/include/Luau/TypeInfer.h | 40 +- Analysis/include/Luau/TypePack.h | 8 + Analysis/include/Luau/TypeVar.h | 1 + Analysis/include/Luau/TypedAllocator.h | 23 +- Analysis/include/Luau/Unifier.h | 48 +- Analysis/src/Autocomplete.cpp | 33 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 7 +- Analysis/src/Frontend.cpp | 28 +- Analysis/src/Module.cpp | 2 +- Analysis/src/Quantify.cpp | 14 +- Analysis/src/ToString.cpp | 48 +- Analysis/src/TxnLog.cpp | 295 ++- Analysis/src/TypeInfer.cpp | 863 ++++--- Analysis/src/TypePack.cpp | 51 +- Analysis/src/TypeVar.cpp | 19 +- Analysis/src/TypedAllocator.cpp | 1 + Analysis/src/Unifier.cpp | 2342 +++++++++++++------ Ast/include/Luau/Common.h | 8 +- CLI/Analyze.cpp | 2 +- CLI/Repl.cpp | 46 +- CLI/Web.cpp | 24 +- Compiler/include/Luau/Bytecode.h | 1 + Compiler/include/Luau/BytecodeBuilder.h | 2 + Compiler/src/BytecodeBuilder.cpp | 69 +- Compiler/src/Compiler.cpp | 7 +- Sources.cmake | 1 + VM/include/luaconf.h | 4 - VM/src/lapi.cpp | 34 +- VM/src/laux.cpp | 38 +- VM/src/lbitlib.cpp | 8 - VM/src/ldebug.cpp | 17 +- VM/src/ldo.cpp | 29 +- VM/src/lgc.cpp | 2 - VM/src/lgcdebug.cpp | 7 +- VM/src/lnumprint.cpp | 375 +++ VM/src/lnumutils.h | 7 +- VM/src/lobject.h | 1 + VM/src/ltable.cpp | 6 +- VM/src/lvmload.cpp | 28 +- VM/src/lvmutils.cpp | 7 +- fuzz/number.cpp | 35 + tests/AstQuery.test.cpp | 2 - tests/Autocomplete.test.cpp | 43 +- tests/Conformance.test.cpp | 31 +- tests/Fixture.cpp | 5 - tests/Fixture.h | 9 - tests/Frontend.test.cpp | 7 +- tests/TypeInfer.annotations.test.cpp | 2 +- tests/TypeInfer.builtins.test.cpp | 37 +- tests/TypeInfer.generics.test.cpp | 2 - tests/TypeInfer.test.cpp | 2 - tests/TypeInfer.tryUnify.test.cpp | 54 +- tests/TypeInfer.typePacks.cpp | 14 +- tests/TypeInfer.unionTypes.test.cpp | 10 +- tests/TypeVar.test.cpp | 4 +- tests/conformance/debug.lua | 9 + tests/conformance/strconv.lua | 51 + tests/main.cpp | 23 +- tools/numprint.py | 82 + 62 files changed, 3901 insertions(+), 1375 deletions(-) create mode 100644 VM/src/lnumprint.cpp create mode 100644 fuzz/number.cpp create mode 100644 tests/conformance/strconv.lua create mode 100644 tools/numprint.py diff --git a/.gitignore b/.gitignore index fa11b45b..5688dff5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ -^build/ -^coverage/ -^fuzz/luau.pb.* -^crash-* -^default.prof* -^fuzz-* -^luau$ -/.vs +/build/ +/build[.-]*/ +/coverage/ +/.vs/ +/.vscode/ +/fuzz/luau.pb.* +/crash-* +/default.prof* +/fuzz-* +/luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 07a0296a..1f64db30 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -68,7 +68,7 @@ struct FrontendOptions // is complete. bool retainFullTypeGraphs = false; - // When true, we run typechecking twice, one in the regular mode, ond once in strict mode + // When true, we run typechecking twice, once in the regular mode, and once in strict mode // in order to get more precise type information (e.g. for autocomplete). bool typecheckTwice = false; }; @@ -109,18 +109,18 @@ struct Frontend Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {}); - CheckResult check(const ModuleName& name); // new shininess - LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); + CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess + LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); /** Lint some code that has no associated DataModel object * * Since this source fragment has no name, we cannot cache its AST. Instead, * we return it to the caller to use as they wish. */ - std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); + std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); CheckResult check(const SourceModule& module); // OLD. TODO KILL - LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); + LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); bool isDirty(const ModuleName& name) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 29988a3b..dc45bebf 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -1,7 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include +#include + #include "Luau/TypeVar.h" +#include "Luau/TypePack.h" LUAU_FASTFLAG(LuauShareTxnSeen); @@ -9,27 +13,28 @@ namespace Luau { // Log of where what TypeIds we are rebinding and what they used to be -struct TxnLog +// Remove with LuauUseCommitTxnLog +struct DEPRECATED_TxnLog { - TxnLog() + DEPRECATED_TxnLog() : originalSeenSize(0) , ownedSeen() , sharedSeen(&ownedSeen) { } - explicit TxnLog(std::vector>* sharedSeen) + explicit DEPRECATED_TxnLog(std::vector>* sharedSeen) : originalSeenSize(sharedSeen->size()) , ownedSeen() , sharedSeen(sharedSeen) { } - TxnLog(const TxnLog&) = delete; - TxnLog& operator=(const TxnLog&) = delete; + DEPRECATED_TxnLog(const DEPRECATED_TxnLog&) = delete; + DEPRECATED_TxnLog& operator=(const DEPRECATED_TxnLog&) = delete; - TxnLog(TxnLog&&) = default; - TxnLog& operator=(TxnLog&&) = default; + DEPRECATED_TxnLog(DEPRECATED_TxnLog&&) = default; + DEPRECATED_TxnLog& operator=(DEPRECATED_TxnLog&&) = default; void operator()(TypeId a); void operator()(TypePackId a); @@ -37,7 +42,7 @@ struct TxnLog void rollback(); - void concat(TxnLog rhs); + void concat(DEPRECATED_TxnLog rhs); bool haveSeen(TypeId lhs, TypeId rhs); void pushSeen(TypeId lhs, TypeId rhs); @@ -54,4 +59,263 @@ public: std::vector>* sharedSeen; // shared with all the descendent logs }; +// Pending state for a TypeVar. Generated by a TxnLog and committed via +// TxnLog::commit. +struct PendingType +{ + // The pending TypeVar state. + TypeVar pending; + + explicit PendingType(TypeVar state) + : pending(std::move(state)) + { + } +}; + +// Pending state for a TypePackVar. Generated by a TxnLog and committed via +// TxnLog::commit. +struct PendingTypePack +{ + // The pending TypePackVar state. + TypePackVar pending; + + explicit PendingTypePack(TypePackVar state) + : pending(std::move(state)) + { + } +}; + +template +T* getMutable(PendingType* pending) +{ + // We use getMutable here because this state is intended to be mutated freely. + return getMutable(&pending->pending); +} + +template +T* getMutable(PendingTypePack* pending) +{ + // We use getMutable here because this state is intended to be mutated freely. + return getMutable(&pending->pending); +} + +// Log of what TypeIds we are rebinding, to be committed later. +struct TxnLog +{ + TxnLog() + : ownedSeen() + , sharedSeen(&ownedSeen) + { + } + + explicit TxnLog(TxnLog* parent) + : parent(parent) + { + if (parent) + { + sharedSeen = parent->sharedSeen; + } + else + { + sharedSeen = &ownedSeen; + } + } + + explicit TxnLog(std::vector>* sharedSeen) + : sharedSeen(sharedSeen) + { + } + + TxnLog(TxnLog* parent, std::vector>* sharedSeen) + : parent(parent) + , sharedSeen(sharedSeen) + { + } + + TxnLog(const TxnLog&) = delete; + TxnLog& operator=(const TxnLog&) = delete; + + TxnLog(TxnLog&&) = default; + TxnLog& operator=(TxnLog&&) = default; + + // Gets an empty TxnLog pointer. This is useful for constructs that + // take a TxnLog, like TypePackIterator - use the empty log if you + // don't have a TxnLog to give it. + static const TxnLog* empty(); + + // Joins another TxnLog onto this one. You should use std::move to avoid + // copying the rhs TxnLog. + // + // If both logs talk about the same type, pack, or table, the rhs takes + // priority. + void concat(TxnLog rhs); + + // Commits the TxnLog, rebinding all type pointers to their pending states. + // Clears the TxnLog afterwards. + void commit(); + + // Clears the TxnLog without committing any pending changes. + void clear(); + + // Computes an inverse of this TxnLog at the current time. + // This method should be called before commit is called in order to give an + // accurate result. Committing the inverse of a TxnLog will undo the changes + // made by commit, assuming the inverse log is accurate. + TxnLog inverse(); + + bool haveSeen(TypeId lhs, TypeId rhs) const; + void pushSeen(TypeId lhs, TypeId rhs); + void popSeen(TypeId lhs, TypeId rhs); + + // Queues a type for modification. The original type will not change until commit + // is called. Use pending to get the pending state. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* queue(TypeId ty); + + // Queues a type pack for modification. The original type pack will not change + // until commit is called. Use pending to get the pending state. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* queue(TypePackId tp); + + // Returns the pending state of a type, or nullptr if there isn't any. It is important + // to note that this pending state is not transitive: the pending state may reference + // non-pending types freely, so you may need to call pending multiple times to view the + // entire pending state of a type graph. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* pending(TypeId ty) const; + + // Returns the pending state of a type pack, or nullptr if there isn't any. It is + // important to note that this pending state is not transitive: the pending state may + // reference non-pending types freely, so you may need to call pending multiple times + // to view the entire pending state of a type graph. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* pending(TypePackId tp) const; + + // Queues a replacement of a type with another type. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* replace(TypeId ty, TypeVar replacement); + + // Queues a replacement of a type pack with another type pack. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* replace(TypePackId tp, TypePackVar replacement); + + // Queues a replacement of a table type with another table type that is bound + // to a specific value. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* bindTable(TypeId ty, std::optional newBoundTo); + + // Queues a replacement of a type with a level with a duplicate of that type + // with a new type level. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* changeLevel(TypeId ty, TypeLevel newLevel); + + // Queues a replacement of a type pack with a level with a duplicate of that + // type pack with a new type level. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* changeLevel(TypePackId tp, TypeLevel newLevel); + + // Queues a replacement of a table type with another table type with a new + // indexer. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* changeIndexer(TypeId ty, std::optional indexer); + + // Returns the type level of the pending state of the type, or the level of that + // type, if no pending state exists. If the type doesn't have a notion of a level, + // returns nullopt. If the pending state doesn't have a notion of a level, but the + // original state does, returns nullopt. + std::optional getLevel(TypeId ty) const; + + // Follows a type, accounting for pending type states. The returned type may have + // pending state; you should use `pending` or `get` to find out. + TypeId follow(TypeId ty); + + // Follows a type pack, accounting for pending type states. The returned type pack + // may have pending state; you should use `pending` or `get` to find out. + TypePackId follow(TypePackId tp) const; + + // Replaces a given type's state with a new variant. Returns the new pending state + // of that type. + // + // The pointer returned lives until `commit` or `clear` is called. + template + PendingType* replace(TypeId ty, T replacement) + { + return replace(ty, TypeVar(replacement)); + } + + // Replaces a given type pack's state with a new variant. Returns the new + // pending state of that type pack. + // + // The pointer returned lives until `commit` or `clear` is called. + template + PendingTypePack* replace(TypePackId tp, T replacement) + { + return replace(tp, TypePackVar(replacement)); + } + + // Returns T if a given type or type pack is this variant, respecting the + // log's pending state. + // + // Do not retain this pointer; it has the potential to be invalidated when + // commit or clear is called. + template + T* getMutable(TID ty) const + { + auto* pendingTy = pending(ty); + if (pendingTy) + return Luau::getMutable(pendingTy); + + return Luau::getMutable(ty); + } + + // Returns whether a given type or type pack is a given state, respecting the + // log's pending state. + // + // This method will not assert if called on a BoundTypeVar or BoundTypePack. + template + bool is(TID ty) const + { + // We do not use getMutable here because this method can be called on + // BoundTypeVars, which triggers an assertion. + auto* pendingTy = pending(ty); + if (pendingTy) + return Luau::get_if(&pendingTy->pending.ty) != nullptr; + + return Luau::get_if(&ty->ty) != nullptr; + } + +private: + // unique_ptr is used to give us stable pointers across insertions into the + // map. Otherwise, it would be really easy to accidentally invalidate the + // pointers returned from queue/pending. + // + // We can't use a DenseHashMap here because we need a non-const iterator + // over the map when we concatenate. + std::unordered_map> typeVarChanges; + std::unordered_map> typePackChanges; + + TxnLog* parent = nullptr; + + // Owned version of sharedSeen. This should not be accessed directly in + // TxnLogs; use sharedSeen instead. This field exists because in the tree + // of TxnLogs, the root must own its seen set. In all descendant TxnLogs, + // this is an empty vector. + std::vector> ownedSeen; + +public: + // Used to avoid infinite recursion when types are cyclic. + // Shared with all the descendent TxnLogs. + std::vector>* sharedSeen; +}; + } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 862f50d7..312283b0 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -198,32 +198,32 @@ struct TypeChecker */ TypeId anyIfNonstrict(TypeId ty) const; - /** Attempt to unify the types left and right. Treat any failures as type errors - * in the final typecheck report. + /** Attempt to unify the types. + * Treat any failures as type errors in the final typecheck report. */ - bool unify(TypeId left, TypeId right, const Location& location); - bool unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); + bool unify(TypeId subTy, TypeId superTy, const Location& location); + bool unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); - /** Attempt to unify the types left and right. - * If this fails, and the right type can be instantiated, do so and try unification again. + /** Attempt to unify the types. + * If this fails, and the subTy type can be instantiated, do so and try unification again. */ - bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location); - void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state); + bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, const Location& location); + void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, Unifier& state); - /** Attempt to unify left with right. + /** Attempt to unify. * If there are errors, undo everything and return the errors. * If there are no errors, commit and return an empty error vector. */ - ErrorVec tryUnify(TypeId left, TypeId right, const Location& location); - ErrorVec tryUnify(TypePackId left, TypePackId right, const Location& location); + template + ErrorVec tryUnify_(Id subTy, Id superTy, const Location& location); + ErrorVec tryUnify(TypeId subTy, TypeId superTy, const Location& location); + ErrorVec tryUnify(TypePackId subTy, TypePackId superTy, const Location& location); // Test whether the two type vars unify. Never commits the result. - ErrorVec canUnify(TypeId superTy, TypeId subTy, const Location& location); - ErrorVec canUnify(TypePackId superTy, TypePackId subTy, const Location& location); - - // Variant that takes a preexisting 'seen' set. We need this in certain cases to avoid infinitely recursing - // into cyclic types. - ErrorVec canUnify(const std::vector>& seen, TypeId left, TypeId right, const Location& location); + template + ErrorVec canUnify_(Id subTy, Id superTy, const Location& location); + ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); + ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); @@ -237,12 +237,6 @@ struct TypeChecker std::optional tryStripUnionFromNil(TypeId ty); TypeId stripFromNilAndReport(TypeId ty, const Location& location); - template - ErrorVec tryUnify_(Id left, Id right, const Location& location); - - template - ErrorVec canUnify_(Id left, Id right, const Location& location); - public: /* * Convert monotype into a a polytype, by replacing any metavariables in descendant scopes diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index e72808da..ca588ccb 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -18,6 +18,8 @@ struct VariadicTypePack; struct TypePackVar; +struct TxnLog; + using TypePackId = const TypePackVar*; using FreeTypePack = Unifiable::Free; using BoundTypePack = Unifiable::Bound; @@ -84,6 +86,7 @@ struct TypePackIterator TypePackIterator() = default; explicit TypePackIterator(TypePackId tp); + TypePackIterator(TypePackId tp, const TxnLog* log); TypePackIterator& operator++(); TypePackIterator operator++(int); @@ -104,9 +107,13 @@ private: TypePackId currentTypePack = nullptr; const TypePack* tp = nullptr; size_t currentIndex = 0; + + // Only used if LuauUseCommittingTxnLog is true. + const TxnLog* log; }; TypePackIterator begin(TypePackId tp); +TypePackIterator begin(TypePackId tp, TxnLog* log); TypePackIterator end(TypePackId tp); using SeenSet = std::set>; @@ -114,6 +121,7 @@ using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); +TypePackId follow(TypePackId tp, std::function mapper); size_t size(TypePackId tp); bool finite(TypePackId tp); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index f6829ec3..d6e17714 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -453,6 +453,7 @@ bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); // Follow BoundTypeVars until we get to something real TypeId follow(TypeId t); +TypeId follow(TypeId t, std::function mapper); std::vector flattenIntersection(TypeId ty); diff --git a/Analysis/include/Luau/TypedAllocator.h b/Analysis/include/Luau/TypedAllocator.h index 0ded1489..64227e7c 100644 --- a/Analysis/include/Luau/TypedAllocator.h +++ b/Analysis/include/Luau/TypedAllocator.h @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAG(LuauTypedAllocatorZeroStart) + namespace Luau { @@ -20,7 +22,10 @@ class TypedAllocator public: TypedAllocator() { - appendBlock(); + if (FFlag::LuauTypedAllocatorZeroStart) + currentBlockSize = kBlockSize; + else + appendBlock(); } ~TypedAllocator() @@ -59,12 +64,18 @@ public: bool empty() const { - return stuff.size() == 1 && currentBlockSize == 0; + if (FFlag::LuauTypedAllocatorZeroStart) + return stuff.empty(); + else + return stuff.size() == 1 && currentBlockSize == 0; } size_t size() const { - return kBlockSize * (stuff.size() - 1) + currentBlockSize; + if (FFlag::LuauTypedAllocatorZeroStart) + return stuff.empty() ? 0 : kBlockSize * (stuff.size() - 1) + currentBlockSize; + else + return kBlockSize * (stuff.size() - 1) + currentBlockSize; } void clear() @@ -72,7 +83,11 @@ public: if (frozen) unfreeze(); free(); - appendBlock(); + + if (FFlag::LuauTypedAllocatorZeroStart) + currentBlockSize = kBlockSize; + else + appendBlock(); } void freeze() diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 7681b966..a3be739a 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -25,6 +25,7 @@ struct Unifier Mode mode; ScopePtr globalScope; // sigh. Needed solely to get at string's metatable. + DEPRECATED_TxnLog DEPRECATED_log; TxnLog log; ErrorVec errors; Location location; @@ -33,44 +34,45 @@ struct Unifier UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState, + TxnLog* parentLog = nullptr); Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState); + Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. - ErrorVec canUnify(TypeId superTy, TypeId subTy); - ErrorVec canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + ErrorVec canUnify(TypeId subTy, TypeId superTy); + ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); - /** Attempt to unify left with right. + /** Attempt to unify. * Populate the vector errors with any type errors that may arise. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. */ - void tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); private: - void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); - void tryUnifyPrimitives(TypeId superTy, TypeId subTy); - void tryUnifySingletons(TypeId superTy, TypeId subTy); - void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); - void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); - void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); - void tryUnifyFreeTable(TypeId free, TypeId other); - void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection); - void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); - void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); - void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); + void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnifyPrimitives(TypeId subTy, TypeId superTy); + void tryUnifySingletons(TypeId subTy, TypeId superTy); + void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); + void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void tryUnifyFreeTable(TypeId subTy, TypeId superTy); + void tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection); + void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer); TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); - void cacheResult(TypeId superTy, TypeId subTy); + void cacheResult(TypeId subTy, TypeId superTy); public: - void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); private: - void tryUnify_(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); - void tryUnifyVariadics(TypePackId superTy, TypePackId subTy, bool reversed, int subOffset = 0); + void tryUnify_(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); + void tryUnifyVariadics(TypePackId subTy, TypePackId superTy, bool reversed, int subOffset = 0); - void tryUnifyWithAny(TypeId any, TypeId ty); - void tryUnifyWithAny(TypePackId any, TypePackId ty); + void tryUnifyWithAny(TypeId subTy, TypeId anyTy); + void tryUnifyWithAny(TypePackId subTy, TypePackId anyTp); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 4b583792..67ebd075 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,10 +12,12 @@ #include #include +LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); +LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -236,28 +238,31 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { ty = follow(ty); - auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) { + auto canUnify = [&typeArena, &module](TypeId subTy, TypeId superTy) { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); - if (FFlag::LuauAutocompleteAvoidMutation) + if (FFlag::LuauAutocompleteAvoidMutation && !FFlag::LuauUseCommittingTxnLog) { SeenTypes seenTypes; SeenTypePacks seenTypePacks; CloneState cloneState; - expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, cloneState); - actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, cloneState); + superTy = clone(superTy, *typeArena, seenTypes, seenTypePacks, cloneState); + subTy = clone(subTy, *typeArena, seenTypes, seenTypePacks, cloneState); - auto errors = unifier.canUnify(expectedType, actualType); + auto errors = unifier.canUnify(subTy, superTy); return errors.empty(); } else { - unifier.tryUnify(expectedType, actualType); + unifier.tryUnify(subTy, superTy); bool ok = unifier.errors.empty(); - unifier.log.rollback(); + + if (!FFlag::LuauUseCommittingTxnLog) + unifier.DEPRECATED_log.rollback(); + return ok; } }; @@ -293,22 +298,22 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { auto [retHead, retTail] = flatten(ftv->retType); - if (!retHead.empty() && canUnify(expectedType, retHead.front())) + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return TypeCorrectKind::CorrectFunctionResult; // We might only have a variadic tail pack, check if the element is compatible if (retTail) { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(expectedType, vtp->ty)) + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) return TypeCorrectKind::CorrectFunctionResult; } } - return canUnify(expectedType, ty) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } else { - if (canUnify(expectedType, ty)) + if (canUnify(ty, expectedType)) return TypeCorrectKind::Correct; // We also want to suggest functions that return compatible result @@ -320,13 +325,13 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ auto [retHead, retTail] = flatten(ftv->retType); if (!retHead.empty()) - return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + return canUnify(retHead.front(), expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; // We might only have a variadic tail pack, check if the element is compatible if (retTail) { if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + return canUnify(vtp->ty, expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; } return TypeCorrectKind::None; @@ -1319,7 +1324,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - if (!nodes.back()->is()) + if (!nodes.back()->is() && (!FFlag::LuauCompleteBrokenStringParams || !nodes.back()->is())) { return std::nullopt; } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index d0afa742..24982506 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -138,12 +138,7 @@ declare function gcinfo(): number -- (nil, string). declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) - -- a userdata object is "roughly" the same as a sealed empty table - -- except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too. - -- another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT - -- setmetatable. - -- FIXME: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`. - declare function newproxy(mt: boolean?): {} + declare function newproxy(mt: boolean?): any declare coroutine: { create: ((A...) -> R...) -> thread, diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index e332f07d..fe4b6529 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -351,7 +351,7 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) { } -CheckResult Frontend::check(const ModuleName& name) +CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) { LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); @@ -372,6 +372,8 @@ CheckResult Frontend::check(const ModuleName& name) std::vector buildQueue; bool cycleDetected = parseGraph(buildQueue, checkResult, name); + FrontendOptions frontendOptions = optionOverride.value_or(options); + // Keep track of which AST nodes we've reported cycles in std::unordered_set reportedCycles; @@ -411,31 +413,11 @@ CheckResult Frontend::check(const ModuleName& name) // If we're typechecking twice, we do so. // The second typecheck is always in strict mode with DM awareness // to provide better typen information for IDE features. - if (options.typecheckTwice) + if (frontendOptions.typecheckTwice) { ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; } - else if (options.retainFullTypeGraphs && options.typecheckTwice && mode != Mode::Strict) - { - ModulePtr strictModule = typeChecker.check(sourceModule, Mode::Strict, environmentScope); - module->astTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astExpectedTypes.clear(); - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - - for (const auto& [expr, strictTy] : strictModule->astTypes) - module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); - - for (const auto& [expr, strictTy] : strictModule->astOriginalCallTypes) - module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); - - for (const auto& [expr, strictTy] : strictModule->astExpectedTypes) - module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); - } stats.timeCheck += getTimestamp() - timestamp; stats.filesStrict += mode == Mode::Strict; @@ -444,7 +426,7 @@ CheckResult Frontend::check(const ModuleName& name) if (module == nullptr) throw std::runtime_error("Frontend::check produced a nullptr module for " + moduleName); - if (!options.retainFullTypeGraphs) + if (!frontendOptions.retainFullTypeGraphs) { // copyErrors needs to allocate into interfaceTypes as it copies // types out of internalTypes, so we unfreeze it here. diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index e1e53c97..cff85897 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -13,7 +13,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 0) +LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) namespace Luau { diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index c773e208..04ebffc1 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,8 +4,6 @@ #include "Luau/VisitTypeVar.h" -LUAU_FASTFLAGVARIABLE(LuauQuantifyVisitOnce, false) - namespace Luau { @@ -81,16 +79,8 @@ struct Quantifier void quantify(ModulePtr module, TypeId ty, TypeLevel level) { Quantifier q{std::move(module), level}; - - if (FFlag::LuauQuantifyVisitOnce) - { - DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, q, seen); - } - else - { - visitTypeVar(ty, q); - } + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, q, seen); FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index a6be5348..889dd6dc 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) -LUAU_FASTFLAGVARIABLE(LuauFunctionArgumentNameSize, false) /* * Prefix generic typenames with gen- @@ -766,24 +765,12 @@ struct TypePackStringifier else state.emit(", "); - if (FFlag::LuauFunctionArgumentNameSize) + if (elemIndex < elemNames.size() && elemNames[elemIndex]) { - if (elemIndex < elemNames.size() && elemNames[elemIndex]) - { - state.emit(elemNames[elemIndex]->name); - state.emit(": "); - } + state.emit(elemNames[elemIndex]->name); + state.emit(": "); } - else - { - LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); - if (!elemNames.empty() && elemNames[elemIndex]) - { - state.emit(elemNames[elemIndex]->name); - state.emit(": "); - } - } elemIndex++; stringify(typeId); @@ -1151,38 +1138,19 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV s += ", "; first = false; - if (FFlag::LuauFunctionArgumentNameSize) + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) { - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (argNameIter != ftv.argNames.end()) - { - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; - ++argNameIter; - } - else - { - s += "_: "; - } + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + ++argNameIter; } else { - // argNames is guaranteed to be equal to argTypes iff argNames is not empty. - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (!ftv.argNames.empty()) - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + s += "_: "; } s += toString_(*argPackIter); ++argPackIter; - - if (!FFlag::LuauFunctionArgumentNameSize) - { - if (!ftv.argNames.empty()) - { - LUAU_ASSERT(argNameIter != ftv.argNames.end()); - ++argNameIter; - } - } } if (argPackIter.tail()) diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index f6a61581..a46ac0c3 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -4,27 +4,34 @@ #include "Luau/TypePack.h" #include +#include + +LUAU_FASTFLAGVARIABLE(LuauUseCommittingTxnLog, false) namespace Luau { -void TxnLog::operator()(TypeId a) +void DEPRECATED_TxnLog::operator()(TypeId a) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); typeVarChanges.emplace_back(a, *a); } -void TxnLog::operator()(TypePackId a) +void DEPRECATED_TxnLog::operator()(TypePackId a) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); typePackChanges.emplace_back(a, *a); } -void TxnLog::operator()(TableTypeVar* a) +void DEPRECATED_TxnLog::operator()(TableTypeVar* a) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); tableChanges.emplace_back(a, a->boundTo); } -void TxnLog::rollback() +void DEPRECATED_TxnLog::rollback() { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); for (auto it = typeVarChanges.rbegin(); it != typeVarChanges.rend(); ++it) std::swap(*asMutable(it->first), it->second); @@ -38,8 +45,9 @@ void TxnLog::rollback() sharedSeen->resize(originalSeenSize); } -void TxnLog::concat(TxnLog rhs) +void DEPRECATED_TxnLog::concat(DEPRECATED_TxnLog rhs) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); typeVarChanges.insert(typeVarChanges.end(), rhs.typeVarChanges.begin(), rhs.typeVarChanges.end()); rhs.typeVarChanges.clear(); @@ -50,23 +58,298 @@ void TxnLog::concat(TxnLog rhs) rhs.tableChanges.clear(); } -bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) +bool DEPRECATED_TxnLog::haveSeen(TypeId lhs, TypeId rhs) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); } +void DEPRECATED_TxnLog::pushSeen(TypeId lhs, TypeId rhs) +{ + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + sharedSeen->push_back(sortedPair); +} + +void DEPRECATED_TxnLog::popSeen(TypeId lhs, TypeId rhs) +{ + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); +} + +static const TxnLog emptyLog; + +const TxnLog* TxnLog::empty() +{ + return &emptyLog; +} + +void TxnLog::concat(TxnLog rhs) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (auto& [ty, rep] : rhs.typeVarChanges) + typeVarChanges[ty] = std::move(rep); + + for (auto& [tp, rep] : rhs.typePackChanges) + typePackChanges[tp] = std::move(rep); +} + +void TxnLog::commit() +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (auto& [ty, rep] : typeVarChanges) + *asMutable(ty) = rep.get()->pending; + + for (auto& [tp, rep] : typePackChanges) + *asMutable(tp) = rep.get()->pending; + + clear(); +} + +void TxnLog::clear() +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + typeVarChanges.clear(); + typePackChanges.clear(); +} + +TxnLog TxnLog::inverse() +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + TxnLog inversed(sharedSeen); + + for (auto& [ty, _rep] : typeVarChanges) + inversed.typeVarChanges[ty] = std::make_unique(*ty); + + for (auto& [tp, _rep] : typePackChanges) + inversed.typePackChanges[tp] = std::make_unique(*tp); + + return inversed; +} + +bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) + { + return true; + } + + if (parent) + { + return parent->haveSeen(lhs, rhs); + } + + return false; +} + void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); sharedSeen->push_back(sortedPair); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); LUAU_ASSERT(sortedPair == sharedSeen->back()); sharedSeen->pop_back(); } +PendingType* TxnLog::queue(TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(!ty->persistent); + + // Explicitly don't look in ancestors. If we have discovered something new + // about this type, we don't want to mutate the parent's state. + auto& pending = typeVarChanges[ty]; + if (!pending) + pending = std::make_unique(*ty); + + return pending.get(); +} + +PendingTypePack* TxnLog::queue(TypePackId tp) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(!tp->persistent); + + // Explicitly don't look in ancestors. If we have discovered something new + // about this type, we don't want to mutate the parent's state. + auto& pending = typePackChanges[tp]; + if (!pending) + pending = std::make_unique(*tp); + + return pending.get(); +} + +PendingType* TxnLog::pending(TypeId ty) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) + return it->second.get(); + } + + return nullptr; +} + +PendingTypePack* TxnLog::pending(TypePackId tp) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) + return it->second.get(); + } + + return nullptr; +} + +PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + PendingType* newTy = queue(ty); + newTy->pending = replacement; + return newTy; +} + +PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + PendingTypePack* newTp = queue(tp); + newTp->pending = replacement; + return newTp; +} + +PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(ty)); + + PendingType* newTy = queue(ty); + if (TableTypeVar* ttv = Luau::getMutable(newTy)) + ttv->boundTo = newBoundTo; + + return newTy; +} + +PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + + PendingType* newTy = queue(ty); + if (FreeTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->level = newLevel; + } + else if (TableTypeVar* ttv = Luau::getMutable(newTy)) + { + LUAU_ASSERT(ttv->state == TableState::Free || ttv->state == TableState::Generic); + ttv->level = newLevel; + } + else if (FunctionTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->level = newLevel; + } + + return newTy; +} + +PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(tp)); + + PendingTypePack* newTp = queue(tp); + if (FreeTypePack* ftp = Luau::getMutable(newTp)) + { + ftp->level = newLevel; + } + + return newTp; +} + +PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexer) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(ty)); + + PendingType* newTy = queue(ty); + if (TableTypeVar* ttv = Luau::getMutable(newTy)) + { + ttv->indexer = indexer; + } + + return newTy; +} + +std::optional TxnLog::getLevel(TypeId ty) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + if (FreeTypeVar* ftv = getMutable(ty)) + return ftv->level; + else if (TableTypeVar* ttv = getMutable(ty); ttv && (ttv->state == TableState::Free || ttv->state == TableState::Generic)) + return ttv->level; + else if (FunctionTypeVar* ftv = getMutable(ty)) + return ftv->level; + + return std::nullopt; +} + +TypeId TxnLog::follow(TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + return Luau::follow(ty, [this](TypeId ty) { + PendingType* state = this->pending(ty); + + if (state == nullptr) + return ty; + + // Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants + // that normally apply. This is safe because follow will only call get<> + // on the returned pointer. + return const_cast(&state->pending); + }); +} + +TypePackId TxnLog::follow(TypePackId tp) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + return Luau::follow(tp, [this](TypePackId tp) { + PendingTypePack* state = this->pending(tp); + + if (state == nullptr) + return tp; + + // Ugly: Fabricate a TypePackId that doesn't adhere to most of the + // invariants that normally apply. This is safe because follow will + // only call get<> on the returned pointer. + return const_cast(&state->pending); + }); +} + } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e29b6ec6..1689a5c3 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -27,15 +27,16 @@ LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) +LUAU_FASTFLAG(LuauUseCommittingTxnLog) +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) +LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) -LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) @@ -450,7 +451,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) ++subLevel; TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); - unify(leftType, funTy, fun->location); + unify(funTy, leftType, fun->location); } else if (auto fun = (*protoIter)->as()) { @@ -556,21 +557,21 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) } } -ErrorVec TypeChecker::canUnify(TypeId left, TypeId right, const Location& location) -{ - return canUnify_(left, right, location); -} - -ErrorVec TypeChecker::canUnify(TypePackId left, TypePackId right, const Location& location) -{ - return canUnify_(left, right, location); -} - template -ErrorVec TypeChecker::canUnify_(Id superTy, Id subTy, const Location& location) +ErrorVec TypeChecker::canUnify_(Id subTy, Id superTy, const Location& location) { Unifier state = mkUnifier(location); - return state.canUnify(superTy, subTy); + return state.canUnify(subTy, superTy); +} + +ErrorVec TypeChecker::canUnify(TypeId subTy, TypeId superTy, const Location& location) +{ + return canUnify_(subTy, superTy, location); +} + +ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Location& location) +{ + return canUnify_(subTy, superTy, location); } void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) @@ -619,7 +620,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) // start typechecking everything across module boundaries. if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) { - ErrorVec errors = tryUnify(scope->returnType, retPack, return_.location); + ErrorVec errors = tryUnify(retPack, scope->returnType, return_.location); if (!errors.empty()) currentModule->getModuleScope()->returnType = addTypePack({anyType}); @@ -627,31 +628,41 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) return; } - unify(scope->returnType, retPack, return_.location, CountMismatch::Context::Return); -} - -ErrorVec TypeChecker::tryUnify(TypeId left, TypeId right, const Location& location) -{ - return tryUnify_(left, right, location); -} - -ErrorVec TypeChecker::tryUnify(TypePackId left, TypePackId right, const Location& location) -{ - return tryUnify_(left, right, location); + unify(retPack, scope->returnType, return_.location, CountMismatch::Context::Return); } template -ErrorVec TypeChecker::tryUnify_(Id left, Id right, const Location& location) +ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const Location& location) { Unifier state = mkUnifier(location); - state.tryUnify(left, right); - if (!state.errors.empty()) - state.log.rollback(); + if (FFlag::LuauUseCommittingTxnLog && FFlag::DebugLuauFreezeDuringUnification) + freeze(currentModule->internalTypes); + + state.tryUnify(subTy, superTy); + + if (FFlag::LuauUseCommittingTxnLog && FFlag::DebugLuauFreezeDuringUnification) + unfreeze(currentModule->internalTypes); + + if (!state.errors.empty() && !FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + + if (state.errors.empty() && FFlag::LuauUseCommittingTxnLog) + state.log.commit(); return state.errors; } +ErrorVec TypeChecker::tryUnify(TypeId subTy, TypeId superTy, const Location& location) +{ + return tryUnify_(subTy, superTy, location); +} + +ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const Location& location) +{ + return tryUnify_(subTy, superTy, location); +} + void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { std::vector> expectedTypes; @@ -743,9 +754,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. if (isNonstrictMode() && get(follow(left)) && !get(follow(right))) - unify(left, anyType, loc); + unify(anyType, left, loc); else - unify(left, right, loc); + unify(right, left, loc); } } } @@ -760,7 +771,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assi TypeId result = checkBinaryOperation(scope, expr, left, right); - unify(left, result, assign.location); + unify(result, left, assign.location); } void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) @@ -817,9 +828,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) Unifier state = mkUnifier(local.location); state.ctx = CountMismatch::Result; - state.tryUnify(variablePack, valuePack); + state.tryUnify(valuePack, variablePack); reportErrors(state.errors); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + // In the code 'local T = {}', we wish to ascribe the name 'T' to the type of the table for error-reporting purposes. // We also want to do this for 'local T = setmetatable(...)'. if (local.vars.size == 1 && local.values.size == 1) @@ -889,7 +903,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) TypeId loopVarType = numberType; if (expr.var->annotation) - unify(resolveType(scope, *expr.var->annotation), loopVarType, expr.location); + unify(loopVarType, resolveType(scope, *expr.var->annotation), expr.location); loopScope->bindings[expr.var] = {loopVarType, expr.var->location}; @@ -899,11 +913,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) if (!expr.to) ice("Bad AstStatFor has no to expr"); - unify(loopVarType, checkExpr(loopScope, *expr.from).type, expr.from->location); - unify(loopVarType, checkExpr(loopScope, *expr.to).type, expr.to->location); + unify(checkExpr(loopScope, *expr.from).type, loopVarType, expr.from->location); + unify(checkExpr(loopScope, *expr.to).type, loopVarType, expr.to->location); if (expr.step) - unify(loopVarType, checkExpr(loopScope, *expr.step).type, expr.step->location); + unify(checkExpr(loopScope, *expr.step).type, loopVarType, expr.step->location); check(loopScope, *expr.body); } @@ -956,12 +970,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) if (get(callRetPack)) { iterTy = freshType(scope); - unify(addTypePack({{iterTy}, freshTypePack(scope)}), callRetPack, forin.location); + unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), forin.location); } else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) - unify(var, errorRecoveryType(scope), forin.location); + unify(errorRecoveryType(scope), var, forin.location); return check(loopScope, *forin.body); } @@ -982,7 +996,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) - unify(var, varTy, forin.location); + unify(varTy, var, forin.location); if (!get(iterTy) && !get(iterTy) && !get(iterTy)) reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); @@ -1010,6 +1024,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) Unifier state = mkUnifier(firstValue->location); checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + reportErrors(state.errors); } @@ -1024,10 +1041,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; TypePackId retPack = checkExprPack(scope, exprCall).type; - unify(varPack, retPack, forin.location); + unify(retPack, varPack, forin.location); } else - unify(varPack, iterFunc->retType, forin.location); + unify(iterFunc->retType, varPack, forin.location); check(loopScope, *forin.body); } @@ -1112,7 +1129,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - unify(leftType, ty, function.location); + unify(ty, leftType, function.location); if (FFlag::LuauUpdateFunctionNameBinding) { @@ -1242,7 +1259,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias else if (auto mtv = getMutable(follow(ty))) mtv->syntheticName = name; - unify(bindingsMap[name].type, ty, typealias.location); + unify(ty, bindingsMap[name].type, typealias.location); } } @@ -1526,7 +1543,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa { TypeId head = freshType(scope); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); - unify(retPack, pack, expr.location); + unify(pack, retPack, expr.location); return {head, std::move(result.predicates)}; } if (get(retPack)) @@ -1598,7 +1615,7 @@ std::optional TypeChecker::getIndexTypeFromType( return it->second.type; else if (auto indexer = tableType->indexer) { - tryUnify(indexer->indexType, stringType, location); + tryUnify(stringType, indexer->indexType, location); return indexer->indexResultType; } else if (tableType->state == TableState::Free) @@ -1824,7 +1841,7 @@ TypeId TypeChecker::checkExprTable( indexer = expectedTable->indexer; if (indexer) - unify(indexer->indexResultType, valueType, value->location); + unify(valueType, indexer->indexResultType, value->location); else indexer = TableIndexer{numberType, anyIfNonstrict(valueType)}; } @@ -1842,13 +1859,13 @@ TypeId TypeChecker::checkExprTable( if (it != expectedTable->props.end()) { Property expectedProp = it->second; - ErrorVec errors = tryUnify(expectedProp.type, exprType, k->location); + ErrorVec errors = tryUnify(exprType, expectedProp.type, k->location); if (errors.empty()) exprType = expectedProp.type; } else if (expectedTable->indexer && isString(expectedTable->indexer->indexType)) { - ErrorVec errors = tryUnify(expectedTable->indexer->indexResultType, exprType, k->location); + ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); if (errors.empty()) exprType = expectedTable->indexer->indexResultType; } @@ -1863,8 +1880,8 @@ TypeId TypeChecker::checkExprTable( if (indexer) { - unify(indexer->indexType, keyType, k->location); - unify(indexer->indexResultType, valueType, value->location); + unify(keyType, indexer->indexType, k->location); + unify(valueType, indexer->indexResultType, value->location); } else if (isNonstrictMode()) { @@ -1992,7 +2009,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); - state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) @@ -2006,7 +2026,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn return {errorRecoveryType(scope)}; } - reportErrors(tryUnify(numberType, operandType, expr.location)); + reportErrors(tryUnify(operandType, numberType, expr.location)); return {numberType}; } case AstExprUnary::Len: @@ -2072,7 +2092,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b { if (unifyFreeTypes && (get(a) || get(b))) { - if (unify(a, b, location)) + if (unify(b, a, location)) return a; return errorRecoveryType(anyType); @@ -2175,7 +2195,13 @@ TypeId TypeChecker::checkRelationalOperation( */ Unifier state = mkUnifier(expr.location); if (!isEquality) - state.tryUnify(lhsType, rhsType); + { + state.tryUnify(rhsType, lhsType); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + } + bool needsMetamethod = !isEquality; @@ -2216,13 +2242,16 @@ TypeId TypeChecker::checkRelationalOperation( if (isEquality) { Unifier state = mkUnifier(expr.location); - state.tryUnify(ftv->retType, addTypePack({booleanType})); + state.tryUnify(addTypePack({booleanType}), ftv->retType); if (!state.errors.empty()) { reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); return errorRecoveryType(booleanType); } + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); } } @@ -2230,7 +2259,10 @@ TypeId TypeChecker::checkRelationalOperation( TypeId actualFunctionType = addType(FunctionTypeVar(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType}))); state.tryUnify( - instantiate(scope, *metamethod, expr.location), instantiate(scope, actualFunctionType, expr.location), /*isFunctionCall*/ true); + instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return booleanType; @@ -2323,7 +2355,7 @@ TypeId TypeChecker::checkBinaryOperation( } if (get(rhsType)) - unify(lhsType, rhsType, expr.location); + unify(rhsType, lhsType, expr.location); if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) { @@ -2334,7 +2366,7 @@ TypeId TypeChecker::checkBinaryOperation( TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); - state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); reportErrors(state.errors); bool hasErrors = !state.errors.empty(); @@ -2345,11 +2377,28 @@ TypeId TypeChecker::checkBinaryOperation( // so we loosen the argument types to see if that helps. TypePackId fallbackArguments = freshTypePack(scope); TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); - state.log.rollback(); state.errors.clear(); - state.tryUnify(fallbackFunctionType, actualFunctionType, /*isFunctionCall*/ true); - if (!state.errors.empty()) - state.log.rollback(); + + if (FFlag::LuauUseCommittingTxnLog) + { + state.log.clear(); + } + else + { + state.DEPRECATED_log.rollback(); + } + + state.tryUnify(actualFunctionType, fallbackFunctionType, /*isFunctionCall*/ true); + + if (FFlag::LuauUseCommittingTxnLog && state.errors.empty()) + state.log.commit(); + else if (!state.errors.empty() && !FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + } + + if (FFlag::LuauUseCommittingTxnLog && !hasErrors) + { + state.log.commit(); } TypeId retType = first(retTypePack).value_or(nilType); @@ -2377,8 +2426,8 @@ TypeId TypeChecker::checkBinaryOperation( switch (expr.op) { case AstExprBinary::Concat: - reportErrors(tryUnify(addType(UnionTypeVar{{stringType, numberType}}), lhsType, expr.left->location)); - reportErrors(tryUnify(addType(UnionTypeVar{{stringType, numberType}}), rhsType, expr.right->location)); + reportErrors(tryUnify(lhsType, addType(UnionTypeVar{{stringType, numberType}}), expr.left->location)); + reportErrors(tryUnify(rhsType, addType(UnionTypeVar{{stringType, numberType}}), expr.right->location)); return stringType; case AstExprBinary::Add: case AstExprBinary::Sub: @@ -2386,8 +2435,8 @@ TypeId TypeChecker::checkBinaryOperation( case AstExprBinary::Div: case AstExprBinary::Mod: case AstExprBinary::Pow: - reportErrors(tryUnify(numberType, lhsType, expr.left->location)); - reportErrors(tryUnify(numberType, rhsType, expr.right->location)); + reportErrors(tryUnify(lhsType, numberType, expr.left->location)); + reportErrors(tryUnify(rhsType, numberType, expr.right->location)); return numberType; default: // These should have been handled with checkRelationalOperation @@ -2466,10 +2515,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy if (FFlag::LuauBidirectionalAsExpr) { // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (canUnify(result.type, annotationType, expr.location).empty()) + if (canUnify(annotationType, result.type, expr.location).empty()) return {annotationType, std::move(result.predicates)}; - if (canUnify(annotationType, result.type, expr.location).empty()) + if (canUnify(result.type, annotationType, expr.location).empty()) return {annotationType, std::move(result.predicates)}; reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); @@ -2477,7 +2526,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy } else { - ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + ErrorVec errorVec = canUnify(annotationType, result.type, expr.location); reportErrors(errorVec); if (!errorVec.empty()) annotationType = errorRecoveryType(annotationType); @@ -2512,7 +2561,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIf resolve(result.predicates, falseScope, false); ExprResult falseType = checkExpr(falseScope, *expr.falseExpr); - unify(trueType.type, falseType.type, expr.location); + unify(falseType.type, trueType.type, expr.location); // TODO: normalize(UnionTypeVar{{trueType, falseType}}) // For now both trueType and falseType must be the same type. @@ -2607,14 +2656,18 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope else if (auto indexer = lhsTable->indexer) { Unifier state = mkUnifier(expr.location); - state.tryUnify(indexer->indexType, stringType); + state.tryUnify(stringType, indexer->indexType); TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + reportError(expr.location, UnknownProperty{lhs, name}); retType = errorRecoveryType(retType); } + else if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); return std::pair(retType, nullptr); } @@ -2713,7 +2766,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (exprTable->indexer) { const TableIndexer& indexer = *exprTable->indexer; - unify(indexer.indexType, indexType, expr.index->location); + unify(indexType, indexer.indexType, expr.index->location); return std::pair(indexer.indexResultType, nullptr); } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) @@ -3106,204 +3159,402 @@ void TypeChecker::checkArgumentList( * A function requires parameters. * To call a function, you supply arguments. */ - TypePackIterator argIter = begin(argPack); - TypePackIterator paramIter = begin(paramPack); + TypePackIterator argIter = begin(argPack, &state.log); + TypePackIterator paramIter = begin(paramPack, &state.log); TypePackIterator endIter = end(argPack); // Important subtlety: All end TypePackIterators are equivalent size_t paramIndex = 0; size_t minParams = getMinParameterCount(paramPack); - while (true) + if (FFlag::LuauUseCommittingTxnLog) { - state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; - - if (argIter == endIter && paramIter == endIter) + while (true) { - std::optional argTail = argIter.tail(); - std::optional paramTail = paramIter.tail(); + state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; - // If we hit the end of both type packs simultaneously, then there are definitely no further type - // errors to report. All we need to do is tie up any free tails. - // - // If one side has a free tail and the other has none at all, we create an empty pack and bind the - // free tail to that. - - if (argTail) + if (argIter == endIter && paramIter == endIter) { - if (get(*argTail)) + std::optional argTail = argIter.tail(); + std::optional paramTail = paramIter.tail(); + + // If we hit the end of both type packs simultaneously, then there are definitely no further type + // errors to report. All we need to do is tie up any free tails. + // + // If one side has a free tail and the other has none at all, we create an empty pack and bind the + // free tail to that. + + if (argTail) { - if (paramTail) - state.tryUnify(*argTail, *paramTail); + if (state.log.getMutable(state.log.follow(*argTail))) + { + if (paramTail) + state.tryUnify(*paramTail, *argTail); + else + state.log.replace(*argTail, TypePackVar(TypePack{{}})); + } + } + else if (paramTail) + { + // argTail is definitely empty + if (state.log.getMutable(state.log.follow(*paramTail))) + state.log.replace(*paramTail, TypePackVar(TypePack{{}})); + } + + return; + } + else if (argIter == endIter) + { + // Not enough arguments. + + // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. + if (argIter.tail()) + { + TypePackId tail = *argIter.tail(); + if (state.log.getMutable(tail)) + { + // Unify remaining parameters so we don't leave any free-types hanging around. + while (paramIter != endIter) + { + state.tryUnify(errorRecoveryType(anyType), *paramIter); + ++paramIter; + } + return; + } + else if (auto vtp = state.log.getMutable(tail)) + { + while (paramIter != endIter) + { + state.tryUnify(vtp->ty, *paramIter); + ++paramIter; + } + + return; + } + else if (state.log.getMutable(tail)) + { + std::vector rest; + rest.reserve(std::distance(paramIter, endIter)); + while (paramIter != endIter) + { + rest.push_back(*paramIter); + ++paramIter; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); + state.tryUnify(varPack, tail); + return; + } + } + + // If any remaining unfulfilled parameters are nonoptional, this is a problem. + while (paramIter != endIter) + { + TypeId t = state.log.follow(*paramIter); + if (isOptional(t)) + { + } // ok + else if (state.log.getMutable(t)) + { + } // ok + else if (isNonstrictMode() && state.log.getMutable(t)) + { + } // ok else { - state.log(*argTail); - *asMutable(*argTail) = TypePack{{}}; + state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + return; + } + ++paramIter; + } + } + else if (paramIter == endIter) + { + // too many parameters passed + if (!paramIter.tail()) + { + while (argIter != endIter) + { + // The use of unify here is deliberate. We don't want this unification + // to be undoable. + unify(errorRecoveryType(scope), *argIter, state.location); + ++argIter; + } + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; + } + TypePackId tail = state.log.follow(*paramIter.tail()); + + if (state.log.getMutable(tail)) + { + // Function is variadic. Ok. + return; + } + else if (auto vtp = state.log.getMutable(tail)) + { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. + size_t argIndex = paramIndex; + while (argIter != endIter) + { + Location location = state.location; + + if (argIndex < argLocations.size()) + location = argLocations[argIndex]; + + unify(*argIter, vtp->ty, location); + ++argIter; + ++argIndex; + } + + return; + } + else if (state.log.getMutable(tail)) + { + // Create a type pack out of the remaining argument types + // and unify it with the tail. + std::vector rest; + rest.reserve(std::distance(argIter, endIter)); + while (argIter != endIter) + { + rest.push_back(*argIter); + ++argIter; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); + state.tryUnify(varPack, tail); + return; + } + else if (state.log.getMutable(tail)) + { + state.log.replace(tail, TypePackVar(TypePack{{}})); + return; + } + else if (state.log.getMutable(tail)) + { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + // TODO: Better error message? + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; + } + } + else + { + unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); + ++argIter; + ++paramIter; + } + + ++paramIndex; + } + } + else + { + while (true) + { + state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; + + if (argIter == endIter && paramIter == endIter) + { + std::optional argTail = argIter.tail(); + std::optional paramTail = paramIter.tail(); + + // If we hit the end of both type packs simultaneously, then there are definitely no further type + // errors to report. All we need to do is tie up any free tails. + // + // If one side has a free tail and the other has none at all, we create an empty pack and bind the + // free tail to that. + + if (argTail) + { + if (get(*argTail)) + { + if (paramTail) + state.tryUnify(*paramTail, *argTail); + else + { + state.DEPRECATED_log(*argTail); + *asMutable(*argTail) = TypePack{{}}; + } } } - } - else if (paramTail) - { - // argTail is definitely empty - if (get(*paramTail)) + else if (paramTail) { - state.log(*paramTail); - *asMutable(*paramTail) = TypePack{{}}; + // argTail is definitely empty + if (get(*paramTail)) + { + state.DEPRECATED_log(*paramTail); + *asMutable(*paramTail) = TypePack{{}}; + } + } + + return; + } + else if (argIter == endIter) + { + // Not enough arguments. + + // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. + if (argIter.tail()) + { + TypePackId tail = *argIter.tail(); + if (get(tail)) + { + // Unify remaining parameters so we don't leave any free-types hanging around. + while (paramIter != endIter) + { + state.tryUnify(*paramIter, errorRecoveryType(anyType)); + ++paramIter; + } + return; + } + else if (auto vtp = get(tail)) + { + while (paramIter != endIter) + { + state.tryUnify(*paramIter, vtp->ty); + ++paramIter; + } + + return; + } + else if (get(tail)) + { + std::vector rest; + rest.reserve(std::distance(paramIter, endIter)); + while (paramIter != endIter) + { + rest.push_back(*paramIter); + ++paramIter; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); + state.tryUnify(varPack, tail); + return; + } + } + + // If any remaining unfulfilled parameters are nonoptional, this is a problem. + while (paramIter != endIter) + { + TypeId t = follow(*paramIter); + if (isOptional(t)) + { + } // ok + else if (get(t)) + { + } // ok + else if (isNonstrictMode() && get(t)) + { + } // ok + else + { + state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + return; + } + ++paramIter; } } - - return; - } - else if (argIter == endIter) - { - // Not enough arguments. - - // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. - if (argIter.tail()) + else if (paramIter == endIter) { - TypePackId tail = *argIter.tail(); + // too many parameters passed + if (!paramIter.tail()) + { + while (argIter != endIter) + { + unify(*argIter, errorRecoveryType(scope), state.location); + ++argIter; + } + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; + } + TypePackId tail = *paramIter.tail(); + if (get(tail)) { - // Unify remaining parameters so we don't leave any free-types hanging around. - while (paramIter != endIter) - { - state.tryUnify(*paramIter, errorRecoveryType(anyType)); - ++paramIter; - } + // Function is variadic. Ok. return; } else if (auto vtp = get(tail)) { - while (paramIter != endIter) + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. + size_t argIndex = paramIndex; + while (argIter != endIter) { - state.tryUnify(*paramIter, vtp->ty); - ++paramIter; + Location location = state.location; + + if (argIndex < argLocations.size()) + location = argLocations[argIndex]; + + unify(*argIter, vtp->ty, location); + ++argIter; + ++argIndex; } return; } else if (get(tail)) { + // Create a type pack out of the remaining argument types + // and unify it with the tail. std::vector rest; - rest.reserve(std::distance(paramIter, endIter)); - while (paramIter != endIter) + rest.reserve(std::distance(argIter, endIter)); + while (argIter != endIter) { - rest.push_back(*paramIter); - ++paramIter; + rest.push_back(*argIter); + ++argIter; } - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); state.tryUnify(tail, varPack); return; } - } + else if (get(tail)) + { + if (FFlag::LuauUseCommittingTxnLog) + { + state.log.replace(tail, TypePackVar(TypePack{{}})); + } + else + { + state.DEPRECATED_log(tail); + *asMutable(tail) = TypePack{}; + } - // If any remaining unfulfilled parameters are nonoptional, this is a problem. - while (paramIter != endIter) - { - TypeId t = follow(*paramIter); - if (isOptional(t)) - { - } // ok - else if (get(t)) - { - } // ok - else if (isNonstrictMode() && get(t)) - { - } // ok - else - { - state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); return; } + else if (get(tail)) + { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + // TODO: Better error message? + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; + } + } + else + { + unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); + ++argIter; ++paramIter; } + + ++paramIndex; } - else if (paramIter == endIter) - { - // too many parameters passed - if (!paramIter.tail()) - { - while (argIter != endIter) - { - unify(*argIter, errorRecoveryType(scope), state.location); - ++argIter; - } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } - TypePackId tail = *paramIter.tail(); - - if (get(tail)) - { - // Function is variadic. Ok. - return; - } - else if (auto vtp = get(tail)) - { - // Function is variadic and requires that all subsequent parameters - // be compatible with a type. - size_t argIndex = paramIndex; - while (argIter != endIter) - { - Location location = state.location; - - if (argIndex < argLocations.size()) - location = argLocations[argIndex]; - - unify(vtp->ty, *argIter, location); - ++argIter; - ++argIndex; - } - - return; - } - else if (get(tail)) - { - // Create a type pack out of the remaining argument types - // and unify it with the tail. - std::vector rest; - rest.reserve(std::distance(argIter, endIter)); - while (argIter != endIter) - { - rest.push_back(*argIter); - ++argIter; - } - - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - state.tryUnify(tail, varPack); - return; - } - else if (get(tail)) - { - state.log(tail); - *asMutable(tail) = TypePack{}; - - return; - } - else if (get(tail)) - { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } - } - else - { - unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); - ++argIter; - ++paramIter; - } - - ++paramIndex; } } @@ -3475,7 +3726,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (get(fn)) { - unify(argPack, anyTypePack, expr.location); + unify(anyTypePack, argPack, expr.location); return {{anyTypePack}}; } @@ -3490,7 +3741,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - unify(r, fn, expr.location); + unify(fn, r, expr.location); return {{retPack}}; } @@ -3533,7 +3784,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!ftv) { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); - unify(retPack, errorRecoveryTypePack(scope), expr.func->location); + unify(errorRecoveryTypePack(scope), retPack, expr.func->location); return {{errorRecoveryTypePack(retPack)}}; } @@ -3552,7 +3803,9 @@ std::optional> TypeChecker::checkCallOverload(const Scope checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); if (!state.errors.empty()) { - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + return {}; } @@ -3580,10 +3833,15 @@ std::optional> TypeChecker::checkCallOverload(const Scope overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); - state.log.rollback(); + + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); } else { + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + if (isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) { // If we are running in nonstrict mode, passing fewer arguments than the function is declared to take AND @@ -3640,6 +3898,9 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal if (editedState.errors.empty()) { + if (FFlag::LuauUseCommittingTxnLog) + editedState.log.commit(); + reportError(TypeError{expr.location, FunctionDoesNotTakeSelf{}}); // This is a little bit suspect: If this overload would work with a . replaced by a : // we eagerly assume that that's what you actually meant and we commit to it. @@ -3648,8 +3909,8 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); return true; } - else - editedState.log.rollback(); + else if (!FFlag::LuauUseCommittingTxnLog) + editedState.DEPRECATED_log.rollback(); } else if (ftv->hasSelf) { @@ -3671,6 +3932,9 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal if (editedState.errors.empty()) { + if (FFlag::LuauUseCommittingTxnLog) + editedState.log.commit(); + reportError(TypeError{expr.location, FunctionRequiresSelf{}}); // This is a little bit suspect: If this overload would work with a : replaced by a . // we eagerly assume that that's what you actually meant and we commit to it. @@ -3679,8 +3943,8 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); return true; } - else - editedState.log.rollback(); + else if (!FFlag::LuauUseCommittingTxnLog) + editedState.DEPRECATED_log.rollback(); } } } @@ -3740,6 +4004,9 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); } + if (FFlag::LuauUseCommittingTxnLog && state.errors.empty()) + state.log.commit(); + if (i > 0) s += "; "; @@ -3748,7 +4015,8 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast s += toString(overload); - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); } if (overloadsThatMatchArgCount.size() == 0) @@ -3781,6 +4049,8 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L Unifier state = mkUnifier(location); + std::vector inverseLogs; + for (size_t i = 0; i < exprs.size; ++i) { AstExpr* expr = exprs.data[i]; @@ -3791,18 +4061,15 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); - if (FFlag::LuauTailArgumentTypeInfo) + if (std::optional firstTy = first(typePack)) { - if (std::optional firstTy = first(typePack)) - { - if (!currentModule->astTypes.find(expr)) - currentModule->astTypes[expr] = follow(*firstTy); - } - - if (expectedType) - currentModule->astExpectedTypes[expr] = *expectedType; + if (!currentModule->astTypes.find(expr)) + currentModule->astTypes[expr] = follow(*firstTy); } + if (expectedType) + currentModule->astExpectedTypes[expr] = *expectedType; + tp->tail = typePack; } else @@ -3816,13 +4083,31 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L actualType = instantiate(scope, actualType, expr->location); if (expectedType) - state.tryUnify(*expectedType, actualType); + { + state.tryUnify(actualType, *expectedType); + + // Ugly: In future iterations of the loop, we might need the state of the unification we + // just performed. There's not a great way to pass that into checkExpr. Instead, we store + // the inverse of the current log, and commit it. When we're done, we'll commit all the + // inverses. This isn't optimal, and a better solution is welcome here. + if (FFlag::LuauUseCommittingTxnLog) + { + inverseLogs.push_back(state.log.inverse()); + state.log.commit(); + } + } tp->head.push_back(actualType); } } - state.log.rollback(); + if (FFlag::LuauUseCommittingTxnLog) + { + for (TxnLog& log : inverseLogs) + log.commit(); + } + else + state.DEPRECATED_log.rollback(); return {pack, predicates}; } @@ -3884,7 +4169,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module TypePackId modulePack = module->getModuleScope()->returnType; - if (FFlag::LuauModuleRequireErrorPack && get(modulePack)) + if (get(modulePack)) return errorRecoveryType(scope); std::optional moduleType = first(modulePack); @@ -3917,72 +4202,94 @@ TypeId TypeChecker::anyIfNonstrict(TypeId ty) const return ty; } -bool TypeChecker::unify(TypeId left, TypeId right, const Location& location) +bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location) { Unifier state = mkUnifier(location); - state.tryUnify(left, right); + state.tryUnify(subTy, superTy); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -bool TypeChecker::unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx) +bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx) { Unifier state = mkUnifier(location); state.ctx = ctx; - state.tryUnify(left, right); + state.tryUnify(subTy, superTy); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location) +bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, const Location& location) { Unifier state = mkUnifier(location); - unifyWithInstantiationIfNeeded(scope, left, right, state); + unifyWithInstantiationIfNeeded(scope, subTy, superTy, state); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state) +void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, Unifier& state) { - if (!maybeGeneric(right)) + if (!maybeGeneric(subTy)) // Quick check to see if we definitely can't instantiate - state.tryUnify(left, right, /*isFunctionCall*/ false); - else if (!maybeGeneric(left) && isGeneric(right)) + state.tryUnify(subTy, superTy, /*isFunctionCall*/ false); + else if (!maybeGeneric(superTy) && isGeneric(subTy)) { // Quick check to see if we definitely have to instantiate - TypeId instantiated = instantiate(scope, right, state.location); - state.tryUnify(left, instantiated, /*isFunctionCall*/ false); + TypeId instantiated = instantiate(scope, subTy, state.location); + state.tryUnify(instantiated, superTy, /*isFunctionCall*/ false); } else { // First try unifying with the original uninstantiated type // but if that fails, try the instantiated one. Unifier child = state.makeChildUnifier(); - child.tryUnify(left, right, /*isFunctionCall*/ false); + child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); if (!child.errors.empty()) { - TypeId instantiated = instantiate(scope, right, state.location); - if (right == instantiated) + TypeId instantiated = instantiate(scope, subTy, state.location); + if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors - state.log.concat(std::move(child.log)); + if (FFlag::LuauUseCommittingTxnLog) + state.log.concat(std::move(child.log)); + else + state.DEPRECATED_log.concat(std::move(child.DEPRECATED_log)); + state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); } else { - child.log.rollback(); - state.tryUnify(left, instantiated, /*isFunctionCall*/ false); + if (!FFlag::LuauUseCommittingTxnLog) + child.DEPRECATED_log.rollback(); + + state.tryUnify(instantiated, superTy, /*isFunctionCall*/ false); } } else { - state.log.concat(std::move(child.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + state.log.concat(std::move(child.log)); + } + else + { + state.DEPRECATED_log.concat(std::move(child.DEPRECATED_log)); + } } } } @@ -4139,7 +4446,7 @@ TypePackId Quantification::clean(TypePackId tp) bool Anyification::isDirty(TypeId ty) { if (const TableTypeVar* ttv = get(ty)) - return (ttv->state == TableState::Free); + return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); else if (get(ty)) return true; else @@ -4162,6 +4469,12 @@ TypeId Anyification::clean(TypeId ty) TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; + if (FFlag::LuauSealExports) + { + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.tags = ttv->tags; + } return addType(std::move(clone)); } else @@ -5194,8 +5507,8 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement // This by itself is not truly enough to determine that A is stronger than B or vice versa. // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) - bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); - bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + bool optionIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + bool targetIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. if (!optionIsSubtype && targetIsSubtype) @@ -5379,7 +5692,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa for (TypeId right : rhs) { // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) + if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) set.insert(left); } } @@ -5406,7 +5719,7 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp for (size_t i = 0; i < expectedLength; ++i) expectedPack->head.push_back(freshType(scope)); - unify(expectedTypePack, tp, location); + unify(tp, expectedTypePack, location); for (TypeId& tp : expectedPack->head) tp = follow(tp); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index d3221c73..b15548a8 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -1,8 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypePack.h" +#include "Luau/TxnLog.h" + #include +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + namespace Luau { @@ -35,14 +39,28 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) } TypePackIterator::TypePackIterator(TypePackId typePack) + : TypePackIterator(typePack, TxnLog::empty()) +{ +} + +TypePackIterator::TypePackIterator(TypePackId typePack, const TxnLog* log) : currentTypePack(follow(typePack)) , tp(get(currentTypePack)) , currentIndex(0) + , log(log) { while (tp && tp->head.empty()) { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; + if (FFlag::LuauUseCommittingTxnLog) + { + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; + } + else + { + currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; + tp = currentTypePack ? get(currentTypePack) : nullptr; + } } } @@ -53,8 +71,17 @@ TypePackIterator& TypePackIterator::operator++() ++currentIndex; while (tp && currentIndex >= tp->head.size()) { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; + if (FFlag::LuauUseCommittingTxnLog) + { + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; + } + else + { + currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; + tp = currentTypePack ? get(currentTypePack) : nullptr; + } + currentIndex = 0; } @@ -95,6 +122,11 @@ TypePackIterator begin(TypePackId tp) return TypePackIterator{tp}; } +TypePackIterator begin(TypePackId tp, TxnLog* log) +{ + return TypePackIterator{tp, log}; +} + TypePackIterator end(TypePackId tp) { return TypePackIterator{}; @@ -160,8 +192,15 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId follow(TypePackId tp) { - auto advance = [](TypePackId ty) -> std::optional { - if (const Unifiable::Bound* btv = get>(ty)) + return follow(tp, [](TypePackId t) { + return t; + }); +} + +TypePackId follow(TypePackId tp, std::function mapper) +{ + auto advance = [&mapper](TypePackId ty) -> std::optional { + if (const Unifiable::Bound* btv = get>(mapper(ty))) return btv->boundTo; else return std::nullopt; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index fb75aa02..4cab79c8 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -31,17 +31,24 @@ std::optional> magicFunctionFormat( TypeId follow(TypeId t) { - auto advance = [](TypeId ty) -> std::optional { - if (auto btv = get>(ty)) + return follow(t, [](TypeId t) { + return t; + }); +} + +TypeId follow(TypeId t, std::function mapper) +{ + auto advance = [&mapper](TypeId ty) -> std::optional { + if (auto btv = get>(mapper(ty))) return btv->boundTo; - else if (auto ttv = get(ty)) + else if (auto ttv = get(mapper(ty))) return ttv->boundTo; else return std::nullopt; }; - auto force = [](TypeId ty) { - if (auto ltv = get_if(&ty->ty)) + auto force = [&mapper](TypeId ty) { + if (auto ltv = get_if(&mapper(ty)->ty)) { TypeId res = ltv->thunk(); if (get(res)) @@ -1004,7 +1011,7 @@ std::optional> magicFunctionFormat( { Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - typechecker.unify(expected[i], params[i + paramOffset], location); + typechecker.unify(params[i + paramOffset], expected[i], location); } // if we know the argument count or if we have too many arguments for sure, we can issue an error diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index f037351e..1f7ef8c2 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -20,6 +20,7 @@ const size_t kPageSize = sysconf(_SC_PAGESIZE); #include LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAGVARIABLE(LuauTypedAllocatorZeroStart, false) namespace Luau { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 43ea37e7..393a84a7 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -13,6 +13,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); +LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) @@ -29,27 +30,39 @@ namespace Luau struct PromoteTypeLevels { + DEPRECATED_TxnLog& DEPRECATED_log; TxnLog& log; TypeLevel minLevel; - explicit PromoteTypeLevels(TxnLog& log, TypeLevel minLevel) - : log(log) + explicit PromoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel) + : DEPRECATED_log(DEPRECATED_log) + , log(log) , minLevel(minLevel) - {} + { + } - template + template void promote(TID ty, T* t) { LUAU_ASSERT(t); if (minLevel.subsumesStrict(t->level)) { - log(ty); - t->level = minLevel; + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeLevel(ty, minLevel); + } + else + { + DEPRECATED_log(ty); + t->level = minLevel; + } } } template - void cycle(TID) {} + void cycle(TID) + { + } template bool operator()(TID, const T&) @@ -59,39 +72,47 @@ struct PromoteTypeLevels bool operator()(TypeId ty, const FreeTypeVar&) { - promote(ty, getMutable(ty)); + // Surprise, it's actually a BoundTypeVar that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (FFlag::LuauUseCommittingTxnLog && !log.is(ty)) + return true; + + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } bool operator()(TypeId ty, const FunctionTypeVar&) { - promote(ty, getMutable(ty)); + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } - bool operator()(TypeId ty, const TableTypeVar&) + bool operator()(TypeId ty, const TableTypeVar& ttv) { - promote(ty, getMutable(ty)); + if (ttv.state != TableState::Free && ttv.state != TableState::Generic) + return true; + + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } bool operator()(TypePackId tp, const FreeTypePack&) { - promote(tp, getMutable(tp)); + promote(tp, FFlag::LuauUseCommittingTxnLog ? log.getMutable(tp) : getMutable(tp)); return true; } }; -void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypeId ty) +void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel, TypeId ty) { - PromoteTypeLevels ptl{log, minLevel}; + PromoteTypeLevels ptl{DEPRECATED_log, log, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(ty, ptl, seen); } -void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypePackId tp) +void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel, TypePackId tp) { - PromoteTypeLevels ptl{log, minLevel}; + PromoteTypeLevels ptl{DEPRECATED_log, log, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(tp, ptl, seen); } @@ -221,10 +242,12 @@ static std::optional> getTableMat return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState, + TxnLog* parentLog) : types(types) , mode(mode) , globalScope(std::move(globalScope)) + , log(parentLog) , location(location) , variance(variance) , sharedState(sharedState) @@ -233,11 +256,12 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati } Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState) + Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) , globalScope(std::move(globalScope)) - , log(sharedSeen) + , DEPRECATED_log(sharedSeen) + , log(parentLog, sharedSeen) , location(location) , variance(variance) , sharedState(sharedState) @@ -245,14 +269,14 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector< LUAU_ASSERT(sharedState.iceHandler); } -void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { sharedState.counters.iterationCount = 0; - tryUnify_(superTy, subTy, isFunctionCall, isIntersection); + tryUnify_(subTy, superTy, isFunctionCall, isIntersection); } -void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); @@ -264,55 +288,112 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool return; } - superTy = follow(superTy); - subTy = follow(subTy); + if (FFlag::LuauUseCommittingTxnLog) + { + superTy = log.follow(superTy); + subTy = log.follow(subTy); + } + else + { + superTy = follow(superTy); + subTy = follow(subTy); + } if (superTy == subTy) return; - auto l = getMutable(superTy); - auto r = getMutable(subTy); + auto superFree = getMutable(superTy); + auto subFree = getMutable(subTy); - if (l && r && l->level.subsumes(r->level)) + if (FFlag::LuauUseCommittingTxnLog) + { + superFree = log.getMutable(superTy); + subFree = log.getMutable(subTy); + } + + if (superFree && subFree && superFree->level.subsumes(subFree->level)) { occursCheck(subTy, superTy); // The occurrence check might have caused superTy no longer to be a free type - if (!get(subTy)) + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(subTy)); + else + occursFailed = bool(get(subTy)); + + if (!occursFailed) { - log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); + if (FFlag::LuauUseCommittingTxnLog) + { + log.replace(subTy, BoundTypeVar(superTy)); + } + else + { + DEPRECATED_log(subTy); + *asMutable(subTy) = BoundTypeVar(superTy); + } } return; } - else if (l && r) + else if (superFree && subFree) { - if (!FFlag::LuauErrorRecoveryType) - log(superTy); - occursCheck(superTy, subTy); - r->level = min(r->level, l->level); - - // The occurrence check might have caused superTy no longer to be a free type - if (!FFlag::LuauErrorRecoveryType) - *asMutable(superTy) = BoundTypeVar(subTy); - else if (!get(superTy)) + if (!FFlag::LuauErrorRecoveryType && !FFlag::LuauUseCommittingTxnLog) + { + DEPRECATED_log(superTy); + subFree->level = min(subFree->level, superFree->level); + } + + occursCheck(superTy, subTy); + + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(superTy)); + else + occursFailed = bool(get(superTy)); + + if (!FFlag::LuauErrorRecoveryType && !FFlag::LuauUseCommittingTxnLog) { - log(superTy); *asMutable(superTy) = BoundTypeVar(subTy); + return; + } + + if (!occursFailed) + { + if (FFlag::LuauUseCommittingTxnLog) + { + if (superFree->level.subsumes(subFree->level)) + { + log.changeLevel(subTy, superFree->level); + } + + log.replace(superTy, BoundTypeVar(subTy)); + } + else + { + DEPRECATED_log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + subFree->level = min(subFree->level, superFree->level); + } } return; } - else if (l) + else if (superFree) { occursCheck(superTy, subTy); + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(superTy)); + else + occursFailed = bool(get(superTy)); - TypeLevel superLevel = l->level; + TypeLevel superLevel = superFree->level; // Unification can't change the level of a generic. - auto rightGeneric = get(subTy); - if (rightGeneric && !rightGeneric->level.subsumes(superLevel)) + auto subGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy); + if (subGeneric && !subGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -320,63 +401,83 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } // The occurrence check might have caused superTy no longer to be a free type - if (!get(superTy)) + if (!occursFailed) { - if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(log, superLevel, subTy); - else if (auto rightLevel = getMutableLevel(subTy)) + if (FFlag::LuauUseCommittingTxnLog) { - if (!rightLevel->subsumes(l->level)) - *rightLevel = l->level; + promoteTypeLevels(DEPRECATED_log, log, superLevel, subTy); + log.replace(superTy, BoundTypeVar(subTy)); } + else + { + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(DEPRECATED_log, log, superLevel, subTy); + else if (auto subLevel = getMutableLevel(subTy)) + { + if (!subLevel->subsumes(superFree->level)) + *subLevel = superFree->level; + } - log(superTy); - *asMutable(superTy) = BoundTypeVar(subTy); + DEPRECATED_log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + } } return; } - else if (r) + else if (subFree) { - TypeLevel subLevel = r->level; + TypeLevel subLevel = subFree->level; occursCheck(subTy, superTy); + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(subTy)); + else + occursFailed = bool(get(subTy)); // Unification can't change the level of a generic. - auto leftGeneric = get(superTy); - if (leftGeneric && !leftGeneric->level.subsumes(r->level)) + auto superGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy); + if (superGeneric && !superGeneric->level.subsumes(subFree->level)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); return; } - if (!get(subTy)) + if (!occursFailed) { - if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(log, subLevel, superTy); - - if (auto superLevel = getMutableLevel(superTy)) + if (FFlag::LuauUseCommittingTxnLog) { - if (!superLevel->subsumes(r->level)) - { - log(superTy); - *superLevel = r->level; - } + promoteTypeLevels(DEPRECATED_log, log, subLevel, superTy); + log.replace(subTy, BoundTypeVar(superTy)); } + else + { + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(DEPRECATED_log, log, subLevel, superTy); + else if (auto superLevel = getMutableLevel(superTy)) + { + if (!superLevel->subsumes(subFree->level)) + { + DEPRECATED_log(superTy); + *superLevel = subFree->level; + } + } - log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); + DEPRECATED_log(subTy); + *asMutable(subTy) = BoundTypeVar(superTy); + } } return; } if (get(superTy) || get(superTy)) - return tryUnifyWithAny(superTy, subTy); + return tryUnifyWithAny(subTy, superTy); if (get(subTy) || get(subTy)) - return tryUnifyWithAny(subTy, superTy); + return tryUnifyWithAny(superTy, subTy); bool cacheEnabled = !isFunctionCall && !isIntersection; auto& cache = sharedState.cachedUnify; @@ -389,12 +490,22 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // Here, we assume that the types unify. If they do not, we will find out as we roll back // the stack. - if (log.haveSeen(superTy, subTy)) - return; + if (FFlag::LuauUseCommittingTxnLog) + { + if (log.haveSeen(superTy, subTy)) + return; - log.pushSeen(superTy, subTy); + log.pushSeen(superTy, subTy); + } + else + { + if (DEPRECATED_log.haveSeen(superTy, subTy)) + return; - if (const UnionTypeVar* uv = get(subTy)) + DEPRECATED_log.pushSeen(superTy, subTy); + } + + if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { // A | B <: T if A <: T and B <: T bool failed = false; @@ -407,7 +518,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool for (TypeId type : uv->options) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTy, type); + innerState.tryUnify_(type, superTy); if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; @@ -420,10 +531,24 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool failed = true; } - if (i != count - 1) - innerState.log.rollback(); + if (FFlag::LuauUseCommittingTxnLog) + { + if (i == count - 1) + { + log.concat(std::move(innerState.log)); + } + } else - log.concat(std::move(innerState.log)); + { + if (i != count - 1) + { + innerState.DEPRECATED_log.rollback(); + } + else + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } + } ++i; } @@ -438,7 +563,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } } - else if (const UnionTypeVar* uv = get(superTy)) + else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) { // T <: A | B if T <: A or T <: B bool found = false; @@ -502,12 +627,16 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { TypeId type = uv->options[(i + startIndex) % uv->options.size()]; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, subTy, isFunctionCall); + innerState.tryUnify_(subTy, type, isFunctionCall); if (innerState.errors.empty()) { found = true; - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + break; } else if (auto e = hasUnificationTooComplex(innerState.errors)) @@ -522,7 +651,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool failedOption = {innerState.errors.front()}; } - innerState.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); } if (unificationTooComplex) @@ -538,7 +668,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); } } - else if (const IntersectionTypeVar* uv = get(superTy)) + else if (const IntersectionTypeVar* uv = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) { std::optional unificationTooComplex; std::optional firstFailedOption; @@ -547,7 +678,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool for (TypeId type : uv->parts) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; @@ -557,7 +688,10 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool firstFailedOption = {innerState.errors.front()}; } - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); } if (unificationTooComplex) @@ -565,7 +699,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (firstFailedOption) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } - else if (const IntersectionTypeVar* uv = get(subTy)) + else if (const IntersectionTypeVar* uv = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { // A & B <: T if T <: A or T <: B bool found = false; @@ -591,12 +726,15 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTy, type, isFunctionCall); + innerState.tryUnify_(type, superTy, isFunctionCall); if (innerState.errors.empty()) { found = true; - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); break; } else if (auto e = hasUnificationTooComplex(innerState.errors)) @@ -604,7 +742,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool unificationTooComplex = e; } - innerState.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); } if (unificationTooComplex) @@ -614,44 +753,56 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } - else if (get(superTy) && get(subTy)) - tryUnifyPrimitives(superTy, subTy); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + tryUnifyPrimitives(subTy, superTy); - else if (FFlag::LuauSingletonTypes && (get(superTy) || get(superTy)) && get(subTy)) - tryUnifySingletons(superTy, subTy); + else if (FFlag::LuauSingletonTypes && + ((FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) || + (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) && + (FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy))) + tryUnifySingletons(subTy, superTy); - else if (get(superTy) && get(subTy)) - tryUnifyFunctions(superTy, subTy, isFunctionCall); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + tryUnifyFunctions(subTy, superTy, isFunctionCall); - else if (get(superTy) && get(subTy)) + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) { - tryUnifyTables(superTy, subTy, isIntersection); + tryUnifyTables(subTy, superTy, isIntersection); if (cacheEnabled && errors.empty()) - cacheResult(superTy, subTy); + cacheResult(subTy, superTy); } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. - else if (get(superTy)) - tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false); - else if (get(subTy)) - tryUnifyWithMetatable(subTy, superTy, /*reversed*/ true); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + tryUnifyWithMetatable(subTy, superTy, /*reversed*/ false); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + tryUnifyWithMetatable(superTy, subTy, /*reversed*/ true); - else if (get(superTy)) - tryUnifyWithClass(superTy, subTy, /*reversed*/ false); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + tryUnifyWithClass(subTy, superTy, /*reversed*/ false); // Unification of nonclasses with classes is almost, but not quite symmetrical. // The order in which we perform this test is significant in the case that both types are classes. - else if (get(subTy)) - tryUnifyWithClass(superTy, subTy, /*reversed*/ true); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + tryUnifyWithClass(subTy, superTy, /*reversed*/ true); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); - log.popSeen(superTy, subTy); + if (FFlag::LuauUseCommittingTxnLog) + log.popSeen(superTy, subTy); + else + DEPRECATED_log.popSeen(superTy, subTy); } -void Unifier::cacheResult(TypeId superTy, TypeId subTy) +void Unifier::cacheResult(TypeId subTy, TypeId superTy) { bool* superTyInfo = sharedState.skipCacheForType.find(superTy); @@ -684,7 +835,7 @@ void Unifier::cacheResult(TypeId superTy, TypeId subTy) sharedState.cachedUnify.insert({subTy, superTy}); } -struct WeirdIter +struct DEPRECATED_WeirdIter { TypePackId packId; const TypePack* pack; @@ -692,7 +843,7 @@ struct WeirdIter bool growing; TypeLevel level; - WeirdIter(TypePackId packId) + DEPRECATED_WeirdIter(TypePackId packId) : packId(packId) , pack(get(packId)) , index(0) @@ -705,7 +856,7 @@ struct WeirdIter } } - WeirdIter(const WeirdIter&) = default; + DEPRECATED_WeirdIter(const DEPRECATED_WeirdIter&) = default; const TypeId& operator*() { @@ -756,34 +907,152 @@ struct WeirdIter } }; -ErrorVec Unifier::canUnify(TypeId superTy, TypeId subTy) +struct WeirdIter +{ + TypePackId packId; + TxnLog& log; + TypePack* pack; + size_t index; + bool growing; + TypeLevel level; + + WeirdIter(TypePackId packId, TxnLog& log) + : packId(packId) + , log(log) + , pack(log.getMutable(packId)) + , index(0) + , growing(false) + { + while (pack && pack->head.empty() && pack->tail) + { + packId = *pack->tail; + pack = log.getMutable(packId); + } + } + + WeirdIter(const WeirdIter&) = default; + + TypeId& operator*() + { + LUAU_ASSERT(good()); + return pack->head[index]; + } + + bool good() const + { + return pack != nullptr && index < pack->head.size(); + } + + bool advance() + { + if (!pack) + return good(); + + if (index < pack->head.size()) + ++index; + + if (growing || index < pack->head.size()) + return good(); + + if (pack->tail) + { + packId = log.follow(*pack->tail); + pack = log.getMutable(packId); + index = 0; + } + + return good(); + } + + bool canGrow() const + { + return nullptr != log.getMutable(packId); + } + + void grow(TypePackId newTail) + { + LUAU_ASSERT(canGrow()); + LUAU_ASSERT(log.getMutable(newTail)); + + level = log.getMutable(packId)->level; + log.replace(packId, Unifiable::Bound(newTail)); + packId = newTail; + pack = log.getMutable(newTail); + index = 0; + growing = true; + } + + void pushType(TypeId ty) + { + LUAU_ASSERT(pack); + PendingTypePack* pendingPack = log.queue(packId); + if (TypePack* pending = getMutable(pendingPack)) + { + pending->head.push_back(ty); + // We've potentially just replaced the TypePack* that we need to look + // in. We need to replace pack. + pack = pending; + } + else + { + LUAU_ASSERT(!"Pending state for this pack was not a TypePack"); + } + } +}; + +ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) { Unifier s = makeChildUnifier(); - s.tryUnify_(superTy, subTy); - s.log.rollback(); + s.tryUnify_(subTy, superTy); + + if (!FFlag::LuauUseCommittingTxnLog) + s.DEPRECATED_log.rollback(); + return s.errors; } -ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall) +ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall) { Unifier s = makeChildUnifier(); - s.tryUnify_(superTy, subTy, isFunctionCall); - s.log.rollback(); + s.tryUnify_(subTy, superTy, isFunctionCall); + + if (!FFlag::LuauUseCommittingTxnLog) + s.DEPRECATED_log.rollback(); + return s.errors; } -void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) +void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { sharedState.counters.iterationCount = 0; - tryUnify_(superTp, subTp, isFunctionCall); + tryUnify_(subTp, superTp, isFunctionCall); +} + +static std::pair, std::optional> logAwareFlatten(TypePackId tp, const TxnLog& log) +{ + tp = log.follow(tp); + + std::vector flattened; + std::optional tail = std::nullopt; + + TypePackIterator it(tp, &log); + + for (; it != end(tp); ++it) + { + flattened.push_back(*it); + } + + tail = it.tail(); + + return {flattened, tail}; } /* * This is quite tricky: we are walking two rope-like structures and unifying corresponding elements. * If one is longer than the other, but the short end is free, we grow it to the required length. */ -void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) +void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); @@ -795,252 +1064,458 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal return; } - superTp = follow(superTp); - subTp = follow(subTp); - - while (auto r = get(subTp)) + if (FFlag::LuauUseCommittingTxnLog) { - if (r->head.empty() && r->tail) - subTp = follow(*r->tail); - else - break; - } + superTp = log.follow(superTp); + subTp = log.follow(subTp); - while (auto l = get(superTp)) - { - if (l->head.empty() && l->tail) - superTp = follow(*l->tail); - else - break; - } - - if (superTp == subTp) - return; - - if (get(superTp)) - { - occursCheck(superTp, subTp); - - // The occurrence check might have caused superTp no longer to be a free type - if (!get(superTp)) + while (auto tp = log.getMutable(subTp)) { - log(superTp); - *asMutable(superTp) = Unifiable::Bound(subTp); - } - } - else if (get(subTp)) - { - occursCheck(subTp, superTp); - - // The occurrence check might have caused superTp no longer to be a free type - if (!get(subTp)) - { - log(subTp); - *asMutable(subTp) = Unifiable::Bound(superTp); - } - } - - else if (get(superTp)) - tryUnifyWithAny(superTp, subTp); - - else if (get(subTp)) - tryUnifyWithAny(subTp, superTp); - - else if (get(superTp)) - tryUnifyVariadics(superTp, subTp, false); - else if (get(subTp)) - tryUnifyVariadics(subTp, superTp, true); - - else if (get(superTp) && get(subTp)) - { - auto l = get(superTp); - auto r = get(subTp); - - // If the size of two heads does not match, but both packs have free tail - // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = flatten(superTp); - auto [subTypes, subTail] = flatten(subTp); - - bool noInfiniteGrowth = - (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); - - auto superIter = WeirdIter{superTp}; - auto subIter = WeirdIter{subTp}; - - auto mkFreshType = [this](TypeLevel level) { - return types->freshType(level); - }; - - const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); - - int loopCount = 0; - - do - { - if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) - ice("Detected possibly infinite TypePack growth"); - - ++loopCount; - - if (superIter.good() && subIter.growing) - asMutable(subIter.pack)->head.push_back(mkFreshType(subIter.level)); - - if (subIter.good() && superIter.growing) - asMutable(superIter.pack)->head.push_back(mkFreshType(superIter.level)); - - if (superIter.good() && subIter.good()) - { - tryUnify_(*superIter, *subIter); - - if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) - firstPackErrorPos = loopCount; - - superIter.advance(); - subIter.advance(); - continue; - } - - // If both are at the end, we're done - if (!superIter.good() && !subIter.good()) - { - const bool lFreeTail = l->tail && get(follow(*l->tail)) != nullptr; - const bool rFreeTail = r->tail && get(follow(*r->tail)) != nullptr; - if (lFreeTail && rFreeTail) - tryUnify_(*l->tail, *r->tail); - else if (lFreeTail) - tryUnify_(*l->tail, emptyTp); - else if (rFreeTail) - tryUnify_(*r->tail, emptyTp); - - break; - } - - // If both tails are free, bind one to the other and call it a day - if (superIter.canGrow() && subIter.canGrow()) - return tryUnify_(*superIter.pack->tail, *subIter.pack->tail); - - // If just one side is free on its tail, grow it to fit the other side. - // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. - if (superIter.canGrow()) - superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - - else if (subIter.canGrow()) - subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - + if (tp->head.empty() && tp->tail) + subTp = log.follow(*tp->tail); else + break; + } + + while (auto tp = log.getMutable(superTp)) + { + if (tp->head.empty() && tp->tail) + superTp = log.follow(*tp->tail); + else + break; + } + + if (superTp == subTp) + return; + + if (log.getMutable(superTp)) + { + occursCheck(superTp, subTp); + + if (!log.getMutable(superTp)) { - // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) - { - superIter.advance(); - continue; - } - else if (subIter.good() && isOptional(*subIter)) - { - subIter.advance(); - continue; - } - - // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) - { - superIter.advance(); - continue; - } - - if (get(superIter.packId)) - { - tryUnifyVariadics(superIter.packId, subIter.packId, false, int(subIter.index)); - return; - } - - if (get(subIter.packId)) - { - tryUnifyVariadics(subIter.packId, superIter.packId, true, int(superIter.index)); - return; - } - - if (!isFunctionCall && subIter.good()) - { - // Sometimes it is ok to pass too many arguments - return; - } - - // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking the values returned by a function, we swap - // these to produce the expected error message. - size_t expectedSize = size(superTp); - size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result) - std::swap(expectedSize, actualSize); - errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); - - while (superIter.good()) - { - tryUnify_(getSingletonTypes().errorRecoveryType(), *superIter); - superIter.advance(); - } - - while (subIter.good()) - { - tryUnify_(getSingletonTypes().errorRecoveryType(), *subIter); - subIter.advance(); - } - - return; + log.replace(superTp, Unifiable::Bound(subTp)); } + } + else if (log.getMutable(subTp)) + { + occursCheck(subTp, superTp); - } while (!noInfiniteGrowth); + if (!log.getMutable(subTp)) + { + log.replace(subTp, Unifiable::Bound(superTp)); + } + } + else if (log.getMutable(superTp)) + tryUnifyWithAny(subTp, superTp); + else if (log.getMutable(subTp)) + tryUnifyWithAny(superTp, subTp); + else if (log.getMutable(superTp)) + tryUnifyVariadics(subTp, superTp, false); + else if (log.getMutable(subTp)) + tryUnifyVariadics(superTp, subTp, true); + else if (log.getMutable(superTp) && log.getMutable(subTp)) + { + auto superTpv = log.getMutable(superTp); + auto subTpv = log.getMutable(subTp); + + // If the size of two heads does not match, but both packs have free tail + // We set the sentinel variable to say so to avoid growing it forever. + auto [superTypes, superTail] = logAwareFlatten(superTp, log); + auto [subTypes, subTail] = logAwareFlatten(subTp, log); + + bool noInfiniteGrowth = + (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); + + auto superIter = WeirdIter(superTp, log); + auto subIter = WeirdIter(subTp, log); + + auto mkFreshType = [this](TypeLevel level) { + return types->freshType(level); + }; + + const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); + + int loopCount = 0; + + do + { + if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) + ice("Detected possibly infinite TypePack growth"); + + ++loopCount; + + if (superIter.good() && subIter.growing) + { + subIter.pushType(mkFreshType(subIter.level)); + } + + if (subIter.good() && superIter.growing) + { + superIter.pushType(mkFreshType(superIter.level)); + } + + if (superIter.good() && subIter.good()) + { + tryUnify_(*subIter, *superIter); + + if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; + + superIter.advance(); + subIter.advance(); + continue; + } + + // If both are at the end, we're done + if (!superIter.good() && !subIter.good()) + { + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; + if (lFreeTail && rFreeTail) + tryUnify_(*subTpv->tail, *superTpv->tail); + else if (lFreeTail) + tryUnify_(emptyTp, *superTpv->tail); + else if (rFreeTail) + tryUnify_(emptyTp, *subTpv->tail); + + break; + } + + // If both tails are free, bind one to the other and call it a day + if (superIter.canGrow() && subIter.canGrow()) + return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); + + // If just one side is free on its tail, grow it to fit the other side. + // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. + if (superIter.canGrow()) + superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + else if (subIter.canGrow()) + subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + else + { + // A union type including nil marks an optional argument + if (superIter.good() && isOptional(*superIter)) + { + superIter.advance(); + continue; + } + else if (subIter.good() && isOptional(*subIter)) + { + subIter.advance(); + continue; + } + + // In nonstrict mode, any also marks an optional argument. + else if (superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) + { + superIter.advance(); + continue; + } + + if (log.getMutable(superIter.packId)) + { + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + return; + } + + if (log.getMutable(subIter.packId)) + { + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + return; + } + + if (!isFunctionCall && subIter.good()) + { + // Sometimes it is ok to pass too many arguments + return; + } + + // This is a bit weird because we don't actually know expected vs actual. We just know + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. + size_t expectedSize = size(superTp); + size_t actualSize = size(subTp); + if (ctx == CountMismatch::Result) + std::swap(expectedSize, actualSize); + errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + + while (superIter.good()) + { + tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); + superIter.advance(); + } + + while (subIter.good()) + { + tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); + subIter.advance(); + } + + return; + } + + } while (!noInfiniteGrowth); + } + else + { + errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + } } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + superTp = follow(superTp); + subTp = follow(subTp); + + while (auto tp = get(subTp)) + { + if (tp->head.empty() && tp->tail) + subTp = follow(*tp->tail); + else + break; + } + + while (auto tp = get(superTp)) + { + if (tp->head.empty() && tp->tail) + superTp = follow(*tp->tail); + else + break; + } + + if (superTp == subTp) + return; + + if (get(superTp)) + { + occursCheck(superTp, subTp); + + if (!get(superTp)) + { + DEPRECATED_log(superTp); + *asMutable(superTp) = Unifiable::Bound(subTp); + } + } + else if (get(subTp)) + { + occursCheck(subTp, superTp); + + if (!get(subTp)) + { + DEPRECATED_log(subTp); + *asMutable(subTp) = Unifiable::Bound(superTp); + } + } + + else if (get(superTp)) + tryUnifyWithAny(subTp, superTp); + + else if (get(subTp)) + tryUnifyWithAny(superTp, subTp); + + else if (get(superTp)) + tryUnifyVariadics(subTp, superTp, false); + else if (get(subTp)) + tryUnifyVariadics(superTp, subTp, true); + + else if (get(superTp) && get(subTp)) + { + auto superTpv = get(superTp); + auto subTpv = get(subTp); + + // If the size of two heads does not match, but both packs have free tail + // We set the sentinel variable to say so to avoid growing it forever. + auto [superTypes, superTail] = flatten(superTp); + auto [subTypes, subTail] = flatten(subTp); + + bool noInfiniteGrowth = + (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); + + auto superIter = DEPRECATED_WeirdIter{superTp}; + auto subIter = DEPRECATED_WeirdIter{subTp}; + + auto mkFreshType = [this](TypeLevel level) { + return types->freshType(level); + }; + + const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); + + int loopCount = 0; + + do + { + if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) + ice("Detected possibly infinite TypePack growth"); + + ++loopCount; + + if (superIter.good() && subIter.growing) + asMutable(subIter.pack)->head.push_back(mkFreshType(subIter.level)); + + if (subIter.good() && superIter.growing) + asMutable(superIter.pack)->head.push_back(mkFreshType(superIter.level)); + + if (superIter.good() && subIter.good()) + { + tryUnify_(*subIter, *superIter); + + if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; + + superIter.advance(); + subIter.advance(); + continue; + } + + // If both are at the end, we're done + if (!superIter.good() && !subIter.good()) + { + const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; + if (lFreeTail && rFreeTail) + tryUnify_(*subTpv->tail, *superTpv->tail); + else if (lFreeTail) + tryUnify_(emptyTp, *superTpv->tail); + else if (rFreeTail) + tryUnify_(emptyTp, *subTpv->tail); + + break; + } + + // If both tails are free, bind one to the other and call it a day + if (superIter.canGrow() && subIter.canGrow()) + return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); + + // If just one side is free on its tail, grow it to fit the other side. + // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. + if (superIter.canGrow()) + superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + + else if (subIter.canGrow()) + subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + + else + { + // A union type including nil marks an optional argument + if (superIter.good() && isOptional(*superIter)) + { + superIter.advance(); + continue; + } + else if (subIter.good() && isOptional(*subIter)) + { + subIter.advance(); + continue; + } + + // In nonstrict mode, any also marks an optional argument. + else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) + { + superIter.advance(); + continue; + } + + if (get(superIter.packId)) + { + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + return; + } + + if (get(subIter.packId)) + { + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + return; + } + + if (!isFunctionCall && subIter.good()) + { + // Sometimes it is ok to pass too many arguments + return; + } + + // This is a bit weird because we don't actually know expected vs actual. We just know + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. + size_t expectedSize = size(superTp); + size_t actualSize = size(subTp); + if (ctx == CountMismatch::Result) + std::swap(expectedSize, actualSize); + errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + + while (superIter.good()) + { + tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); + superIter.advance(); + } + + while (subIter.good()) + { + tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); + subIter.advance(); + } + + return; + } + + } while (!noInfiniteGrowth); + } + else + { + errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + } } } -void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy) +void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) { - const PrimitiveTypeVar* lp = get(superTy); - const PrimitiveTypeVar* rp = get(subTy); - if (!lp || !rp) + const PrimitiveTypeVar* superPrim = get(superTy); + const PrimitiveTypeVar* subPrim = get(subTy); + if (!superPrim || !subPrim) ice("passed non primitive types to unifyPrimitives"); - if (lp->type != rp->type) + if (superPrim->type != subPrim->type) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } -void Unifier::tryUnifySingletons(TypeId superTy, TypeId subTy) +void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) { - const PrimitiveTypeVar* lp = get(superTy); - const SingletonTypeVar* ls = get(superTy); - const SingletonTypeVar* rs = get(subTy); + const PrimitiveTypeVar* superPrim = get(superTy); + const SingletonTypeVar* superSingleton = get(superTy); + const SingletonTypeVar* subSingleton = get(subTy); - if ((!lp && !ls) || !rs) + if ((!superPrim && !superSingleton) || !subSingleton) ice("passed non singleton/primitive types to unifySingletons"); - if (ls && *ls == *rs) + if (superSingleton && *superSingleton == *subSingleton) return; - if (lp && lp->type == PrimitiveTypeVar::Boolean && get(rs) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) return; - if (lp && lp->type == PrimitiveTypeVar::String && get(rs) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } -void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) +void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) { - FunctionTypeVar* lf = getMutable(superTy); - FunctionTypeVar* rf = getMutable(subTy); - if (!lf || !rf) + FunctionTypeVar* superFunction = getMutable(superTy); + FunctionTypeVar* subFunction = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superFunction = log.getMutable(superTy); + subFunction = log.getMutable(subTy); + } + + if (!superFunction || !subFunction) ice("passed non-function types to unifyFunction"); - size_t numGenerics = lf->generics.size(); - if (numGenerics != rf->generics.size()) + size_t numGenerics = superFunction->generics.size(); + if (numGenerics != subFunction->generics.size()) { - numGenerics = std::min(lf->generics.size(), rf->generics.size()); + numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); if (FFlag::LuauExtendedFunctionMismatchError) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); @@ -1048,10 +1523,10 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } - size_t numGenericPacks = lf->genericPacks.size(); - if (numGenericPacks != rf->genericPacks.size()) + size_t numGenericPacks = superFunction->genericPacks.size(); + if (numGenericPacks != subFunction->genericPacks.size()) { - numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); + numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); if (FFlag::LuauExtendedFunctionMismatchError) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); @@ -1060,7 +1535,12 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal } for (size_t i = 0; i < numGenerics; i++) - log.pushSeen(lf->generics[i], rf->generics[i]); + { + if (FFlag::LuauUseCommittingTxnLog) + log.pushSeen(superFunction->generics[i], subFunction->generics[i]); + else + DEPRECATED_log.pushSeen(superFunction->generics[i], subFunction->generics[i]); + } CountMismatch::Context context = ctx; @@ -1071,7 +1551,7 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal if (FFlag::LuauExtendedFunctionMismatchError) { innerState.ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); bool reported = !innerState.errors.empty(); @@ -1085,13 +1565,13 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); innerState.ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + innerState.tryUnify_(subFunction->retType, superFunction->retType); if (!reported) { if (auto e = hasUnificationTooComplex(innerState.errors)) errors.push_back(*e); - else if (!innerState.errors.empty() && size(lf->retType) == 1 && finite(lf->retType)) + else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) errors.push_back( @@ -1104,38 +1584,70 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal else { ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + innerState.tryUnify_(subFunction->retType, superFunction->retType); checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); } - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + log.concat(std::move(innerState.log)); + } + else + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } } else { ctx = CountMismatch::Arg; - tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); ctx = CountMismatch::Result; - tryUnify_(lf->retType, rf->retType); + tryUnify_(subFunction->retType, superFunction->retType); } - if (lf->definition && !rf->definition && !subTy->persistent) + if (FFlag::LuauUseCommittingTxnLog) { - rf->definition = lf->definition; + if (superFunction->definition && !subFunction->definition && !subTy->persistent) + { + PendingType* newSubTy = log.queue(subTy); + FunctionTypeVar* newSubFtv = getMutable(newSubTy); + LUAU_ASSERT(newSubFtv); + newSubFtv->definition = superFunction->definition; + } + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + { + PendingType* newSuperTy = log.queue(superTy); + FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); + LUAU_ASSERT(newSuperFtv); + newSuperFtv->definition = subFunction->definition; + } } - else if (!lf->definition && rf->definition && !superTy->persistent) + else { - lf->definition = rf->definition; + if (superFunction->definition && !subFunction->definition && !subTy->persistent) + { + subFunction->definition = superFunction->definition; + } + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + { + superFunction->definition = subFunction->definition; + } } ctx = context; for (int i = int(numGenerics) - 1; 0 <= i; i--) - log.popSeen(lf->generics[i], rf->generics[i]); + { + if (FFlag::LuauUseCommittingTxnLog) + log.popSeen(superFunction->generics[i], subFunction->generics[i]); + else + DEPRECATED_log.popSeen(superFunction->generics[i], subFunction->generics[i]); + } } namespace @@ -1160,77 +1672,84 @@ struct Resetter } // namespace -void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { if (!FFlag::LuauTableSubtypingVariance2) - return DEPRECATED_tryUnifyTables(left, right, isIntersection); + return DEPRECATED_tryUnifyTables(subTy, superTy, isIntersection); - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) + TableTypeVar* superTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + if (!superTable || !subTable) ice("passed non-table types to unifyTables"); std::vector missingProperties; std::vector extraProperties; // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer && rt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && !subTable->indexer && subTable->state != TableState::Free) { - for (const auto& [propName, superProp] : lt->props) + for (const auto& [propName, superProp] : superTable->props) { - auto subIter = rt->props.find(propName); - if (subIter == rt->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) + auto subIter = subTable->props.find(propName); + if (subIter == subTable->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) missingProperties.push_back(propName); } if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } } // And vice versa if we're invariant - if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && - lt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && + superTable->state != TableState::Free) { - for (const auto& [propName, subProp] : rt->props) + for (const auto& [propName, subProp] : subTable->props) { - auto superIter = lt->props.find(propName); - if (superIter == lt->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) + auto superIter = superTable->props.find(propName); + if (superIter == superTable->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) extraProperties.push_back(propName); } if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } } - // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. - for (const auto& [name, prop] : lt->props) + for (const auto& [name, prop] : superTable->props) { - const auto& r = rt->props.find(name); - if (r != rt->props.end()) + const auto& r = subTable->props.find(name); + if (r != subTable->props.end()) { // TODO: read-only properties don't need invariance Resetter resetter{&variance}; variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, r->second.type); + innerState.tryUnify_(r->second.type, prop.type); - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } - else if (rt->indexer && isString(rt->indexer->indexType)) + else if (subTable->indexer && isString(subTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. @@ -1238,37 +1757,55 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, rt->indexer->indexResultType); + innerState.tryUnify_(subTable->indexer->indexResultType, prop.type); - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } else if (isOptional(prop.type) || get(follow(prop.type))) // TODO: this case is unsound, but without it our test suite fails. CLI-46031 // TODO: should isOptional(anyType) be true? { } - else if (rt->state == TableState::Free) + else if (subTable->state == TableState::Free) { - log(rt); - rt->props[name] = prop; + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* ttv = getMutable(pendingSub); + LUAU_ASSERT(ttv); + ttv->props[name] = prop; + } + else + { + DEPRECATED_log(subTy); + subTable->props[name] = prop; + } } else missingProperties.push_back(name); } - for (const auto& [name, prop] : rt->props) + for (const auto& [name, prop] : subTable->props) { - if (lt->props.count(name)) + if (superTable->props.count(name)) { // If both lt and rt contain the property, then // we're done since we already unified them above } - else if (lt->indexer && isString(lt->indexer->indexType)) + else if (superTable->indexer && isString(superTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. @@ -1276,24 +1813,42 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, lt->indexer->indexResultType); + innerState.tryUnify_(superTable->indexer->indexResultType, prop.type); - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } - else if (lt->state == TableState::Unsealed) + else if (superTable->state == TableState::Unsealed) { // TODO: this case is unsound when variance is Invariant, but without it lua-apps fails to typecheck. // TODO: file a JIRA // TODO: hopefully readonly/writeonly properties will fix this. Property clone = prop; clone.type = deeplyOptional(clone.type); - log(left); - lt->props[name] = clone; + + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = clone; + } + else + { + DEPRECATED_log(superTy); + superTable->props[name] = clone; + } } else if (variance == Covariant) { @@ -1303,61 +1858,93 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // TODO: should isOptional(anyType) be true? { } - else if (lt->state == TableState::Free) + else if (superTable->state == TableState::Free) { - log(left); - lt->props[name] = prop; + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = prop; + } + else + { + DEPRECATED_log(superTy); + superTable->props[name] = prop; + } } else extraProperties.push_back(name); } // Unify indexers - if (lt->indexer && rt->indexer) + if (superTable->indexer && subTable->indexer) { // TODO: read-only indexers don't need invariance Resetter resetter{&variance}; variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify(*lt->indexer, *rt->indexer); - checkChildUnifierTypeMismatch(innerState.errors, left, right); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } - else if (lt->indexer) + else if (superTable->indexer) { - if (rt->state == TableState::Unsealed || rt->state == TableState::Free) + if (subTable->state == TableState::Unsealed || subTable->state == TableState::Free) { // passing/assigning a table without an indexer to something that has one // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. // TODO: we only need to do this if the supertype's indexer is read/write // since that can add indexed elements. - log(right); - rt->indexer = lt->indexer; + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(subTy, superTable->indexer); + } + else + { + DEPRECATED_log(subTy); + subTable->indexer = superTable->indexer; + } } } - else if (rt->indexer && variance == Invariant) + else if (subTable->indexer && variance == Invariant) { // Symmetric if we are invariant - if (lt->state == TableState::Unsealed || lt->state == TableState::Free) + if (superTable->state == TableState::Unsealed || superTable->state == TableState::Free) { - log(left); - lt->indexer = rt->indexer; + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(superTy, subTable->indexer); + } + else + { + DEPRECATED_log(superTy); + superTable->indexer = subTable->indexer; + } } } if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } @@ -1369,18 +1956,32 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) * I believe this is guaranteed to terminate eventually because this will * only happen when a free table is bound to another table. */ - if (lt->boundTo || rt->boundTo) - return tryUnify_(left, right); + if (superTable->boundTo || subTable->boundTo) + return tryUnify_(subTy, superTy); - if (lt->state == TableState::Free) + if (superTable->state == TableState::Free) { - log(lt); - lt->boundTo = right; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(superTy, subTy); + } + else + { + DEPRECATED_log(superTable); + superTable->boundTo = subTy; + } } - else if (rt->state == TableState::Free) + else if (subTable->state == TableState::Free) { - log(rt); - rt->boundTo = left; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(subTy, superTy); + } + else + { + DEPRECATED_log(subTy); + subTable->boundTo = superTy; + } } } @@ -1406,99 +2007,129 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); } -void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); Resetter resetter{&variance}; variance = Invariant; - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) + TableTypeVar* superTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + + if (!superTable || !subTable) ice("passed non-table types to unifyTables"); - if (lt->state == TableState::Sealed && rt->state == TableState::Sealed) - return tryUnifySealedTables(left, right, isIntersection); - else if ((lt->state == TableState::Sealed && rt->state == TableState::Unsealed) || - (lt->state == TableState::Unsealed && rt->state == TableState::Sealed)) - return tryUnifySealedTables(left, right, isIntersection); - else if ((lt->state == TableState::Sealed && rt->state == TableState::Generic) || - (lt->state == TableState::Generic && rt->state == TableState::Sealed)) - errors.push_back(TypeError{location, TypeMismatch{left, right}}); - else if ((lt->state == TableState::Free) != (rt->state == TableState::Free)) // one table is free and the other is not + if (superTable->state == TableState::Sealed && subTable->state == TableState::Sealed) + return tryUnifySealedTables(subTy, superTy, isIntersection); + else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Unsealed) || + (superTable->state == TableState::Unsealed && subTable->state == TableState::Sealed)) + return tryUnifySealedTables(subTy, superTy, isIntersection); + else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Generic) || + (superTable->state == TableState::Generic && subTable->state == TableState::Sealed)) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + else if ((superTable->state == TableState::Free) != (subTable->state == TableState::Free)) // one table is free and the other is not { - TypeId freeTypeId = rt->state == TableState::Free ? right : left; - TypeId otherTypeId = rt->state == TableState::Free ? left : right; + TypeId freeTypeId = subTable->state == TableState::Free ? subTy : superTy; + TypeId otherTypeId = subTable->state == TableState::Free ? superTy : subTy; - return tryUnifyFreeTable(freeTypeId, otherTypeId); + return tryUnifyFreeTable(otherTypeId, freeTypeId); } - else if (lt->state == TableState::Free && rt->state == TableState::Free) + else if (superTable->state == TableState::Free && subTable->state == TableState::Free) { - tryUnifyFreeTable(left, right); + tryUnifyFreeTable(subTy, superTy); // avoid creating a cycle when the types are already pointing at each other - if (follow(left) != follow(right)) + if (follow(superTy) != follow(subTy)) { - log(lt); - lt->boundTo = right; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(superTy, subTy); + } + else + { + DEPRECATED_log(superTable); + superTable->boundTo = subTy; + } } return; } - else if (lt->state != TableState::Sealed && rt->state != TableState::Sealed) + else if (superTable->state != TableState::Sealed && subTable->state != TableState::Sealed) { // All free tables are checked in one of the branches above - LUAU_ASSERT(lt->state != TableState::Free); - LUAU_ASSERT(rt->state != TableState::Free); + LUAU_ASSERT(superTable->state != TableState::Free); + LUAU_ASSERT(subTable->state != TableState::Free); // Tables must have exactly the same props and their types must all unify // I honestly have no idea if this is remotely close to reasonable. - for (const auto& [name, prop] : lt->props) + for (const auto& [name, prop] : superTable->props) { - const auto& r = rt->props.find(name); - if (r == rt->props.end()) - errors.push_back(TypeError{location, UnknownProperty{right, name}}); + const auto& r = subTable->props.find(name); + if (r == subTable->props.end()) + errors.push_back(TypeError{location, UnknownProperty{subTy, name}}); else - tryUnify_(prop.type, r->second.type); + tryUnify_(r->second.type, prop.type); } - if (lt->indexer && rt->indexer) - tryUnify(*lt->indexer, *rt->indexer); - else if (lt->indexer) + if (superTable->indexer && subTable->indexer) + tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (superTable->indexer) { // passing/assigning a table without an indexer to something that has one // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. - if (rt->state == TableState::Unsealed) - rt->indexer = lt->indexer; + if (subTable->state == TableState::Unsealed) + { + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(subTy, superTable->indexer); + } + else + { + subTable->indexer = superTable->indexer; + } + } else - errors.push_back(TypeError{location, CannotExtendTable{right, CannotExtendTable::Indexer}}); + errors.push_back(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); } } - else if (lt->state == TableState::Sealed) + else if (superTable->state == TableState::Sealed) { // lt is sealed and so it must be possible for rt to have precisely the same shape // Verify that this is the case, then bind rt to lt. ice("unsealed tables are not working yet", location); } - else if (rt->state == TableState::Sealed) - return tryUnifyTables(right, left, isIntersection); + else if (subTable->state == TableState::Sealed) + return tryUnifyTables(superTy, subTy, isIntersection); else ice("tryUnifyTables"); } -void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) +void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { - TableTypeVar* freeTable = getMutable(freeTypeId); - TableTypeVar* otherTable = getMutable(otherTypeId); - if (!freeTable || !otherTable) + TableTypeVar* freeTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + freeTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + + if (!freeTable || !subTable) ice("passed non-table types to tryUnifyFreeTable"); // Any properties in freeTable must unify with those in otherTable. // Then bind freeTable to otherTable. for (const auto& [freeName, freeProp] : freeTable->props) { - if (auto otherProp = findTablePropertyRespectingMeta(otherTypeId, freeName)) + if (auto subProp = findTablePropertyRespectingMeta(subTy, freeName)) { - tryUnify_(*otherProp, freeProp.type); + tryUnify_(freeProp.type, *subProp); /* * TypeVars are commonly cyclic, so it is entirely possible @@ -1508,84 +2139,133 @@ void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) * I believe this is guaranteed to terminate eventually because this will * only happen when a free table is bound to another table. */ - if (!get(freeTypeId) || !get(otherTypeId)) - return tryUnify_(freeTypeId, otherTypeId); + if (FFlag::LuauUseCommittingTxnLog) + { + if (!log.getMutable(superTy) || !log.getMutable(subTy)) + return tryUnify_(subTy, superTy); - if (freeTable->boundTo) - return tryUnify_(freeTypeId, otherTypeId); + if (TableTypeVar* pendingFreeTtv = log.getMutable(superTy); pendingFreeTtv && pendingFreeTtv->boundTo) + return tryUnify_(subTy, superTy); + } + else + { + if (!get(superTy) || !get(subTy)) + return tryUnify_(subTy, superTy); + + if (freeTable->boundTo) + return tryUnify_(subTy, superTy); + } } else { // If the other table is also free, then we are learning that it has more // properties than we previously thought. Else, it is an error. - if (otherTable->state == TableState::Free) - otherTable->props.insert({freeName, freeProp}); + if (subTable->state == TableState::Free) + { + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* pendingSubTtv = getMutable(pendingSub); + LUAU_ASSERT(pendingSubTtv); + pendingSubTtv->props.insert({freeName, freeProp}); + } + else + { + subTable->props.insert({freeName, freeProp}); + } + } else - errors.push_back(TypeError{location, UnknownProperty{otherTypeId, freeName}}); + errors.push_back(TypeError{location, UnknownProperty{subTy, freeName}}); } } - if (freeTable->indexer && otherTable->indexer) + if (freeTable->indexer && subTable->indexer) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify(*freeTable->indexer, *otherTable->indexer); + innerState.tryUnifyIndexer(*subTable->indexer, *freeTable->indexer); - checkChildUnifierTypeMismatch(innerState.errors, freeTypeId, otherTypeId); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); } - else if (otherTable->state == TableState::Free && freeTable->indexer) - freeTable->indexer = otherTable->indexer; - - if (!freeTable->boundTo && otherTable->state != TableState::Free) + else if (subTable->state == TableState::Free && freeTable->indexer) { - log(freeTable); - freeTable->boundTo = otherTypeId; + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(superTy, subTable->indexer); + } + else + { + freeTable->indexer = subTable->indexer; + } + } + + if (!freeTable->boundTo && subTable->state != TableState::Free) + { + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(superTy, subTy); + } + else + { + DEPRECATED_log(freeTable); + freeTable->boundTo = subTy; + } } } -void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection) +void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection) { - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) + TableTypeVar* superTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + + if (!superTable || !subTable) ice("passed non-table types to unifySealedTables"); Unifier innerState = makeChildUnifier(); std::vector missingPropertiesInSuper; - bool isUnnamedTable = rt->name == std::nullopt && rt->syntheticName == std::nullopt; + bool isUnnamedTable = subTable->name == std::nullopt && subTable->syntheticName == std::nullopt; bool errorReported = false; // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer) + if (FFlag::LuauTableUnificationEarlyTest && !subTable->indexer) { - for (const auto& [propName, superProp] : lt->props) + for (const auto& [propName, superProp] : superTable->props) { - auto subIter = rt->props.find(propName); - if (subIter == rt->props.end() && !isOptional(superProp.type)) + auto subIter = subTable->props.find(propName); + if (subIter == subTable->props.end() && !isOptional(superProp.type)) missingPropertiesInSuper.push_back(propName); } if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingPropertiesInSuper)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } } // Tables must have exactly the same props and their types must all unify - for (const auto& it : lt->props) + for (const auto& it : superTable->props) { - const auto& r = rt->props.find(it.first); - if (r == rt->props.end()) + const auto& r = subTable->props.find(it.first); + if (r == subTable->props.end()) { if (isOptional(it.second.type)) continue; missingPropertiesInSuper.push_back(it.first); - innerState.errors.push_back(TypeError{location, TypeMismatch{left, right}}); + innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } else { @@ -1594,7 +2274,7 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio size_t oldErrorSize = innerState.errors.size(); Location old = innerState.location; innerState.location = *r->second.location; - innerState.tryUnify_(it.second.type, r->second.type); + innerState.tryUnify_(r->second.type, it.second.type); innerState.location = old; if (oldErrorSize != innerState.errors.size() && !errorReported) @@ -1605,113 +2285,165 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio } else { - innerState.tryUnify_(it.second.type, r->second.type); + innerState.tryUnify_(r->second.type, it.second.type); } } } - if (lt->indexer || rt->indexer) + if (superTable->indexer || subTable->indexer) { - if (lt->indexer && rt->indexer) - innerState.tryUnify(*lt->indexer, *rt->indexer); - else if (rt->state == TableState::Unsealed) + if (FFlag::LuauUseCommittingTxnLog) { - if (lt->indexer && !rt->indexer) - rt->indexer = lt->indexer; - } - else if (lt->state == TableState::Unsealed) - { - if (rt->indexer && !lt->indexer) - lt->indexer = rt->indexer; - } - else if (lt->indexer) - { - innerState.tryUnify_(lt->indexer->indexType, getSingletonTypes().stringType); - // We already try to unify properties in both tables. - // Skip those and just look for the ones remaining and see if they fit into the indexer. - for (const auto& [name, type] : rt->props) + if (superTable->indexer && subTable->indexer) + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (subTable->state == TableState::Unsealed) { - const auto& it = lt->props.find(name); - if (it == lt->props.end()) - innerState.tryUnify_(lt->indexer->indexResultType, type.type); + if (superTable->indexer && !subTable->indexer) + { + log.changeIndexer(subTy, superTable->indexer); + } } + else if (superTable->state == TableState::Unsealed) + { + if (subTable->indexer && !superTable->indexer) + { + log.changeIndexer(superTy, subTable->indexer); + } + } + else if (superTable->indexer) + { + innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); + for (const auto& [name, type] : subTable->props) + { + const auto& it = superTable->props.find(name); + if (it == superTable->props.end()) + innerState.tryUnify_(type.type, superTable->indexer->indexResultType); + } + } + else + innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } else - innerState.errors.push_back(TypeError{location, TypeMismatch{left, right}}); + { + if (superTable->indexer && subTable->indexer) + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (subTable->state == TableState::Unsealed) + { + if (superTable->indexer && !subTable->indexer) + subTable->indexer = superTable->indexer; + } + else if (superTable->state == TableState::Unsealed) + { + if (subTable->indexer && !superTable->indexer) + superTable->indexer = subTable->indexer; + } + else if (superTable->indexer) + { + innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); + // We already try to unify properties in both tables. + // Skip those and just look for the ones remaining and see if they fit into the indexer. + for (const auto& [name, type] : subTable->props) + { + const auto& it = superTable->props.find(name); + if (it == superTable->props.end()) + innerState.tryUnify_(type.type, superTable->indexer->indexResultType); + } + } + else + innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } } - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (!errorReported) + log.concat(std::move(innerState.log)); + } + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); if (errorReported) return; if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingPropertiesInSuper)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } - // If the superTy/left is an immediate part of an intersection type, do not do extra-property check. + // If the superTy is an immediate part of an intersection type, do not do extra-property check. // Otherwise, we would falsely generate an extra-property-error for 's' in this code: // local a: {n: number} & {s: string} = {n=1, s=""} // When checking against the table '{n: number}'. - if (!isIntersection && lt->state != TableState::Unsealed && !lt->indexer) + if (!isIntersection && superTable->state != TableState::Unsealed && !superTable->indexer) { // Check for extra properties in the subTy std::vector extraPropertiesInSub; - for (const auto& it : rt->props) + for (const auto& [subKey, subProp] : subTable->props) { - const auto& r = lt->props.find(it.first); - if (r == lt->props.end()) + const auto& superIt = superTable->props.find(subKey); + if (superIt == superTable->props.end()) { - if (isOptional(it.second.type)) + if (isOptional(subProp.type)) continue; - extraPropertiesInSub.push_back(it.first); + extraPropertiesInSub.push_back(subKey); } } if (!extraPropertiesInSub.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraPropertiesInSub), MissingProperties::Extra}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); return; } } - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); } -void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed) +void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { - const MetatableTypeVar* lhs = get(metatable); - if (!lhs) + const MetatableTypeVar* superMetatable = get(superTy); + if (!superMetatable) ice("tryUnifyMetatable invoked with non-metatable TypeVar"); - TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other}}; + TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy}}; - if (const MetatableTypeVar* rhs = get(other)) + if (const MetatableTypeVar* subMetatable = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(lhs->table, rhs->table); - innerState.tryUnify_(lhs->metatable, rhs->metatable); + innerState.tryUnify_(subMetatable->table, superMetatable->table); + innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); if (auto e = hasUnificationTooComplex(innerState.errors)) errors.push_back(*e); else if (!innerState.errors.empty()) errors.push_back( - TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); + TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); } - else if (TableTypeVar* rhs = getMutable(other)) + else if (TableTypeVar* subTable = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : getMutable(subTy)) { - switch (rhs->state) + switch (subTable->state) { case TableState::Free: { - tryUnify_(lhs->table, other); - rhs->boundTo = metatable; + tryUnify_(subTy, superMetatable->table); + + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(subTy, superTy); + } + else + { + subTable->boundTo = superTy; + } break; } @@ -1722,7 +2454,8 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse errors.push_back(mismatchError); } } - else if (get(other) || get(other)) + else if (FFlag::LuauUseCommittingTxnLog ? (log.getMutable(subTy) || log.getMutable(subTy)) + : (get(subTy) || get(subTy))) { } else @@ -1732,7 +2465,7 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse } // Class unification is almost, but not quite symmetrical. We use the 'reversed' boolean to indicate which scenario we are evaluating. -void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) +void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { if (reversed) std::swap(superTy, subTy); @@ -1763,7 +2496,7 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) } ice("Illegal variance setting!"); } - else if (TableTypeVar* table = getMutable(subTy)) + else if (TableTypeVar* subTable = getMutable(subTy)) { /** * A free table is something whose shape we do not exactly know yet. @@ -1775,12 +2508,12 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) * * Tables that are not free are known to be actual tables. */ - if (table->state != TableState::Free) + if (subTable->state != TableState::Free) return fail(); bool ok = true; - for (const auto& [propName, prop] : table->props) + for (const auto& [propName, prop] : subTable->props) { const Property* classProp = lookupClassProp(superClass, propName); if (!classProp) @@ -1791,23 +2524,37 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) else { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, classProp->type); + innerState.tryUnify_(classProp->type, prop.type); checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - if (innerState.errors.empty()) + if (FFlag::LuauUseCommittingTxnLog) { - log.concat(std::move(innerState.log)); + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); + } + else + { + ok = false; + } } else { - ok = false; - innerState.log.rollback(); + if (innerState.errors.empty()) + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } + else + { + ok = false; + innerState.DEPRECATED_log.rollback(); + } } } } - if (table->indexer) + if (subTable->indexer) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; @@ -1817,17 +2564,24 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) if (!ok) return; - log(table); - table->boundTo = superTy; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(subTy, superTy); + } + else + { + DEPRECATED_log(subTable); + subTable->boundTo = superTy; + } } else return fail(); } -void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer) +void Unifier::tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer) { - tryUnify_(superIndexer.indexType, subIndexer.indexType); - tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); + tryUnify_(subIndexer.indexType, superIndexer.indexType); + tryUnify_(subIndexer.indexResultType, superIndexer.indexResultType); } static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) @@ -1840,54 +2594,85 @@ static void queueTypePack(std::vector& queue, DenseHashSet& break; seenTypePacks.insert(a); - if (get(a)) + if (FFlag::LuauUseCommittingTxnLog) { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; + if (state.log.getMutable(a)) + { + state.log.replace(a, Unifiable::Bound{anyTypePack}); + } + else if (auto tp = state.log.getMutable(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } } - else if (auto tp = get(a)) + else { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; + if (get(a)) + { + state.DEPRECATED_log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + else if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } } } } -void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool reversed, int subOffset) +void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool reversed, int subOffset) { - const VariadicTypePack* lv = get(superTp); - if (!lv) + const VariadicTypePack* superVariadic = get(superTp); + + if (FFlag::LuauUseCommittingTxnLog) + { + superVariadic = log.getMutable(superTp); + } + + if (!superVariadic) ice("passed non-variadic pack to tryUnifyVariadics"); - if (const VariadicTypePack* rv = get(subTp)) - tryUnify_(reversed ? rv->ty : lv->ty, reversed ? lv->ty : rv->ty); + if (const VariadicTypePack* subVariadic = get(subTp)) + tryUnify_(reversed ? superVariadic->ty : subVariadic->ty, reversed ? subVariadic->ty : superVariadic->ty); else if (get(subTp)) { - TypePackIterator rIter = begin(subTp); - TypePackIterator rEnd = end(subTp); + TypePackIterator subIter = begin(subTp, &log); + TypePackIterator subEnd = end(subTp); - std::advance(rIter, subOffset); + std::advance(subIter, subOffset); - while (rIter != rEnd) + while (subIter != subEnd) { - tryUnify_(reversed ? *rIter : lv->ty, reversed ? lv->ty : *rIter); - ++rIter; + tryUnify_(reversed ? superVariadic->ty : *subIter, reversed ? *subIter : superVariadic->ty); + ++subIter; } - if (std::optional maybeTail = rIter.tail()) + if (std::optional maybeTail = subIter.tail()) { TypePackId tail = follow(*maybeTail); if (get(tail)) { - log(tail); - *asMutable(tail) = BoundTypePack{superTp}; + if (FFlag::LuauUseCommittingTxnLog) + { + log.replace(tail, BoundTypePack(superTp)); + } + else + { + DEPRECATED_log(tail); + *asMutable(tail) = BoundTypePack{superTp}; + } } else if (const VariadicTypePack* vtp = get(tail)) { - tryUnify_(lv->ty, vtp->ty); + tryUnify_(vtp->ty, superVariadic->ty); } else if (get(tail)) { @@ -1914,65 +2699,113 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas { while (!queue.empty()) { - TypeId ty = follow(queue.back()); - queue.pop_back(); - if (seen.find(ty)) - continue; - seen.insert(ty); + if (FFlag::LuauUseCommittingTxnLog) + { + TypeId ty = state.log.follow(queue.back()); + queue.pop_back(); + if (seen.find(ty)) + continue; + seen.insert(ty); - if (get(ty)) - { - state.log(ty); - *asMutable(ty) = BoundTypeVar{anyType}; - } - else if (auto fun = get(ty)) - { - queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); - } - else if (auto table = get(ty)) - { - for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); - - if (table->indexer) + if (state.log.getMutable(ty)) { - queue.push_back(table->indexer->indexType); - queue.push_back(table->indexer->indexResultType); + state.log.replace(ty, BoundTypeVar{anyType}); } + else if (auto fun = state.log.getMutable(ty)) + { + queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = state.log.getMutable(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = state.log.getMutable(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (state.log.getMutable(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = state.log.getMutable(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = state.log.getMutable(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. } - else if (auto mt = get(ty)) - { - queue.push_back(mt->table); - queue.push_back(mt->metatable); - } - else if (get(ty)) - { - // ClassTypeVars never contain free typevars. - } - else if (auto union_ = get(ty)) - queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = get(ty)) - queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); else { - } // Primitives, any, errors, and generics are left untouched. + TypeId ty = follow(queue.back()); + queue.pop_back(); + if (seen.find(ty)) + continue; + seen.insert(ty); + + if (get(ty)) + { + state.DEPRECATED_log(ty); + *asMutable(ty) = BoundTypeVar{anyType}; + } + else if (auto fun = get(ty)) + { + queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = get(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = get(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (get(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = get(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = get(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. + } } } -void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) +void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) { - LUAU_ASSERT(get(any) || get(any)); + LUAU_ASSERT(get(anyTy) || get(anyTy)); // These types are not visited in general loop below - if (get(ty) || get(ty) || get(ty)) + if (get(subTy) || get(subTy) || get(subTy)) return; const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); - const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + const TypePackId anyTP = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - std::vector queue = {ty}; + std::vector queue = {subTy}; sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); @@ -1980,9 +2813,9 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, getSingletonTypes().anyType, anyTP); } -void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) +void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) { - LUAU_ASSERT(get(any)); + LUAU_ASSERT(get(anyTp)); const TypeId anyTy = getSingletonTypes().errorRecoveryType(); @@ -1991,9 +2824,9 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, subTy, anyTp); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, anyTp); } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) @@ -2012,54 +2845,105 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays { RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - needle = follow(needle); - haystack = follow(haystack); - - if (seen.find(haystack)) - return; - - seen.insert(haystack); - - if (get(needle)) - return; - - if (!get(needle)) - ice("Expected needle to be free"); - - if (needle == haystack) - { - errors.push_back(TypeError{location, OccursCheckFailed{}}); - log(needle); - *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); - return; - } - auto check = [&](TypeId tv) { occursCheck(seen, needle, tv); }; - if (get(haystack)) - return; - else if (auto a = get(haystack)) + if (FFlag::LuauUseCommittingTxnLog) { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (TypeId ty : a->argTypes) - check(ty); + needle = log.follow(needle); + haystack = log.follow(haystack); - for (TypeId ty : a->retType) + if (seen.find(haystack)) + return; + + seen.insert(haystack); + + if (log.getMutable(needle)) + return; + + if (!log.getMutable(needle)) + ice("Expected needle to be free"); + + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + log.replace(needle, *getSingletonTypes().errorRecoveryType()); + + return; + } + + if (log.getMutable(haystack)) + return; + else if (auto a = log.getMutable(haystack)) + { + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (TypePackIterator it(a->argTypes, &log); it != end(a->argTypes); ++it) + check(*it); + + for (TypePackIterator it(a->retType, &log); it != end(a->retType); ++it) + check(*it); + } + } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->options) + check(ty); + } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->parts) check(ty); } } - else if (auto a = get(haystack)) + else { - for (TypeId ty : a->options) - check(ty); - } - else if (auto a = get(haystack)) - { - for (TypeId ty : a->parts) - check(ty); + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + + seen.insert(haystack); + + if (get(needle)) + return; + + if (!get(needle)) + ice("Expected needle to be free"); + + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + DEPRECATED_log(needle); + *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); + return; + } + + if (get(haystack)) + return; + else if (auto a = get(haystack)) + { + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (TypeId ty : a->argTypes) + check(ty); + + for (TypeId ty : a->retType) + check(ty); + } + } + else if (auto a = get(haystack)) + { + for (TypeId ty : a->options) + check(ty); + } + else if (auto a = get(haystack)) + { + for (TypeId ty : a->parts) + check(ty); + } } } @@ -2072,59 +2956,115 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) { - needle = follow(needle); - haystack = follow(haystack); - - if (seen.find(haystack)) - return; - - seen.insert(haystack); - - if (get(needle)) - return; - - if (!get(needle)) - ice("Expected needle pack to be free"); - - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - - while (!get(haystack)) + if (FFlag::LuauUseCommittingTxnLog) { - if (needle == haystack) - { - errors.push_back(TypeError{location, OccursCheckFailed{}}); - log(needle); - *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); - return; - } + needle = log.follow(needle); + haystack = log.follow(haystack); - if (auto a = get(haystack)) + if (seen.find(haystack)) + return; + + seen.insert(haystack); + + if (log.getMutable(needle)) + return; + + if (!get(needle)) + ice("Expected needle pack to be free"); + + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + + while (!log.getMutable(haystack)) { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); + + return; + } + + if (auto a = get(haystack)) { for (const auto& ty : a->head) { - if (auto f = get(follow(ty))) + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); + if (auto f = log.getMutable(log.follow(ty))) + { + occursCheck(seen, needle, f->argTypes); + occursCheck(seen, needle, f->retType); + } } } + + if (a->tail) + { + haystack = follow(*a->tail); + continue; + } + } + break; + } + } + else + { + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + + seen.insert(haystack); + + if (get(needle)) + return; + + if (!get(needle)) + ice("Expected needle pack to be free"); + + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + + while (!get(haystack)) + { + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + DEPRECATED_log(needle); + *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); } - if (a->tail) + if (auto a = get(haystack)) { - haystack = follow(*a->tail); - continue; + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (const auto& ty : a->head) + { + if (auto f = get(follow(ty))) + { + occursCheck(seen, needle, f->argTypes); + occursCheck(seen, needle, f->retType); + } + } + } + + if (a->tail) + { + haystack = follow(*a->tail); + continue; + } } + break; } - break; } } Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState}; + if (FFlag::LuauUseCommittingTxnLog) + return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, &log}; + else + return Unifier{types, mode, globalScope, DEPRECATED_log.sharedSeen, location, variance, sharedState, &log}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/Common.h b/Ast/include/Luau/Common.h index 63cd3df4..fbb03a9e 100644 --- a/Ast/include/Luau/Common.h +++ b/Ast/include/Luau/Common.h @@ -29,7 +29,7 @@ namespace Luau { -using AssertHandler = int (*)(const char* expression, const char* file, int line); +using AssertHandler = int (*)(const char* expression, const char* file, int line, const char* function); inline AssertHandler& assertHandler() { @@ -37,10 +37,10 @@ inline AssertHandler& assertHandler() return handler; } -inline int assertCallHandler(const char* expression, const char* file, int line) +inline int assertCallHandler(const char* expression, const char* file, int line, const char* function) { if (AssertHandler handler = assertHandler()) - return handler(expression, file, line); + return handler(expression, file, line, function); return 1; } @@ -48,7 +48,7 @@ inline int assertCallHandler(const char* expression, const char* file, int line) } // namespace Luau #if !defined(NDEBUG) || defined(LUAU_ENABLE_ASSERT) -#define LUAU_ASSERT(expr) ((void)(!!(expr) || (Luau::assertCallHandler(#expr, __FILE__, __LINE__) && (LUAU_DEBUGBREAK(), 0)))) +#define LUAU_ASSERT(expr) ((void)(!!(expr) || (Luau::assertCallHandler(#expr, __FILE__, __LINE__, __FUNCTION__) && (LUAU_DEBUGBREAK(), 0)))) #define LUAU_ASSERTENABLED #else #define LUAU_ASSERT(expr) (void)sizeof(!!(expr)) diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index aecb619a..54a9a26f 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -107,7 +107,7 @@ static void displayHelp(const char* argv0) printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); return 1; diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 35c02f2c..26d4333a 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -235,11 +235,14 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, while (lua_next(L, -2) != 0) { - // table, key, value - std::string_view key = lua_tostring(L, -2); + if (lua_type(L, -2) == LUA_TSTRING) + { + // table, key, value + std::string_view key = lua_tostring(L, -2); - if (!key.empty() && Luau::startsWith(key, prefix)) - completions.push_back(editBuffer + std::string(key.substr(prefix.size()))); + if (!key.empty() && Luau::startsWith(key, prefix)) + completions.push_back(editBuffer + std::string(key.substr(prefix.size()))); + } lua_pop(L, 1); } @@ -253,7 +256,7 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, lua_rawget(L, -2); lua_remove(L, -2); - if (lua_isnil(L, -1)) + if (!lua_istable(L, -1)) break; lookup.remove_prefix(dot + 1); @@ -266,7 +269,7 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, static void completeRepl(lua_State* L, const char* editBuffer, std::vector& completions) { size_t start = strlen(editBuffer); - while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.')) + while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.' || editBuffer[start - 1] == '_')) start--; // look the value up in current global table first @@ -278,6 +281,34 @@ static void completeRepl(lua_State* L, const char* editBuffer, std::vector globalState(luaL_newstate(), lua_close); @@ -292,6 +323,7 @@ static void runRepl() }); std::string buffer; + LinenoiseScopedHistory scopedHistory; for (;;) { @@ -457,7 +489,7 @@ static void displayHelp(const char* argv0) printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); return 1; diff --git a/CLI/Web.cpp b/CLI/Web.cpp index cf5c831e..416a79f2 100644 --- a/CLI/Web.cpp +++ b/CLI/Web.cpp @@ -53,30 +53,38 @@ static std::string runCode(lua_State* L, const std::string& source) lua_insert(T, 1); lua_pcall(T, n, 0, 0); } + + lua_pop(L, 1); // pop T + return std::string(); } else { std::string error; + lua_Debug ar; + if (lua_getinfo(L, 0, "sln", &ar)) + { + error += ar.short_src; + error += ':'; + error += std::to_string(ar.currentline); + error += ": "; + } + if (status == LUA_YIELD) { - error = "thread yielded unexpectedly"; + error += "thread yielded unexpectedly"; } else if (const char* str = lua_tostring(T, -1)) { - error = str; + error += str; } error += "\nstack backtrace:\n"; error += lua_debugtrace(T); - error = "Error:" + error; - - fprintf(stdout, "%s", error.c_str()); + lua_pop(L, 1); // pop T + return error; } - - lua_pop(L, 1); - return std::string(); } extern "C" const char* executeScript(const char* source) diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 71631d10..d9694d7d 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -377,6 +377,7 @@ enum LuauBytecodeTag { // Bytecode version LBC_VERSION = 1, + LBC_VERSION_FUTURE = 2, // TODO: This will be removed in favor of LBC_VERSION with LuauBytecodeV2Force // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index d4ebad6b..287bf4ee 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -74,6 +74,7 @@ public: void expandJumps(); void setDebugFunctionName(StringRef name); + void setDebugFunctionLineDefined(int line); void setDebugLine(int line); void pushDebugLocal(StringRef name, uint8_t reg, uint32_t startpc, uint32_t endpc); void pushDebugUpval(StringRef name); @@ -162,6 +163,7 @@ private: bool isvararg = false; unsigned int debugname = 0; + int debuglinedefined = 0; std::string dump; std::string dumpname; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 3280c8a4..2d31c409 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Write, false) + namespace Luau { @@ -81,6 +83,52 @@ static int getOpLength(LuauOpcode op) } } +inline bool isJumpD(LuauOpcode op) +{ + switch (op) + { + case LOP_JUMP: + case LOP_JUMPIF: + case LOP_JUMPIFNOT: + case LOP_JUMPIFEQ: + case LOP_JUMPIFLE: + case LOP_JUMPIFLT: + case LOP_JUMPIFNOTEQ: + case LOP_JUMPIFNOTLE: + case LOP_JUMPIFNOTLT: + case LOP_FORNPREP: + case LOP_FORNLOOP: + case LOP_FORGLOOP: + case LOP_FORGPREP_INEXT: + case LOP_FORGLOOP_INEXT: + case LOP_FORGPREP_NEXT: + case LOP_FORGLOOP_NEXT: + case LOP_JUMPBACK: + case LOP_JUMPIFEQK: + case LOP_JUMPIFNOTEQK: + return true; + + default: + return false; + } +} + +inline bool isSkipC(LuauOpcode op) +{ + switch (op) + { + case LOP_LOADB: + case LOP_FASTCALL: + case LOP_FASTCALL1: + case LOP_FASTCALL2: + case LOP_FASTCALL2K: + return true; + + default: + return false; + } +} + bool BytecodeBuilder::StringRef::operator==(const StringRef& other) const { return (data && other.data) ? (length == other.length && memcmp(data, other.data, length) == 0) : (data == other.data); @@ -365,13 +413,7 @@ bool BytecodeBuilder::patchJumpD(size_t jumpLabel, size_t targetLabel) unsigned int jumpInsn = insns[jumpLabel]; (void)jumpInsn; - LUAU_ASSERT(LUAU_INSN_OP(jumpInsn) == LOP_JUMP || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIF || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOT || - LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFEQ || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFLE || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFLT || - LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTEQ || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTLE || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTLT || - LUAU_INSN_OP(jumpInsn) == LOP_FORNPREP || LUAU_INSN_OP(jumpInsn) == LOP_FORNLOOP || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP || - LUAU_INSN_OP(jumpInsn) == LOP_FORGPREP_INEXT || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP_INEXT || - LUAU_INSN_OP(jumpInsn) == LOP_FORGPREP_NEXT || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP_NEXT || - LUAU_INSN_OP(jumpInsn) == LOP_JUMPBACK || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFEQK || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTEQK); + LUAU_ASSERT(isJumpD(LuauOpcode(LUAU_INSN_OP(jumpInsn)))); LUAU_ASSERT(LUAU_INSN_D(jumpInsn) == 0); LUAU_ASSERT(targetLabel <= insns.size()); @@ -403,8 +445,7 @@ bool BytecodeBuilder::patchSkipC(size_t jumpLabel, size_t targetLabel) unsigned int jumpInsn = insns[jumpLabel]; (void)jumpInsn; - LUAU_ASSERT(LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL || LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL1 || LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL2 || - LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL2K); + LUAU_ASSERT(isSkipC(LuauOpcode(LUAU_INSN_OP(jumpInsn)))); LUAU_ASSERT(LUAU_INSN_C(jumpInsn) == 0); int offset = int(targetLabel) - int(jumpLabel) - 1; @@ -428,6 +469,11 @@ void BytecodeBuilder::setDebugFunctionName(StringRef name) functions[currentFunction].dumpname = std::string(name.data, name.length); } +void BytecodeBuilder::setDebugFunctionLineDefined(int line) +{ + functions[currentFunction].debuglinedefined = line; +} + void BytecodeBuilder::setDebugLine(int line) { debugLine = line; @@ -464,7 +510,7 @@ uint32_t BytecodeBuilder::getDebugPC() const void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); - bytecode = char(LBC_VERSION); + bytecode = char(FFlag::LuauBytecodeV2Write ? LBC_VERSION_FUTURE : LBC_VERSION); writeStringTable(bytecode); @@ -565,6 +611,9 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const writeVarInt(ss, child); // debug info + if (FFlag::LuauBytecodeV2Write) + writeVarInt(ss, func.debuglinedefined); + writeVarInt(ss, func.debugname); bool hasLines = true; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 8f74ffed..6ae49027 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) -LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) namespace Luau { @@ -179,6 +178,8 @@ struct Compiler if (options.optimizationLevel >= 1 && options.debugLevel >= 2) gatherConstUpvals(func); + bytecode.setDebugFunctionLineDefined(func->location.begin.line + 1); + if (options.debugLevel >= 1 && func->debugname.value) bytecode.setDebugFunctionName(sref(func->debugname)); @@ -3626,9 +3627,9 @@ struct Compiler return LBF_BIT32_RROTATE; if (builtin.method == "rshift") return LBF_BIT32_RSHIFT; - if (builtin.method == "countlz" && FFlag::LuauBit32CountBuiltin) + if (builtin.method == "countlz") return LBF_BIT32_COUNTLZ; - if (builtin.method == "countrz" && FFlag::LuauBit32CountBuiltin) + if (builtin.method == "countrz") return LBF_BIT32_COUNTRZ; } diff --git a/Sources.cmake b/Sources.cmake index a7153eb3..5dd486aa 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -125,6 +125,7 @@ target_sources(Luau.VM PRIVATE VM/src/linit.cpp VM/src/lmathlib.cpp VM/src/lmem.cpp + VM/src/lnumprint.cpp VM/src/lobject.cpp VM/src/loslib.cpp VM/src/lperf.cpp diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index a01a1481..7e0832e7 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -138,10 +138,6 @@ /* }================================================================== */ -/* Default number printing format and the string length limit */ -#define LUA_NUMBER_FMT "%.14g" -#define LUAI_MAXNUMBER2STR 32 /* 16 digits, sign, point, and \0 */ - /* @@ LUAI_USER_ALIGNMENT_T is a type that requires maximum alignment. ** CHANGE it if your system requires alignments larger than double. (For diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index a65b0325..c98b9590 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,8 +14,6 @@ #include -LUAU_FASTFLAG(LuauActivateBeforeExec) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -939,21 +937,7 @@ void lua_call(lua_State* L, int nargs, int nresults) checkresults(L, nargs, nresults); func = L->top - (nargs + 1); - if (FFlag::LuauActivateBeforeExec) - { - luaD_call(L, func, nresults); - } - else - { - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); - - luaD_call(L, func, nresults); - - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } + luaD_call(L, func, nresults); adjustresults(L, nresults); return; @@ -994,21 +978,7 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) c.func = L->top - (nargs + 1); /* function to be called */ c.nresults = nresults; - if (FFlag::LuauActivateBeforeExec) - { - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); - } - else - { - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); - - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); - - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } + status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); adjustresults(L, nresults); return status; diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 7ed2a62e..71975a52 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -7,9 +7,12 @@ #include "lstring.h" #include "lapi.h" #include "lgc.h" +#include "lnumutils.h" #include +LUAU_FASTFLAG(LuauSchubfach) + /* convert a stack index to positive */ #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -477,7 +480,17 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) switch (lua_type(L, idx)) { case LUA_TNUMBER: - lua_pushstring(L, lua_tostring(L, idx)); + if (FFlag::LuauSchubfach) + { + double n = lua_tonumber(L, idx); + char s[LUAI_MAXNUM2STR]; + char* e = luai_num2str(s, n); + lua_pushlstring(L, s, e - s); + } + else + { + lua_pushstring(L, lua_tostring(L, idx)); + } break; case LUA_TSTRING: lua_pushvalue(L, idx); @@ -491,11 +504,30 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) case LUA_TVECTOR: { const float* v = lua_tovector(L, idx); + + if (FFlag::LuauSchubfach) + { + char s[LUAI_MAXNUM2STR * LUA_VECTOR_SIZE]; + char* e = s; + for (int i = 0; i < LUA_VECTOR_SIZE; ++i) + { + if (i != 0) + { + *e++ = ','; + *e++ = ' '; + } + e = luai_num2str(e, v[i]); + } + lua_pushlstring(L, s, e - s); + } + else + { #if LUA_VECTOR_SIZE == 4 - lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); + lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); #else - lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); + lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); #endif + } break; } default: diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 8b511edf..093400f2 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -5,8 +5,6 @@ #include "lcommon.h" #include "lnumutils.h" -LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) - #define ALLONES ~0u #define NBITS int(8 * sizeof(unsigned)) @@ -182,9 +180,6 @@ static int b_replace(lua_State* L) static int b_countlz(lua_State* L) { - if (!FFlag::LuauBit32Count) - luaL_error(L, "bit32.countlz isn't enabled"); - b_uint v = luaL_checkunsigned(L, 1); b_uint r = NBITS; @@ -201,9 +196,6 @@ static int b_countlz(lua_State* L) static int b_countrz(lua_State* L) { - if (!FFlag::LuauBit32Count) - luaL_error(L, "bit32.countrz isn't enabled"); - b_uint v = luaL_checkunsigned(L, 1); b_uint r = NBITS; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 9fe1885f..2b5382bb 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,6 +12,9 @@ #include #include +LUAU_FASTFLAG(LuauBytecodeV2Read) +LUAU_FASTFLAG(LuauBytecodeV2Force) + static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -89,6 +92,16 @@ const char* lua_setlocal(lua_State* L, int level, int n) return name; } +static int getlinedefined(Proto* p) +{ + if (FFlag::LuauBytecodeV2Force) + return p->linedefined; + else if (FFlag::LuauBytecodeV2Read && p->linedefined >= 0) + return p->linedefined; + else + return luaG_getline(p, 0); +} + static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, CallInfo* ci) { int status = 1; @@ -108,7 +121,7 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, { ar->source = getstr(f->l.p->source); ar->what = "Lua"; - ar->linedefined = luaG_getline(f->l.p, 0); + ar->linedefined = getlinedefined(f->l.p); } luaO_chunkid(ar->short_src, ar->source, LUA_IDSIZE); break; @@ -121,7 +134,7 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, } else { - ar->currentline = f->isC ? -1 : luaG_getline(f->l.p, 0); + ar->currentline = f->isC ? -1 : getlinedefined(f->l.p); } break; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 62bbdb7c..eb47971a 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -19,7 +19,6 @@ LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAG(LuauCoroutineClose) -LUAU_FASTFLAGVARIABLE(LuauActivateBeforeExec, false) /* ** {====================================================== @@ -228,21 +227,14 @@ void luaD_call(lua_State* L, StkId func, int nResults) { /* is a Lua function? */ L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */ - if (FFlag::LuauActivateBeforeExec) - { - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); - luau_execute(L); /* call it */ + luau_execute(L); /* call it */ - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } - else - { - luau_execute(L); /* call it */ - } + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); } L->nCcalls--; luaC_checkGC(L); @@ -549,12 +541,9 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e status = LUA_ERRERR; } - if (FFlag::LuauActivateBeforeExec) - { - // since the call failed with an error, we might have to reset the 'active' thread state - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } + // since the call failed with an error, we might have to reset the 'active' thread state + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); if (FFlag::LuauCcallRestoreFix) { diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 7393fc74..76ef7a06 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -12,8 +12,6 @@ #include -LUAU_FASTFLAG(LuauArrayBoundary) - #define GC_SWEEPMAX 40 #define GC_SWEEPCOST 10 diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index f6f7a878..c66de9c1 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauArrayBoundary) - static void validateobjref(global_State* g, GCObject* f, GCObject* t) { LUAU_ASSERT(!isdead(g, t)); @@ -38,10 +36,7 @@ static void validatetable(global_State* g, Table* h) { int sizenode = 1 << h->lsizenode; - if (FFlag::LuauArrayBoundary) - LUAU_ASSERT(h->lastfree <= sizenode); - else - LUAU_ASSERT(h->lastfree >= 0 && h->lastfree <= sizenode); + LUAU_ASSERT(h->lastfree <= sizenode); if (h->metatable) validateobjref(g, obj2gco(h), obj2gco(h->metatable)); diff --git a/VM/src/lnumprint.cpp b/VM/src/lnumprint.cpp new file mode 100644 index 00000000..2fd0f1bb --- /dev/null +++ b/VM/src/lnumprint.cpp @@ -0,0 +1,375 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "luaconf.h" +#include "lnumutils.h" + +#include "lcommon.h" + +#include +#include // TODO: Remove with LuauSchubfach + +#ifdef _MSC_VER +#include +#endif + +// This work is based on: +// Raffaello Giulietti. The Schubfach way to render doubles. 2021 +// https://drive.google.com/file/d/1IEeATSVnEE6TkrHlCYNY2GjaraBjOT4f/edit + +// The code uses the notation from the paper for local variables where appropriate, and refers to paper sections/figures/results. + +LUAU_FASTFLAGVARIABLE(LuauSchubfach, false) + +// 9.8.2. Precomputed table for 128-bit overestimates of powers of 10 (see figure 3 for table bounds) +// To avoid storing 616 128-bit numbers directly we use a technique inspired by Dragonbox implementation and store 16 consecutive +// powers using a 128-bit baseline and a bitvector with 1-bit scale and 3-bit offset for the delta between each entry and base*5^k +static const int kPow10TableMin = -292; +static const int kPow10TableMax = 324; + +// clang-format off +static const uint64_t kPow5Table[16] = { + 0x8000000000000000, 0xa000000000000000, 0xc800000000000000, 0xfa00000000000000, 0x9c40000000000000, 0xc350000000000000, + 0xf424000000000000, 0x9896800000000000, 0xbebc200000000000, 0xee6b280000000000, 0x9502f90000000000, 0xba43b74000000000, + 0xe8d4a51000000000, 0x9184e72a00000000, 0xb5e620f480000000, 0xe35fa931a0000000, +}; +static const uint64_t kPow10Table[(kPow10TableMax - kPow10TableMin + 1 + 15) / 16][3] = { + {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b, 0x333443443333443b}, {0x8dd01fad907ffc3b, 0xae3da7d97f6792e4, 0xbbb3ab3cb3ba3cbc}, + {0x9d71ac8fada6c9b5, 0x6f773fc3603db4aa, 0x4ba4bc4bb4bb4bcc}, {0xaecc49914078536d, 0x58fae9f773886e19, 0x3ba3bc33b43b43bb}, + {0xc21094364dfb5636, 0x985915fc12f542e5, 0x33b43b43a33b33cb}, {0xd77485cb25823ac7, 0x7d633293366b828c, 0x34b44c444343443c}, + {0xef340a98172aace4, 0x86fb897116c87c35, 0x333343333343334b}, {0x84c8d4dfd2c63f3b, 0x29ecd9f40041e074, 0xccaccbbcbcbb4bbc}, + {0x936b9fcebb25c995, 0xcab10dd900beec35, 0x3ab3ab3ab3bb3bbb}, {0xa3ab66580d5fdaf5, 0xc13e60d0d2e0ebbb, 0x4cc3dc4db4db4dbb}, + {0xb5b5ada8aaff80b8, 0x0d819992132456bb, 0x33b33a34c33b34ab}, {0xc9bcff6034c13052, 0xfc89b393dd02f0b6, 0x33c33b44b43c34bc}, + {0xdff9772470297ebd, 0x59787e2b93bc56f8, 0x43b444444443434c}, {0xf8a95fcf88747d94, 0x75a44c6397ce912b, 0x443334343443343b}, + {0x8a08f0f8bf0f156b, 0x1b8e9ecb641b5900, 0xbbabab3aa3ab4ccc}, {0x993fe2c6d07b7fab, 0xe546a8038efe402a, 0x4cb4bc4db4db4bcc}, + {0xaa242499697392d2, 0xdde50bd1d5d0b9ea, 0x3ba3ba3bb33b33bc}, {0xbce5086492111aea, 0x88f4bb1ca6bcf585, 0x44b44c44c44c43cb}, + {0xd1b71758e219652b, 0xd3c36113404ea4a9, 0x44c44c44c444443b}, {0xe8d4a51000000000, 0x0000000000000000, 0x444444444444444c}, + {0x813f3978f8940984, 0x4000000000000000, 0xcccccccccccccccc}, {0x8f7e32ce7bea5c6f, 0xe4820023a2000000, 0xbba3bc4cc4cc4ccc}, + {0x9f4f2726179a2245, 0x01d762422c946591, 0x4aa3bb3aa3ba3bab}, {0xb0de65388cc8ada8, 0x3b25a55f43294bcc, 0x3ca33b33b44b43bc}, + {0xc45d1df942711d9a, 0x3ba5d0bd324f8395, 0x44c44c34c44b44cb}, {0xda01ee641a708de9, 0xe80e6f4820cc9496, 0x33b33b343333333c}, + {0xf209787bb47d6b84, 0xc0678c5dbd23a49b, 0x443444444443443b}, {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3, 0xdbccbcccb4cb3bbb}, + {0x952ab45cfa97a0b2, 0xdd945a747bf26184, 0x3bc4bb4ab3ca3cbc}, {0xa59bc234db398c25, 0x43fab9837e699096, 0x3bb3ac3ab3bb33ac}, + {0xb7dcbf5354e9bece, 0x0c11ed6d538aeb30, 0x33b43b43b34c34dc}, {0xcc20ce9bd35c78a5, 0x31ec038df7b441f5, 0x34c44c43c44b44cb}, + {0xe2a0b5dc971f303a, 0x2e44ae64840fd61e, 0x333333333333333c}, {0xfb9b7cd9a4a7443c, 0x169840ef017da3b2, 0x433344443333344c}, + {0x8bab8eefb6409c1a, 0x1ad089b6c2f7548f, 0xdcbdcc3cc4cc4bcb}, {0x9b10a4e5e9913128, 0xca7cf2b4191c8327, 0x3ab3cb3bc3bb4bbb}, + {0xac2820d9623bf429, 0x546345fa9fbdcd45, 0x3bb3cc43c43c43cb}, {0xbf21e44003acdd2c, 0xe0470a63e6bd56c4, 0x44b34a43b44c44bc}, + {0xd433179d9c8cb841, 0x5fa60692a46151ec, 0x43a33a33a333333c}, +}; +// clang-format on + +static const char kDigitTable[] = "0001020304050607080910111213141516171819202122232425262728293031323334353637383940414243444546474849" + "5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899"; + +// x*y => 128-bit product (lo+hi) +inline uint64_t mul128(uint64_t x, uint64_t y, uint64_t* hi) +{ +#if defined(_MSC_VER) && defined(_M_X64) + return _umul128(x, y, hi); +#elif defined(__SIZEOF_INT128__) + unsigned __int128 r = x; + r *= y; + *hi = uint64_t(r >> 64); + return uint64_t(r); +#else + uint32_t x0 = uint32_t(x), x1 = uint32_t(x >> 32); + uint32_t y0 = uint32_t(y), y1 = uint32_t(y >> 32); + uint64_t p11 = uint64_t(x1) * y1, p01 = uint64_t(x0) * y1; + uint64_t p10 = uint64_t(x1) * y0, p00 = uint64_t(x0) * y0; + uint64_t mid = p10 + (p00 >> 32) + uint32_t(p01); + uint64_t r0 = (mid << 32) | uint32_t(p00); + uint64_t r1 = p11 + (mid >> 32) + (p01 >> 32); + *hi = r1; + return r0; +#endif +} + +// (x*y)>>64 => 128-bit product (lo+hi) +inline uint64_t mul192hi(uint64_t xhi, uint64_t xlo, uint64_t y, uint64_t* hi) +{ + uint64_t z2; + uint64_t z1 = mul128(xhi, y, &z2); + + uint64_t z1c; + uint64_t z0 = mul128(xlo, y, &z1c); + (void)z0; + + z1 += z1c; + z2 += (z1 < z1c); + + *hi = z2; + return z1; +} + +// 9.3. Rounding to odd (+ figure 8 + result 23) +inline uint64_t roundodd(uint64_t ghi, uint64_t glo, uint64_t cp) +{ + uint64_t xhi; + uint64_t xlo = mul128(glo, cp, &xhi); + (void)xlo; + + uint64_t yhi; + uint64_t ylo = mul128(ghi, cp, &yhi); + + uint64_t z = ylo + xhi; + return (yhi + (z < xhi)) | (z > 1); +} + +struct Decimal +{ + uint64_t s; + int k; +}; + +static Decimal schubfach(int exponent, uint64_t fraction) +{ + // Extract c & q such that c*2^q == |v| + uint64_t c = fraction; + int q = exponent - 1023 - 51; + + if (exponent != 0) // normal numbers have implicit leading 1 + { + c |= (1ull << 52); + q--; + } + + // 8.3. Fast path for integers + if (unsigned(-q) < 53 && (c & ((1ull << (-q)) - 1)) == 0) + return {c >> (-q), 0}; + + // 5. Rounding interval + int irr = (c == (1ull << 52) && q != -1074); // Qmin + int out = int(c & 1); + + // 9.8.1. Boundaries for c + uint64_t cbl = 4 * c - 2 + irr; + uint64_t cb = 4 * c; + uint64_t cbr = 4 * c + 2; + + // 9.1. Computing k and h + const int Q = 20; + const int C = 315652; // floor(2^Q * log10(2)) + const int A = -131008; // floor(2^Q * log10(3/4)) + const int C2 = 3483294; // floor(2^Q * log2(10)) + int k = (q * C + (irr ? A : 0)) >> Q; + int h = q + ((-k * C2) >> Q) + 1; // see (9) in 9.9 + + // 9.8.2. Overestimates of powers of 10 + // Recover 10^-k fraction using compact tables generated by tools/numutils.py + // The 128-bit fraction is encoded as 128-bit baseline * power-of-5 * scale + offset + LUAU_ASSERT(-k >= kPow10TableMin && -k <= kPow10TableMax); + int gtoff = -k - kPow10TableMin; + const uint64_t* gt = kPow10Table[gtoff >> 4]; + + uint64_t ghi; + uint64_t glo = mul192hi(gt[0], gt[1], kPow5Table[gtoff & 15], &ghi); + + // Apply 1-bit scale + 3-bit offset; note, offset is intentionally applied without carry, numutils.py validates that this is sufficient + int gterr = (gt[2] >> ((gtoff & 15) * 4)) & 15; + int gtscale = gterr >> 3; + + ghi <<= gtscale; + ghi += (glo >> 63) & gtscale; + glo <<= gtscale; + glo -= (gterr & 7) - 4; + + // 9.9. Boundaries for v + uint64_t vbl = roundodd(ghi, glo, cbl << h); + uint64_t vb = roundodd(ghi, glo, cb << h); + uint64_t vbr = roundodd(ghi, glo, cbr << h); + + // Main algorithm; see figure 7 + figure 9 + uint64_t s = vb / 4; + + if (s >= 10) + { + uint64_t sp = s / 10; + + bool upin = vbl + out <= 40 * sp; + bool wpin = vbr >= 40 * sp + 40 + out; + + if (upin != wpin) + return {sp + wpin, k + 1}; + } + + // Figure 7 contains the algorithm to select between u (s) and w (s+1) + // rup computes the last 4 conditions in that algorithm + // rup is only used when uin == win, but since these branches predict poorly we use branchless selects + bool uin = vbl + out <= 4 * s; + bool win = 4 * s + 4 + out <= vbr; + bool rup = vb >= 4 * s + 2 + 1 - (s & 1); + + return {s + (uin != win ? win : rup), k}; +} + +static char* printspecial(char* buf, int sign, uint64_t fraction) +{ + if (fraction == 0) + { + memcpy(buf, ("-inf") + (1 - sign), 4); + return buf + 3 + sign; + } + else + { + memcpy(buf, "nan", 4); + return buf + 3; + } +} + +static char* printunsignedrev(char* end, uint64_t num) +{ + while (num >= 10000) + { + unsigned int tail = unsigned(num % 10000); + + memcpy(end - 4, &kDigitTable[int(tail / 100) * 2], 2); + memcpy(end - 2, &kDigitTable[int(tail % 100) * 2], 2); + num /= 10000; + end -= 4; + } + + unsigned int rest = unsigned(num); + + while (rest >= 10) + { + memcpy(end - 2, &kDigitTable[int(rest % 100) * 2], 2); + rest /= 100; + end -= 2; + } + + if (rest) + { + end[-1] = '0' + int(rest); + end -= 1; + } + + return end; +} + +static char* printexp(char* buf, int num) +{ + *buf++ = 'e'; + *buf++ = num < 0 ? '-' : '+'; + + int v = num < 0 ? -num : num; + + if (v >= 100) + { + *buf++ = '0' + (v / 100); + v %= 100; + } + + memcpy(buf, &kDigitTable[v * 2], 2); + return buf + 2; +} + +inline char* trimzero(char* end) +{ + while (end[-1] == '0') + end--; + + return end; +} + +// We use fixed-length memcpy/memset since they lower to fast SIMD+scalar writes; the target buffers should have padding space +#define fastmemcpy(dst, src, size, sizefast) check_exp((size) <= sizefast, memcpy(dst, src, sizefast)) +#define fastmemset(dst, val, size, sizefast) check_exp((size) <= sizefast, memset(dst, val, sizefast)) + +char* luai_num2str(char* buf, double n) +{ + if (!FFlag::LuauSchubfach) + { + snprintf(buf, LUAI_MAXNUM2STR, LUA_NUMBER_FMT, n); + return buf + strlen(buf); + } + + // IEEE-754 + union + { + double v; + uint64_t bits; + } v = {n}; + int sign = int(v.bits >> 63); + int exponent = int(v.bits >> 52) & 2047; + uint64_t fraction = v.bits & ((1ull << 52) - 1); + + // specials + if (LUAU_UNLIKELY(exponent == 0x7ff)) + return printspecial(buf, sign, fraction); + + // sign bit + *buf = '-'; + buf += sign; + + // zero + if (exponent == 0 && fraction == 0) + { + buf[0] = '0'; + return buf + 1; + } + + // convert binary to decimal using Schubfach + Decimal d = schubfach(exponent, fraction); + LUAU_ASSERT(d.s < uint64_t(1e17)); + + // print the decimal to a temporary buffer; we'll need to insert the decimal point and figure out the format + char decbuf[40]; + char* decend = decbuf + 20; // significand needs at most 17 digits; the rest of the buffer may be copied using fixed length memcpy + char* dec = printunsignedrev(decend, d.s); + + int declen = int(decend - dec); + LUAU_ASSERT(declen <= 17); + + int dot = declen + d.k; + + // the limits are somewhat arbitrary but changing them may require changing fastmemset/fastmemcpy sizes below + if (dot >= -5 && dot <= 21) + { + // fixed point format + if (dot <= 0) + { + buf[0] = '0'; + buf[1] = '.'; + + fastmemset(buf + 2, '0', -dot, 5); + fastmemcpy(buf + 2 + (-dot), dec, declen, 17); + + return trimzero(buf + 2 + (-dot) + declen); + } + else if (dot == declen) + { + // no dot + fastmemcpy(buf, dec, dot, 17); + + return buf + dot; + } + else if (dot < declen) + { + // dot in the middle + fastmemcpy(buf, dec, dot, 16); + + buf[dot] = '.'; + + fastmemcpy(buf + dot + 1, dec + dot, declen - dot, 16); + + return trimzero(buf + declen + 1); + } + else + { + // no dot, zero padding + fastmemcpy(buf, dec, declen, 17); + fastmemset(buf + declen, '0', dot - declen, 8); + + return buf + dot; + } + } + else + { + // scientific format + buf[0] = dec[0]; + buf[1] = '.'; + fastmemcpy(buf + 2, dec + 1, declen - 1, 16); + + char* exp = trimzero(buf + declen + 1); + + return printexp(exp, dot - 1); + } +} diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index 67f832dc..fba07bc3 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -3,7 +3,6 @@ #pragma once #include -#include #define luai_numadd(a, b) ((a) + (b)) #define luai_numsub(a, b) ((a) - (b)) @@ -56,5 +55,9 @@ LUAU_FASTMATH_END #define luai_num2unsigned(i, n) ((i) = (unsigned)(long long)(n)) #endif -#define luai_num2str(s, n) snprintf((s), sizeof(s), LUA_NUMBER_FMT, (n)) +#define LUA_NUMBER_FMT "%.14g" /* TODO: Remove with LuauSchubfach */ +#define LUAI_MAXNUM2STR 48 + +LUAI_FUNC char* luai_num2str(char* buf, double n); + #define luai_str2num(s, p) strtod((s), (p)) diff --git a/VM/src/lobject.h b/VM/src/lobject.h index fd0a15b7..b642cf78 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -289,6 +289,7 @@ typedef struct Proto int sizek; int sizelineinfo; int linegaplog2; + int linedefined; uint8_t nups; /* number of upvalues */ diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 0b55fcea..83b59f3f 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -24,8 +24,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false) - // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -222,7 +220,7 @@ int luaH_next(lua_State* L, Table* t, StkId key) #define maybesetaboundary(t, boundary) \ { \ - if (FFlag::LuauArrayBoundary && t->aboundary <= 0) \ + if (t->aboundary <= 0) \ t->aboundary = -int(boundary); \ } @@ -705,7 +703,7 @@ int luaH_getn(Table* t) { int boundary = getaboundary(t); - if (FFlag::LuauArrayBoundary && boundary > 0) + if (boundary > 0) { if (!ttisnil(&t->array[t->sizearray - 1]) && t->node == dummynode) return t->sizearray; /* fast-path: the end of the array in `t' already refers to a boundary */ diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index add3588d..cdb276c0 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,6 +13,9 @@ #include +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, false) +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) + // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template struct TempBuffer @@ -146,15 +149,19 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size uint8_t version = read(data, size, offset); // 0 means the rest of the bytecode is the error message - if (version == 0 || version != LBC_VERSION) + if (version == 0) { char chunkid[LUA_IDSIZE]; luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + lua_pushfstring(L, "%s%.*s", chunkid, int(size - offset), data + offset); + return 1; + } - if (version == 0) - lua_pushfstring(L, "%s%.*s", chunkid, int(size - offset), data + offset); - else - lua_pushfstring(L, "%s: bytecode version mismatch", chunkid); + if (FFlag::LuauBytecodeV2Force ? (version != LBC_VERSION_FUTURE) : FFlag::LuauBytecodeV2Read ? (version != LBC_VERSION && version != LBC_VERSION_FUTURE) : (version != LBC_VERSION)) + { + char chunkid[LUA_IDSIZE]; + luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, FFlag::LuauBytecodeV2Force ? LBC_VERSION_FUTURE : LBC_VERSION, version); return 1; } @@ -285,6 +292,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->p[j] = protos[fid]; } + if (FFlag::LuauBytecodeV2Force || (FFlag::LuauBytecodeV2Read && version == LBC_VERSION_FUTURE)) + p->linedefined = readVarInt(data, size, offset); + else + p->linedefined = -1; + p->debugname = readString(strings, data, size, offset); uint8_t lineinfo = read(data, size, offset); @@ -307,11 +319,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->lineinfo[j] = lastoffset; } - int lastLine = 0; + int lastline = 0; for (int j = 0; j < intervals; ++j) { - lastLine += read(data, size, offset); - p->abslineinfo[j] = lastLine; + lastline += read(data, size, offset); + p->abslineinfo[j] = lastline; } } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 5d802277..31dd59c8 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -34,10 +34,11 @@ int luaV_tostring(lua_State* L, StkId obj) return 0; else { - char s[LUAI_MAXNUMBER2STR]; + char s[LUAI_MAXNUM2STR]; double n = nvalue(obj); - luai_num2str(s, n); - setsvalue2s(L, obj, luaS_new(L, s)); + char* e = luai_num2str(s, n); + LUAU_ASSERT(e < s + sizeof(s)); + setsvalue2s(L, obj, luaS_newlstr(L, s, e - s)); return 1; } } diff --git a/fuzz/number.cpp b/fuzz/number.cpp new file mode 100644 index 00000000..70447409 --- /dev/null +++ b/fuzz/number.cpp @@ -0,0 +1,35 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" + +#include +#include +#include +#include + +LUAU_FASTFLAG(LuauSchubfach); + +#define LUAI_MAXNUM2STR 48 + +char* luai_num2str(char* buf, double n); + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) +{ + if (Size < 8) + return 0; + + FFlag::LuauSchubfach.value = true; + + double num; + memcpy(&num, Data, 8); + + char buf[LUAI_MAXNUM2STR]; + char* end = luai_num2str(buf, num); + LUAU_ASSERT(end < buf + sizeof(buf)); + + *end = 0; + + double rec = strtod(buf, nullptr); + + LUAU_ASSERT(rec == num || (rec != rec && num != num)); + return 0; +} diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 2090b014..41b553b5 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -83,8 +83,6 @@ TEST_SUITE_BEGIN("AstQuery"); TEST_CASE_FIXTURE(Fixture, "last_argument_function_call_type") { - ScopedFastFlag luauTailArgumentTypeInfo{"LuauTailArgumentTypeInfo", true}; - check(R"( local function foo() return 2 end local function bar(a: number) return -a end diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 8ca09c0e..210db7ee 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) +LUAU_FASTFLAG(LuauUseCommittingTxnLog) using namespace Luau; @@ -1911,11 +1912,14 @@ local bar: @1= foo CHECK(!ac.entryMap.count("foo")); } -TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") +// Switch back to TEST_CASE_FIXTURE with regular ACFixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("type_correct_function_no_parenthesis") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag sff_LuauUseCommittingTxnLog = ScopedFastFlag("LuauUseCommittingTxnLog", true); + ACFixture fix; - check(R"( + fix.check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end local function bar2(a: string) return a .. 'x' end @@ -1923,7 +1927,7 @@ local function bar2(a: string) return a .. 'x' end return target(b@1 )"); - auto ac = autocomplete('1'); + auto ac = fix.autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); @@ -1976,11 +1980,14 @@ local fp: @1= f CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") +// Switch back to TEST_CASE_FIXTURE with regular ACFixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("type_correct_keywords") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag sff_LuauUseCommittingTxnLog = ScopedFastFlag("LuauUseCommittingTxnLog", true); + ACFixture fix; - check(R"( + fix.check(R"( local function a(x: boolean) end local function b(x: number?) end local function c(x: (number) -> string) end @@ -1997,26 +2004,26 @@ local dc = d(f@4) local ec = e(f@5) )"); - auto ac = autocomplete('1'); + auto ac = fix.autocomplete('1'); CHECK(ac.entryMap.count("tru")); CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('2'); + ac = fix.autocomplete('2'); CHECK(ac.entryMap.count("ni")); CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('3'); + ac = fix.autocomplete('3'); CHECK(ac.entryMap.count("false")); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('4'); + ac = fix.autocomplete('4'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('5'); + ac = fix.autocomplete('5'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); } @@ -2507,21 +2514,23 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(UnfrozenFixture, "autocomplete_documentation_symbols") +TEST_CASE("autocomplete_documentation_symbols") { - loadDefinition(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + fix.loadDefinition(R"( declare y: { x: number, } )"); - fileResolver.source["Module/A"] = R"( + fix.fileResolver.source["Module/A"] = R"( local a = y. )"; - frontend.check("Module/A"); + fix.frontend.check("Module/A"); - auto ac = autocomplete(frontend, "Module/A", Position{1, 21}, nullCallback); + auto ac = autocomplete(fix.frontend, "Module/A", Position{1, 21}, nullCallback); REQUIRE(ac.entryMap.count("x")); CHECK_EQ(ac.entryMap["x"].documentationSymbol, "@test/global/y.x"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index b055a38e..663b329e 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -13,8 +13,11 @@ #include "ScopedFlags.h" #include +#include #include +extern bool verbose; + static int lua_collectgarbage(lua_State* L) { static const char* const opts[] = {"stop", "restart", "collect", "count", "isrunning", "step", "setgoal", "setstepmul", "setstepsize", nullptr}; @@ -146,15 +149,21 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n luaL_openlibs(L); // Register a few global functions for conformance tests - static const luaL_Reg funcs[] = { + std::vector funcs = { {"collectgarbage", lua_collectgarbage}, {"loadstring", lua_loadstring}, - {"print", lua_silence}, // Disable print() by default; comment this out to enable debug prints in tests - {nullptr, nullptr}, }; + if (!verbose) + { + funcs.push_back({"print", lua_silence}); + } + + // "null" terminate the list of functions to register + funcs.push_back({nullptr, nullptr}); + lua_pushvalue(L, LUA_GLOBALSINDEX); - luaL_register(L, nullptr, funcs); + luaL_register(L, nullptr, funcs.data()); lua_pop(L, 1); // In some configurations we have a larger C stack consumption which trips some conformance tests @@ -312,8 +321,6 @@ TEST_CASE("GC") TEST_CASE("Bitwise") { - ScopedFastFlag sff("LuauBit32Count", true); - runConformance("bitwise.lua"); } @@ -491,6 +498,9 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { + ScopedFastFlag sffr("LuauBytecodeV2Read", true); + ScopedFastFlag sffw("LuauBytecodeV2Write", true); + runConformance("debug.lua"); } @@ -738,8 +748,6 @@ TEST_CASE("ApiFunctionCalls") // lua_equal with a sleeping thread wake up { - ScopedFastFlag luauActivateBeforeExec("LuauActivateBeforeExec", true); - lua_State* L2 = lua_newthread(L); lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); @@ -913,4 +921,11 @@ TEST_CASE("Coverage") nullptr, nullptr, &copts); } +TEST_CASE("StringConversion") +{ + ScopedFastFlag sff{"LuauSchubfach", true}; + + runConformance("strconv.lua"); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 36d6f561..ca4281a0 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -103,11 +103,6 @@ Fixture::~Fixture() Luau::resetPrintLine(); } -UnfrozenFixture::UnfrozenFixture() - : Fixture(false) -{ -} - AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& parseOptions) { sourceModule.reset(new SourceModule); diff --git a/tests/Fixture.h b/tests/Fixture.h index de2b7381..e01632ea 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -152,15 +152,6 @@ struct Fixture LoadDefinitionFileResult loadDefinition(const std::string& source); }; -// Disables arena freezing for a given test case. -// Do not use this in new tests. If you are running into access violations, you -// are violating Luau's memory model - the fix is not to use UnfrozenFixture. -// Related: CLI-45692 -struct UnfrozenFixture : Fixture -{ - UnfrozenFixture(); -}; - ModuleName fromString(std::string_view name); template diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 51fcd3d6..405f26e0 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -914,6 +914,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") TEST_CASE_FIXTURE(FrontendFixture, "imported_table_modification_2") { + ScopedFastFlag sffs("LuauSealExports", true); + frontend.options.retainFullTypeGraphs = false; fileResolver.source["Module/A"] = R"( @@ -927,7 +929,7 @@ return a; --!nonstrict local a = require(script.Parent.A) local b = {} -function a:b() end -- this should error, but doesn't +function a:b() end -- this should error, since A doesn't define a:b() return b )"; @@ -942,8 +944,7 @@ a:b() -- this should error, since A doesn't define a:b() LUAU_REQUIRE_NO_ERRORS(resultA); CheckResult resultB = frontend.check("Module/B"); - // TODO (CLI-45592): this should error, since we shouldn't be adding properties to objects from other modules - LUAU_REQUIRE_NO_ERRORS(resultB); + LUAU_REQUIRE_ERRORS(resultB); CheckResult resultC = frontend.check("Module/C"); LUAU_REQUIRE_ERRORS(resultC); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 71ff4e1b..275782b3 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -620,7 +620,7 @@ struct AssertionCatcher { tripped = 0; oldhook = Luau::assertHandler(); - Luau::assertHandler() = [](const char* expr, const char* file, int line) -> int { + Luau::assertHandler() = [](const char* expr, const char* file, int line, const char* function) -> int { ++tripped; return 0; }; diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 506279b9..5e08654a 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -11,6 +11,8 @@ LUAU_FASTFLAG(LuauFixTonumberReturnType) using namespace Luau; +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + TEST_SUITE_BEGIN("BuiltinTests"); TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") @@ -444,19 +446,28 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") CHECK_EQ(*typeChecker.numberType, *requireType("n3")); } -TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("thread_is_a_type") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( local co = coroutine.create(function() end) )"); - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.threadType, *requireType("co")); + // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() == 0); + CHECK_EQ(*fix.typeChecker.threadType, *fix.requireType("co")); } -TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("coroutine_resume_anything_goes") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( local function nifty(x, y) print(x, y) local z = coroutine.yield(1, 2) @@ -469,12 +480,17 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") local answer = coroutine.resume(co, 3) )"); - LUAU_REQUIRE_NO_ERRORS(result); + // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() == 0); } -TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("coroutine_wrap_anything_goes") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( --!nonstrict local function nifty(x, y) print(x, y) @@ -488,7 +504,8 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") local answer = f(3) )"); - LUAU_REQUIRE_NO_ERRORS(result); + // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() == 0); } TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index b62044fa..114679e3 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -629,8 +629,6 @@ return exports TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names") { - ScopedFastFlag luauFunctionArgumentNameSize{"LuauFunctionArgumentNameSize", true}; - CheckResult result = check(R"( local function f(a: T, ...: U...) end diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 76324556..f70f3b1c 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4803,8 +4803,6 @@ local bar = foo.nutrition + 100 TEST_CASE_FIXTURE(Fixture, "require_failed_module") { - ScopedFastFlag luauModuleRequireErrorPack{"LuauModuleRequireErrorPack", true}; - fileResolver.source["game/A"] = R"( return unfortunately() )"; diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index f55b46a4..1e790eba 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -12,6 +12,8 @@ LUAU_FASTFLAG(LuauQuantifyInPlace2); using namespace Luau; +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + struct TryUnifyFixture : Fixture { TypeArena arena; @@ -28,7 +30,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TypeVar numberOne{TypeVariant{PrimitiveTypeVar{PrimitiveTypeVar::Number}}}; TypeVar numberTwo = numberOne; - state.tryUnify(&numberOne, &numberTwo); + state.tryUnify(&numberTwo, &numberOne); CHECK(state.errors.empty()); } @@ -41,9 +43,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TypeVar functionTwo{TypeVariant{ FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; - state.tryUnify(&functionOne, &functionTwo); + state.tryUnify(&functionTwo, &functionOne); CHECK(state.errors.empty()); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + CHECK_EQ(functionOne, functionTwo); } @@ -61,7 +66,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TypeVar functionTwoSaved = functionTwo; - state.tryUnify(&functionOne, &functionTwo); + state.tryUnify(&functionTwo, &functionOne); CHECK(!state.errors.empty()); CHECK_EQ(functionOne, functionOneSaved); @@ -80,10 +85,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); - state.tryUnify(&tableOne, &tableTwo); + state.tryUnify(&tableTwo, &tableOne); CHECK(state.errors.empty()); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -101,11 +109,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); - state.tryUnify(&tableOne, &tableTwo); + state.tryUnify(&tableTwo, &tableOne); CHECK_EQ(1, state.errors.size()); - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -170,7 +179,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") TypePackVar testPack{TypePack{{typeChecker.numberType, typeChecker.stringType}, std::nullopt}}; TypePackVar variadicPack{VariadicTypePack{typeChecker.numberType}}; - state.tryUnify(&variadicPack, &testPack); + state.tryUnify(&testPack, &variadicPack); CHECK(!state.errors.empty()); } @@ -180,7 +189,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") TypePackVar a{TypePack{{typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType, typeChecker.booleanType}}}; TypePackVar b{TypePack{{typeChecker.numberType, typeChecker.stringType}, &variadicPack}}; - state.tryUnify(&a, &b); + state.tryUnify(&b, &a); CHECK(state.errors.empty()); } @@ -214,32 +223,41 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unifica CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } -TEST_CASE_FIXTURE(TryUnifyFixture, "undo_new_prop_on_unsealed_table") +TEST_CASE("undo_new_prop_on_unsealed_table") { ScopedFastFlag flags[] = { {"LuauTableSubtypingVariance2", true}, + // This test makes no sense with a committing TxnLog. + {"LuauUseCommittingTxnLog", false}, }; // I am not sure how to make this happen in Luau code. - TypeId unsealedTable = arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); - TypeId sealedTable = arena.addType(TableTypeVar{ - {{"prop", Property{getSingletonTypes().numberType}}}, - std::nullopt, - TypeLevel{}, - TableState::Sealed - }); + TryUnifyFixture fix; + + TypeId unsealedTable = fix.arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); + TypeId sealedTable = + fix.arena.addType(TableTypeVar{{{"prop", Property{getSingletonTypes().numberType}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); const TableTypeVar* ttv = get(unsealedTable); REQUIRE(ttv); - state.tryUnify(unsealedTable, sealedTable); + fix.state.tryUnify(sealedTable, unsealedTable); // To be honest, it's really quite spooky here that we're amending an unsealed table in this case. CHECK(!ttv->props.empty()); - state.log.rollback(); + fix.state.DEPRECATED_log.rollback(); CHECK(ttv->props.empty()); } +TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") +{ + TypePackId threeNumbers = arena.addTypePack(TypePack{{typeChecker.numberType, typeChecker.numberType, typeChecker.numberType}, std::nullopt}); + TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); + + ErrorVec unifyErrors = state.canUnify(numberAndFreeTail, threeNumbers); + CHECK(unifyErrors.size() == 0); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 3f4420cd..5d37b032 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -10,6 +10,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") @@ -263,10 +265,13 @@ TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") CHECK_EQ(toString(requireType("foo")), "(...number) -> ()"); } -// CLI-45791 -TEST_CASE_FIXTURE(UnfrozenFixture, "type_pack_hidden_free_tail_infinite_growth") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("type_pack_hidden_free_tail_infinite_growth") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( --!nonstrict if _ then _[function(l0)end],l0 = _ @@ -278,7 +283,8 @@ elseif _ then end )"); - LUAU_REQUIRE_ERRORS(result); + // Switch back to LUAU_REQUIRE_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() > 0); } TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 2357869e..b54ba996 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -8,6 +8,7 @@ #include "doctest.h" LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauUseCommittingTxnLog) using namespace Luau; @@ -282,16 +283,19 @@ local c = b:foo(1, 2) CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } -TEST_CASE_FIXTURE(UnfrozenFixture, "optional_union_follow") +TEST_CASE("optional_union_follow") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( local y: number? = 2 local x = y local function f(a: number, b: typeof(x), c: typeof(x)) return -a end return f() )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE_EQ(result.errors.size(), 1); + // LUAU_REQUIRE_ERROR_COUNT(1, result); auto acm = get(result.errors[0]); REQUIRE(acm); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 13db923e..2e0d149e 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -185,6 +185,8 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_empty_union") TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { + ScopedFastFlag sff{"LuauSealExports", true}; + TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; TypePackVar tp24{TypePack{{&ftv11}}}; @@ -261,7 +263,7 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); - CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result)); + CHECK_EQ("{| f: t1 |} where t1 = () -> {| f: () -> {| f: ({| f: t1 |}) -> (), signal: {| f: (any) -> () |} |} |}", toString(result)); } TEST_CASE("tagging_tables") diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index 9cf3c742..8c96ab33 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -98,4 +98,13 @@ assert(quuz(function(...) end) == "0 true") assert(quuz(function(a, b) end) == "2 false") assert(quuz(function(a, b, ...) end) == "2 true") +-- info linedefined & line +function testlinedefined() + local line = debug.info(1, "l") + local linedefined = debug.info(testlinedefined, "l") + assert(linedefined + 1 == line) +end + +testlinedefined() + return 'OK' diff --git a/tests/conformance/strconv.lua b/tests/conformance/strconv.lua new file mode 100644 index 00000000..85ad0295 --- /dev/null +++ b/tests/conformance/strconv.lua @@ -0,0 +1,51 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print("testing string-number conversion") + +-- zero +assert(tostring(0) == "0") +assert(tostring(0/-1) == "-0") + +-- specials +assert(tostring(1/0) == "inf") +assert(tostring(-1/0) == "-inf") +assert(tostring(0/0) == "nan") + +-- integers +assert(tostring(1) == "1") +assert(tostring(42) == "42") +assert(tostring(-4294967296) == "-4294967296") +assert(tostring(9007199254740991) == "9007199254740991") + +-- decimals +assert(tostring(0.5) == "0.5") +assert(tostring(0.1) == "0.1") +assert(tostring(-0.17) == "-0.17") +assert(tostring(math.pi) == "3.141592653589793") + +-- fuzzing corpus +assert(tostring(5.4536123983019448e-311) == "5.453612398302e-311") +assert(tostring(5.4834368411298348e-311) == "5.48343684113e-311") +assert(tostring(4.4154895841930002e-305) == "4.415489584193e-305") +assert(tostring(1125968630513728) == "1125968630513728") +assert(tostring(3.3951932655938423e-313) == "3.3951932656e-313") +assert(tostring(1.625) == "1.625") +assert(tostring(4.9406564584124654e-324) == "5.e-324") +assert(tostring(2.0049288280105384) == "2.0049288280105384") +assert(tostring(3.0517578125e-05) == "0.000030517578125") +assert(tostring(1.383544921875) == "1.383544921875") +assert(tostring(3.0053350932691001) == "3.0053350932691") +assert(tostring(0.0001373291015625) == "0.0001373291015625") +assert(tostring(-1.9490628022799998e+289) == "-1.94906280228e+289") +assert(tostring(-0.00610404721867928) == "-0.00610404721867928") +assert(tostring(0.00014495849609375) == "0.00014495849609375") +assert(tostring(0.453125) == "0.453125") +assert(tostring(-4.2375343999999997e+73) == "-4.2375344e+73") +assert(tostring(1.3202313930270133e-192) == "1.3202313930270133e-192") +assert(tostring(3.6984408976312836e+19) == "36984408976312840000") +assert(tostring(2.0563000527063302) == "2.05630005270633") +assert(tostring(4.8970527433648997e-260) == "4.8970527433649e-260") +assert(tostring(1.62890625) == "1.62890625") +assert(tostring(1.1295093211933533e+65) == "1.1295093211933533e+65") + +return "OK" diff --git a/tests/main.cpp b/tests/main.cpp index ed17070c..cd24e100 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -2,6 +2,8 @@ #include "Luau/Common.h" #define DOCTEST_CONFIG_IMPLEMENT +// Our calls to parseOption/parseFlag don't provide a prefix so set the prefix to the empty string. +#define DOCTEST_CONFIG_OPTIONS_PREFIX "" #include "doctest.h" #ifdef _WIN32 @@ -18,6 +20,10 @@ #include +// Indicates if verbose output is enabled. +// Currently, this enables output from lua's 'print', but other verbose output could be enabled eventually. +bool verbose = false; + static bool skipFastFlag(const char* flagName) { if (strncmp(flagName, "Test", 4) == 0) @@ -46,7 +52,7 @@ static bool debuggerPresent() #endif } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { if (debuggerPresent()) LUAU_DEBUGBREAK(); @@ -235,6 +241,11 @@ int main(int argc, char** argv) return 0; } + if (doctest::parseFlag(argc, argv, "--verbose")) + { + verbose = true; + } + if (std::vector flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags)) setFastFlags(flags); @@ -261,7 +272,15 @@ int main(int argc, char** argv) } } - return context.run(); + int result = context.run(); + if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h")) + { + printf("Additional command line options:\n"); + printf(" --verbose Enables verbose output (e.g. lua 'print' statements)\n"); + printf(" --fflags= Sets specified fast flags\n"); + printf(" --list-fflags List all fast flags\n"); + } + return result; } diff --git a/tools/numprint.py b/tools/numprint.py new file mode 100644 index 00000000..47ad36d9 --- /dev/null +++ b/tools/numprint.py @@ -0,0 +1,82 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# This code can be used to generate power tables for Schubfach algorithm (see lnumprint.cpp) + +import math +import sys + +(_, pow10min, pow10max, compact) = sys.argv +pow10min = int(pow10min) +pow10max = int(pow10max) +compact = compact == "True" + +# extract high 128 bits of the value +def high128(n, roundup): + L = math.ceil(math.log2(n)) + + r = 0 + for i in range(L - 128, L): + if i >= 0 and (n & (1 << i)) != 0: + r |= (1 << (i - L + 128)) + + return r + (1 if roundup else 0) + +def pow10approx(n): + if n == 0: + return 1 << 127 + elif n > 0: + return high128(10**n, 5**n >= 2**128) + else: + # 10^-n is a binary fraction that can't be represented in floating point + # we need to extract top 128 bits of the fraction starting from the first 1 + # to get there, we need to divide 2^k by 10^n for a sufficiently large k and repeat the extraction process + p = 10**-n + k = 2**128 * 16**-n # this guarantees that the fraction has more than 128 extra bits + return high128(k // p, True) + +def pow5_64(n): + assert(n >= 0) + if n == 0: + return 1 << 63 + else: + return high128(5**n, False) >> 64 + +if not compact: + print("// kPow10Table", pow10min, "..", pow10max) + print("{") + for p in range(pow10min, pow10max + 1): + h = hex(pow10approx(p))[2:] + assert(len(h) == 32) + print(" {0x%s, 0x%s}," % (h[0:16].upper(), h[16:32].upper())) + print("}") +else: + print("// kPow5Table") + print("{") + for i in range(16): + print(" " + hex(pow5_64(i)) + ",") + print("}") + print("// kPow10Table", pow10min, "..", pow10max) + print("{") + for p in range(pow10min, pow10max + 1, 16): + base = pow10approx(p) + errw = 0 + for i in range(16): + real = pow10approx(p + i) + appr = (base * pow5_64(i)) >> 64 + scale = 1 if appr < (1 << 127) else 0 # 1-bit scale + + offset = (appr << scale) - real + assert(offset >= -4 and offset <= 3) # 3-bit offset + assert((appr << scale) >> 64 == real >> 64) # offset only affects low half + assert((appr << scale) - offset == real) # validate full reconstruction + + err = (scale << 3) | (offset + 4) + errw |= err << (i * 4) + + hbase = hex(base)[2:] + assert(len(hbase) == 32) + assert(errw < 1 << 64) + + print(" {0x%s, 0x%s, 0x%16x}," % (hbase[0:16], hbase[16:32], errw)) + print("}") From d189bd9b1a3c6ecc882c793c7382a1c886635900 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 6 Jan 2022 17:37:50 -0800 Subject: [PATCH 013/102] Enable V2Read flag early --- VM/src/lvmload.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index cdb276c0..7839c68c 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,7 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, false) +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, true) LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens From 80d5c0000ee34767f8fec7ec24c02a438174a667 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 14 Jan 2022 08:06:31 -0800 Subject: [PATCH 014/102] Sync to upstream/release/510 --- Analysis/include/Luau/TypeInfer.h | 12 +- Analysis/include/Luau/TypeVar.h | 20 +- Analysis/src/Autocomplete.cpp | 79 +-- Analysis/src/Error.cpp | 12 +- Analysis/src/IostreamHelpers.cpp | 8 +- Analysis/src/JsonEncoder.cpp | 38 ++ Analysis/src/Module.cpp | 26 +- Analysis/src/ToString.cpp | 107 +++- Analysis/src/Transpiler.cpp | 65 +- Analysis/src/TypeAttach.cpp | 12 +- Analysis/src/TypeInfer.cpp | 310 ++++++++-- Analysis/src/TypeVar.cpp | 4 + Analysis/src/Unifier.cpp | 17 +- Ast/include/Luau/Ast.h | 49 +- Ast/include/Luau/Parser.h | 6 +- Ast/src/Ast.cpp | 32 +- Ast/src/Parser.cpp | 128 ++-- CLI/Analyze.cpp | 18 +- CLI/Repl.cpp | 137 ++-- Compiler/src/Builtins.cpp | 197 ++++++ Compiler/src/Builtins.h | 41 ++ Compiler/src/BytecodeBuilder.cpp | 6 +- Compiler/src/Compiler.cpp | 894 ++------------------------- Compiler/src/ConstantFolding.cpp | 394 ++++++++++++ Compiler/src/ConstantFolding.h | 48 ++ Compiler/src/TableShape.cpp | 129 ++++ Compiler/src/TableShape.h | 21 + Compiler/src/ValueTracking.cpp | 103 +++ Compiler/src/ValueTracking.h | 42 ++ Sources.cmake | 8 + VM/src/ldo.cpp | 12 +- VM/src/lfunc.cpp | 8 +- VM/src/lstrlib.cpp | 20 +- bench/tests/sunspider/3d-cube.lua | 32 +- tests/Autocomplete.test.cpp | 61 +- tests/Compiler.test.cpp | 93 +-- tests/Conformance.test.cpp | 8 - tests/Fixture.cpp | 5 +- tests/Fixture.h | 2 +- tests/Parser.test.cpp | 53 +- tests/ToString.test.cpp | 22 +- tests/Transpiler.test.cpp | 14 +- tests/TypeInfer.annotations.test.cpp | 2 +- tests/TypeInfer.refinements.test.cpp | 28 +- tests/TypeInfer.singletons.test.cpp | 26 +- tests/TypeInfer.tables.test.cpp | 101 ++- tests/TypeInfer.test.cpp | 150 ++++- tests/TypeInfer.typePacks.cpp | 324 ++++++++++ 48 files changed, 2652 insertions(+), 1272 deletions(-) create mode 100644 Compiler/src/Builtins.cpp create mode 100644 Compiler/src/Builtins.h create mode 100644 Compiler/src/ConstantFolding.cpp create mode 100644 Compiler/src/ConstantFolding.h create mode 100644 Compiler/src/TableShape.cpp create mode 100644 Compiler/src/TableShape.h create mode 100644 Compiler/src/ValueTracking.cpp create mode 100644 Compiler/src/ValueTracking.h diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 312283b0..aa090014 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -97,6 +97,12 @@ struct ApplyTypeFunction : Substitution TypePackId clean(TypePackId tp) override; }; +struct GenericTypeDefinitions +{ + std::vector genericTypes; + std::vector genericPacks; +}; + // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -146,7 +152,7 @@ struct TypeChecker ExprResult checkExpr(const ScopePtr& scope, const AstExprBinary& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprError& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType); @@ -336,8 +342,8 @@ private: const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. - std::pair, std::vector> createGenericTypes(const ScopePtr& scope, std::optional levelOpt, - const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, + const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index d6e17714..fd2c2afa 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -181,6 +181,18 @@ const T* get(const SingletonTypeVar* stv) return nullptr; } +struct GenericTypeDefinition +{ + TypeId ty; + std::optional defaultValue; +}; + +struct GenericTypePackDefinition +{ + TypePackId tp; + std::optional defaultValue; +}; + struct FunctionArgument { Name name; @@ -358,8 +370,8 @@ struct ClassTypeVar struct TypeFun { // These should all be generic - std::vector typeParams; - std::vector typePackParams; + std::vector typeParams; + std::vector typePackParams; /** The underlying type. * @@ -369,13 +381,13 @@ struct TypeFun TypeId type; TypeFun() = default; - TypeFun(std::vector typeParams, TypeId type) + TypeFun(std::vector typeParams, TypeId type) : typeParams(std::move(typeParams)) , type(type) { } - TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) : typeParams(std::move(typeParams)) , typePackParams(std::move(typePackParams)) , type(type) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 67ebd075..7a801f97 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,11 +13,10 @@ #include LUAU_FASTFLAG(LuauUseCommittingTxnLog) -LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); +LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -291,51 +290,23 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ expectedType = follow(*it); } - if (FFlag::LuauAutocompletePreferToCallFunctions) + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty)) { - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty)) - { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) - { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - } - } - - return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - } - else - { - if (canUnify(ty, expectedType)) - return TypeCorrectKind::Correct; - - // We also want to suggest functions that return compatible result - const FunctionTypeVar* ftv = get(ty); - - if (!ftv) - return TypeCorrectKind::None; - auto [retHead, retTail] = flatten(ftv->retType); - if (!retHead.empty()) - return canUnify(retHead.front(), expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + return TypeCorrectKind::CorrectFunctionResult; // We might only have a variadic tail pack, check if the element is compatible if (retTail) { - if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(vtp->ty, expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return TypeCorrectKind::CorrectFunctionResult; } - - return TypeCorrectKind::None; } + + return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } enum class PropIndexType @@ -435,13 +406,28 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { - if (get(indexIt->second.type) || get(indexIt->second.type)) - autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); - else if (auto indexFunction = get(indexIt->second.type)) + if (FFlag::LuauMissingFollowACMetatables) { - std::optional indexFunctionResult = first(indexFunction->retType); - if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + TypeId followed = follow(indexIt->second.type); + if (get(followed) || get(followed)) + autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen); + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + } + } + else + { + if (get(indexIt->second.type) || get(indexIt->second.type)) + autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); + else if (auto indexFunction = get(indexIt->second.type)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + } } } } @@ -1224,7 +1210,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul if (auto it = module.astTypes.find(node->asExpr())) autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result); } - else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) + else if (autocompleteIfElseExpression(node, ancestry, position, result)) return; else if (node->is()) return; @@ -1261,8 +1247,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - if (FFlag::LuauIfElseExpressionAnalysisSupport) - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index ce832c6b..88069f1f 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -190,24 +190,24 @@ struct ErrorConverter { name += "<"; bool first = true; - for (TypeId t : e.typeFun.typeParams) + for (auto param : e.typeFun.typeParams) { if (first) first = false; else name += ", "; - name += toString(t); + name += toString(param.ty); } - for (TypePackId t : e.typeFun.typePackParams) + for (auto param : e.typeFun.typePackParams) { if (first) first = false; else name += ", "; - name += toString(t); + name += toString(param.tp); } name += ">"; @@ -544,13 +544,13 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC for (size_t i = 0; i < typeFun.typeParams.size(); ++i) { - if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) + if (typeFun.typeParams[i].ty != rhs.typeFun.typeParams[i].ty) return false; } for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) { - if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + if (typeFun.typePackParams[i].tp != rhs.typeFun.typePackParams[i].tp) return false; } diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 5bc76ade..19c2ddab 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -96,24 +96,24 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "<"; bool first = true; - for (TypeId t : error.typeFun.typeParams) + for (auto param : error.typeFun.typeParams) { if (first) first = false; else stream << ", "; - stream << toString(t); + stream << toString(param.ty); } - for (TypePackId t : error.typeFun.typePackParams) + for (auto param : error.typeFun.typePackParams) { if (first) first = false; else stream << ", "; - stream << toString(t); + stream << toString(param.tp); } stream << ">"; diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 23491a5a..8dd597e1 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -5,6 +5,8 @@ #include "Luau/StringUtils.h" #include "Luau/Common.h" +LUAU_FASTFLAG(LuauTypeAliasDefaults) + namespace Luau { @@ -337,6 +339,42 @@ struct AstJsonEncoder : public AstVisitor writeRaw("}"); } + void write(const AstGenericType& genericType) + { + if (FFlag::LuauTypeAliasDefaults) + { + writeRaw("{"); + bool c = pushComma(); + write("name", genericType.name); + if (genericType.defaultValue) + write("type", genericType.defaultValue); + popComma(c); + writeRaw("}"); + } + else + { + write(genericType.name); + } + } + + void write(const AstGenericTypePack& genericTypePack) + { + if (FFlag::LuauTypeAliasDefaults) + { + writeRaw("{"); + bool c = pushComma(); + write("name", genericTypePack.name); + if (genericTypePack.defaultValue) + write("type", genericTypePack.defaultValue); + popComma(c); + writeRaw("}"); + } + else + { + write(genericTypePack.name); + } + } + void write(AstExprTable::Item::Kind kind) { switch (kind) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index cff85897..9f352f4b 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAG(LuauTypeAliasDefaults) namespace Luau { @@ -447,11 +448,28 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { TypeFun result; - for (TypeId ty : typeFun.typeParams) - result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - for (TypePackId tp : typeFun.typePackParams) - result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, cloneState)); + for (auto param : typeFun.typeParams) + { + TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typeParams.push_back({ty, defaultValue}); + } + + for (auto param : typeFun.typePackParams) + { + TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typePackParams.push_back({tp, defaultValue}); + } result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 889dd6dc..4b898d3a 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) +LUAU_FASTFLAG(LuauTypeAliasDefaults) /* * Prefix generic typenames with gen- @@ -209,6 +210,14 @@ struct StringifierState result.name += s; } + + void emit(const char* s) + { + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + return; + + result.name += s; + } }; struct TypeVarStringifier @@ -280,13 +289,28 @@ struct TypeVarStringifier else first = false; - if (!singleTp) - state.emit("("); + if (FFlag::LuauTypeAliasDefaults) + { + bool wrap = !singleTp && get(follow(tp)); - stringify(tp); + if (wrap) + state.emit("("); - if (!singleTp) - state.emit(")"); + stringify(tp); + + if (wrap) + state.emit(")"); + } + else + { + if (!singleTp) + state.emit("("); + + stringify(tp); + + if (!singleTp) + state.emit(")"); + } } if (types.size() || typePacks.size()) @@ -1086,7 +1110,7 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } -std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +std::string toStringNamedFunction_DEPRECATED(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) { std::string s = prefix; @@ -1175,6 +1199,77 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV return s; } +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +{ + if (!FFlag::LuauTypeAliasDefaults) + return toStringNamedFunction_DEPRECATED(prefix, ftv, opts); + + ToStringResult result; + StringifierState state(opts, result, opts.nameMap); + TypeVarStringifier tvs{state}; + + state.emit(prefix); + + if (!opts.hideNamedFunctionTypeParameters) + tvs.stringify(ftv.generics, ftv.genericPacks); + + state.emit("("); + + auto argPackIter = begin(ftv.argTypes); + auto argNameIter = ftv.argNames.begin(); + + bool first = true; + while (argPackIter != end(ftv.argTypes)) + { + if (!first) + state.emit(", "); + first = false; + + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) + { + state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": "); + ++argNameIter; + } + else + { + state.emit("_: "); + } + + tvs.stringify(*argPackIter); + ++argPackIter; + } + + if (argPackIter.tail()) + { + if (!first) + state.emit(", "); + + state.emit("...: "); + + if (auto vtp = get(*argPackIter.tail())) + tvs.stringify(vtp->ty); + else + tvs.stringify(*argPackIter.tail()); + } + + state.emit("): "); + + size_t retSize = size(ftv.retType); + bool hasTail = !finite(ftv.retType); + bool wrap = get(follow(ftv.retType)) && (hasTail ? retSize != 0 : retSize != 1); + + if (wrap) + state.emit("("); + + tvs.stringify(ftv.retType); + + if (wrap) + state.emit(")"); + + return result.name; +} + std::string dump(TypeId ty) { ToStringOptions opts; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 8e13ea5b..f5908683 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_FASTFLAG(LuauTypeAliasDefaults) + namespace { bool isIdentifierStartChar(char c) @@ -793,14 +795,47 @@ struct Printer for (auto o : a->generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + { + writer.advance(o.location.begin); + writer.identifier(o.name.value); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*o.defaultValue); + } + } + else + { + writer.identifier(o.name.value); + } } for (auto o : a->genericPacks) { comma(); - writer.identifier(o.value); - writer.symbol("..."); + + if (FFlag::LuauTypeAliasDefaults) + { + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypePackAnnotation(*o.defaultValue, false); + } + } + else + { + writer.identifier(o.name.value); + writer.symbol("..."); + } } writer.symbol(">"); @@ -846,12 +881,20 @@ struct Printer for (const auto& o : func.generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); } for (const auto& o : func.genericPacks) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); writer.symbol("..."); } writer.symbol(">"); @@ -979,12 +1022,20 @@ struct Printer for (const auto& o : a->generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); } for (const auto& o : a->genericPacks) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); writer.symbol("..."); } writer.symbol(">"); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 9e61c792..2ec02093 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -212,24 +212,24 @@ public: if (hasSeen(&ftv)) return allocator->alloc(Location(), std::nullopt, AstName("")); - AstArray generics; + AstArray generics; generics.size = ftv.generics.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); + generics.data = static_cast(allocator->allocate(sizeof(AstGenericType) * generics.size)); size_t numGenerics = 0; for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { if (auto gtv = get(*it)) - generics.data[numGenerics++] = AstName(gtv->name.c_str()); + generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } - AstArray genericPacks; + AstArray genericPacks; genericPacks.size = ftv.genericPacks.size(); - genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { if (auto gtv = get(*it)) - genericPacks.data[numGenericPacks++] = AstName(gtv->name.c_str()); + genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } AstArray argTypes; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 1689a5c3..bedcc022 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -15,8 +15,8 @@ #include "Luau/TypeVar.h" #include "Luau/TimeTrace.h" -#include #include +#include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) @@ -24,25 +24,30 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) +LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) +LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) +LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) +LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) +LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau @@ -279,6 +284,14 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); } + if (FFlag::LuauPerModuleUnificationCache) + { + // Clear unifier cache since it's keyed off internal types that get deallocated + // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. + unifierState.cachedUnify.clear(); + unifierState.skipCacheForType.clear(); + } + return std::move(currentModule); } @@ -1213,18 +1226,18 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ScopePtr aliasScope = childScope(scope, typealias.location); aliasScope->level = scope->level.incr(); - for (TypeId ty : binding->typeParams) + for (auto param : binding->typeParams) { - auto generic = get(ty); + auto generic = get(param.ty); LUAU_ASSERT(generic); - aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; + aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, param.ty}; } - for (TypePackId tp : binding->typePackParams) + for (auto param : binding->typePackParams) { - auto generic = get(tp); + auto generic = get(param.tp); LUAU_ASSERT(generic); - aliasScope->privateTypePackBindings[generic->name] = tp; + aliasScope->privateTypePackBindings[generic->name] = param.tp; } TypeId ty = resolveType(aliasScope, *typealias.type); @@ -1233,9 +1246,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // If the table is already named and we want to rename the type function, we have to bind new alias to a copy if (ttv->name) { + bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), + binding->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), + binding->typePackParams.begin(), binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + // Copy can be skipped if this is an identical alias - if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || - ttv->instantiatedTypePackParams != binding->typePackParams) + if (ttv->name != name || !sameTys || !sameTps) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1243,8 +1264,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; - clone.instantiatedTypeParams = binding->typeParams; - clone.instantiatedTypePackParams = binding->typePackParams; + + for (auto param : binding->typeParams) + clone.instantiatedTypeParams.push_back(param.ty); + + for (auto param : binding->typePackParams) + clone.instantiatedTypePackParams.push_back(param.tp); ty = addType(std::move(clone)); } @@ -1252,8 +1277,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias else { ttv->name = name; - ttv->instantiatedTypeParams = binding->typeParams; - ttv->instantiatedTypePackParams = binding->typePackParams; + + ttv->instantiatedTypeParams.clear(); + for (auto param : binding->typeParams) + ttv->instantiatedTypeParams.push_back(param.ty); + + ttv->instantiatedTypePackParams.clear(); + for (auto param : binding->typePackParams) + ttv->instantiatedTypePackParams.push_back(param.tp); } } else if (auto mtv = getMutable(follow(ty))) @@ -1367,9 +1398,21 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); - TypeId fnType = addType(FunctionTypeVar{funScope->level, generics, genericPacks, argPack, retPack}); + TypeId fnType = addType(FunctionTypeVar{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); FunctionTypeVar* ftv = getMutable(fnType); ftv->argNames.reserve(global.paramNames.size); @@ -1394,7 +1437,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& ExprResult result; if (auto a = expr.as()) - result = checkExpr(scope, *a->expr); + result = checkExpr(scope, *a->expr, FFlag::LuauGroupExpectedType ? expectedType : std::nullopt); else if (expr.is()) result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) @@ -1438,21 +1481,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - { - if (FFlag::LuauIfElseExpressionAnalysisSupport) - { - result = checkExpr(scope, *a); - } - else - { - // Note: When the fast flag is disabled we can't skip the handling of AstExprIfElse - // because we would generate an ICE. We also can't use the default value - // of result, because it will lead to a compiler crash. - // Note: LuauIfElseExpressionBaseSupport can be used to disable parser support - // for if-else expressions which will mean this node type is never created. - result = {anyType}; - } - } + result = checkExpr(scope, *a, FFlag::LuauIfElseExpectedType2 ? expectedType : std::nullopt); else ice("Unhandled AstExpr?"); @@ -1895,7 +1924,7 @@ TypeId TypeChecker::checkExprTable( } } - TableState state = (expr.items.size == 0 || isNonstrictMode()) ? TableState::Unsealed : TableState::Sealed; + TableState state = (expr.items.size == 0 || isNonstrictMode() || FFlag::LuauUnsealedTableLiteral) ? TableState::Unsealed : TableState::Sealed; TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; table.definitionModuleName = currentModuleName; return addType(table); @@ -2549,23 +2578,34 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { ExprResult result = checkExpr(scope, *expr.condition); ScopePtr trueScope = childScope(scope, expr.trueExpr->location); reportErrors(resolve(result.predicates, trueScope, true)); - ExprResult trueType = checkExpr(trueScope, *expr.trueExpr); + ExprResult trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); // Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope. resolve(result.predicates, falseScope, false); - ExprResult falseType = checkExpr(falseScope, *expr.falseExpr); + ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); - unify(falseType.type, trueType.type, expr.location); + if (FFlag::LuauIfElseBranchTypeUnion) + { + if (falseType.type == trueType.type) + return {trueType.type}; - // TODO: normalize(UnionTypeVar{{trueType, falseType}}) - // For now both trueType and falseType must be the same type. - return {trueType.type}; + std::vector types = reduceUnion({trueType.type, falseType.type}); + return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; + } + else + { + unify(falseType.type, trueType.type, expr.location); + + // TODO: normalize(UnionTypeVar{{trueType, falseType}}) + // For now both trueType and falseType must be the same type. + return {trueType.type}; + } } TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) @@ -3032,7 +3072,20 @@ std::pair TypeChecker::checkFunctionSignature( defn.varargLocation = expr.vararg ? std::make_optional(expr.varargLocation) : std::nullopt; defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); - TypeId funTy = addType(FunctionTypeVar(funScope->level, generics, genericPacks, argPack, retPack, std::move(defn), bool(expr.self))); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + + TypeId funTy = + addType(FunctionTypeVar(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self))); FunctionTypeVar* ftv = getMutable(funTy); @@ -4848,11 +4901,38 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) return tf->type; - if (!lit->hasParameterList && !tf->typePackParams.empty()) + bool hasDefaultTypes = false; + bool hasDefaultPacks = false; + bool parameterCountErrorReported = false; + + if (FFlag::LuauTypeAliasDefaults) { - reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); + hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + + if (!lit->hasParameterList) + { + if ((!tf->typeParams.empty() && !hasDefaultTypes) || (!tf->typePackParams.empty() && !hasDefaultPacks)) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + parameterCountErrorReported = true; + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); + } + } + } + else + { + if (!lit->hasParameterList && !tf->typePackParams.empty()) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); + } } std::vector typeParams; @@ -4892,14 +4972,89 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (typePackParams.empty() && !extraTypes.empty()) typePackParams.push_back(addTypePack(extraTypes)); + if (FFlag::LuauTypeAliasDefaults) + { + size_t typesProvided = typeParams.size(); + size_t typesRequired = tf->typeParams.size(); + + size_t packsProvided = typePackParams.size(); + size_t packsRequired = tf->typePackParams.size(); + + bool notEnoughParameters = + (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); + bool hasDefaultParameters = hasDefaultTypes || hasDefaultPacks; + + // Add default type and type pack parameters if that's required and it's possible + if (notEnoughParameters && hasDefaultParameters) + { + // 'applyTypeFunction' is used to substitute default types that reference previous generic types + applyTypeFunction.typeArguments.clear(); + applyTypeFunction.typePackArguments.clear(); + applyTypeFunction.currentModule = currentModule; + applyTypeFunction.level = scope->level; + applyTypeFunction.encounteredForwardedType = false; + + for (size_t i = 0; i < typesProvided; ++i) + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; + + if (typesProvided < typesRequired) + { + for (size_t i = typesProvided; i < typesRequired; ++i) + { + TypeId defaultTy = tf->typeParams[i].defaultValue.value_or(nullptr); + + if (!defaultTy) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTy); + + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryType(scope); + } + + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = *maybeInstantiated; + typeParams.push_back(*maybeInstantiated); + } + } + + for (size_t i = 0; i < packsProvided; ++i) + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = typePackParams[i]; + + if (packsProvided < packsRequired) + { + for (size_t i = packsProvided; i < packsRequired; ++i) + { + TypePackId defaultTp = tf->typePackParams[i].defaultValue.value_or(nullptr); + + if (!defaultTp) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTp); + + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryTypePack(scope); + } + + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = *maybeInstantiated; + typePackParams.push_back(*maybeInstantiated); + } + } + } + } + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) typePackParams.push_back(addTypePack({})); if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) { - reportError( - TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + if (!parameterCountErrorReported) + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); if (FFlag::LuauErrorRecoveryType) { @@ -4913,11 +5068,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return errorRecoveryType(scope); } - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + if (FFlag::LuauRecursiveTypeParameterRestriction) { + bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal( + typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + // If the generic parameters and the type arguments are the same, we are about to // perform an identity substitution, which we can just short-circuit. - return tf->type; + if (sameTys && sameTps) + return tf->type; } return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); @@ -4948,7 +5112,19 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation TypePackId argTypes = resolveTypePack(funcScope, func->argTypes); TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes); - TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(generics), std::move(genericPacks), argTypes, retTypes}); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + + TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes}); FunctionTypeVar* ftv = getMutable(fnType); @@ -5137,11 +5313,11 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) - applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; + applyTypeFunction.typeArguments[tf.typeParams[i].ty] = typeParams[i]; applyTypeFunction.typePackArguments.clear(); for (size_t i = 0; i < tf.typePackParams.size(); ++i) - applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; + applyTypeFunction.typePackArguments[tf.typePackParams[i].tp] = typePackParams[i]; applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; @@ -5213,17 +5389,23 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, return instantiated; } -std::pair, std::vector> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, - const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, + const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; - std::vector generics; - for (const AstName& generic : genericNames) + std::vector generics; + + for (const AstGenericType& generic : genericNames) { - Name n = generic.value; + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && generic.defaultValue) + defaultValue = resolveType(scope, *generic.defaultValue); + + Name n = generic.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic typevars have the same name. @@ -5246,14 +5428,20 @@ std::pair, std::vector> TypeChecker::createGener g = addType(Unifiable::Generic{level, n}); } - generics.push_back(g); + generics.push_back({g, defaultValue}); scope->privateTypeBindings[n] = TypeFun{{}, g}; } - std::vector genericPacks; - for (const AstName& genericPack : genericPackNames) + std::vector genericPacks; + + for (const AstGenericTypePack& genericPack : genericPackNames) { - Name n = genericPack.value; + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && genericPack.defaultValue) + defaultValue = resolveTypePack(scope, *genericPack.defaultValue); + + Name n = genericPack.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic typevars have the same name. @@ -5276,7 +5464,7 @@ std::pair, std::vector> TypeChecker::createGener g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); } - genericPacks.push_back(g); + genericPacks.push_back({g, defaultValue}); scope->privateTypePackBindings[n] = g; } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4cab79c8..ac2b2541 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) +LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(DebugLuauFreezeArena) @@ -453,6 +454,9 @@ bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs) static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs) { + if (FFlag::LuauMetatableAreEqualRecursion && areSeen(seen, &lhs, &rhs)) + return true; + return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable); } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 393a84a7..6873c657 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); +LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) @@ -1170,9 +1171,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { + if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + break; + } + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) + if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) tryUnify_(*subTpv->tail, *superTpv->tail); else if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); @@ -1370,9 +1377,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { + if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + break; + } + const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) + if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) tryUnify_(*subTpv->tail, *superTpv->tail); else if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 5b4bfa03..573850a5 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -334,6 +334,20 @@ struct AstTypeList using AstArgumentName = std::pair; // TODO: remove and replace when we get a common struct for this pair instead of AstName +struct AstGenericType +{ + AstName name; + Location location; + AstType* defaultValue = nullptr; +}; + +struct AstGenericTypePack +{ + AstName name; + Location location; + AstTypePack* defaultValue = nullptr; +}; + extern int gAstRttiIndex; template @@ -569,15 +583,15 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, - const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, - std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, + AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + AstLocal* self, const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, std::optional argLocation = std::nullopt); void visit(AstVisitor* visitor) override; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstLocal* self; AstArray args; bool hasReturnAnnotation; @@ -942,14 +956,14 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, - AstType* type, bool exported); + AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported); void visit(AstVisitor* visitor) override; AstName name; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -972,14 +986,15 @@ class AstStatDeclareFunction : public AstStat public: LUAU_RTTI(AstStatDeclareFunction) - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes); + AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, + const AstTypeList& retTypes); void visit(AstVisitor* visitor) override; AstName name; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList params; AstArray paramNames; AstTypeList retTypes; @@ -1077,13 +1092,13 @@ class AstTypeFunction : public AstType public: LUAU_RTTI(AstTypeFunction) - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, - const AstArray>& argNames, const AstTypeList& returnTypes); + AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); void visit(AstVisitor* visitor) override; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 87ebc48b..40ecdcdd 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -219,7 +219,7 @@ private: AstTableIndexer* parseTableIndexerAnnotation(); AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); - AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); AstType* parseTableTypeAnnotation(); @@ -281,7 +281,7 @@ private: Name parseIndexName(const char* context, const Position& previous); // `<' namelist `>' - std::pair, AstArray> parseGenericTypeList(); + std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); @@ -418,6 +418,8 @@ private: std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; + std::vector scratchGenericTypes; + std::vector scratchGenericTypePacks; std::vector> scratchOptArgName; std::string scratchData; }; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index e709894d..9b5bc0c7 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -158,9 +158,10 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, - const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, - std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, std::optional argLocation) +AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + AstLocal* self, const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, + std::optional argLocation) : AstExpr(ClassIndex(), location) , generics(generics) , genericPacks(genericPacks) @@ -641,8 +642,8 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported) : AstStat(ClassIndex(), location) , name(name) , generics(generics) @@ -655,7 +656,21 @@ AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name void AstStatTypeAlias::visit(AstVisitor* visitor) { if (visitor->visit(this)) + { + for (const AstGenericType& el : generics) + { + if (el.defaultValue) + el.defaultValue->visit(visitor); + } + + for (const AstGenericTypePack& el : genericPacks) + { + if (el.defaultValue) + el.defaultValue->visit(visitor); + } + type->visit(visitor); + } } AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) @@ -671,8 +686,9 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor) type->visit(visitor); } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, + const AstTypeList& retTypes) : AstStat(ClassIndex(), location) , name(name) , generics(generics) @@ -778,7 +794,7 @@ void AstTypeTable::visit(AstVisitor* visitor) } } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, +AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) : AstType(ClassIndex(), location) , generics(generics) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 72f61649..77787cb1 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,10 +10,10 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) -LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) +LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) namespace Luau { @@ -394,23 +394,13 @@ AstStat* Parser::parseIf() if (lexer.current().type == Lexeme::ReservedElseif) { - if (FFlag::LuauIfStatementRecursionGuard) - { - unsigned int recursionCounterOld = recursionCounter; - incrementRecursionCounter("elseif"); - elseLocation = lexer.current().location; - elsebody = parseIf(); - end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - recursionCounter = recursionCounterOld; - } - else - { - elseLocation = lexer.current().location; - elsebody = parseIf(); - end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - } + unsigned int recursionCounterOld = recursionCounter; + incrementRecursionCounter("elseif"); + elseLocation = lexer.current().location; + elsebody = parseIf(); + end = elsebody->location; + hasEnd = elsebody->as()->hasEnd; + recursionCounter = recursionCounterOld; } else { @@ -772,7 +762,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ FFlag::LuauParseTypeAliasDefaults); expectAndConsume('=', "type alias"); @@ -788,8 +778,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() Name fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; generics.size = 0; generics.data = nullptr; genericPacks.size = 0; @@ -849,7 +839,7 @@ AstStat* Parser::parseDeclaration(const Location& start) nextLexeme(); Name globalName = parseName("global function name"); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme matchParen = lexer.current(); @@ -991,7 +981,7 @@ std::pair Parser::parseFunctionBody( { Location start = matchFunction.location; - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme matchParen = lexer.current(); expectAndConsume('(', "function"); @@ -1228,8 +1218,8 @@ std::pair Parser::parseReturnTypeAnnotation() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstArray generics{nullptr, 0}; - AstArray genericPacks{nullptr, 0}; + AstArray generics{nullptr, 0}; + AstArray genericPacks{nullptr, 0}; AstArray types = copy(result); AstArray> names = copy(resultNames); @@ -1363,7 +1353,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) Lexeme begin = lexer.current(); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme parameterStart = lexer.current(); @@ -1401,7 +1391,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, +AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation) { @@ -1448,7 +1438,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isUnion = true; } else if (c == '?') @@ -1461,7 +1451,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isIntersection = true; } else @@ -1498,7 +1488,7 @@ AstTypeOrPack Parser::parseTypeOrPackAnnotation() TempVector parts(scratchAnnotation); - auto [type, typePack] = parseSimpleTypeAnnotation(true); + auto [type, typePack] = parseSimpleTypeAnnotation(/* allowPack= */ true); if (typePack) { @@ -1521,7 +1511,7 @@ AstType* Parser::parseTypeAnnotation() Location begin = lexer.current().location; TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); recursionCounter = oldRecursionCount; @@ -2121,7 +2111,7 @@ AstExpr* Parser::parseSimpleExpr() { return parseTableConstructor(); } - else if (FFlag::LuauIfElseExpressionBaseSupport && lexer.current().type == Lexeme::ReservedIf) + else if (lexer.current().type == Lexeme::ReservedIf) { return parseIfElseExpr(); } @@ -2341,10 +2331,10 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou return Name(nameError, location); } -std::pair, AstArray> Parser::parseGenericTypeList() +std::pair, AstArray> Parser::parseGenericTypeList(bool withDefaultValues) { - TempVector names{scratchName}; - TempVector namePacks{scratchPackName}; + TempVector names{scratchGenericTypes}; + TempVector namePacks{scratchGenericTypePacks}; if (lexer.current().type == '<') { @@ -2352,21 +2342,73 @@ std::pair, AstArray> Parser::parseGenericTypeList() nextLexeme(); bool seenPack = false; + bool seenDefault = false; + while (true) { + Location nameLocation = lexer.current().location; AstName name = parseName().name; - if (lexer.current().type == Lexeme::Dot3) + if (lexer.current().type == Lexeme::Dot3 || (FFlag::LuauParseRecoverTypePackEllipsis && seenPack)) { seenPack = true; - nextLexeme(); - namePacks.push_back(name); + + if (FFlag::LuauParseRecoverTypePackEllipsis && lexer.current().type != Lexeme::Dot3) + report(lexer.current().location, "Generic types come before generic type packs"); + else + nextLexeme(); + + if (withDefaultValues && lexer.current().type == '=') + { + seenDefault = true; + nextLexeme(); + + Lexeme packBegin = lexer.current(); + + if (shouldParseTypePackAnnotation(lexer)) + { + auto typePack = parseTypePackAnnotation(); + + namePacks.push_back({name, nameLocation, typePack}); + } + else if (lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (type) + report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); + + namePacks.push_back({name, nameLocation, typePack}); + } + } + else + { + if (seenDefault) + report(lexer.current().location, "Expected default type pack after type pack name"); + + namePacks.push_back({name, nameLocation, nullptr}); + } } else { - if (seenPack) + if (!FFlag::LuauParseRecoverTypePackEllipsis && seenPack) report(lexer.current().location, "Generic types come before generic type packs"); - names.push_back(name); + if (withDefaultValues && lexer.current().type == '=') + { + seenDefault = true; + nextLexeme(); + + AstType* defaultType = parseTypeAnnotation(); + + names.push_back({name, nameLocation, defaultType}); + } + else + { + if (seenDefault) + report(lexer.current().location, "Expected default type after type name"); + + names.push_back({name, nameLocation, nullptr}); + } } if (lexer.current().type == ',') @@ -2378,8 +2420,8 @@ std::pair, AstArray> Parser::parseGenericTypeList() expectMatchAndConsume('>', begin); } - AstArray generics = copy(names); - AstArray genericPacks = copy(namePacks); + AstArray generics = copy(names); + AstArray genericPacks = copy(namePacks); return {generics, genericPacks}; } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 54a9a26f..e0dc3e0f 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -8,6 +8,8 @@ #include "FileUtils.h" +LUAU_FASTFLAG(DebugLuauTimeTracing) + enum class ReportFormat { Default, @@ -105,6 +107,7 @@ static void displayHelp(const char* argv0) printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -213,8 +216,18 @@ int main(int argc, char** argv) format = ReportFormat::Gnu; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; + else if (strcmp(argv[i], "--timetrace") == 0) + FFlag::DebugLuauTimeTracing.value = true; } +#if !defined(LUAU_ENABLE_TIME_TRACE) + if (FFlag::DebugLuauTimeTracing) + { + printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; + } +#endif + Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; @@ -240,7 +253,10 @@ int main(int argc, char** argv) fprintf(stderr, "%s: %s\n", pair.first.c_str(), pair.second.c_str()); } - return (format == ReportFormat::Luacheck) ? 0 : failed; + if (format == ReportFormat::Luacheck) + return 0; + else + return failed ? 1 : 0; } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 26d4333a..36747f48 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -19,6 +19,16 @@ #include #endif +LUAU_FASTFLAG(DebugLuauTimeTracing) + +enum class CliMode +{ + Unknown, + Repl, + Compile, + RunSourceFiles +}; + enum class CompileFormat { Text, @@ -485,8 +495,10 @@ static void displayHelp(const char* argv0) printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); printf("\n"); printf("Available options:\n"); + printf(" -h, --help: Display this usage message.\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -503,71 +515,112 @@ int main(int argc, char** argv) if (strncmp(flag->name, "Luau", 4) == 0) flag->value = true; - if (argc == 1) - { - runRepl(); - return 0; - } - - if (argc >= 2 && strcmp(argv[1], "--help") == 0) - { - displayHelp(argv[0]); - return 0; - } - + CliMode mode = CliMode::Unknown; + CompileFormat compileFormat{}; + int profile = 0; + bool coverage = false; + // Set the mode if the user has explicitly specified one. + int argStart = 1; if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { - CompileFormat format = CompileFormat::Text; + argStart++; + mode = CliMode::Compile; + if (strcmp(argv[1], "--compile") == 0) + { + compileFormat = CompileFormat::Text; + } + else if (strcmp(argv[1], "--compile=binary") == 0) + { + compileFormat = CompileFormat::Binary; + } + else if (strcmp(argv[1], "--compile=text") == 0) + { + compileFormat = CompileFormat::Text; + } + else + { + fprintf(stdout, "Error: Unrecognized value for '--compile' specified.\n"); + return -1; + } + } - if (strcmp(argv[1], "--compile=binary") == 0) - format = CompileFormat::Binary; + for (int i = argStart; i < argc; i++) + { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (strcmp(argv[i], "--profile") == 0) + { + profile = 10000; // default to 10 KHz + } + else if (strncmp(argv[i], "--profile=", 10) == 0) + { + profile = atoi(argv[i] + 10); + } + else if (strcmp(argv[i], "--coverage") == 0) + { + coverage = true; + } + else if (strcmp(argv[i], "--timetrace") == 0) + { + FFlag::DebugLuauTimeTracing.value = true; +#if !defined(LUAU_ENABLE_TIME_TRACE) + printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; +#endif + } + else if (argv[i][0] == '-') + { + fprintf(stdout, "Error: Unrecognized option '%s'.\n\n", argv[i]); + displayHelp(argv[0]); + return 1; + } + } + + const std::vector files = getSourceFiles(argc, argv); + if (mode == CliMode::Unknown) + { + mode = files.empty() ? CliMode::Repl : CliMode::RunSourceFiles; + } + + switch (mode) + { + case CliMode::Compile: + { #ifdef _WIN32 - if (format == CompileFormat::Binary) + if (compileFormat == CompileFormat::Binary) _setmode(_fileno(stdout), _O_BINARY); #endif - std::vector files = getSourceFiles(argc, argv); - int failed = 0; for (const std::string& path : files) - failed += !compileFile(path.c_str(), format); + failed += !compileFile(path.c_str(), compileFormat); - return failed; + return failed ? 1 : 0; } - + case CliMode::Repl: + { + runRepl(); + return 0; + } + case CliMode::RunSourceFiles: { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); setupState(L); - int profile = 0; - bool coverage = false; - - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] != '-') - continue; - - if (strcmp(argv[i], "--profile") == 0) - profile = 10000; // default to 10 KHz - else if (strncmp(argv[i], "--profile=", 10) == 0) - profile = atoi(argv[i] + 10); - else if (strcmp(argv[i], "--coverage") == 0) - coverage = true; - } - if (profile) profilerStart(L, profile); if (coverage) coverageInit(L); - std::vector files = getSourceFiles(argc, argv); - int failed = 0; for (const std::string& path : files) @@ -582,7 +635,11 @@ int main(int argc, char** argv) if (coverage) coverageDump("coverage.out"); - return failed; + return failed ? 1 : 0; + } + case CliMode::Unknown: + default: + LUAU_ASSERT(!"Unhandled cli mode."); } } diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp new file mode 100644 index 00000000..e344eb91 --- /dev/null +++ b/Compiler/src/Builtins.cpp @@ -0,0 +1,197 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Builtins.h" + +#include "Luau/Bytecode.h" +#include "Luau/Compiler.h" + +namespace Luau +{ +namespace Compile +{ + +Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, const DenseHashMap& variables) +{ + if (AstExprLocal* expr = node->as()) + { + const Variable* v = variables.find(expr->local); + + return v && !v->written && v->init ? getBuiltin(v->init, globals, variables) : Builtin(); + } + else if (AstExprIndexName* expr = node->as()) + { + if (AstExprGlobal* object = expr->expr->as()) + { + return getGlobalState(globals, object->name) == Global::Default ? Builtin{object->name, expr->index} : Builtin(); + } + else + { + return Builtin(); + } + } + else if (AstExprGlobal* expr = node->as()) + { + return getGlobalState(globals, expr->name) == Global::Default ? Builtin{AstName(), expr->name} : Builtin(); + } + else + { + return Builtin(); + } +} + +int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) +{ + if (builtin.empty()) + return -1; + + if (builtin.isGlobal("assert")) + return LBF_ASSERT; + + if (builtin.isGlobal("type")) + return LBF_TYPE; + + if (builtin.isGlobal("typeof")) + return LBF_TYPEOF; + + if (builtin.isGlobal("rawset")) + return LBF_RAWSET; + if (builtin.isGlobal("rawget")) + return LBF_RAWGET; + if (builtin.isGlobal("rawequal")) + return LBF_RAWEQUAL; + + if (builtin.isGlobal("unpack")) + return LBF_TABLE_UNPACK; + + if (builtin.object == "math") + { + if (builtin.method == "abs") + return LBF_MATH_ABS; + if (builtin.method == "acos") + return LBF_MATH_ACOS; + if (builtin.method == "asin") + return LBF_MATH_ASIN; + if (builtin.method == "atan2") + return LBF_MATH_ATAN2; + if (builtin.method == "atan") + return LBF_MATH_ATAN; + if (builtin.method == "ceil") + return LBF_MATH_CEIL; + if (builtin.method == "cosh") + return LBF_MATH_COSH; + if (builtin.method == "cos") + return LBF_MATH_COS; + if (builtin.method == "deg") + return LBF_MATH_DEG; + if (builtin.method == "exp") + return LBF_MATH_EXP; + if (builtin.method == "floor") + return LBF_MATH_FLOOR; + if (builtin.method == "fmod") + return LBF_MATH_FMOD; + if (builtin.method == "frexp") + return LBF_MATH_FREXP; + if (builtin.method == "ldexp") + return LBF_MATH_LDEXP; + if (builtin.method == "log10") + return LBF_MATH_LOG10; + if (builtin.method == "log") + return LBF_MATH_LOG; + if (builtin.method == "max") + return LBF_MATH_MAX; + if (builtin.method == "min") + return LBF_MATH_MIN; + if (builtin.method == "modf") + return LBF_MATH_MODF; + if (builtin.method == "pow") + return LBF_MATH_POW; + if (builtin.method == "rad") + return LBF_MATH_RAD; + if (builtin.method == "sinh") + return LBF_MATH_SINH; + if (builtin.method == "sin") + return LBF_MATH_SIN; + if (builtin.method == "sqrt") + return LBF_MATH_SQRT; + if (builtin.method == "tanh") + return LBF_MATH_TANH; + if (builtin.method == "tan") + return LBF_MATH_TAN; + if (builtin.method == "clamp") + return LBF_MATH_CLAMP; + if (builtin.method == "sign") + return LBF_MATH_SIGN; + if (builtin.method == "round") + return LBF_MATH_ROUND; + } + + if (builtin.object == "bit32") + { + if (builtin.method == "arshift") + return LBF_BIT32_ARSHIFT; + if (builtin.method == "band") + return LBF_BIT32_BAND; + if (builtin.method == "bnot") + return LBF_BIT32_BNOT; + if (builtin.method == "bor") + return LBF_BIT32_BOR; + if (builtin.method == "bxor") + return LBF_BIT32_BXOR; + if (builtin.method == "btest") + return LBF_BIT32_BTEST; + if (builtin.method == "extract") + return LBF_BIT32_EXTRACT; + if (builtin.method == "lrotate") + return LBF_BIT32_LROTATE; + if (builtin.method == "lshift") + return LBF_BIT32_LSHIFT; + if (builtin.method == "replace") + return LBF_BIT32_REPLACE; + if (builtin.method == "rrotate") + return LBF_BIT32_RROTATE; + if (builtin.method == "rshift") + return LBF_BIT32_RSHIFT; + if (builtin.method == "countlz") + return LBF_BIT32_COUNTLZ; + if (builtin.method == "countrz") + return LBF_BIT32_COUNTRZ; + } + + if (builtin.object == "string") + { + if (builtin.method == "byte") + return LBF_STRING_BYTE; + if (builtin.method == "char") + return LBF_STRING_CHAR; + if (builtin.method == "len") + return LBF_STRING_LEN; + if (builtin.method == "sub") + return LBF_STRING_SUB; + } + + if (builtin.object == "table") + { + if (builtin.method == "insert") + return LBF_TABLE_INSERT; + if (builtin.method == "unpack") + return LBF_TABLE_UNPACK; + } + + if (options.vectorCtor) + { + if (options.vectorLib) + { + if (builtin.isMethod(options.vectorLib, options.vectorCtor)) + return LBF_VECTOR; + } + else + { + if (builtin.isGlobal(options.vectorCtor)) + return LBF_VECTOR; + } + } + + return -1; +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h new file mode 100644 index 00000000..60df53a1 --- /dev/null +++ b/Compiler/src/Builtins.h @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "ValueTracking.h" + +namespace Luau +{ +struct CompileOptions; +} + +namespace Luau +{ +namespace Compile +{ + +struct Builtin +{ + AstName object; + AstName method; + + bool empty() const + { + return object == AstName() && method == AstName(); + } + + bool isGlobal(const char* name) const + { + return object == AstName() && method == name; + } + + bool isMethod(const char* table, const char* name) const + { + return object == table && method == name; + } +}; + +Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, const DenseHashMap& variables); +int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 2d31c409..e6d02454 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -714,7 +714,7 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const // third pass: write resulting data int logspan = log2(span); - writeByte(ss, logspan); + writeByte(ss, uint8_t(logspan)); uint8_t lastOffset = 0; @@ -723,8 +723,8 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const int delta = lines[i] - baseline[i >> logspan]; LUAU_ASSERT(delta >= 0 && delta <= 255); - writeByte(ss, delta - lastOffset); - lastOffset = delta; + writeByte(ss, uint8_t(delta) - lastOffset); + lastOffset = uint8_t(delta); } int lastLine = 0; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 6ae49027..9758c4a9 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -6,15 +6,20 @@ #include "Luau/Common.h" #include "Luau/TimeTrace.h" +#include "Builtins.h" +#include "ConstantFolding.h" +#include "TableShape.h" +#include "ValueTracking.h" + #include #include #include -LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) - namespace Luau { +using namespace Luau::Compile; + static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; @@ -62,7 +67,6 @@ static BytecodeBuilder::StringRef sref(AstArray data) struct Compiler { - struct Constant; struct RegScope; Compiler(BytecodeBuilder& bytecode, const CompileOptions& options) @@ -71,8 +75,9 @@ struct Compiler , functions(nullptr) , locals(nullptr) , globals(AstName()) + , variables(nullptr) , constants(nullptr) - , predictedTableSize(nullptr) + , tableShapes(nullptr) { } @@ -96,8 +101,10 @@ struct Compiler local->location, "Out of upvalue registers when trying to allocate %s: exceeded limit %d", local->name.value, kMaxUpvalueCount); // mark local as captured so that closeLocals emits LOP_CLOSEUPVALS accordingly - Local& l = locals[local]; - l.captured = true; + Variable* v = variables.find(local); + + if (v && v->written) + locals[local].captured = true; upvals.push_back(local); @@ -273,8 +280,8 @@ struct Compiler if (options.optimizationLevel >= 1) { - Builtin builtin = getBuiltin(expr->func); - bfid = getBuiltinFunctionId(builtin); + Builtin builtin = getBuiltin(expr->func, globals, variables); + bfid = getBuiltinFunctionId(builtin, options); } if (expr->self) @@ -364,12 +371,12 @@ struct Compiler else { args[i] = uint8_t(regs + 1 + i); - compileExprTempTop(expr->args.data[i], args[i]); + compileExprTempTop(expr->args.data[i], uint8_t(args[i])); } } fastcallLabel = bytecode.emitLabel(); - bytecode.emitABC(opc, uint8_t(bfid), args[0], 0); + bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); if (opc != LOP_FASTCALL1) bytecode.emitAux(args[1]); @@ -385,7 +392,7 @@ struct Compiler } if (args[i] != regs + 1 + i) - bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), args[i], 0); + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); } } else @@ -424,8 +431,10 @@ struct Compiler for (AstLocal* uv : f->upvals) { - Local* ul = locals.find(uv); - LUAU_ASSERT(ul); + Variable* ul = variables.find(uv); + + if (!ul) + return false; if (ul->written) return false; @@ -437,10 +446,11 @@ struct Compiler // this will only deoptimize (outside of fenv changes) if top level code is executed twice with different results. if (uv->functionDepth != 0 || uv->loopDepth != 0) { - if (!ul->func) + AstExprFunction* uf = ul->init ? ul->init->as() : nullptr; + if (!uf) return false; - if (ul->func != func && !shouldShareClosure(ul->func)) + if (uf != func && !shouldShareClosure(uf)) return false; } } @@ -471,7 +481,7 @@ struct Compiler if (cid >= 0 && cid < 32768) { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); + bytecode.emitAD(LOP_DUPCLOSURE, target, int16_t(cid)); shared = true; } } @@ -483,17 +493,15 @@ struct Compiler { LUAU_ASSERT(uv->functionDepth < expr->functionDepth); - Local* ul = locals.find(uv); - LUAU_ASSERT(ul); - - bool immutable = !ul->written; + Variable* ul = variables.find(uv); + bool immutable = !ul || !ul->written; if (uv->functionDepth == expr->functionDepth - 1) { // get local variable uint8_t reg = getLocal(uv); - bytecode.emitABC(LOP_CAPTURE, immutable ? LCT_VAL : LCT_REF, reg, 0); + bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), reg, 0); } else { @@ -635,7 +643,7 @@ struct Compiler if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe) { - bytecode.emitAD(opc, rr, 0); + bytecode.emitAD(opc, uint8_t(rr), 0); bytecode.emitAux(rl); } else @@ -687,7 +695,7 @@ struct Compiler break; case Constant::Type_String: - cid = bytecode.addConstantString(sref(c->valueString)); + cid = bytecode.addConstantString(sref(c->getString())); break; default: @@ -1066,10 +1074,10 @@ struct Compiler // Optimization: if the table is empty, we can compute it directly into the target if (expr->items.size == 0) { - auto [hashSize, arraySize] = predictedTableSize[expr]; + TableShape shape = tableShapes[expr]; - bytecode.emitABC(LOP_NEWTABLE, target, encodeHashSize(hashSize), 0); - bytecode.emitAux(arraySize); + bytecode.emitABC(LOP_NEWTABLE, target, encodeHashSize(shape.hashSize), 0); + bytecode.emitAux(shape.arraySize); return; } @@ -1144,7 +1152,7 @@ struct Compiler } else { - bytecode.emitABC(LOP_NEWTABLE, reg, encodedHashSize, 0); + bytecode.emitABC(LOP_NEWTABLE, reg, uint8_t(encodedHashSize), 0); bytecode.emitAux(0); } } @@ -1157,7 +1165,7 @@ struct Compiler bool trailingVarargs = last && last->kind == AstExprTable::Item::List && last->value->is(); LUAU_ASSERT(!trailingVarargs || arraySize > 0); - bytecode.emitABC(LOP_NEWTABLE, reg, encodedHashSize, 0); + bytecode.emitABC(LOP_NEWTABLE, reg, uint8_t(encodedHashSize), 0); bytecode.emitAux(arraySize - trailingVarargs + indexSize); } @@ -1252,16 +1260,12 @@ struct Compiler bool canImport(AstExprGlobal* expr) { - const Global* global = globals.find(expr->name); - - return options.optimizationLevel >= 1 && (!global || !global->written); + return options.optimizationLevel >= 1 && getGlobalState(globals, expr->name) != Global::Written; } bool canImportChain(AstExprGlobal* expr) { - const Global* global = globals.find(expr->name); - - return options.optimizationLevel >= 1 && (!global || (!global->written && !global->writable)); + return options.optimizationLevel >= 1 && getGlobalState(globals, expr->name) == Global::Default; } void compileExprIndexName(AstExprIndexName* expr, uint8_t target) @@ -1341,7 +1345,7 @@ struct Compiler { uint8_t rt = compileExprAuto(expr->expr, rs); - BytecodeBuilder::StringRef iname = sref(cv->valueString); + BytecodeBuilder::StringRef iname = sref(cv->getString()); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); @@ -1427,7 +1431,7 @@ struct Compiler case Constant::Type_String: { - int32_t cid = bytecode.addConstantString(sref(cv->valueString)); + int32_t cid = bytecode.addConstantString(sref(cv->getString())); if (cid < 0) CompileError::raise(node->location, "Exceeded constant limit; simplify the code to compile"); @@ -1546,7 +1550,7 @@ struct Compiler { compileExpr(expr->expr, target, targetTemp); } - else if (AstExprIfElse* expr = node->as(); FFlag::LuauIfElseExpressionBaseSupport && expr) + else if (AstExprIfElse* expr = node->as()) { compileExprIfElse(expr, target, targetTemp); } @@ -1711,7 +1715,7 @@ struct Compiler { LValue result = {LValue::Kind_IndexName}; result.reg = compileExprAuto(expr->expr, rs); - result.name = sref(cv->valueString); + result.name = sref(cv->getString()); result.location = node->location; return result; @@ -1796,9 +1800,8 @@ struct Compiler return false; Local* l = locals.find(le->local); - LUAU_ASSERT(l); - return l->allocated; + return l && l->allocated; } bool isStatBreak(AstStat* node) @@ -2040,9 +2043,9 @@ struct Compiler for (AstLocal* local : stat->vars) { - Local* l = locals.find(local); + Variable* v = variables.find(local); - if (!l || l->constant.type == Constant::Type_Unknown) + if (!v || !v->constant) return false; } @@ -2082,9 +2085,7 @@ struct Compiler // through) uint8_t varreg = regs + 2; - Local* il = locals.find(stat->var); - - if (il && il->written) + if (Variable* il = variables.find(stat->var); il && il->written) varreg = allocReg(stat, 1); compileExprTemp(stat->from, uint8_t(regs + 2)); @@ -2164,7 +2165,7 @@ struct Compiler { if (stat->values.size == 1 && stat->values.data[0]->is()) { - Builtin builtin = getBuiltin(stat->values.data[0]->as()->func); + Builtin builtin = getBuiltin(stat->values.data[0]->as()->func, globals, variables); if (builtin.isGlobal("ipairs")) // for .. in ipairs(t) { @@ -2179,7 +2180,7 @@ struct Compiler } else if (stat->values.size == 2) { - Builtin builtin = getBuiltin(stat->values.data[0]); + Builtin builtin = getBuiltin(stat->values.data[0], globals, variables); if (builtin.isGlobal("next")) // for .. in next,t { @@ -2594,7 +2595,7 @@ struct Compiler Local* l = locals.find(localStack[i]); LUAU_ASSERT(l); - if (l->captured && l->written) + if (l->captured) return true; } @@ -2613,7 +2614,7 @@ struct Compiler Local* l = locals.find(localStack[i]); LUAU_ASSERT(l); - if (l->captured && l->written) + if (l->captured) { captured = true; captureReg = std::min(captureReg, l->reg); @@ -2728,519 +2729,6 @@ struct Compiler return !node->is() && !node->is(); } - struct AssignmentVisitor : AstVisitor - { - struct Hasher - { - size_t operator()(const std::pair& p) const - { - return std::hash()(p.first) ^ std::hash()(p.second); - } - }; - - DenseHashMap localToTable; - DenseHashSet, Hasher> fields; - - AssignmentVisitor(Compiler* self) - : localToTable(nullptr) - , fields(std::pair()) - , self(self) - { - } - - void assignField(AstExpr* expr, AstName index) - { - if (AstExprLocal* lv = expr->as()) - { - if (AstExprTable** table = localToTable.find(lv->local)) - { - std::pair field = {*table, index}; - - if (!fields.contains(field)) - { - fields.insert(field); - self->predictedTableSize[*table].first += 1; - } - } - } - } - - void assignField(AstExpr* expr, AstExpr* index) - { - AstExprLocal* lv = expr->as(); - AstExprConstantNumber* number = index->as(); - - if (lv && number) - { - if (AstExprTable** table = localToTable.find(lv->local)) - { - unsigned int& arraySize = self->predictedTableSize[*table].second; - - if (number->value == double(arraySize + 1)) - arraySize += 1; - } - } - } - - void assign(AstExpr* var) - { - if (AstExprLocal* lv = var->as()) - { - self->locals[lv->local].written = true; - } - else if (AstExprGlobal* gv = var->as()) - { - self->globals[gv->name].written = true; - } - else if (AstExprIndexName* index = var->as()) - { - assignField(index->expr, index->index); - - var->visit(this); - } - else if (AstExprIndexExpr* index = var->as()) - { - assignField(index->expr, index->index); - - var->visit(this); - } - else - { - // we need to be able to track assignments in all expressions, including crazy ones like t[function() t = nil end] = 5 - var->visit(this); - } - } - - AstExprTable* getTableHint(AstExpr* expr) - { - // unadorned table literal - if (AstExprTable* table = expr->as()) - return table; - - // setmetatable(table literal, ...) - if (AstExprCall* call = expr->as(); call && !call->self && call->args.size == 2) - if (AstExprGlobal* func = call->func->as(); func && func->name == "setmetatable") - if (AstExprTable* table = call->args.data[0]->as()) - return table; - - return nullptr; - } - - bool visit(AstStatLocal* node) override - { - // track local -> table association so that we can update table size prediction in assignField - if (node->vars.size == 1 && node->values.size == 1) - if (AstExprTable* table = getTableHint(node->values.data[0]); table && table->items.size == 0) - localToTable[node->vars.data[0]] = table; - - return true; - } - - bool visit(AstStatAssign* node) override - { - for (size_t i = 0; i < node->vars.size; ++i) - assign(node->vars.data[i]); - - for (size_t i = 0; i < node->values.size; ++i) - node->values.data[i]->visit(this); - - return false; - } - - bool visit(AstStatCompoundAssign* node) override - { - assign(node->var); - node->value->visit(this); - - return false; - } - - bool visit(AstStatFunction* node) override - { - assign(node->name); - node->func->visit(this); - - return false; - } - - Compiler* self; - }; - - struct ConstantVisitor : AstVisitor - { - ConstantVisitor(Compiler* self) - : self(self) - { - } - - void analyzeUnary(Constant& result, AstExprUnary::Op op, const Constant& arg) - { - switch (op) - { - case AstExprUnary::Not: - if (arg.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = !arg.isTruthful(); - } - break; - - case AstExprUnary::Minus: - if (arg.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = -arg.valueNumber; - } - break; - - case AstExprUnary::Len: - if (arg.type == Constant::Type_String) - { - result.type = Constant::Type_Number; - result.valueNumber = double(arg.valueString.size); - } - break; - - default: - LUAU_ASSERT(!"Unexpected unary operation"); - } - } - - bool constantsEqual(const Constant& la, const Constant& ra) - { - LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown); - - switch (la.type) - { - case Constant::Type_Nil: - return ra.type == Constant::Type_Nil; - - case Constant::Type_Boolean: - return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean; - - case Constant::Type_Number: - return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber; - - case Constant::Type_String: - return ra.type == Constant::Type_String && la.valueString.size == ra.valueString.size && - memcmp(la.valueString.data, ra.valueString.data, la.valueString.size) == 0; - - default: - LUAU_ASSERT(!"Unexpected constant type in comparison"); - return false; - } - } - - void analyzeBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra) - { - switch (op) - { - case AstExprBinary::Add: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber + ra.valueNumber; - } - break; - - case AstExprBinary::Sub: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber - ra.valueNumber; - } - break; - - case AstExprBinary::Mul: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber * ra.valueNumber; - } - break; - - case AstExprBinary::Div: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber / ra.valueNumber; - } - break; - - case AstExprBinary::Mod: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber; - } - break; - - case AstExprBinary::Pow: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = pow(la.valueNumber, ra.valueNumber); - } - break; - - case AstExprBinary::Concat: - break; - - case AstExprBinary::CompareNe: - if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = !constantsEqual(la, ra); - } - break; - - case AstExprBinary::CompareEq: - if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = constantsEqual(la, ra); - } - break; - - case AstExprBinary::CompareLt: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber < ra.valueNumber; - } - break; - - case AstExprBinary::CompareLe: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber <= ra.valueNumber; - } - break; - - case AstExprBinary::CompareGt: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber > ra.valueNumber; - } - break; - - case AstExprBinary::CompareGe: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber >= ra.valueNumber; - } - break; - - case AstExprBinary::And: - if (la.type != Constant::Type_Unknown) - { - result = la.isTruthful() ? ra : la; - } - break; - - case AstExprBinary::Or: - if (la.type != Constant::Type_Unknown) - { - result = la.isTruthful() ? la : ra; - } - break; - - default: - LUAU_ASSERT(!"Unexpected binary operation"); - } - } - - Constant analyze(AstExpr* node) - { - Constant result; - result.type = Constant::Type_Unknown; - - if (AstExprGroup* expr = node->as()) - { - result = analyze(expr->expr); - } - else if (node->is()) - { - result.type = Constant::Type_Nil; - } - else if (AstExprConstantBool* expr = node->as()) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = expr->value; - } - else if (AstExprConstantNumber* expr = node->as()) - { - result.type = Constant::Type_Number; - result.valueNumber = expr->value; - } - else if (AstExprConstantString* expr = node->as()) - { - result.type = Constant::Type_String; - result.valueString = expr->value; - } - else if (AstExprLocal* expr = node->as()) - { - const Local* l = self->locals.find(expr->local); - - if (l && l->constant.type != Constant::Type_Unknown) - { - LUAU_ASSERT(!l->written); - result = l->constant; - } - } - else if (node->is()) - { - // nope - } - else if (node->is()) - { - // nope - } - else if (AstExprCall* expr = node->as()) - { - analyze(expr->func); - - for (size_t i = 0; i < expr->args.size; ++i) - analyze(expr->args.data[i]); - } - else if (AstExprIndexName* expr = node->as()) - { - analyze(expr->expr); - } - else if (AstExprIndexExpr* expr = node->as()) - { - analyze(expr->expr); - analyze(expr->index); - } - else if (AstExprFunction* expr = node->as()) - { - // this is necessary to propagate constant information in all child functions - expr->body->visit(this); - } - else if (AstExprTable* expr = node->as()) - { - for (size_t i = 0; i < expr->items.size; ++i) - { - const AstExprTable::Item& item = expr->items.data[i]; - - if (item.key) - analyze(item.key); - - analyze(item.value); - } - } - else if (AstExprUnary* expr = node->as()) - { - Constant arg = analyze(expr->expr); - - analyzeUnary(result, expr->op, arg); - } - else if (AstExprBinary* expr = node->as()) - { - Constant la = analyze(expr->left); - Constant ra = analyze(expr->right); - - analyzeBinary(result, expr->op, la, ra); - } - else if (AstExprTypeAssertion* expr = node->as()) - { - Constant arg = analyze(expr->expr); - - result = arg; - } - else if (AstExprIfElse* expr = node->as(); FFlag::LuauIfElseExpressionBaseSupport && expr) - { - Constant cond = analyze(expr->condition); - Constant trueExpr = analyze(expr->trueExpr); - Constant falseExpr = analyze(expr->falseExpr); - if (cond.type != Constant::Type_Unknown) - { - result = cond.isTruthful() ? trueExpr : falseExpr; - } - } - else - { - LUAU_ASSERT(!"Unknown expression type"); - } - - if (result.type != Constant::Type_Unknown) - self->constants[node] = result; - - return result; - } - - bool visit(AstExpr* node) override - { - // note: we short-circuit the visitor traversal through any expression trees by returning false - // recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression - analyze(node); - - return false; - } - - bool visit(AstStatLocal* node) override - { - // for values that match 1-1 we record the initializing expression for future analysis - for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) - { - Local& l = self->locals[node->vars.data[i]]; - - l.init = node->values.data[i]; - } - - // all values that align wrt indexing are simple - we just match them 1-1 - for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) - { - Constant arg = analyze(node->values.data[i]); - - if (arg.type != Constant::Type_Unknown) - { - Local& l = self->locals[node->vars.data[i]]; - - // note: we rely on AssignmentVisitor to have been run before us - if (!l.written) - l.constant = arg; - } - } - - if (node->vars.size > node->values.size) - { - // if we have trailing variables, then depending on whether the last value is capable of returning multiple values - // (aka call or varargs), we either don't know anything about these vars, or we know they're nil - AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr; - bool multRet = last && (last->is() || last->is()); - - for (size_t i = node->values.size; i < node->vars.size; ++i) - { - if (!multRet) - { - Local& l = self->locals[node->vars.data[i]]; - - // note: we rely on AssignmentVisitor to have been run before us - if (!l.written) - { - l.constant.type = Constant::Type_Nil; - } - } - } - } - else - { - // we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside - // them - for (size_t i = node->vars.size; i < node->values.size; ++i) - analyze(node->values.data[i]); - } - - return false; - } - - Compiler* self; - }; - struct FenvVisitor : AstVisitor { bool& getfenvUsed; @@ -3283,14 +2771,6 @@ struct Compiler return false; } - - bool visit(AstStatLocalFunction* node) override - { - // record local->function association for some optimizations - self->locals[node->name].func = node->func; - - return true; - } }; struct UndefinedLocalVisitor : AstVisitor @@ -3397,70 +2877,12 @@ struct Compiler std::vector upvals; }; - struct Constant - { - enum Type - { - Type_Unknown, - Type_Nil, - Type_Boolean, - Type_Number, - Type_String, - }; - - Type type = Type_Unknown; - - union - { - bool valueBoolean; - double valueNumber; - AstArray valueString = {}; - }; - - bool isTruthful() const - { - LUAU_ASSERT(type != Type_Unknown); - return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); - } - }; - struct Local { uint8_t reg = 0; bool allocated = false; bool captured = false; - bool written = false; - AstExpr* init = nullptr; uint32_t debugpc = 0; - Constant constant; - AstExprFunction* func = nullptr; - }; - - struct Global - { - bool writable = false; - bool written = false; - }; - - struct Builtin - { - AstName object; - AstName method; - - bool empty() const - { - return object == AstName() && method == AstName(); - } - - bool isGlobal(const char* name) const - { - return object == AstName() && method == name; - } - - bool isMethod(const char* table, const char* name) const - { - return object == table && method == name; - } }; struct LoopJump @@ -3482,194 +2904,6 @@ struct Compiler AstExpr* untilCondition; }; - Builtin getBuiltin(AstExpr* node) - { - if (AstExprLocal* expr = node->as()) - { - Local* l = locals.find(expr->local); - - return l && !l->written && l->init ? getBuiltin(l->init) : Builtin(); - } - else if (AstExprIndexName* expr = node->as()) - { - if (AstExprGlobal* object = expr->expr->as()) - { - Global* g = globals.find(object->name); - - return !g || (!g->writable && !g->written) ? Builtin{object->name, expr->index} : Builtin(); - } - else - { - return Builtin(); - } - } - else if (AstExprGlobal* expr = node->as()) - { - Global* g = globals.find(expr->name); - - return !g || !g->written ? Builtin{AstName(), expr->name} : Builtin(); - } - else - { - return Builtin(); - } - } - - int getBuiltinFunctionId(const Builtin& builtin) - { - if (builtin.empty()) - return -1; - - if (builtin.isGlobal("assert")) - return LBF_ASSERT; - - if (builtin.isGlobal("type")) - return LBF_TYPE; - - if (builtin.isGlobal("typeof")) - return LBF_TYPEOF; - - if (builtin.isGlobal("rawset")) - return LBF_RAWSET; - if (builtin.isGlobal("rawget")) - return LBF_RAWGET; - if (builtin.isGlobal("rawequal")) - return LBF_RAWEQUAL; - - if (builtin.isGlobal("unpack")) - return LBF_TABLE_UNPACK; - - if (builtin.object == "math") - { - if (builtin.method == "abs") - return LBF_MATH_ABS; - if (builtin.method == "acos") - return LBF_MATH_ACOS; - if (builtin.method == "asin") - return LBF_MATH_ASIN; - if (builtin.method == "atan2") - return LBF_MATH_ATAN2; - if (builtin.method == "atan") - return LBF_MATH_ATAN; - if (builtin.method == "ceil") - return LBF_MATH_CEIL; - if (builtin.method == "cosh") - return LBF_MATH_COSH; - if (builtin.method == "cos") - return LBF_MATH_COS; - if (builtin.method == "deg") - return LBF_MATH_DEG; - if (builtin.method == "exp") - return LBF_MATH_EXP; - if (builtin.method == "floor") - return LBF_MATH_FLOOR; - if (builtin.method == "fmod") - return LBF_MATH_FMOD; - if (builtin.method == "frexp") - return LBF_MATH_FREXP; - if (builtin.method == "ldexp") - return LBF_MATH_LDEXP; - if (builtin.method == "log10") - return LBF_MATH_LOG10; - if (builtin.method == "log") - return LBF_MATH_LOG; - if (builtin.method == "max") - return LBF_MATH_MAX; - if (builtin.method == "min") - return LBF_MATH_MIN; - if (builtin.method == "modf") - return LBF_MATH_MODF; - if (builtin.method == "pow") - return LBF_MATH_POW; - if (builtin.method == "rad") - return LBF_MATH_RAD; - if (builtin.method == "sinh") - return LBF_MATH_SINH; - if (builtin.method == "sin") - return LBF_MATH_SIN; - if (builtin.method == "sqrt") - return LBF_MATH_SQRT; - if (builtin.method == "tanh") - return LBF_MATH_TANH; - if (builtin.method == "tan") - return LBF_MATH_TAN; - if (builtin.method == "clamp") - return LBF_MATH_CLAMP; - if (builtin.method == "sign") - return LBF_MATH_SIGN; - if (builtin.method == "round") - return LBF_MATH_ROUND; - } - - if (builtin.object == "bit32") - { - if (builtin.method == "arshift") - return LBF_BIT32_ARSHIFT; - if (builtin.method == "band") - return LBF_BIT32_BAND; - if (builtin.method == "bnot") - return LBF_BIT32_BNOT; - if (builtin.method == "bor") - return LBF_BIT32_BOR; - if (builtin.method == "bxor") - return LBF_BIT32_BXOR; - if (builtin.method == "btest") - return LBF_BIT32_BTEST; - if (builtin.method == "extract") - return LBF_BIT32_EXTRACT; - if (builtin.method == "lrotate") - return LBF_BIT32_LROTATE; - if (builtin.method == "lshift") - return LBF_BIT32_LSHIFT; - if (builtin.method == "replace") - return LBF_BIT32_REPLACE; - if (builtin.method == "rrotate") - return LBF_BIT32_RROTATE; - if (builtin.method == "rshift") - return LBF_BIT32_RSHIFT; - if (builtin.method == "countlz") - return LBF_BIT32_COUNTLZ; - if (builtin.method == "countrz") - return LBF_BIT32_COUNTRZ; - } - - if (builtin.object == "string") - { - if (builtin.method == "byte") - return LBF_STRING_BYTE; - if (builtin.method == "char") - return LBF_STRING_CHAR; - if (builtin.method == "len") - return LBF_STRING_LEN; - if (builtin.method == "sub") - return LBF_STRING_SUB; - } - - if (builtin.object == "table") - { - if (builtin.method == "insert") - return LBF_TABLE_INSERT; - if (builtin.method == "unpack") - return LBF_TABLE_UNPACK; - } - - if (options.vectorCtor) - { - if (options.vectorLib) - { - if (builtin.isMethod(options.vectorLib, options.vectorCtor)) - return LBF_VECTOR; - } - else - { - if (builtin.isGlobal(options.vectorCtor)) - return LBF_VECTOR; - } - } - - return -1; - } - BytecodeBuilder& bytecode; CompileOptions options; @@ -3677,8 +2911,9 @@ struct Compiler DenseHashMap functions; DenseHashMap locals; DenseHashMap globals; + DenseHashMap variables; DenseHashMap constants; - DenseHashMap> predictedTableSize; + DenseHashMap tableShapes; unsigned int regTop = 0; unsigned int stackSize = 0; @@ -3699,23 +2934,18 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables - if (AstName name = names.get("_G"); name.value) - compiler.globals[name].writable = true; + assignMutable(compiler.globals, names, options.mutableGlobals); - if (options.mutableGlobals) - for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) - if (AstName name = names.get(*ptr); name.value) - compiler.globals[name].writable = true; + // this pass analyzes mutability of locals/globals and associates locals with their initial values + trackValues(compiler.globals, compiler.variables, root); - // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written - Compiler::AssignmentVisitor assignmentVisitor(&compiler); - root->visit(&assignmentVisitor); - - // this visitor traverses the AST to analyze constantness of expressions, filling constants[] and Local::constant/Local::init if (options.optimizationLevel >= 1) { - Compiler::ConstantVisitor constantVisitor(&compiler); - root->visit(&constantVisitor); + // this pass analyzes constantness of expressions + foldConstants(compiler.constants, compiler.variables, root); + + // this pass analyzes table assignments to estimate table shapes for initially empty tables + predictTableShapes(compiler.tableShapes, root); } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found @@ -3734,8 +2964,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName for (AstExprFunction* expr : functions) compiler.compileFunction(expr); - AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), /* self= */ nullptr, - AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); + AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), + /* self= */ nullptr, AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main); bytecode.setMainFunction(mainid); diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp new file mode 100644 index 00000000..60a7c169 --- /dev/null +++ b/Compiler/src/ConstantFolding.cpp @@ -0,0 +1,394 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ConstantFolding.h" + +#include + +namespace Luau +{ +namespace Compile +{ + +static bool constantsEqual(const Constant& la, const Constant& ra) +{ + LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown); + + switch (la.type) + { + case Constant::Type_Nil: + return ra.type == Constant::Type_Nil; + + case Constant::Type_Boolean: + return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean; + + case Constant::Type_Number: + return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber; + + case Constant::Type_String: + return ra.type == Constant::Type_String && la.stringLength == ra.stringLength && memcmp(la.valueString, ra.valueString, la.stringLength) == 0; + + default: + LUAU_ASSERT(!"Unexpected constant type in comparison"); + return false; + } +} + +static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg) +{ + switch (op) + { + case AstExprUnary::Not: + if (arg.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = !arg.isTruthful(); + } + break; + + case AstExprUnary::Minus: + if (arg.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = -arg.valueNumber; + } + break; + + case AstExprUnary::Len: + if (arg.type == Constant::Type_String) + { + result.type = Constant::Type_Number; + result.valueNumber = double(arg.stringLength); + } + break; + + default: + LUAU_ASSERT(!"Unexpected unary operation"); + } +} + +static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra) +{ + switch (op) + { + case AstExprBinary::Add: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber + ra.valueNumber; + } + break; + + case AstExprBinary::Sub: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber - ra.valueNumber; + } + break; + + case AstExprBinary::Mul: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber * ra.valueNumber; + } + break; + + case AstExprBinary::Div: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber / ra.valueNumber; + } + break; + + case AstExprBinary::Mod: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber; + } + break; + + case AstExprBinary::Pow: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = pow(la.valueNumber, ra.valueNumber); + } + break; + + case AstExprBinary::Concat: + break; + + case AstExprBinary::CompareNe: + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = !constantsEqual(la, ra); + } + break; + + case AstExprBinary::CompareEq: + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = constantsEqual(la, ra); + } + break; + + case AstExprBinary::CompareLt: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber < ra.valueNumber; + } + break; + + case AstExprBinary::CompareLe: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber <= ra.valueNumber; + } + break; + + case AstExprBinary::CompareGt: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber > ra.valueNumber; + } + break; + + case AstExprBinary::CompareGe: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber >= ra.valueNumber; + } + break; + + case AstExprBinary::And: + if (la.type != Constant::Type_Unknown) + { + result = la.isTruthful() ? ra : la; + } + break; + + case AstExprBinary::Or: + if (la.type != Constant::Type_Unknown) + { + result = la.isTruthful() ? la : ra; + } + break; + + default: + LUAU_ASSERT(!"Unexpected binary operation"); + } +} + +struct ConstantVisitor : AstVisitor +{ + DenseHashMap& constants; + DenseHashMap& variables; + + DenseHashMap locals; + + ConstantVisitor(DenseHashMap& constants, DenseHashMap& variables) + : constants(constants) + , variables(variables) + , locals(nullptr) + { + } + + Constant analyze(AstExpr* node) + { + Constant result; + result.type = Constant::Type_Unknown; + + if (AstExprGroup* expr = node->as()) + { + result = analyze(expr->expr); + } + else if (node->is()) + { + result.type = Constant::Type_Nil; + } + else if (AstExprConstantBool* expr = node->as()) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = expr->value; + } + else if (AstExprConstantNumber* expr = node->as()) + { + result.type = Constant::Type_Number; + result.valueNumber = expr->value; + } + else if (AstExprConstantString* expr = node->as()) + { + result.type = Constant::Type_String; + result.valueString = expr->value.data; + result.stringLength = unsigned(expr->value.size); + } + else if (AstExprLocal* expr = node->as()) + { + const Constant* l = locals.find(expr->local); + + if (l) + result = *l; + } + else if (node->is()) + { + // nope + } + else if (node->is()) + { + // nope + } + else if (AstExprCall* expr = node->as()) + { + analyze(expr->func); + + for (size_t i = 0; i < expr->args.size; ++i) + analyze(expr->args.data[i]); + } + else if (AstExprIndexName* expr = node->as()) + { + analyze(expr->expr); + } + else if (AstExprIndexExpr* expr = node->as()) + { + analyze(expr->expr); + analyze(expr->index); + } + else if (AstExprFunction* expr = node->as()) + { + // this is necessary to propagate constant information in all child functions + expr->body->visit(this); + } + else if (AstExprTable* expr = node->as()) + { + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + + if (item.key) + analyze(item.key); + + analyze(item.value); + } + } + else if (AstExprUnary* expr = node->as()) + { + Constant arg = analyze(expr->expr); + + if (arg.type != Constant::Type_Unknown) + foldUnary(result, expr->op, arg); + } + else if (AstExprBinary* expr = node->as()) + { + Constant la = analyze(expr->left); + Constant ra = analyze(expr->right); + + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + foldBinary(result, expr->op, la, ra); + } + else if (AstExprTypeAssertion* expr = node->as()) + { + Constant arg = analyze(expr->expr); + + result = arg; + } + else if (AstExprIfElse* expr = node->as()) + { + Constant cond = analyze(expr->condition); + Constant trueExpr = analyze(expr->trueExpr); + Constant falseExpr = analyze(expr->falseExpr); + + if (cond.type != Constant::Type_Unknown) + result = cond.isTruthful() ? trueExpr : falseExpr; + } + else + { + LUAU_ASSERT(!"Unknown expression type"); + } + + if (result.type != Constant::Type_Unknown) + constants[node] = result; + + return result; + } + + bool visit(AstExpr* node) override + { + // note: we short-circuit the visitor traversal through any expression trees by returning false + // recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression + analyze(node); + + return false; + } + + bool visit(AstStatLocal* node) override + { + // all values that align wrt indexing are simple - we just match them 1-1 + for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) + { + Constant arg = analyze(node->values.data[i]); + + if (arg.type != Constant::Type_Unknown) + { + // note: we rely on trackValues to have been run before us + Variable* v = variables.find(node->vars.data[i]); + LUAU_ASSERT(v); + + if (!v->written) + { + locals[node->vars.data[i]] = arg; + v->constant = true; + } + } + } + + if (node->vars.size > node->values.size) + { + // if we have trailing variables, then depending on whether the last value is capable of returning multiple values + // (aka call or varargs), we either don't know anything about these vars, or we know they're nil + AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr; + bool multRet = last && (last->is() || last->is()); + + if (!multRet) + { + for (size_t i = node->values.size; i < node->vars.size; ++i) + { + // note: we rely on trackValues to have been run before us + Variable* v = variables.find(node->vars.data[i]); + LUAU_ASSERT(v); + + if (!v->written) + { + locals[node->vars.data[i]].type = Constant::Type_Nil; + v->constant = true; + } + } + } + } + else + { + // we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside + // them + for (size_t i = node->vars.size; i < node->values.size; ++i) + analyze(node->values.data[i]); + } + + return false; + } +}; + +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root) +{ + ConstantVisitor visitor{constants, variables}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h new file mode 100644 index 00000000..c0e63539 --- /dev/null +++ b/Compiler/src/ConstantFolding.h @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "ValueTracking.h" + +namespace Luau +{ +namespace Compile +{ + +struct Constant +{ + enum Type + { + Type_Unknown, + Type_Nil, + Type_Boolean, + Type_Number, + Type_String, + }; + + Type type = Type_Unknown; + unsigned int stringLength = 0; + + union + { + bool valueBoolean; + double valueNumber; + char* valueString = nullptr; // length stored in stringLength + }; + + bool isTruthful() const + { + LUAU_ASSERT(type != Type_Unknown); + return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); + } + + AstArray getString() const + { + LUAU_ASSERT(type == Type_String); + return {valueString, stringLength}; + } +}; + +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/TableShape.cpp b/Compiler/src/TableShape.cpp new file mode 100644 index 00000000..7d99f222 --- /dev/null +++ b/Compiler/src/TableShape.cpp @@ -0,0 +1,129 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "TableShape.h" + +namespace Luau +{ +namespace Compile +{ + +static AstExprTable* getTableHint(AstExpr* expr) +{ + // unadorned table literal + if (AstExprTable* table = expr->as()) + return table; + + // setmetatable(table literal, ...) + if (AstExprCall* call = expr->as(); call && !call->self && call->args.size == 2) + if (AstExprGlobal* func = call->func->as(); func && func->name == "setmetatable") + if (AstExprTable* table = call->args.data[0]->as()) + return table; + + return nullptr; +} + +struct ShapeVisitor : AstVisitor +{ + struct Hasher + { + size_t operator()(const std::pair& p) const + { + return std::hash()(p.first) ^ std::hash()(p.second); + } + }; + + DenseHashMap& shapes; + + DenseHashMap tables; + DenseHashSet, Hasher> fields; + + ShapeVisitor(DenseHashMap& shapes) + : shapes(shapes) + , tables(nullptr) + , fields(std::pair()) + { + } + + void assignField(AstExpr* expr, AstName index) + { + if (AstExprLocal* lv = expr->as()) + { + if (AstExprTable** table = tables.find(lv->local)) + { + std::pair field = {*table, index}; + + if (!fields.contains(field)) + { + fields.insert(field); + shapes[*table].hashSize += 1; + } + } + } + } + + void assignField(AstExpr* expr, AstExpr* index) + { + AstExprLocal* lv = expr->as(); + AstExprConstantNumber* number = index->as(); + + if (lv && number) + { + if (AstExprTable** table = tables.find(lv->local)) + { + TableShape& shape = shapes[*table]; + + if (number->value == double(shape.arraySize + 1)) + shape.arraySize += 1; + } + } + } + + void assign(AstExpr* var) + { + if (AstExprIndexName* index = var->as()) + { + assignField(index->expr, index->index); + } + else if (AstExprIndexExpr* index = var->as()) + { + assignField(index->expr, index->index); + } + } + + bool visit(AstStatLocal* node) override + { + // track local -> table association so that we can update table size prediction in assignField + if (node->vars.size == 1 && node->values.size == 1) + if (AstExprTable* table = getTableHint(node->values.data[0]); table && table->items.size == 0) + tables[node->vars.data[0]] = table; + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatFunction* node) override + { + assign(node->name); + node->func->visit(this); + + return false; + } +}; + +void predictTableShapes(DenseHashMap& shapes, AstNode* root) +{ + ShapeVisitor visitor{shapes}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/TableShape.h b/Compiler/src/TableShape.h new file mode 100644 index 00000000..f30853a7 --- /dev/null +++ b/Compiler/src/TableShape.h @@ -0,0 +1,21 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +namespace Compile +{ + +struct TableShape +{ + unsigned int arraySize = 0; + unsigned int hashSize = 0; +}; + +void predictTableShapes(DenseHashMap& shapes, AstNode* root); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ValueTracking.cpp b/Compiler/src/ValueTracking.cpp new file mode 100644 index 00000000..0bfaf9b3 --- /dev/null +++ b/Compiler/src/ValueTracking.cpp @@ -0,0 +1,103 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ValueTracking.h" + +#include "Luau/Lexer.h" + +namespace Luau +{ +namespace Compile +{ + +struct ValueVisitor : AstVisitor +{ + DenseHashMap& globals; + DenseHashMap& variables; + + ValueVisitor(DenseHashMap& globals, DenseHashMap& variables) + : globals(globals) + , variables(variables) + { + } + + void assign(AstExpr* var) + { + if (AstExprLocal* lv = var->as()) + { + variables[lv->local].written = true; + } + else if (AstExprGlobal* gv = var->as()) + { + globals[gv->name] = Global::Written; + } + else + { + // we need to be able to track assignments in all expressions, including crazy ones like t[function() t = nil end] = 5 + var->visit(this); + } + } + + bool visit(AstStatLocal* node) override + { + for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) + variables[node->vars.data[i]].init = node->values.data[i]; + + for (size_t i = node->values.size; i < node->vars.size; ++i) + variables[node->vars.data[i]].init = nullptr; + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatCompoundAssign* node) override + { + assign(node->var); + node->value->visit(this); + + return false; + } + + bool visit(AstStatLocalFunction* node) override + { + variables[node->name].init = node->func; + + return true; + } + + bool visit(AstStatFunction* node) override + { + assign(node->name); + node->func->visit(this); + + return false; + } +}; + +void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char** mutableGlobals) +{ + if (AstName name = names.get("_G"); name.value) + globals[name] = Global::Mutable; + + if (mutableGlobals) + for (const char** ptr = mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) + globals[name] = Global::Mutable; +} + +void trackValues(DenseHashMap& globals, DenseHashMap& variables, AstNode* root) +{ + ValueVisitor visitor{globals, variables}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ValueTracking.h b/Compiler/src/ValueTracking.h new file mode 100644 index 00000000..fc74c84a --- /dev/null +++ b/Compiler/src/ValueTracking.h @@ -0,0 +1,42 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +class AstNameTable; +} + +namespace Luau +{ +namespace Compile +{ + +enum class Global +{ + Default = 0, + Mutable, // builtin that has contents unknown at compile time, blocks GETIMPORT for chains + Written, // written in the code which means we can't reason about the value +}; + +struct Variable +{ + AstExpr* init = nullptr; // initial value of the variable; filled by trackValues + bool written = false; // is the variable ever assigned to? filled by trackValues + bool constant = false; // is the variable's value a compile-time constant? filled by constantFold +}; + +void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char** mutableGlobals); +void trackValues(DenseHashMap& globals, DenseHashMap& variables, AstNode* root); + +inline Global getGlobalState(const DenseHashMap& globals, AstName name) +{ + const Global* it = globals.find(name); + + return it ? *it : Global::Default; +} + +} // namespace Compile +} // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index 5dd486aa..bafe7594 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -29,7 +29,15 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/BytecodeBuilder.cpp Compiler/src/Compiler.cpp + Compiler/src/Builtins.cpp + Compiler/src/ConstantFolding.cpp + Compiler/src/TableShape.cpp + Compiler/src/ValueTracking.cpp Compiler/src/lcode.cpp + Compiler/src/Builtins.h + Compiler/src/ConstantFolding.h + Compiler/src/TableShape.h + Compiler/src/ValueTracking.h ) # Luau.Analysis Sources diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index eb47971a..3cce7665 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,7 +17,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAG(LuauCoroutineClose) /* @@ -545,11 +544,8 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e if (!oldactive) resetbit(L->stackstate, THREAD_ACTIVEBIT); - if (FFlag::LuauCcallRestoreFix) - { - // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. - L->nCcalls = oldnCcalls; - } + // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + L->nCcalls = oldnCcalls; // an error occurred, check if we have a protected error callback if (L->global->cb.debugprotectederror) @@ -564,10 +560,6 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); /* close eventual pending closures */ seterrorobj(L, status, oldtop); - if (!FFlag::LuauCcallRestoreFix) - { - L->nCcalls = oldnCcalls; - } L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 64878569..4178eda4 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,6 +6,8 @@ #include "lmem.h" #include "lgc.h" +LUAU_FASTFLAGVARIABLE(LuauNoDirectUpvalRemoval, false) + Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_new(L, Proto, sizeof(Proto), L->activememcat); @@ -113,14 +115,16 @@ void luaF_freeupval(lua_State* L, UpVal* uv) void luaF_close(lua_State* L, StkId level) { UpVal* uv; - global_State* g = L->global; + global_State* g = L->global; // TODO: remove with FFlagLuauNoDirectUpvalRemoval while (L->openupval != NULL && (uv = gco2uv(L->openupval))->v >= level) { GCObject* o = obj2gco(uv); LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); L->openupval = uv->next; /* remove from `open' list */ - if (isdead(g, o)) + if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o)) + { luaF_freeupval(L, uv); /* free upvalue */ + } else { unlinkupval(uv); diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 0b3054ae..74a8aa8a 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauStrPackUBCastFix, false) - /* macro to `unsign' a character */ #define uchar(c) ((unsigned char)(c)) @@ -1406,20 +1404,10 @@ static int str_pack(lua_State* L) } case Kuint: { /* unsigned integers */ - if (FFlag::LuauStrPackUBCastFix) - { - long long n = (long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, (unsigned long long)n, h.islittle, size, 0); - } - else - { - unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, n, h.islittle, size, 0); - } + long long n = (long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, (unsigned long long)n, h.islittle, size, 0); break; } case Kfloat: diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua index 6d40406c..5d162ab9 100644 --- a/bench/tests/sunspider/3d-cube.lua +++ b/bench/tests/sunspider/3d-cube.lua @@ -111,15 +111,10 @@ end -- multiplies two matrices function MMulti(M1, M2) local M = {{},{},{},{}}; - local i = 1; - local j = 1; - while i <= 4 do - j = 1; - while j <= 4 do - M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j]; j = j + 1 + for i = 1,4 do + for j = 1,4 do + M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j]; end - - i = i + 1 end return M; end @@ -127,28 +122,27 @@ end -- multiplies matrix with vector function VMulti(M, V) local Vect = {}; - local i = 1; - while i <= 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; i = i + 1 end + for i = 1,4 do + Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; + end return Vect; end function VMulti2(M, V) local Vect = {}; - local i = 1; - while i < 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; i = i + 1 end + for i = 1,3 do + Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; + end return Vect; end -- add to matrices function MAdd(M1, M2) local M = {{},{},{},{}}; - local i = 1; - local j = 1; - while i <= 4 do - j = 1; - while j <= 4 do M[i][j] = M1[i][j] + M2[i][j]; j = j + 1 end - - i = i + 1 + for i = 1,4 do + for j = 1,4 do + M[i][j] = M1[i][j] + M2[i][j]; + end end return M; end diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 210db7ee..211e1be1 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1938,7 +1938,6 @@ return target(b@1 TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); check(R"( local function bar(a: number) return -a end @@ -1954,7 +1953,6 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); check(R"( local function foo() return 1 end @@ -2538,10 +2536,6 @@ TEST_CASE("autocomplete_documentation_symbols") TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { check(R"( local temp = false local even = true; @@ -2614,7 +2608,6 @@ a = if temp then even elseif true then temp else e@9 CHECK(ac.entryMap.count("then") == 0); CHECK(ac.entryMap.count("else") == 0); CHECK(ac.entryMap.count("elseif") == 0); - } } TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") @@ -2681,4 +2674,58 @@ local r4 = t:bar1(@4) CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + check(R"( +type A = () -> T + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_pack_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + check(R"( +type A = () -> T + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_oop_implicit_self") +{ + ScopedFastFlag flag("LuauMissingFollowACMetatables", true); + check(R"( +--!strict +local Class = {} +Class.__index = Class +type Class = typeof(setmetatable({} :: { x: number }, Class)) +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end +function Class.getx(self: Class) + return self.x +end +function test() + local c = Class.new(42) + local n = c:@1 + print(n) +end + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("getx")); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 95811b3f..8eed953f 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -603,9 +603,9 @@ RETURN R0 1 )"); } -TEST_CASE("EmptyTableHashSizePredictionOptimization") +TEST_CASE("TableSizePredictionBasic") { - const char* hashSizeSource = R"( + CHECK_EQ("\n" + compileFunction0(R"( local t = {} t.a = 1 t.b = 1 @@ -616,36 +616,8 @@ t.f = 1 t.g = 1 t.h = 1 t.i = 1 -)"; - - const char* hashSizeSource2 = R"( -local t = {} -t.x = 1 -t.x = 2 -t.x = 3 -t.x = 4 -t.x = 5 -t.x = 6 -t.x = 7 -t.x = 8 -t.x = 9 -)"; - - const char* arraySizeSource = R"( -local t = {} -t[1] = 1 -t[2] = 1 -t[3] = 1 -t[4] = 1 -t[5] = 1 -t[6] = 1 -t[7] = 1 -t[8] = 1 -t[9] = 1 -t[10] = 1 -)"; - - CHECK_EQ("\n" + compileFunction0(hashSizeSource), R"( +)"), + R"( NEWTABLE R0 16 0 LOADN R1 1 SETTABLEKS R1 R0 K0 @@ -668,7 +640,19 @@ SETTABLEKS R1 R0 K8 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction0(hashSizeSource2), R"( + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t.x = 1 +t.x = 2 +t.x = 3 +t.x = 4 +t.x = 5 +t.x = 6 +t.x = 7 +t.x = 8 +t.x = 9 +)"), + R"( NEWTABLE R0 1 0 LOADN R1 1 SETTABLEKS R1 R0 K0 @@ -691,7 +675,20 @@ SETTABLEKS R1 R0 K0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction0(arraySizeSource), R"( + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t[1] = 1 +t[2] = 1 +t[3] = 1 +t[4] = 1 +t[5] = 1 +t[6] = 1 +t[7] = 1 +t[8] = 1 +t[9] = 1 +t[10] = 1 +)"), + R"( NEWTABLE R0 0 10 LOADN R1 1 SETTABLEN R1 R0 1 @@ -717,6 +714,27 @@ RETURN R0 0 )"); } +TEST_CASE("TableSizePredictionObject") +{ + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +t.field = 1 +function t:getfield() + return self.field +end +return t +)", + 1), + R"( +NEWTABLE R0 2 0 +LOADN R1 1 +SETTABLEKS R1 R0 K0 +DUPCLOSURE R1 K1 +SETTABLEKS R1 R0 K2 +RETURN R0 1 +)"); +} + TEST_CASE("TableSizePredictionSetMetatable") { CHECK_EQ("\n" + compileFunction0(R"( @@ -1031,9 +1049,6 @@ RETURN R0 1 TEST_CASE("IfElseExpression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - // codegen for a true constant condition CHECK_EQ("\n" + compileFunction0("return if true then 10 else 20"), R"( LOADN R0 10 @@ -3058,7 +3073,7 @@ RETURN R0 0 // table variants (indexed by string, number, variable) CHECK_EQ("\n" + compileFunction0("local a = {} a.foo += 5"), R"( -NEWTABLE R0 1 0 +NEWTABLE R0 0 0 GETTABLEKS R1 R0 K0 ADDK R1 R1 K1 SETTABLEKS R1 R0 K0 @@ -3066,7 +3081,7 @@ RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = {} a[1] += 5"), R"( -NEWTABLE R0 0 1 +NEWTABLE R0 0 0 GETTABLEN R1 R0 1 ADDK R1 R1 K0 SETTABLEN R1 R0 1 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 663b329e..5222af33 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -366,15 +366,11 @@ TEST_CASE("PCall") TEST_CASE("Pack") { - ScopedFastFlag sff{"LuauStrPackUBCastFix", true}; - runConformance("tpack.lua"); } TEST_CASE("Vector") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - lua_CompileOptions copts = {}; copts.optimizationLevel = 1; copts.debugLevel = 1; @@ -861,15 +857,11 @@ TEST_CASE("ExceptionObject") TEST_CASE("IfElseExpression") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - runConformance("ifelseexpr.lua"); } TEST_CASE("TagMethodError") { - ScopedFastFlag sff{"LuauCcallRestoreFix", true}; - runConformance("tmerror.lua", [](lua_State* L) { auto* cb = lua_callbacks(L); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index ca4281a0..c74bfa27 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -191,7 +191,7 @@ ParseResult Fixture::tryParse(const std::string& source, const ParseOptions& par return result; } -ParseResult Fixture::matchParseError(const std::string& source, const std::string& message) +ParseResult Fixture::matchParseError(const std::string& source, const std::string& message, std::optional location) { ParseOptions options; options.allowDeclarationSyntax = true; @@ -203,6 +203,9 @@ ParseResult Fixture::matchParseError(const std::string& source, const std::strin CHECK_EQ(result.errors.front().getMessage(), message); + if (location) + CHECK_EQ(result.errors.front().getLocation(), *location); + return result; } diff --git a/tests/Fixture.h b/tests/Fixture.h index e01632ea..ab852ef6 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -106,7 +106,7 @@ struct Fixture /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {}); - ParseResult matchParseError(const std::string& source, const std::string& message); + ParseResult matchParseError(const std::string& source, const std::string& message, std::optional location = std::nullopt); // Verify a parse error occurs and the parse error message has the specified prefix ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 5abcb09a..e9135651 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1255,7 +1255,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_type_group") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") { ScopedFastInt sfis{"LuauRecursionLimit", 10}; - ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; matchParseErrorPrefix( "function f() if true then if true then if true then if true then if true then if true then if true then if true then if true " @@ -1266,7 +1265,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements") { ScopedFastInt sfis{"LuauRecursionLimit", 10}; - ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; matchParseErrorPrefix( "function f() if false then elseif false then elseif false then elseif false then elseif false then elseif false then elseif " @@ -1276,7 +1274,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements" TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError("function f() return if true then 1 elseif true then 2 elseif true then 3 elseif true then 4 elseif true then 5 elseif true then " @@ -1286,7 +1283,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1 TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions2") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError( @@ -1962,6 +1958,37 @@ TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments") matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); } +TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") +{ + matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + AstStat* stat = parse(R"( +type A = {} +type B = {} +type C = {} +type D = {} +type E = {} +type F = (T...) -> U... +type G = (U...) -> T... + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + matchParseError("type Y = {}", "Expected default type after type name", Location{{0, 20}, {0, 21}}); + matchParseError("type Y = {}", "Expected default type pack after type pack name", Location{{0, 29}, {0, 30}}); + matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -2455,10 +2482,19 @@ do end CHECK_EQ(1, result.errors.size()); } +TEST_CASE_FIXTURE(Fixture, "recover_expected_type_pack") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauParseRecoverTypePackEllipsis{"LuauParseRecoverTypePackEllipsis", true}; + + ParseResult result = tryParse(R"( +type Y = (T...) -> U... + )"); + CHECK_EQ(1, result.errors.size()); +} + TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - { AstStat* stat = parse("return if true then 1 else 2"); @@ -2524,9 +2560,4 @@ type C = Packed<(number, X...)> REQUIRE(stat != nullptr); } -TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") -{ - matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); -} - TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 80a258f5..445ee532 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -338,6 +338,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") { + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -353,7 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") TypeId tType = requireType("inst"); ToStringResult r = toStringDetailed(tType); - CHECK_EQ("{ @metatable {| __index: { @metatable {| __index: base |}, child } |}, inst }", r.name); + CHECK_EQ("{ @metatable { __index: { @metatable { __index: base }, child } }, inst }", r.name); CHECK_EQ(0, r.nameMap.typeVars.size()); ToStringOptions opts; @@ -500,6 +502,24 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") +{ + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( + local function f(a: number, b: string) end + local function test(...: T...): U... + f(...) + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("test"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("test(...: T...): U...", toStringNamedFunction("test", *ftv)); +} + TEST_CASE("toStringNamedFunction_unit_f") { TypePackVar empty{TypePack{}}; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 47c3883c..ac5be859 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -421,8 +421,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion") TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") { - ScopedFastFlag luauIfElseExpressionBaseSupport("LuauIfElseExpressionBaseSupport", true); - std::string code = "local a = if 1 then 2 else 3"; CHECK_EQ(code, transpile(code).code); @@ -641,4 +639,16 @@ TEST_CASE_FIXTURE(Fixture, "transpile_to_string") CHECK_EQ("'hello'", toString(expr)); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_alias_default_type_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + std::string code = R"( +type Packed = (T, U, V...)->(W...) +local a: Packed + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 275782b3..86165814 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -497,7 +497,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") CHECK(arrayTable->indexer); CHECK(isInArena(array.type, mod.interfaceTypes)); - CHECK_EQ(array.typeParams[0], arrayTable->indexer->indexResultType); + CHECK_EQ(array.typeParams[0].ty, arrayTable->indexer->indexResultType); } TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 503b613f..d76b920b 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1031,9 +1031,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_n TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function f(v:string?) return if v then v else tostring(v) @@ -1048,9 +1045,6 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function f(v:string?) return if not v then tostring(v) else v @@ -1065,9 +1059,6 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function returnOne(x) return 1 @@ -1119,6 +1110,25 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_ CHECK_EQ("string", toString(requireTypeAtPosition({5, 30}))); } +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined2") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + type T = { x: { y: number }? } + + local function f(t: T?) + if t and t.x then + local foo = t.x.y + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireTypeAtPosition({5, 32}))); +} + TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") { ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 68dc1b4f..94cfb643 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -360,6 +360,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { ScopedFastFlag sffs[] = { {"LuauParseSingletonTypes", true}, + {"LuauUnsealedTableLiteral", true}, }; CheckResult result = check(R"( @@ -369,7 +370,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + CHECK_EQ(R"(Table type '{ ["\n"]: number }' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", toString(result.errors[0])); } @@ -423,4 +424,27 @@ caused by: toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + {"LuauIfElseExpectedType2", true}, + {"LuauIfElseBranchTypeUnion", true}, + }; + + CheckResult result = check(R"( +type Cat = { tag: 'cat', catfood: string } +type Dog = { tag: 'dog', dogfood: string } +type Animal = Cat | Dog + +local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag = 'dog', dogfood = 'other' } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 80f40407..27cda146 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -65,7 +65,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table") TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { - CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); + CheckResult result = check("function mkt() return {prop=999} end local t = mkt() t.foo = 'bar'"); LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; @@ -77,7 +77,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") CHECK_EQ(s, "{| prop: number |}"); CHECK_EQ(error->prop, "foo"); CHECK_EQ(error->context, CannotExtendTable::Property); - CHECK_EQ(err.location, (Location{Position{0, 24}, Position{0, 29}})); + CHECK_EQ(err.location, (Location{Position{0, 59}, Position{0, 64}})); } TEST_CASE_FIXTURE(Fixture, "dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table") @@ -1155,7 +1155,8 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_mu TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail") { CheckResult result = check(R"( - local t = {x = 1} + function mkt() return {x = 1} end + local t = mkt() function t.m() end )"); @@ -1165,13 +1166,38 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must_fail") { CheckResult result = check(R"( - local t = {x = 1} + function mkt() return {x = 1} end + local t = mkt() function t:m() end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + + CheckResult result = check(R"( + local t = {x = 1} + function t.m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_unsealed_table_is_ok") +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + + CheckResult result = check(R"( + local t = {x = 1} + function t:m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + // This unit test could be flaky if the fix has regressed. TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_crashing") { @@ -1439,8 +1465,13 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") CHECK_EQ("{| |}", toString(mp->subType)); } -TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") +TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_indexer") { + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; + CheckResult result = check(R"( type StringToStringMap = { [string]: string } local rt: StringToStringMap = { ["foo"] = 1 } @@ -1448,6 +1479,25 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") LUAU_REQUIRE_ERROR_COUNT(1, result); + ToStringOptions o{/* exhaustive= */ true}; + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); + // Should t now have an indexer? + // It would if the assignment to rt was correctly typed. + CHECK_EQ("{ [string]: string, foo: number }", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_sealed_tables_with_props_into_table_with_indexer") +{ + CheckResult result = check(R"( + type StringToStringMap = { [string]: string } + function mkrt() return { ["foo"] = 1 } end + local rt: StringToStringMap = mkrt() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); @@ -1467,7 +1517,10 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end @@ -1480,7 +1533,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); - CHECK_EQ("{| a: number |}", toString(tm->givenType, o)); + CHECK_EQ("{ a: number }", toString(tm->givenType, o)); } TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") @@ -1536,8 +1589,11 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") { CheckResult result = check(R"( - local vec3 = {{x = 1, y = 2, z = 3}} - local vec1 = {{x = 1}} + function mkvec3() return {x = 1, y = 2, z = 3} end + function mkvec1() return {x = 1} end + + local vec3 = {mkvec3()} + local vec1 = {mkvec1()} vec1 = vec3 )"); @@ -1620,7 +1676,8 @@ TEST_CASE_FIXTURE(Fixture, "reasonable_error_when_adding_a_nonexistent_property_ { CheckResult result = check(R"( --!strict - local A = {"value"} + function mkA() return {"value"} end + local A = mkA() A.B = "Hello" )"); @@ -1668,7 +1725,8 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") --!strict local function f() - local t = { x = 1 } + local function mkt() return { x = 1 } end + local t = mkt() function t.a() end function t.b() end @@ -1995,7 +2053,10 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); @@ -2010,7 +2071,7 @@ local c2: typeof(a2) = b2 LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' caused by: - Type '{| x: number, y: string |}' could not be converted into '{| x: number, y: number |}' + Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); @@ -2018,7 +2079,7 @@ caused by: { CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); } @@ -2026,7 +2087,7 @@ caused by: { CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); } @@ -2059,6 +2120,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, }; CheckResult result = check(R"( @@ -2077,7 +2139,7 @@ local y: number = tmp.p.y LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' caused by: - Property 'p' is not compatible. Table type '{| x: number, y: number |}' not compatible with type 'Super' because the former has extra field 'y')"); + Property 'p' is not compatible. Table type '{ x: number, y: number }' not compatible with type 'Super' because the former has extra field 'y')"); } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") @@ -2103,7 +2165,10 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") { - ScopedFastFlag luauFixRecursiveMetatableCall{"LuauFixRecursiveMetatableCall", true}; + ScopedFastFlag sff[]{ + {"LuauFixRecursiveMetatableCall", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local b @@ -2112,7 +2177,7 @@ b() )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable {| __call: t1 |}, { } })"); + CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index f70f3b1c..7a056af5 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4525,7 +4525,9 @@ f(function(x) print(x) end) } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") -{ +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end return sum(2, 3, function(a, b) return a + b end) @@ -4549,7 +4551,7 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{| c: number, s: number |}", toString(requireType("r"))); + REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") @@ -4689,6 +4691,18 @@ a = setmetatable(a, { __call = function(x) end }) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "infer_through_group_expr") +{ + ScopedFastFlag luauGroupExpectedType{"LuauGroupExpectedType", true}; + + CheckResult result = check(R"( +local function f(a: (number, number) -> number) return a(1, 3) end +f(((function(a, b) return a + b end))) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "refine_and_or") { CheckResult result = check(R"( @@ -4743,46 +4757,75 @@ local c: X TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { - CheckResult result = check(R"(local a = if true then "true" else "false")"); - LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); - } + CheckResult result = check(R"(local a = if true then "true" else "false")"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + // Test expression containing elseif + CheckResult result = check(R"( +local a = if false then "a" elseif false then "b" else "c" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") +{ + ScopedFastFlag sff3{"LuauIfElseBranchTypeUnion", true}; { - // Test expression containing elseif - CheckResult result = check(R"( -local a = if false then "a" elseif false then "b" else "c" - )"); + CheckResult result = check(R"(local a: number? = if true then 42 else nil)"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); + CHECK_EQ(toString(requireType("a"), {true}), "number?"); } } -TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_1") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true}; - { - CheckResult result = check(R"(local a = if true then "true" else 42)"); - // We currently require both true/false expressions to unify to the same type. However, we do intend to lift - // this restriction in the future. - LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); - } + CheckResult result = check(R"( +type X = {number | string} +local a: X = if true then {"1", 2, 3} else {4, 5, 6} +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a"), {true}), "{number | string}"); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_2") +{ + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + ScopedFastFlag luauIfElseBranchTypeUnion{ "LuauIfElseBranchTypeUnion", true }; + + CheckResult result = check(R"( +local a: number? = if true then 1 else nil +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_3") +{ + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + + CheckResult result = check(R"( +local function times(n: any, f: () -> T) + local result: {T} = {} + local res = f() + table.insert(result, if true then res else n) + return result +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "type_error_addition") @@ -5039,4 +5082,51 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_oop") +{ + CheckResult result = check(R"( + --!strict +local Class = {} +Class.__index = Class + +type Class = typeof(setmetatable({} :: { x: number }, Class)) + +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end + +function Class.getx(self: Class) + return self.x +end + +function test() + local c = Class.new(42) + local n = c:getx() + local nn = c.x + + print(string.format("%d %d", n, nn)) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") +{ + ScopedFastFlag luauMetatableAreEqualRecursion{"LuauMetatableAreEqualRecursion", true}; + + CheckResult result = check(R"( +local function getIt() + local y + y = setmetatable({}, y) + return y +end +local a = getIt() +local b = getIt() +local c = a or b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 5d37b032..d4878d14 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -621,4 +621,328 @@ type Other = Packed CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but only 1 is specified"); } +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_explicit") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U } + +local a: Y = { a = 2, b = 3 } +local b: Y = { a = 2, b = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + + result = check(R"( +type Y = { a: T } + +local a: Y = { a = 2 } +local b: Y<> = { a = "s" } +local c: Y = { a = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + CHECK_EQ(toString(requireType("c")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_self") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U } + +local a: Y = { a = 2, b = 3 } +local b: Y = { a = "h", b = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + + result = check(R"( +type Y string> = { a: T, b: U } + +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y string>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_chained") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U, c: V } + +local a: Y +local b: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_explicit") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> () } +local a: Y<> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_ty") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: (U...) -> T } + +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_tp") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> U... } +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string)>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_chained_tp") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> U..., b: (T...) -> V... } +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string), (number, string)>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_mixed_self") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T, U, V...) -> W... } +local a: Y +local b: Y +local c: Y +local d: Y ()> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + CHECK_EQ(toString(requireType("c")), "Y"); + CHECK_EQ(toString(requireType("d")), "Y ()>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T } +local a: Y = { a = 2 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); + + result = check(R"( +type Y = { a: (T...) -> () } +local a: Y<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); + + result = check(R"( +type Y = { a: (T) -> U... } +local a: Y<...number> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Y' expects at least 1 type argument, but none are specified"); + + result = check(R"( +type Packed = (T) -> T +local a: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); + + result = check(R"( +type Y = { a: T } +local a: Y + )"); + + LUAU_REQUIRE_ERRORS(result); + + result = check(R"( +type Y = { a: T } +local a: Y<...number> + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_export") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + fileResolver.source["Module/Types"] = R"( +export type A = { a: T, b: U } +export type B = { a: T, b: U } +export type C string> = { a: T, b: U } +export type D = { a: T, b: U, c: V } +export type E = { a: (T...) -> () } +export type F = { a: T, b: (U...) -> T } +export type G = { b: (U...) -> T... } +export type H = { b: (T...) -> T... } +return {} + )"; + + CheckResult resultTypes = frontend.check("Module/Types"); + LUAU_REQUIRE_NO_ERRORS(resultTypes); + + fileResolver.source["Module/Users"] = R"( +local Types = require(script.Parent.Types) + +local a: Types.A +local b: Types.B +local c: Types.C +local d: Types.D +local e: Types.E<> +local eVoid: Types.E<()> +local f: Types.F +local g: Types.G<...number> +local h: Types.H<> + )"; + + CheckResult resultUsers = frontend.check("Module/Users"); + LUAU_REQUIRE_NO_ERRORS(resultUsers); + + CHECK_EQ(toString(requireType("Module/Users", "a")), "A"); + CHECK_EQ(toString(requireType("Module/Users", "b")), "B"); + CHECK_EQ(toString(requireType("Module/Users", "c")), "C string>"); + CHECK_EQ(toString(requireType("Module/Users", "d")), "D"); + CHECK_EQ(toString(requireType("Module/Users", "e")), "E"); + CHECK_EQ(toString(requireType("Module/Users", "eVoid")), "E<>"); + CHECK_EQ(toString(requireType("Module/Users", "f")), "F"); + CHECK_EQ(toString(requireType("Module/Users", "g")), "G<...number, ()>"); + CHECK_EQ(toString(requireType("Module/Users", "h")), "H<>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_skip_brackets") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = (T...) -> number +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "(...string) -> number"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_confusing_types") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type A = (T, V...) -> (U, W...) +type B = A +type C = A + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("B"), {true}), "(string, ...any) -> (number, ...any)"); + CHECK_EQ(toString(*lookupType("C"), {true}), "(string, boolean) -> (number, boolean)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_recursive_type") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type F ()> = (K) -> V +type R = { m: F } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("R"), {true}), "t1 where t1 = {| m: (t1) -> (t1) -> () |}"); +} + +TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") +{ + ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true}; + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +local a: () -> (number, ...string) +local b: () -> (number, ...boolean) +a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> (number, ...boolean)' could not be converted into '() -> (number, ...string)' +caused by: + Type 'boolean' could not be converted into 'string')"); +} + TEST_SUITE_END(); From d70a0788c5b7993cc535c19b2bb732d56b760663 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 21 Jan 2022 08:23:02 -0800 Subject: [PATCH 015/102] Sync to upstream/release/511 --- Analysis/include/Luau/TypeVar.h | 4 + Analysis/src/Linter.cpp | 15 ++ Analysis/src/TypeInfer.cpp | 29 ++- Analysis/src/TypeVar.cpp | 46 +++++ Analysis/src/Unifier.cpp | 6 + Ast/include/Luau/Ast.h | 3 +- Ast/include/Luau/DenseHash.h | 8 +- Ast/src/Parser.cpp | 20 +- CLI/Repl.cpp | 110 +++++++--- CMakeLists.txt | 6 + Compiler/src/TableShape.cpp | 49 ++++- LICENSE.txt | 2 +- VM/include/lua.h | 2 +- VM/src/lapi.cpp | 2 +- VM/src/ldo.cpp | 9 +- VM/src/lfunc.cpp | 96 ++++++--- VM/src/lfunc.h | 7 +- VM/src/lgc.cpp | 251 +++++++++++++++++++---- VM/src/lgc.h | 1 + VM/src/lgcdebug.cpp | 77 +++++-- VM/src/lmem.cpp | 347 +++++++++++++++++++++++++++++++- VM/src/lmem.h | 15 ++ VM/src/lobject.h | 16 +- VM/src/lstate.cpp | 37 +++- VM/src/lstate.h | 15 +- VM/src/lstring.cpp | 136 +++++++++---- VM/src/lstring.h | 2 +- VM/src/ltable.cpp | 8 +- VM/src/ltable.h | 2 +- VM/src/ludata.cpp | 6 +- VM/src/ludata.h | 2 +- VM/src/lvmexecute.cpp | 2 +- tests/Compiler.test.cpp | 24 +++ tests/Linter.test.cpp | 9 +- tests/Parser.test.cpp | 13 +- tests/TypeInfer.tables.test.cpp | 48 +++++ tests/TypeInfer.test.cpp | 29 +++ tests/conformance/closure.lua | 15 ++ tests/conformance/gc.lua | 28 +++ 39 files changed, 1278 insertions(+), 219 deletions(-) diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index fd2c2afa..3f5e26d6 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include "Luau/Predicate.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" @@ -499,6 +500,9 @@ bool maybeGeneric(const TypeId ty); // Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton bool maybeSingleton(TypeId ty); +// Checks if the length operator can be applied on the value of type +bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount); + struct SingletonTypes { const TypeId nilType; diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 1a5b24fe..905b70bf 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauLintTableCreateTable, false) + namespace Luau { @@ -2153,6 +2155,19 @@ private: "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); } + if (FFlag::LuauLintTableCreateTable && func->index == "create" && node->args.size == 2) + { + // table.create(n, {...}) + if (args[1]->is()) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + + // table.create(n, {...} :: ?) + if (AstExprTypeAssertion* as = args[1]->as(); as && as->expr->is()) + emitWarning(*context, LintWarning::Code_TableOperations, as->expr->location, + "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + } + return true; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index bedcc022..e2d8a4fb 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,6 +33,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) +LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) @@ -2066,17 +2067,27 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn if (get(operandType)) return {errorRecoveryType(scope)}; - if (get(operandType)) - return {numberType}; // Not strictly correct: metatables permit overriding this - - if (auto p = get(operandType)) + if (FFlag::LuauLengthOnCompositeType) { - if (p->type == PrimitiveTypeVar::String) - return {numberType}; - } + DenseHashSet seen{nullptr}; - if (!getTableType(operandType)) - reportError(TypeError{expr.location, NotATable{operandType}}); + if (!hasLength(operandType, seen, &recursionCount)) + reportError(TypeError{expr.location, NotATable{operandType}}); + } + else + { + if (get(operandType)) + return {numberType}; // Not strictly correct: metatables permit overriding this + + if (auto p = get(operandType)) + { + if (p->type == PrimitiveTypeVar::String) + return {numberType}; + } + + if (!getTableType(operandType)) + reportError(TypeError{expr.location, NotATable{operandType}}); + } return {numberType}; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index ac2b2541..df5d76ed 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -5,6 +5,7 @@ #include "Luau/Common.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" +#include "Luau/RecursionCounter.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" @@ -19,6 +20,8 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) +LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTFLAG(LuauLengthOnCompositeType) LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) @@ -326,6 +329,49 @@ bool maybeSingleton(TypeId ty) return false; } +bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) +{ + LUAU_ASSERT(FFlag::LuauLengthOnCompositeType); + + RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit); + + ty = follow(ty); + + if (seen.contains(ty)) + return true; + + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty) || get(ty) || get(ty)) + return true; + + if (auto uty = get(ty)) + { + seen.insert(ty); + + for (TypeId part : uty->options) + { + if (!hasLength(part, seen, recursionCount)) + return false; + } + + return true; + } + + if (auto ity = get(ty)) + { + seen.insert(ty); + + for (TypeId part : ity->parts) + { + if (hasLength(part, seen, recursionCount)) + return true; + } + + return false; + } + + return false; +} + FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) : argTypes(argTypes) , retType(retType) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 6873c657..2bd9cf83 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -13,6 +13,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); +LUAU_FASTFLAGVARIABLE(LuauCommittingTxnLogFreeTpPromote, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); @@ -99,6 +100,11 @@ struct PromoteTypeLevels bool operator()(TypePackId tp, const FreeTypePack&) { + // Surprise, it's actually a BoundTypePack that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (FFlag::LuauCommittingTxnLogFreeTpPromote && FFlag::LuauUseCommittingTxnLog && !log.is(tp)) + return true; + promote(tp, FFlag::LuauUseCommittingTxnLog ? log.getMutable(tp) : getMutable(tp)); return true; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 573850a5..ac5950c0 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -1265,7 +1265,8 @@ struct hash size_t operator()(const Luau::AstName& value) const { // note: since operator== uses pointer identity, hashing function uses it as well - return value.value ? std::hash()(value.value) : 0; + // the hasher is the same as DenseHashPointer (DenseHash.h) + return (uintptr_t(value.value) >> 4) ^ (uintptr_t(value.value) >> 9); } }; diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index a7b2515a..65939bee 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -12,10 +12,6 @@ namespace Luau { -// Internal implementation of DenseHashSet and DenseHashMap -namespace detail -{ - struct DenseHashPointer { size_t operator()(const void* key) const @@ -24,6 +20,10 @@ struct DenseHashPointer } }; +// Internal implementation of DenseHashSet and DenseHashMap +namespace detail +{ + template using DenseHashDefault = std::conditional_t, DenseHashPointer, std::hash>; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 77787cb1..3c607d24 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) +LUAU_FASTFLAGVARIABLE(LuauStartingBrokenComment, false) namespace Luau { @@ -174,10 +175,23 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n const Lexeme::Type type = p.lexer.current().type; const Location loc = p.lexer.current().location; - p.lexer.next(); + if (FFlag::LuauStartingBrokenComment) + { + if (options.captureComments) + p.commentLocations.push_back(Comment{type, loc}); - if (options.captureComments) - p.commentLocations.push_back(Comment{type, loc}); + if (type == Lexeme::BrokenComment) + break; + + p.lexer.next(); + } + else + { + p.lexer.next(); + + if (options.captureComments) + p.commentLocations.push_back(Comment{type, loc}); + } } p.lexer.setSkipComments(true); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 36747f48..e5042152 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -35,10 +35,15 @@ enum class CompileFormat Binary }; +struct GlobalOptions +{ + int optimizationLevel = 1; +} globalOptions; + static Luau::CompileOptions copts() { Luau::CompileOptions result = {}; - result.optimizationLevel = 1; + result.optimizationLevel = globalOptions.optimizationLevel; result.debugLevel = 1; result.coverageLevel = coverageActive() ? 2 : 0; @@ -232,13 +237,14 @@ static std::string runCode(lua_State* L, const std::string& source) static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) { std::string_view lookup = editBuffer + start; + char lastSep = 0; for (;;) { - size_t dot = lookup.find('.'); - std::string_view prefix = lookup.substr(0, dot); + size_t sep = lookup.find_first_of(".:"); + std::string_view prefix = lookup.substr(0, sep); - if (dot == std::string_view::npos) + if (sep == std::string_view::npos) { // table, key lua_pushnil(L); @@ -249,11 +255,22 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, { // table, key, value std::string_view key = lua_tostring(L, -2); + int valueType = lua_type(L, -1); - if (!key.empty() && Luau::startsWith(key, prefix)) - completions.push_back(editBuffer + std::string(key.substr(prefix.size()))); + // If the last separator was a ':' (i.e. a method call) then only functions should be completed. + bool requiredValueType = (lastSep != ':' || valueType == LUA_TFUNCTION); + + if (!key.empty() && requiredValueType && Luau::startsWith(key, prefix)) + { + std::string completion(editBuffer + std::string(key.substr(prefix.size()))); + if (valueType == LUA_TFUNCTION) + { + // Add an opening paren for function calls by default. + completion += "("; + } + completions.push_back(completion); + } } - lua_pop(L, 1); } @@ -266,10 +283,21 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, lua_rawget(L, -2); lua_remove(L, -2); - if (!lua_istable(L, -1)) + if (lua_type(L, -1) == LUA_TSTRING) + { + // Replace the string object with the string class to perform further lookups of string functions + // Note: We retrieve the string class from _G to prevent issues if the user assigns to `string`. + lua_getglobal(L, "_G"); + lua_pushlstring(L, "string", 6); + lua_rawget(L, -2); + lua_remove(L, -2); + LUAU_ASSERT(lua_istable(L, -1)); + } + else if (!lua_istable(L, -1)) break; - lookup.remove_prefix(dot + 1); + lastSep = lookup[sep]; + lookup.remove_prefix(sep + 1); } } @@ -279,7 +307,7 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, static void completeRepl(lua_State* L, const char* editBuffer, std::vector& completions) { size_t start = strlen(editBuffer); - while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.' || editBuffer[start - 1] == '_')) + while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.' || editBuffer[start - 1] == ':' || editBuffer[start - 1] == '_')) start--; // look the value up in current global table first @@ -319,15 +347,8 @@ struct LinenoiseScopedHistory std::string historyFilepath; }; -static void runRepl() +static void runReplImpl(lua_State* L) { - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - setupState(L); - - luaL_sandboxthread(L); - linenoise::SetCompletionCallback([L](const char* editBuffer, std::vector& completions) { completeRepl(L, editBuffer, completions); }); @@ -368,7 +389,18 @@ static void runRepl() } } -static bool runFile(const char* name, lua_State* GL) +static void runRepl() +{ + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + setupState(L); + luaL_sandboxthread(L); + runReplImpl(L); +} + +// `repl` is used it indicate if a repl should be started after executing the file. +static bool runFile(const char* name, lua_State* GL, bool repl) { std::optional source = readFile(name); if (!source) @@ -419,6 +451,10 @@ static bool runFile(const char* name, lua_State* GL) fprintf(stderr, "%s", error.c_str()); } + if (repl) + { + runReplImpl(L); + } lua_pop(GL, 1); return status == 0; } @@ -457,7 +493,7 @@ static bool compileFile(const char* name, CompileFormat format) bcb.setDumpSource(*source); } - Luau::compileOrThrow(bcb, *source); + Luau::compileOrThrow(bcb, *source, copts()); switch (format) { @@ -495,9 +531,11 @@ static void displayHelp(const char* argv0) printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); printf("\n"); printf("Available options:\n"); - printf(" -h, --help: Display this usage message.\n"); - printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); + printf(" -h, --help: Display this usage message.\n"); + printf(" -i, --interactive: Run an interactive REPL after executing the last script specified.\n"); + printf(" -O: use compiler optimization level (n=0-2).\n"); + printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); } @@ -519,6 +557,7 @@ int main(int argc, char** argv) CompileFormat compileFormat{}; int profile = 0; bool coverage = false; + bool interactive = false; // Set the mode if the user has explicitly specified one. int argStart = 1; @@ -540,8 +579,8 @@ int main(int argc, char** argv) } else { - fprintf(stdout, "Error: Unrecognized value for '--compile' specified.\n"); - return -1; + fprintf(stderr, "Error: Unrecognized value for '--compile' specified.\n"); + return 1; } } @@ -552,6 +591,20 @@ int main(int argc, char** argv) displayHelp(argv[0]); return 0; } + else if (strcmp(argv[i], "-i") == 0 || strcmp(argv[i], "--interactive") == 0) + { + interactive = true; + } + else if (strncmp(argv[i], "-O", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Optimization level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.optimizationLevel = level; + } else if (strcmp(argv[i], "--profile") == 0) { profile = 10000; // default to 10 KHz @@ -575,7 +628,7 @@ int main(int argc, char** argv) } else if (argv[i][0] == '-') { - fprintf(stdout, "Error: Unrecognized option '%s'.\n\n", argv[i]); + fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); displayHelp(argv[0]); return 1; } @@ -623,8 +676,11 @@ int main(int argc, char** argv) int failed = 0; - for (const std::string& path : files) - failed += !runFile(path.c_str(), L); + for (size_t i = 0; i < files.size(); ++i) + { + bool isLastFile = i == files.size() - 1; + failed += !runFile(files[i].c_str(), L, interactive && isLastFile); + } if (profile) { diff --git a/CMakeLists.txt b/CMakeLists.txt index bafc59e5..77cf47e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -78,6 +78,12 @@ target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) +if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) + # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: + # https://developercommunity.visualstudio.com/t/performance-regression-on-a-complex-interpreter-lo/1631863 + set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) +endif() + if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) diff --git a/Compiler/src/TableShape.cpp b/Compiler/src/TableShape.cpp index 7d99f222..9dc2f0a4 100644 --- a/Compiler/src/TableShape.cpp +++ b/Compiler/src/TableShape.cpp @@ -1,11 +1,16 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "TableShape.h" +LUAU_FASTFLAGVARIABLE(LuauPredictTableSizeLoop, false) + namespace Luau { namespace Compile { +// conservative limit for the loop bound that establishes table array size +static const int kMaxLoopBound = 16; + static AstExprTable* getTableHint(AstExpr* expr) { // unadorned table literal @@ -27,7 +32,7 @@ struct ShapeVisitor : AstVisitor { size_t operator()(const std::pair& p) const { - return std::hash()(p.first) ^ std::hash()(p.second); + return DenseHashPointer()(p.first) ^ std::hash()(p.second); } }; @@ -36,10 +41,13 @@ struct ShapeVisitor : AstVisitor DenseHashMap tables; DenseHashSet, Hasher> fields; + DenseHashMap loops; // iterator => upper bound for 1..k + ShapeVisitor(DenseHashMap& shapes) : shapes(shapes) , tables(nullptr) , fields(std::pair()) + , loops(nullptr) { } @@ -63,16 +71,31 @@ struct ShapeVisitor : AstVisitor void assignField(AstExpr* expr, AstExpr* index) { AstExprLocal* lv = expr->as(); - AstExprConstantNumber* number = index->as(); + if (!lv) + return; - if (lv && number) + AstExprTable** table = tables.find(lv->local); + if (!table) + return; + + if (AstExprConstantNumber* number = index->as()) { - if (AstExprTable** table = tables.find(lv->local)) + TableShape& shape = shapes[*table]; + + if (number->value == double(shape.arraySize + 1)) + shape.arraySize += 1; + } + else if (AstExprLocal* iter = index->as()) + { + if (!FFlag::LuauPredictTableSizeLoop) + return; + + if (const unsigned int* bound = loops.find(iter->local)) { TableShape& shape = shapes[*table]; - if (number->value == double(shape.arraySize + 1)) - shape.arraySize += 1; + if (shape.arraySize == 0) + shape.arraySize = *bound; } } } @@ -117,6 +140,20 @@ struct ShapeVisitor : AstVisitor return false; } + + bool visit(AstStatFor* node) override + { + if (!FFlag::LuauPredictTableSizeLoop) + return true; + + AstExprConstantNumber* from = node->from->as(); + AstExprConstantNumber* to = node->to->as(); + + if (from && to && from->value == 1.0 && to->value >= 1.0 && to->value <= double(kMaxLoopBound) && !node->step) + loops[node->var] = unsigned(to->value); + + return true; + } }; void predictTableShapes(DenseHashMap& shapes, AstNode* root) diff --git a/LICENSE.txt b/LICENSE.txt index d63e7299..fa9914d7 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2019-2021 Roblox Corporation +Copyright (c) 2019-2022 Roblox Corporation Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/VM/include/lua.h b/VM/include/lua.h index 55902160..c5dcef25 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -382,7 +382,7 @@ typedef struct lua_Callbacks lua_Callbacks; LUA_API lua_Callbacks* lua_callbacks(lua_State* L); /****************************************************************************** - * Copyright (c) 2019-2021 Roblox Corporation + * Copyright (c) 2019-2022 Roblox Corporation * Copyright (C) 1994-2008 Lua.org, PUC-Rio. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index c98b9590..d5416285 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -18,7 +18,7 @@ const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Ri "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; -const char* luau_ident = "$Luau: Copyright (C) 2019-2021 Roblox Corporation $\n" +const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n" "$URL: luau-lang.org $\n"; #define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base)) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 3cce7665..581506a8 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -150,12 +150,11 @@ l_noret luaD_throw(lua_State* L, int errcode) static void correctstack(lua_State* L, TValue* oldstack) { - CallInfo* ci; - GCObject* up; L->top = (L->top - oldstack) + L->stack; - for (up = L->openupval; up != NULL; up = up->gch.next) - gco2uv(up)->v = (gco2uv(up)->v - oldstack) + L->stack; - for (ci = L->base_ci; ci <= L->ci; ci++) + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + for (UpVal* up = L->openupval; up != NULL; up = (UpVal*)up->next) + up->v = (up->v - oldstack) + L->stack; + for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) { ci->top = (ci->top - oldstack) + L->stack; ci->base = (ci->base - oldstack) + L->stack; diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 4178eda4..6088f71c 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -7,10 +7,11 @@ #include "lgc.h" LUAU_FASTFLAGVARIABLE(LuauNoDirectUpvalRemoval, false) +LUAU_FASTFLAG(LuauGcPagedSweep) Proto* luaF_newproto(lua_State* L) { - Proto* f = luaM_new(L, Proto, sizeof(Proto), L->activememcat); + Proto* f = luaM_newgco(L, Proto, sizeof(Proto), L->activememcat); luaC_link(L, f, LUA_TPROTO); f->k = NULL; f->sizek = 0; @@ -38,7 +39,7 @@ Proto* luaF_newproto(lua_State* L) Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) { - Closure* c = luaM_new(L, Closure, sizeLclosure(nelems), L->activememcat); + Closure* c = luaM_newgco(L, Closure, sizeLclosure(nelems), L->activememcat); luaC_link(L, c, LUA_TFUNCTION); c->isC = 0; c->env = e; @@ -53,7 +54,7 @@ Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e) { - Closure* c = luaM_new(L, Closure, sizeCclosure(nelems), L->activememcat); + Closure* c = luaM_newgco(L, Closure, sizeCclosure(nelems), L->activememcat); luaC_link(L, c, LUA_TFUNCTION); c->isC = 1; c->env = e; @@ -69,10 +70,9 @@ Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e) UpVal* luaF_findupval(lua_State* L, StkId level) { global_State* g = L->global; - GCObject** pp = &L->openupval; + UpVal** pp = &L->openupval; UpVal* p; - UpVal* uv; - while (*pp != NULL && (p = gco2uv(*pp))->v >= level) + while (*pp != NULL && (p = *pp)->v >= level) { LUAU_ASSERT(p->v != &p->u.value); if (p->v == level) @@ -81,53 +81,95 @@ UpVal* luaF_findupval(lua_State* L, StkId level) changewhite(obj2gco(p)); /* resurrect it */ return p; } - pp = &p->next; + + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + pp = (UpVal**)&p->next; } - uv = luaM_new(L, UpVal, sizeof(UpVal), L->activememcat); /* not found: create a new one */ + + UpVal* uv = luaM_newgco(L, UpVal, sizeof(UpVal), L->activememcat); /* not found: create a new one */ uv->tt = LUA_TUPVAL; uv->marked = luaC_white(g); uv->memcat = L->activememcat; uv->v = level; /* current value lives in the stack */ - uv->next = *pp; /* chain it in the proper position */ - *pp = obj2gco(uv); - uv->u.l.prev = &g->uvhead; /* double link it in `uvhead' list */ + + // chain the upvalue in the threads open upvalue list at the proper position + UpVal* next = *pp; + + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + uv->next = (GCObject*)next; + + if (FFlag::LuauGcPagedSweep) + { + uv->u.l.threadprev = pp; + if (next) + { + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + next->u.l.threadprev = (UpVal**)&uv->next; + } + } + + *pp = uv; + + // double link the upvalue in the global open upvalue list + uv->u.l.prev = &g->uvhead; uv->u.l.next = g->uvhead.u.l.next; uv->u.l.next->u.l.prev = uv; g->uvhead.u.l.next = uv; LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); return uv; } - -static void unlinkupval(UpVal* uv) +void luaF_unlinkupval(UpVal* uv) { + // unlink upvalue from the global open upvalue list LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); - uv->u.l.next->u.l.prev = uv->u.l.prev; /* remove from `uvhead' list */ + uv->u.l.next->u.l.prev = uv->u.l.prev; uv->u.l.prev->u.l.next = uv->u.l.next; + + if (FFlag::LuauGcPagedSweep) + { + // unlink upvalue from the thread open upvalue list + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and this and the following cast will not be required + *uv->u.l.threadprev = (UpVal*)uv->next; + + if (UpVal* next = (UpVal*)uv->next) + next->u.l.threadprev = uv->u.l.threadprev; + } } -void luaF_freeupval(lua_State* L, UpVal* uv) +void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) { if (uv->v != &uv->u.value) /* is it open? */ - unlinkupval(uv); /* remove from open list */ - luaM_free(L, uv, sizeof(UpVal), uv->memcat); /* free upvalue */ + luaF_unlinkupval(uv); /* remove from open list */ + luaM_freegco(L, uv, sizeof(UpVal), uv->memcat, page); /* free upvalue */ } void luaF_close(lua_State* L, StkId level) { - UpVal* uv; global_State* g = L->global; // TODO: remove with FFlagLuauNoDirectUpvalRemoval - while (L->openupval != NULL && (uv = gco2uv(L->openupval))->v >= level) + UpVal* uv; + while (L->openupval != NULL && (uv = L->openupval)->v >= level) { GCObject* o = obj2gco(uv); LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); - L->openupval = uv->next; /* remove from `open' list */ - if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o)) + + if (!FFlag::LuauGcPagedSweep) + L->openupval = (UpVal*)uv->next; /* remove from `open' list */ + + if (FFlag::LuauGcPagedSweep && isdead(g, o)) { - luaF_freeupval(L, uv); /* free upvalue */ + // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue + luaF_unlinkupval(uv); + // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again + uv->v = &uv->u.value; + } + else if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o)) + { + luaF_freeupval(L, uv, NULL); /* free upvalue */ } else { - unlinkupval(uv); + // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue + luaF_unlinkupval(uv); setobj(L, &uv->u.value, uv->v); uv->v = &uv->u.value; /* now current value lives here */ luaC_linkupval(L, uv); /* link upvalue into `gcroot' list */ @@ -135,7 +177,7 @@ void luaF_close(lua_State* L, StkId level) } } -void luaF_freeproto(lua_State* L, Proto* f) +void luaF_freeproto(lua_State* L, Proto* f, lua_Page* page) { luaM_freearray(L, f->code, f->sizecode, Instruction, f->memcat); luaM_freearray(L, f->p, f->sizep, Proto*, f->memcat); @@ -146,13 +188,13 @@ void luaF_freeproto(lua_State* L, Proto* f) luaM_freearray(L, f->upvalues, f->sizeupvalues, TString*, f->memcat); if (f->debuginsn) luaM_freearray(L, f->debuginsn, f->sizecode, uint8_t, f->memcat); - luaM_free(L, f, sizeof(Proto), f->memcat); + luaM_freegco(L, f, sizeof(Proto), f->memcat, page); } -void luaF_freeclosure(lua_State* L, Closure* c) +void luaF_freeclosure(lua_State* L, Closure* c, lua_Page* page) { int size = c->isC ? sizeCclosure(c->nupvalues) : sizeLclosure(c->nupvalues); - luaM_free(L, c, size, c->memcat); + luaM_freegco(L, c, size, c->memcat, page); } const LocVar* luaF_getlocal(const Proto* f, int local_number, int pc) diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h index 4be23667..8047cebe 100644 --- a/VM/src/lfunc.h +++ b/VM/src/lfunc.h @@ -12,7 +12,8 @@ LUAI_FUNC Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p LUAI_FUNC Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e); LUAI_FUNC UpVal* luaF_findupval(lua_State* L, StkId level); LUAI_FUNC void luaF_close(lua_State* L, StkId level); -LUAI_FUNC void luaF_freeproto(lua_State* L, Proto* f); -LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c); -LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv); +LUAI_FUNC void luaF_freeproto(lua_State* L, Proto* f, struct lua_Page* page); +LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c, struct lua_Page* page); +void luaF_unlinkupval(UpVal* uv); +LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv, struct lua_Page* page); LUAI_FUNC const LocVar* luaF_getlocal(const Proto* func, int local_number, int pc); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 76ef7a06..50859b1e 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -8,12 +8,16 @@ #include "lfunc.h" #include "lstring.h" #include "ldo.h" +#include "lmem.h" #include "ludata.h" #include +LUAU_FASTFLAGVARIABLE(LuauGcPagedSweep, false) + #define GC_SWEEPMAX 40 #define GC_SWEEPCOST 10 +#define GC_SWEEPPAGESTEPCOST 4 #define GC_INTERRUPT(state) \ { \ @@ -457,31 +461,31 @@ static void shrinkstack(lua_State* L) condhardstacktests(luaD_reallocstack(L, s_used)); } -static void freeobj(lua_State* L, GCObject* o) +static void freeobj(lua_State* L, GCObject* o, lua_Page* page) { switch (o->gch.tt) { case LUA_TPROTO: - luaF_freeproto(L, gco2p(o)); + luaF_freeproto(L, gco2p(o), page); break; case LUA_TFUNCTION: - luaF_freeclosure(L, gco2cl(o)); + luaF_freeclosure(L, gco2cl(o), page); break; case LUA_TUPVAL: - luaF_freeupval(L, gco2uv(o)); + luaF_freeupval(L, gco2uv(o), page); break; case LUA_TTABLE: - luaH_free(L, gco2h(o)); + luaH_free(L, gco2h(o), page); break; case LUA_TTHREAD: LUAU_ASSERT(gco2th(o) != L && gco2th(o) != L->global->mainthread); - luaE_freethread(L, gco2th(o)); + luaE_freethread(L, gco2th(o), page); break; case LUA_TSTRING: - luaS_free(L, gco2ts(o)); + luaS_free(L, gco2ts(o), page); break; case LUA_TUSERDATA: - luaU_freeudata(L, gco2u(o)); + luaU_freeudata(L, gco2u(o), page); break; default: LUAU_ASSERT(0); @@ -492,6 +496,8 @@ static void freeobj(lua_State* L, GCObject* o) static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* traversedcount) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + GCObject* curr; global_State* g = L->global; int deadmask = otherwhite(g); @@ -502,7 +508,7 @@ static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* tr int alive = (curr->gch.marked ^ WHITEBITS) & deadmask; if (curr->gch.tt == LUA_TTHREAD) { - sweepwholelist(L, &gco2th(curr)->openupval, traversedcount); /* sweep open upvalues */ + sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval, traversedcount); /* sweep open upvalues */ lua_State* th = gco2th(curr); @@ -524,7 +530,7 @@ static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* tr *p = curr->gch.next; if (curr == g->rootgc) /* is the first element of the list? */ g->rootgc = curr->gch.next; /* adjust first */ - freeobj(L, curr); + freeobj(L, curr, NULL); } } @@ -537,14 +543,16 @@ static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* tr static void deletelist(lua_State* L, GCObject** p, GCObject* limit) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + GCObject* curr; while ((curr = *p) != limit) { if (curr->gch.tt == LUA_TTHREAD) /* delete open upvalues of each thread */ - deletelist(L, &gco2th(curr)->openupval, NULL); + deletelist(L, (GCObject**)&gco2th(curr)->openupval, NULL); *p = curr->gch.next; - freeobj(L, curr); + freeobj(L, curr, NULL); } } @@ -567,23 +575,62 @@ static void shrinkbuffersfull(lua_State* L) luaS_resize(L, hashsize); /* table is too big */ } +static bool deletegco(void* context, lua_Page* page, GCObject* gco) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + // we are in the process of deleting everything + // threads with open upvalues will attempt to close them all on removal + // but those upvalues might point to stack values that were already deleted + if (gco->gch.tt == LUA_TTHREAD) + { + lua_State* th = gco2th(gco); + + while (UpVal* uv = th->openupval) + { + luaF_unlinkupval(uv); + // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again + uv->v = &uv->u.value; + } + } + + lua_State* L = (lua_State*)context; + freeobj(L, gco, page); + return true; +} + void luaC_freeall(lua_State* L) { global_State* g = L->global; LUAU_ASSERT(L == g->mainthread); - LUAU_ASSERT(L->next == NULL); /* mainthread is at the end of rootgc list */ - deletelist(L, &g->rootgc, obj2gco(L)); + if (FFlag::LuauGcPagedSweep) + { + luaM_visitgco(L, L, deletegco); - for (int i = 0; i < g->strt.size; i++) /* free all string lists */ - deletelist(L, &g->strt.hash[i], NULL); + for (int i = 0; i < g->strt.size; i++) /* free all string lists */ + LUAU_ASSERT(g->strt.hash[i] == NULL); - LUAU_ASSERT(L->global->strt.nuse == 0); - deletelist(L, &g->strbufgc, NULL); - // unfortunately, when string objects are freed, the string table use count is decremented - // even when the string is a buffer that wasn't placed into the table - L->global->strt.nuse = 0; + LUAU_ASSERT(L->global->strt.nuse == 0); + LUAU_ASSERT(g->strbufgc == NULL); + } + else + { + LUAU_ASSERT(L->next == NULL); /* mainthread is at the end of rootgc list */ + + deletelist(L, &g->rootgc, obj2gco(L)); + + for (int i = 0; i < g->strt.size; i++) /* free all string lists */ + deletelist(L, (GCObject**)&g->strt.hash[i], NULL); + + LUAU_ASSERT(L->global->strt.nuse == 0); + deletelist(L, (GCObject**)&g->strbufgc, NULL); + + // unfortunately, when string objects are freed, the string table use count is decremented + // even when the string is a buffer that wasn't placed into the table + L->global->strt.nuse = 0; + } } static void markmt(global_State* g) @@ -648,12 +695,88 @@ static size_t atomic(lua_State* L) /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); g->sweepstrgc = 0; - g->sweepgc = &g->rootgc; - g->gcstate = GCSsweepstring; + + if (FFlag::LuauGcPagedSweep) + { + g->sweepgcopage = g->allgcopages; + g->gcstate = GCSsweep; + } + else + { + g->sweepgc = &g->rootgc; + g->gcstate = GCSsweepstring; + } return work; } +static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + int deadmask = otherwhite(g); + LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); // make sure we never sweep fixed objects + + int alive = (gco->gch.marked ^ WHITEBITS) & deadmask; + + g->gcstats.currcycle.sweepitems++; + + if (gco->gch.tt == LUA_TTHREAD) + { + lua_State* th = gco2th(gco); + + if (alive) + { + resetbit(th->stackstate, THREAD_SLEEPINGBIT); + shrinkstack(th); + } + } + + if (alive) + { + LUAU_ASSERT(!isdead(g, gco)); + makewhite(g, gco); // make it white (for next cycle) + return false; + } + + LUAU_ASSERT(isdead(g, gco)); + freeobj(L, gco, page); + return true; +} + +// a version of generic luaM_visitpage specialized for the main sweep stage +static int sweepgcopage(lua_State* L, lua_Page* page) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + char* start; + char* end; + int busyBlocks; + int blockSize; + luaM_getpagewalkinfo(page, &start, &end, &busyBlocks, &blockSize); + + for (char* pos = start; pos != end; pos += blockSize) + { + GCObject* gco = (GCObject*)pos; + + // skip memory blocks that are already freed + if (gco->gch.tt == LUA_TNIL) + continue; + + // when true is returned it means that the element was deleted + if (sweepgco(L, page, gco)) + { + // if the last block was removed, page would be removed as well + if (--busyBlocks == 0) + return (pos - start) / blockSize + 1; + } + } + + return (end - start) / blockSize; +} + static size_t gcstep(lua_State* L, size_t limit) { size_t cost = 0; @@ -706,15 +829,21 @@ static size_t gcstep(lua_State* L, size_t limit) g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; cost = atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); + + if (FFlag::LuauGcPagedSweep) + LUAU_ASSERT(g->gcstate == GCSsweep); + else + LUAU_ASSERT(g->gcstate == GCSsweepstring); break; } case GCSsweepstring: { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + while (g->sweepstrgc < g->strt.size && cost < limit) { size_t traversedcount = 0; - sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount); + sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++], &traversedcount); g->gcstats.currcycle.sweepitems += traversedcount; cost += GC_SWEEPCOST; @@ -727,7 +856,7 @@ static size_t gcstep(lua_State* L, size_t limit) uint32_t nuse = L->global->strt.nuse; size_t traversedcount = 0; - sweepwholelist(L, &g->strbufgc, &traversedcount); + sweepwholelist(L, (GCObject**)&g->strbufgc, &traversedcount); L->global->strt.nuse = nuse; @@ -738,19 +867,44 @@ static size_t gcstep(lua_State* L, size_t limit) } case GCSsweep: { - while (*g->sweepgc && cost < limit) + if (FFlag::LuauGcPagedSweep) { - size_t traversedcount = 0; - g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); + while (g->sweepgcopage && cost < limit) + { + lua_Page* next = luaM_getnextgcopage(g->sweepgcopage); // page sweep might destroy the page - g->gcstats.currcycle.sweepitems += traversedcount; - cost += GC_SWEEPMAX * GC_SWEEPCOST; + int steps = sweepgcopage(L, g->sweepgcopage); + + g->sweepgcopage = next; + cost += steps * GC_SWEEPPAGESTEPCOST; + } + + // nothing more to sweep? + if (g->sweepgcopage == NULL) + { + // don't forget to visit main thread + sweepgco(L, NULL, obj2gco(g->mainthread)); + + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ + } } + else + { + while (*g->sweepgc && cost < limit) + { + size_t traversedcount = 0; + g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); - if (*g->sweepgc == NULL) - { /* nothing more to sweep? */ - shrinkbuffers(L); - g->gcstate = GCSpause; /* end collection */ + g->gcstats.currcycle.sweepitems += traversedcount; + cost += GC_SWEEPMAX * GC_SWEEPCOST; + } + + if (*g->sweepgc == NULL) + { /* nothing more to sweep? */ + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ + } } break; } @@ -877,12 +1031,19 @@ void luaC_fullgc(lua_State* L) { /* reset sweep marks to sweep all elements (returning them to white) */ g->sweepstrgc = 0; - g->sweepgc = &g->rootgc; + if (FFlag::LuauGcPagedSweep) + g->sweepgcopage = g->allgcopages; + else + g->sweepgc = &g->rootgc; /* reset other collector lists */ g->gray = NULL; g->grayagain = NULL; g->weak = NULL; - g->gcstate = GCSsweepstring; + + if (FFlag::LuauGcPagedSweep) + g->gcstate = GCSsweep; + else + g->gcstate = GCSsweepstring; } LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); /* finish any pending sweep phase */ @@ -979,8 +1140,11 @@ void luaC_barrierback(lua_State* L, Table* t) void luaC_linkobj(lua_State* L, GCObject* o, uint8_t tt) { global_State* g = L->global; - o->gch.next = g->rootgc; - g->rootgc = o; + if (!FFlag::LuauGcPagedSweep) + { + o->gch.next = g->rootgc; + g->rootgc = o; + } o->gch.marked = luaC_white(g); o->gch.tt = tt; o->gch.memcat = L->activememcat; @@ -990,8 +1154,13 @@ void luaC_linkupval(lua_State* L, UpVal* uv) { global_State* g = L->global; GCObject* o = obj2gco(uv); - o->gch.next = g->rootgc; /* link upvalue into `rootgc' list */ - g->rootgc = o; + + if (!FFlag::LuauGcPagedSweep) + { + o->gch.next = g->rootgc; /* link upvalue into `rootgc' list */ + g->rootgc = o; + } + if (isgray(o)) { if (keepinvariant(g)) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index f434e506..4455fec5 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -13,6 +13,7 @@ #define GCSpropagate 1 #define GCSpropagateagain 2 #define GCSatomic 3 +// TODO: remove with FFlagLuauGcPagedSweep #define GCSsweepstring 4 #define GCSsweep 5 diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index c66de9c1..906fb0d0 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -2,16 +2,19 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lgc.h" +#include "lfunc.h" +#include "lmem.h" #include "lobject.h" #include "lstate.h" -#include "ltable.h" -#include "lfunc.h" #include "lstring.h" +#include "ltable.h" #include "ludata.h" #include #include +LUAU_FASTFLAG(LuauGcPagedSweep) + static void validateobjref(global_State* g, GCObject* f, GCObject* t) { LUAU_ASSERT(!isdead(g, t)); @@ -101,10 +104,11 @@ static void validatestack(global_State* g, lua_State* l) if (l->namecall) validateobjref(g, obj2gco(l), obj2gco(l->namecall)); - for (GCObject* uv = l->openupval; uv; uv = uv->gch.next) + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + for (UpVal* uv = l->openupval; uv; uv = (UpVal*)uv->next) { - LUAU_ASSERT(uv->gch.tt == LUA_TUPVAL); - LUAU_ASSERT(gco2uv(uv)->v != &gco2uv(uv)->u.value); + LUAU_ASSERT(uv->tt == LUA_TUPVAL); + LUAU_ASSERT(uv->v != &uv->u.value); } } @@ -178,6 +182,8 @@ static void validateobj(global_State* g, GCObject* o) static void validatelist(global_State* g, GCObject* o) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + while (o) { validateobj(g, o); @@ -216,6 +222,17 @@ static void validategraylist(global_State* g, GCObject* o) } } +static bool validategco(void* context, lua_Page* page, GCObject* gco) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + lua_State* L = (lua_State*)context; + global_State* g = L->global; + + validateobj(g, gco); + return false; +} + void luaC_validate(lua_State* L) { global_State* g = L->global; @@ -231,11 +248,18 @@ void luaC_validate(lua_State* L) validategraylist(g, g->gray); validategraylist(g, g->grayagain); - for (int i = 0; i < g->strt.size; ++i) - validatelist(g, g->strt.hash[i]); + if (FFlag::LuauGcPagedSweep) + { + luaM_visitgco(L, L, validategco); + } + else + { + for (int i = 0; i < g->strt.size; ++i) + validatelist(g, (GCObject*)(g->strt.hash[i])); - validatelist(g, g->rootgc); - validatelist(g, g->strbufgc); + validatelist(g, g->rootgc); + validatelist(g, (GCObject*)(g->strbufgc)); + } for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) { @@ -499,6 +523,8 @@ static void dumpobj(FILE* f, GCObject* o) static void dumplist(FILE* f, GCObject* o) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + while (o) { dumpref(f, o); @@ -509,22 +535,45 @@ static void dumplist(FILE* f, GCObject* o) // thread has additional list containing collectable objects that are not present in rootgc if (o->gch.tt == LUA_TTHREAD) - dumplist(f, gco2th(o)->openupval); + dumplist(f, (GCObject*)gco2th(o)->openupval); o = o->gch.next; } } +static bool dumpgco(void* context, lua_Page* page, GCObject* gco) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + FILE* f = (FILE*)context; + + dumpref(f, gco); + fputc(':', f); + dumpobj(f, gco); + fputc(',', f); + fputc('\n', f); + + return false; +} + void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)) { global_State* g = L->global; FILE* f = static_cast(file); fprintf(f, "{\"objects\":{\n"); - dumplist(f, g->rootgc); - dumplist(f, g->strbufgc); - for (int i = 0; i < g->strt.size; ++i) - dumplist(f, g->strt.hash[i]); + + if (FFlag::LuauGcPagedSweep) + { + luaM_visitgco(L, f, dumpgco); + } + else + { + dumplist(f, g->rootgc); + dumplist(f, (GCObject*)(g->strbufgc)); + for (int i = 0; i < g->strt.size; ++i) + dumplist(f, (GCObject*)(g->strt.hash[i])); + } fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , fprintf(f, "},\"roots\":{\n"); diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 9f9d4a98..6d3b7772 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -8,6 +8,8 @@ #include +LUAU_FASTFLAG(LuauGcPagedSweep) + #ifndef __has_feature #define __has_feature(x) 0 #endif @@ -42,13 +44,21 @@ static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table #endif static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); +// TODO (FFlagLuauGcPagedSweep): this will become ABISWITCH(16, 16, 16) static_assert(offsetof(Udata, data) == ABISWITCH(24, 16, 16), "size mismatch for userdata header"); +// TODO (FFlagLuauGcPagedSweep): this will become ABISWITCH(48, 32, 32) static_assert(sizeof(Table) == ABISWITCH(56, 36, 36), "size mismatch for table header"); +// TODO (FFlagLuauGcPagedSweep): new code with old 'next' pointer requires that GCObject start at the same point as TString/UpVal +static_assert(offsetof(GCObject, uv) == 0, "UpVal data must be located at the start of the GCObject"); +static_assert(offsetof(GCObject, ts) == 0, "TString data must be located at the start of the GCObject"); + const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; const size_t kPageSize = 16 * 1024 - 24; // slightly under 16KB since that results in less fragmentation due to heap metadata const size_t kBlockHeader = sizeof(double) > sizeof(void*) ? sizeof(double) : sizeof(void*); // suitable for aligning double & void* on all platforms +// TODO (FFlagLuauGcPagedSweep): when 'next' is removed, 'kBlockHeader' can be used unconditionally +const size_t kGCOHeader = sizeof(GCheader) > kBlockHeader ? sizeof(GCheader) : kBlockHeader; struct SizeClassConfig { @@ -96,6 +106,7 @@ const SizeClassConfig kSizeClassConfig; // metadata for a block is stored in the first pointer of the block #define metadata(block) (*(void**)(block)) +#define freegcolink(block) (*(void**)((char*)block + kGCOHeader)) /* ** About the realloc function: @@ -117,15 +128,22 @@ const SizeClassConfig kSizeClassConfig; struct lua_Page { + // list of pages with free blocks lua_Page* prev; lua_Page* next; + // list of all gco pages + lua_Page* gcolistprev; + lua_Page* gcolistnext; + int busyBlocks; int blockSize; void* freeList; int freeNext; + int pageSize; + union { char data[1]; @@ -141,6 +159,8 @@ l_noret luaM_toobig(lua_State* L) static lua_Page* luaM_newpage(lua_State* L, uint8_t sizeClass) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + global_State* g = L->global; lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, kPageSize); if (!page) @@ -155,6 +175,9 @@ static lua_Page* luaM_newpage(lua_State* L, uint8_t sizeClass) page->prev = NULL; page->next = NULL; + page->gcolistprev = NULL; + page->gcolistnext = NULL; + page->busyBlocks = 0; page->blockSize = blockSize; @@ -171,8 +194,69 @@ static lua_Page* luaM_newpage(lua_State* L, uint8_t sizeClass) return page; } +static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int blockSize, int blockCount) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + LUAU_ASSERT(pageSize - offsetof(lua_Page, data) >= blockSize * blockCount); + + lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, pageSize); + if (!page) + luaD_throw(L, LUA_ERRMEM); + + ASAN_POISON_MEMORY_REGION(page->data, blockSize * blockCount); + + // setup page header + page->prev = NULL; + page->next = NULL; + + page->gcolistprev = NULL; + page->gcolistnext = NULL; + + page->busyBlocks = 0; + page->blockSize = blockSize; + + // note: we start with the last block in the page and move downward + // either order would work, but that way we don't need to store the block count in the page + // additionally, GC stores objects in singly linked lists, and this way the GC lists end up in increasing pointer order + page->freeList = NULL; + page->freeNext = (blockCount - 1) * blockSize; + + page->pageSize = pageSize; + + if (gcopageset) + { + page->gcolistnext = *gcopageset; + if (page->gcolistnext) + page->gcolistnext->gcolistprev = page; + *gcopageset = page; + } + + return page; +} + +static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, uint8_t sizeClass, bool storeMetadata) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + (storeMetadata ? kBlockHeader : 0); + int blockCount = (kPageSize - offsetof(lua_Page, data)) / blockSize; + + lua_Page* page = newpage(L, gcopageset, kPageSize, blockSize, blockCount); + + // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) + LUAU_ASSERT(!freepageset[sizeClass]); + freepageset[sizeClass] = page; + + return page; +} + static void luaM_freepage(lua_State* L, lua_Page* page, uint8_t sizeClass) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + global_State* g = L->global; // remove page from freelist @@ -188,6 +272,44 @@ static void luaM_freepage(lua_State* L, lua_Page* page, uint8_t sizeClass) (*g->frealloc)(L, g->ud, page, kPageSize, 0); } +static void freepage(lua_State* L, lua_Page** gcopageset, lua_Page* page) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + if (gcopageset) + { + // remove page from alllist + if (page->gcolistnext) + page->gcolistnext->gcolistprev = page->gcolistprev; + + if (page->gcolistprev) + page->gcolistprev->gcolistnext = page->gcolistnext; + else if (*gcopageset == page) + *gcopageset = page->gcolistnext; + } + + // so long + (*g->frealloc)(L, g->ud, page, page->pageSize, 0); +} + +static void freeclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, lua_Page* page, uint8_t sizeClass) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + // remove page from freelist + if (page->next) + page->next->prev = page->prev; + + if (page->prev) + page->prev->next = page->next; + else if (freepageset[sizeClass] == page) + freepageset[sizeClass] = page->next; + + freepage(L, gcopageset, page); +} + static void* luaM_newblock(lua_State* L, int sizeClass) { global_State* g = L->global; @@ -195,7 +317,12 @@ static void* luaM_newblock(lua_State* L, int sizeClass) // slow path: no page in the freelist, allocate a new one if (!page) - page = luaM_newpage(L, sizeClass); + { + if (FFlag::LuauGcPagedSweep) + page = newclasspage(L, g->freepages, NULL, sizeClass, true); + else + page = luaM_newpage(L, sizeClass); + } LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); @@ -236,6 +363,55 @@ static void* luaM_newblock(lua_State* L, int sizeClass) return (char*)block + kBlockHeader; } +static void* luaM_newgcoblock(lua_State* L, int sizeClass) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + lua_Page* page = g->freegcopages[sizeClass]; + + // slow path: no page in the freelist, allocate a new one + if (!page) + page = newclasspage(L, g->freegcopages, &g->allgcopages, sizeClass, false); + + LUAU_ASSERT(!page->prev); + LUAU_ASSERT(page->freeList || page->freeNext >= 0); + LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass]); + + void* block; + + if (page->freeNext >= 0) + { + block = &page->data + page->freeNext; + ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); + + page->freeNext -= page->blockSize; + page->busyBlocks++; + } + else + { + // when separate block metadata is not used, free list link is stored inside the block data itself + block = (char*)page->freeList - kGCOHeader; + + ASAN_UNPOISON_MEMORY_REGION((char*)block + kGCOHeader, page->blockSize - kGCOHeader); + + page->freeList = freegcolink(block); + page->busyBlocks++; + } + + // if we allocate the last block out of a page, we need to remove it from free list + if (!page->freeList && page->freeNext < 0) + { + g->freegcopages[sizeClass] = page->next; + if (page->next) + page->next->prev = NULL; + page->next = NULL; + } + + // the user data is right after the metadata + return (char*)block; +} + static void luaM_freeblock(lua_State* L, int sizeClass, void* block) { global_State* g = L->global; @@ -270,12 +446,45 @@ static void luaM_freeblock(lua_State* L, int sizeClass, void* block) // if it's the last block in the page, we don't need the page if (page->busyBlocks == 0) - luaM_freepage(L, page, sizeClass); + { + if (FFlag::LuauGcPagedSweep) + freeclasspage(L, g->freepages, NULL, page, sizeClass); + else + luaM_freepage(L, page, sizeClass); + } +} + +static void luaM_freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + // if the page wasn't in the page free list, it should be now since it got a block! + if (!page->freeList && page->freeNext < 0) + { + LUAU_ASSERT(!page->prev); + LUAU_ASSERT(!page->next); + + page->next = g->freegcopages[sizeClass]; + if (page->next) + page->next->prev = page; + g->freegcopages[sizeClass] = page; + } + + // when separate block metadata is not used, free list link is stored inside the block data itself + freegcolink(block) = page->freeList; + page->freeList = (char*)block + kGCOHeader; + + ASAN_POISON_MEMORY_REGION((char*)block + kGCOHeader, page->blockSize - kGCOHeader); + + page->busyBlocks--; + + // if it's the last block in the page, we don't need the page + if (page->busyBlocks == 0) + freeclasspage(L, g->freegcopages, &g->allgcopages, page, sizeClass); } -/* -** generic allocation routines. -*/ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) { global_State* g = L->global; @@ -292,6 +501,43 @@ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) return block; } +GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) +{ + if (!FFlag::LuauGcPagedSweep) + return (GCObject*)luaM_new_(L, nsize, memcat); + + global_State* g = L->global; + + int nclass = sizeclass(nsize); + + void* block = NULL; + + if (nclass >= 0) + { + LUAU_ASSERT(nsize > 8); + + block = luaM_newgcoblock(L, nclass); + } + else + { + lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + nsize, nsize, 1); + + block = &page->data; + ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); + + page->freeNext -= page->blockSize; + page->busyBlocks++; + } + + if (block == NULL && nsize > 0) + luaD_throw(L, LUA_ERRMEM); + + g->totalbytes += nsize; + g->memcatbytes[memcat] += nsize; + + return (GCObject*)block; +} + void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) { global_State* g = L->global; @@ -308,6 +554,36 @@ void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) g->memcatbytes[memcat] -= osize; } +void luaM_freegco_(lua_State* L, GCObject* block, size_t osize, uint8_t memcat, lua_Page* page) +{ + if (!FFlag::LuauGcPagedSweep) + { + luaM_free_(L, block, osize, memcat); + return; + } + + global_State* g = L->global; + LUAU_ASSERT((osize == 0) == (block == NULL)); + + int oclass = sizeclass(osize); + + if (oclass >= 0) + { + block->gch.tt = LUA_TNIL; + + luaM_freegcoblock(L, oclass, block, page); + } + else + { + LUAU_ASSERT(page->busyBlocks == 1); + + freepage(L, &g->allgcopages, page); + } + + g->totalbytes -= osize; + g->memcatbytes[memcat] -= osize; +} + void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8_t memcat) { global_State* g = L->global; @@ -344,3 +620,64 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 g->memcatbytes[memcat] += nsize - osize; return result; } + +void luaM_getpagewalkinfo(lua_Page* page, char** start, char** end, int* busyBlocks, int* blockSize) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + int blockCount = (page->pageSize - offsetof(lua_Page, data)) / page->blockSize; + + *start = page->data + page->freeNext + page->blockSize; + *end = page->data + blockCount * page->blockSize; + *busyBlocks = page->busyBlocks; + *blockSize = page->blockSize; +} + +lua_Page* luaM_getnextgcopage(lua_Page* page) +{ + return page->gcolistnext; +} + +void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + char* start; + char* end; + int busyBlocks; + int blockSize; + luaM_getpagewalkinfo(page, &start, &end, &busyBlocks, &blockSize); + + for (char* pos = start; pos != end; pos += blockSize) + { + GCObject* gco = (GCObject*)pos; + + // skip memory blocks that are already freed + if (gco->gch.tt == LUA_TNIL) + continue; + + // when true is returned it means that the element was deleted + if (visitor(context, page, gco)) + { + // if the last block was removed, page would be removed as well + if (--busyBlocks == 0) + break; + } + } +} + +void luaM_visitgco(lua_State* L, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + for (lua_Page* curr = g->allgcopages; curr;) + { + lua_Page* next = curr->gcolistnext; // page blockvisit might destroy the page + + luaM_visitpage(curr, context, visitor); + + curr = next; + } +} diff --git a/VM/src/lmem.h b/VM/src/lmem.h index f526a1b6..1bfe48fa 100644 --- a/VM/src/lmem.h +++ b/VM/src/lmem.h @@ -4,8 +4,15 @@ #include "lua.h" +struct lua_Page; +union GCObject; + +// TODO: remove with FFlagLuauGcPagedSweep and rename luaM_newgco to luaM_new #define luaM_new(L, t, size, memcat) cast_to(t*, luaM_new_(L, size, memcat)) +#define luaM_newgco(L, t, size, memcat) cast_to(t*, luaM_newgco_(L, size, memcat)) +// TODO: remove with FFlagLuauGcPagedSweep and rename luaM_freegco to luaM_free #define luaM_free(L, p, size, memcat) luaM_free_(L, (p), size, memcat) +#define luaM_freegco(L, p, size, memcat, page) luaM_freegco_(L, obj2gco(p), size, memcat, page) #define luaM_arraysize_(n, e) ((cast_to(size_t, (n)) <= SIZE_MAX / (e)) ? (n) * (e) : (luaM_toobig(L), SIZE_MAX)) @@ -15,7 +22,15 @@ ((v) = cast_to(t*, luaM_realloc_(L, v, (oldn) * sizeof(t), luaM_arraysize_(n, sizeof(t)), memcat))) LUAI_FUNC void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat); +LUAI_FUNC GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat); LUAI_FUNC void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat); +LUAI_FUNC void luaM_freegco_(lua_State* L, GCObject* block, size_t osize, uint8_t memcat, lua_Page* page); LUAI_FUNC void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8_t memcat); LUAI_FUNC l_noret luaM_toobig(lua_State* L); + +LUAI_FUNC void luaM_getpagewalkinfo(lua_Page* page, char** start, char** end, int* busyBlocks, int* blockSize); +LUAI_FUNC lua_Page* luaM_getnextgcopage(lua_Page* page); + +LUAI_FUNC void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)); +LUAI_FUNC void luaM_visitgco(lua_State* L, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)); diff --git a/VM/src/lobject.h b/VM/src/lobject.h index b642cf78..57ffd82a 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -11,12 +11,11 @@ typedef union GCObject GCObject; /* -** Common Header for all collectible objects (in macro form, to be -** included in other objects) +** Common Header for all collectible objects (in macro form, to be included in other objects) */ // clang-format off #define CommonHeader \ - GCObject* next; \ + GCObject* next; /* TODO: remove with FFlagLuauGcPagedSweep */ \ uint8_t tt; uint8_t marked; uint8_t memcat // clang-format on @@ -229,8 +228,10 @@ typedef TValue* StkId; /* index to stack elements */ typedef struct TString { CommonHeader; + // 1 byte padding int16_t atom; + // 2 byte padding unsigned int hash; unsigned int len; @@ -314,14 +315,21 @@ typedef struct LocVar typedef struct UpVal { CommonHeader; + // 1 (x86) or 5 (x64) byte padding TValue* v; /* points to stack or to its own value */ union { TValue value; /* the value (when closed) */ struct - { /* double linked list (when open) */ + { + /* global double linked list (when open) */ struct UpVal* prev; struct UpVal* next; + + /* thread double linked list (when open) */ + // TODO: when FFlagLuauGcPagedSweep is removed, old outer 'next' value will be placed here + /* note: this is the location of a pointer to this upvalue in the previous element that can be either an UpVal or a lua_State */ + struct UpVal** threadprev; } l; } u; } UpVal; diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 24e97063..6762c638 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -10,6 +10,8 @@ #include "ldo.h" #include "ldebug.h" +LUAU_FASTFLAG(LuauGcPagedSweep) + /* ** Main thread combines a thread state and the global state */ @@ -86,14 +88,21 @@ static void close_state(lua_State* L) global_State* g = L->global; luaF_close(L, L->stack); /* close all upvalues for this thread */ luaC_freeall(L); /* collect all objects */ - LUAU_ASSERT(g->rootgc == obj2gco(L)); + if (!FFlag::LuauGcPagedSweep) + LUAU_ASSERT(g->rootgc == obj2gco(L)); LUAU_ASSERT(g->strbufgc == NULL); LUAU_ASSERT(g->strt.nuse == 0); luaM_freearray(L, L->global->strt.hash, L->global->strt.size, TString*, 0); freestack(L, L); - LUAU_ASSERT(g->totalbytes == sizeof(LG)); for (int i = 0; i < LUA_SIZECLASSES; i++) + { LUAU_ASSERT(g->freepages[i] == NULL); + if (FFlag::LuauGcPagedSweep) + LUAU_ASSERT(g->freegcopages[i] == NULL); + } + if (FFlag::LuauGcPagedSweep) + LUAU_ASSERT(g->allgcopages == NULL); + LUAU_ASSERT(g->totalbytes == sizeof(LG)); LUAU_ASSERT(g->memcatbytes[0] == sizeof(LG)); for (int i = 1; i < LUA_MEMORY_CATEGORIES; i++) LUAU_ASSERT(g->memcatbytes[i] == 0); @@ -102,7 +111,7 @@ static void close_state(lua_State* L) lua_State* luaE_newthread(lua_State* L) { - lua_State* L1 = luaM_new(L, lua_State, sizeof(lua_State), L->activememcat); + lua_State* L1 = luaM_newgco(L, lua_State, sizeof(lua_State), L->activememcat); luaC_link(L, L1, LUA_TTHREAD); preinit_state(L1, L->global); L1->activememcat = L->activememcat; // inherit the active memory category @@ -113,7 +122,7 @@ lua_State* luaE_newthread(lua_State* L) return L1; } -void luaE_freethread(lua_State* L, lua_State* L1) +void luaE_freethread(lua_State* L, lua_State* L1, lua_Page* page) { luaF_close(L1, L1->stack); /* close all upvalues for this thread */ LUAU_ASSERT(L1->openupval == NULL); @@ -121,7 +130,7 @@ void luaE_freethread(lua_State* L, lua_State* L1) if (g->cb.userthread) g->cb.userthread(NULL, L1); freestack(L, L1); - luaM_free(L, L1, sizeof(lua_State), L1->memcat); + luaM_freegco(L, L1, sizeof(lua_State), L1->memcat, page); } void lua_resetthread(lua_State* L) @@ -162,7 +171,8 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) return NULL; L = (lua_State*)l; g = &((LG*)L)->g; - L->next = NULL; + if (!FFlag::LuauGcPagedSweep) + L->next = NULL; L->tt = LUA_TTHREAD; L->marked = g->currentwhite = bit2mask(WHITE0BIT, FIXEDBIT); L->memcat = 0; @@ -185,9 +195,11 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->strt.hash = NULL; setnilvalue(registry(L)); g->gcstate = GCSpause; - g->rootgc = obj2gco(L); + if (!FFlag::LuauGcPagedSweep) + g->rootgc = obj2gco(L); g->sweepstrgc = 0; - g->sweepgc = &g->rootgc; + if (!FFlag::LuauGcPagedSweep) + g->sweepgc = &g->rootgc; g->gray = NULL; g->grayagain = NULL; g->weak = NULL; @@ -197,7 +209,16 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->gcstepmul = LUAI_GCSTEPMUL; g->gcstepsize = LUAI_GCSTEPSIZE << 10; for (i = 0; i < LUA_SIZECLASSES; i++) + { g->freepages[i] = NULL; + if (FFlag::LuauGcPagedSweep) + g->freegcopages[i] = NULL; + } + if (FFlag::LuauGcPagedSweep) + { + g->allgcopages = NULL; + g->sweepgcopage = NULL; + } for (i = 0; i < LUA_T_COUNT; i++) g->mt[i] = NULL; for (i = 0; i < LUA_UTAG_LIMIT; i++) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 56379883..080f0024 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -22,7 +22,7 @@ typedef struct stringtable { - GCObject** hash; + TString** hash; uint32_t nuse; /* number of elements */ int size; } stringtable; @@ -149,13 +149,15 @@ typedef struct global_State int sweepstrgc; /* position of sweep in `strt' */ + // TODO: remove with FFlagLuauGcPagedSweep GCObject* rootgc; /* list of all collectable objects */ + // TODO: remove with FFlagLuauGcPagedSweep GCObject** sweepgc; /* position of sweep in `rootgc' */ GCObject* gray; /* list of gray objects */ GCObject* grayagain; /* list of objects to be traversed atomically */ GCObject* weak; /* list of weak tables (to be cleared) */ - GCObject* strbufgc; // list of all string buffer objects + TString* strbufgc; // list of all string buffer objects size_t GCthreshold; // when totalbytes > GCthreshold; run GC step @@ -164,7 +166,10 @@ typedef struct global_State int gcstepmul; // see LUAI_GCSTEPMUL int gcstepsize; // see LUAI_GCSTEPSIZE - struct lua_Page* freepages[LUA_SIZECLASSES]; /* free page linked list for each size class */ + struct lua_Page* freepages[LUA_SIZECLASSES]; // free page linked list for each size class for non-collectable objects + struct lua_Page* freegcopages[LUA_SIZECLASSES]; // free page linked list for each size class for collectable objects + struct lua_Page* allgcopages; // page linked list with all pages for all classes + struct lua_Page* sweepgcopage; // position of the sweep in `allgcopages' size_t memcatbytes[LUA_MEMORY_CATEGORIES]; /* total amount of memory used by each memory category */ @@ -231,7 +236,7 @@ struct lua_State TValue l_gt; /* table of globals */ TValue env; /* temporary place for environments */ - GCObject* openupval; /* list of open upvalues in this stack */ + UpVal* openupval; /* list of open upvalues in this stack */ GCObject* gclist; TString* namecall; /* when invoked from Luau using NAMECALL, what method do we need to invoke? */ @@ -268,4 +273,4 @@ union GCObject #define obj2gco(v) check_exp(iscollectable(v), cast_to(GCObject*, (v) + 0)) LUAI_FUNC lua_State* luaE_newthread(lua_State* L); -LUAI_FUNC void luaE_freethread(lua_State* L, lua_State* L1); +LUAI_FUNC void luaE_freethread(lua_State* L, lua_State* L1, struct lua_Page* page); diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index a9e90d17..cb22cc23 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -7,6 +7,8 @@ #include +LUAU_FASTFLAG(LuauGcPagedSweep) + unsigned int luaS_hash(const char* str, size_t len) { // Note that this hashing algorithm is replicated in BytecodeBuilder.cpp, BytecodeBuilder::getStringHash @@ -44,26 +46,25 @@ unsigned int luaS_hash(const char* str, size_t len) void luaS_resize(lua_State* L, int newsize) { - GCObject** newhash; - stringtable* tb; - int i; if (L->global->gcstate == GCSsweepstring) return; /* cannot resize during GC traverse */ - newhash = luaM_newarray(L, newsize, GCObject*, 0); - tb = &L->global->strt; - for (i = 0; i < newsize; i++) + TString** newhash = luaM_newarray(L, newsize, TString*, 0); + stringtable* tb = &L->global->strt; + for (int i = 0; i < newsize; i++) newhash[i] = NULL; /* rehash */ - for (i = 0; i < tb->size; i++) + for (int i = 0; i < tb->size; i++) { - GCObject* p = tb->hash[i]; + TString* p = tb->hash[i]; while (p) { /* for each node in the list */ - GCObject* next = p->gch.next; /* save next */ - unsigned int h = gco2ts(p)->hash; + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + TString* next = (TString*)p->next; /* save next */ + unsigned int h = p->hash; int h1 = lmod(h, newsize); /* new position */ LUAU_ASSERT(cast_int(h % newsize) == lmod(h, newsize)); - p->gch.next = newhash[h1]; /* chain it */ + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + p->next = (GCObject*)newhash[h1]; /* chain it */ newhash[h1] = p; p = next; } @@ -79,7 +80,7 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) stringtable* tb; if (l > MAXSSIZE) luaM_toobig(L); - ts = luaM_new(L, TString, sizestring(l), L->activememcat); + ts = luaM_newgco(L, TString, sizestring(l), L->activememcat); ts->len = unsigned(l); ts->hash = h; ts->marked = luaC_white(L->global); @@ -90,8 +91,9 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, l) : -1; tb = &L->global->strt; h = lmod(h, tb->size); - ts->next = tb->hash[h]; /* chain new entry */ - tb->hash[h] = obj2gco(ts); + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the case will not be required + ts->next = (GCObject*)tb->hash[h]; /* chain new entry */ + tb->hash[h] = ts; tb->nuse++; if (tb->nuse > cast_to(uint32_t, tb->size) && tb->size <= INT_MAX / 2) luaS_resize(L, tb->size * 2); /* too crowded */ @@ -101,28 +103,41 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) static void linkstrbuf(lua_State* L, TString* ts) { global_State* g = L->global; - GCObject* o = obj2gco(ts); - o->gch.next = g->strbufgc; - g->strbufgc = o; - o->gch.marked = luaC_white(g); + + if (FFlag::LuauGcPagedSweep) + { + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + ts->next = (GCObject*)g->strbufgc; + g->strbufgc = ts; + ts->marked = luaC_white(g); + } + else + { + GCObject* o = obj2gco(ts); + o->gch.next = (GCObject*)g->strbufgc; + g->strbufgc = gco2ts(o); + o->gch.marked = luaC_white(g); + } } static void unlinkstrbuf(lua_State* L, TString* ts) { global_State* g = L->global; - GCObject** p = &g->strbufgc; + TString** p = &g->strbufgc; - while (GCObject* curr = *p) + while (TString* curr = *p) { - if (curr == obj2gco(ts)) + if (curr == ts) { - *p = curr->gch.next; + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + *p = (TString*)curr->next; return; } else { - p = &curr->gch.next; + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + p = (TString**)&curr->next; } } @@ -134,7 +149,7 @@ TString* luaS_bufstart(lua_State* L, size_t size) if (size > MAXSSIZE) luaM_toobig(L); - TString* ts = luaM_new(L, TString, sizestring(size), L->activememcat); + TString* ts = luaM_newgco(L, TString, sizestring(size), L->activememcat); ts->tt = LUA_TSTRING; ts->memcat = L->activememcat; @@ -152,15 +167,14 @@ TString* luaS_buffinish(lua_State* L, TString* ts) int bucket = lmod(h, tb->size); // search if we already have this string in the hash table - for (GCObject* o = tb->hash[bucket]; o != NULL; o = o->gch.next) + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + for (TString* el = tb->hash[bucket]; el != NULL; el = (TString*)el->next) { - TString* el = gco2ts(o); - if (el->len == ts->len && memcmp(el->data, ts->data, ts->len) == 0) { // string may be dead - if (isdead(L->global, o)) - changewhite(o); + if (isdead(L->global, obj2gco(el))) + changewhite(obj2gco(el)); return el; } @@ -173,8 +187,9 @@ TString* luaS_buffinish(lua_State* L, TString* ts) // Complete string object ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; - ts->next = tb->hash[bucket]; // chain new entry - tb->hash[bucket] = obj2gco(ts); + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + ts->next = (GCObject*)tb->hash[bucket]; // chain new entry + tb->hash[bucket] = ts; tb->nuse++; if (tb->nuse > cast_to(uint32_t, tb->size) && tb->size <= INT_MAX / 2) @@ -185,24 +200,63 @@ TString* luaS_buffinish(lua_State* L, TString* ts) TString* luaS_newlstr(lua_State* L, const char* str, size_t l) { - GCObject* o; unsigned int h = luaS_hash(str, l); - for (o = L->global->strt.hash[lmod(h, L->global->strt.size)]; o != NULL; o = o->gch.next) + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + for (TString* el = L->global->strt.hash[lmod(h, L->global->strt.size)]; el != NULL; el = (TString*)el->next) { - TString* ts = gco2ts(o); - if (ts->len == l && (memcmp(str, getstr(ts), l) == 0)) + if (el->len == l && (memcmp(str, getstr(el), l) == 0)) { /* string may be dead */ - if (isdead(L->global, o)) - changewhite(o); - return ts; + if (isdead(L->global, obj2gco(el))) + changewhite(obj2gco(el)); + return el; } } return newlstr(L, str, l, h); /* not found */ } -void luaS_free(lua_State* L, TString* ts) +static bool unlinkstr(lua_State* L, TString* ts) { - L->global->strt.nuse--; - luaM_free(L, ts, sizestring(ts->len), ts->memcat); + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + TString** p = &g->strt.hash[lmod(ts->hash, g->strt.size)]; + + while (TString* curr = *p) + { + if (curr == ts) + { + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + *p = (TString*)curr->next; + return true; + } + else + { + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + p = (TString**)&curr->next; + } + } + + return false; +} + +void luaS_free(lua_State* L, TString* ts, lua_Page* page) +{ + if (FFlag::LuauGcPagedSweep) + { + // Unchain from the string table + if (!unlinkstr(L, ts)) + unlinkstrbuf(L, ts); // An unlikely scenario when we have a string buffer on our hands + else + L->global->strt.nuse--; + + luaM_freegco(L, ts, sizestring(ts->len), ts->memcat, page); + } + else + { + L->global->strt.nuse--; + + luaM_free(L, ts, sizestring(ts->len), ts->memcat); + } } diff --git a/VM/src/lstring.h b/VM/src/lstring.h index 3fd0bd39..290b64d8 100644 --- a/VM/src/lstring.h +++ b/VM/src/lstring.h @@ -20,7 +20,7 @@ LUAI_FUNC unsigned int luaS_hash(const char* str, size_t len); LUAI_FUNC void luaS_resize(lua_State* L, int newsize); LUAI_FUNC TString* luaS_newlstr(lua_State* L, const char* str, size_t l); -LUAI_FUNC void luaS_free(lua_State* L, TString* ts); +LUAI_FUNC void luaS_free(lua_State* L, TString* ts, struct lua_Page* page); LUAI_FUNC TString* luaS_bufstart(lua_State* L, size_t size); LUAI_FUNC TString* luaS_buffinish(lua_State* L, TString* ts); diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 83b59f3f..c57374e0 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -424,7 +424,7 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) Table* luaH_new(lua_State* L, int narray, int nhash) { - Table* t = luaM_new(L, Table, sizeof(Table), L->activememcat); + Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_link(L, t, LUA_TTABLE); t->metatable = NULL; t->flags = cast_byte(~0); @@ -443,12 +443,12 @@ Table* luaH_new(lua_State* L, int narray, int nhash) return t; } -void luaH_free(lua_State* L, Table* t) +void luaH_free(lua_State* L, Table* t, lua_Page* page) { if (t->node != dummynode) luaM_freearray(L, t->node, sizenode(t), LuaNode, t->memcat); luaM_freearray(L, t->array, t->sizearray, TValue, t->memcat); - luaM_free(L, t, sizeof(Table), t->memcat); + luaM_freegco(L, t, sizeof(Table), t->memcat, page); } static LuaNode* getfreepos(Table* t) @@ -741,7 +741,7 @@ int luaH_getn(Table* t) Table* luaH_clone(lua_State* L, Table* tt) { - Table* t = luaM_new(L, Table, sizeof(Table), L->activememcat); + Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_link(L, t, LUA_TTABLE); t->metatable = tt->metatable; t->flags = tt->flags; diff --git a/VM/src/ltable.h b/VM/src/ltable.h index 45061443..e8413c85 100644 --- a/VM/src/ltable.h +++ b/VM/src/ltable.h @@ -20,7 +20,7 @@ LUAI_FUNC TValue* luaH_set(lua_State* L, Table* t, const TValue* key); LUAI_FUNC Table* luaH_new(lua_State* L, int narray, int lnhash); LUAI_FUNC void luaH_resizearray(lua_State* L, Table* t, int nasize); LUAI_FUNC void luaH_resizehash(lua_State* L, Table* t, int nhsize); -LUAI_FUNC void luaH_free(lua_State* L, Table* t); +LUAI_FUNC void luaH_free(lua_State* L, Table* t, struct lua_Page* page); LUAI_FUNC int luaH_next(lua_State* L, Table* t, StkId key); LUAI_FUNC int luaH_getn(Table* t); LUAI_FUNC Table* luaH_clone(lua_State* L, Table* tt); diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index d180c388..758a9bdb 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -11,7 +11,7 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) { if (s > INT_MAX - sizeof(Udata)) luaM_toobig(L); - Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); + Udata* u = luaM_newgco(L, Udata, sizeudata(s), L->activememcat); luaC_link(L, u, LUA_TUSERDATA); u->len = int(s); u->metatable = NULL; @@ -20,7 +20,7 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) return u; } -void luaU_freeudata(lua_State* L, Udata* u) +void luaU_freeudata(lua_State* L, Udata* u, lua_Page* page) { LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); @@ -33,5 +33,5 @@ void luaU_freeudata(lua_State* L, Udata* u) if (dtor) dtor(u->data); - luaM_free(L, u, sizeudata(u->len), u->memcat); + luaM_freegco(L, u, sizeudata(u->len), u->memcat, page); } diff --git a/VM/src/ludata.h b/VM/src/ludata.h index 59cb85bd..ec374c28 100644 --- a/VM/src/ludata.h +++ b/VM/src/ludata.h @@ -10,4 +10,4 @@ #define sizeudata(len) (offsetof(Udata, data) + len) LUAI_FUNC Udata* luaU_newudata(lua_State* L, size_t s, int tag); -LUAI_FUNC void luaU_freeudata(lua_State* L, Udata* u); +LUAI_FUNC void luaU_freeudata(lua_State* L, Udata* u, struct lua_Page* page); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index cebeeb58..e58ff2a8 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -496,7 +496,7 @@ static void luau_execute(lua_State* L) Instruction insn = *pc++; StkId ra = VM_REG(LUAU_INSN_A(insn)); - if (L->openupval && gco2uv(L->openupval)->v >= ra) + if (L->openupval && L->openupval->v >= ra) luaF_close(L, ra); VM_NEXT(); } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 8eed953f..3b0d677d 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -756,6 +756,30 @@ RETURN R0 1 )"); } +TEST_CASE("TableSizePredictionLoop") +{ + ScopedFastFlag sff("LuauPredictTableSizeLoop", true); + + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +for i=1,4 do + t[i] = 0 +end +return t +)"), + R"( +NEWTABLE R0 0 4 +LOADN R3 1 +LOADN R1 4 +LOADN R2 1 +FORNPREP R1 +3 +LOADN R4 0 +SETTABLE R4 R0 R3 +FORNLOOP R1 -3 +RETURN R0 1 +)"); +} + TEST_CASE("ReflectionEnums") { CHECK_EQ("\n" + compileFunction0("return Enum.EasingStyle.Linear"), R"( diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 1d13df28..5ad06f0d 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1396,6 +1396,8 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { + ScopedFastFlag sff("LuauLintTableCreateTable", true); + LintResult result = lintTyped(R"( local t = {} local tt = {} @@ -1416,9 +1418,12 @@ table.insert(t, string.find("hello", "h")) table.move(t, 0, #t, 1, tt) table.move(t, 1, #t, 0, tt) + +table.create(42, {}) +table.create(42, {} :: {}) )"); - REQUIRE_EQ(result.warnings.size(), 8); + REQUIRE_EQ(result.warnings.size(), 10); CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); @@ -1430,6 +1435,8 @@ table.move(t, 1, #t, 0, tt) "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + CHECK_EQ(result.warnings[8].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + CHECK_EQ(result.warnings[9].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index e9135651..ac81005c 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1498,6 +1498,17 @@ return CHECK_EQ(std::string(str->value.data, str->value.size), "\n"); } +TEST_CASE_FIXTURE(Fixture, "parse_error_broken_comment") +{ + ScopedFastFlag luauStartingBrokenComment{"LuauStartingBrokenComment", true}; + + const char* expected = "Expected identifier when parsing expression, got unfinished comment"; + + matchParseError("--[[unfinished work", expected); + matchParseError("--!strict\n--[[unfinished work", expected); + matchParseError("local x = 1 --[[unfinished work", expected); +} + TEST_CASE_FIXTURE(Fixture, "string_literals_escapes_broken") { const char* expected = "String literal contains malformed escape sequence"; @@ -2333,7 +2344,7 @@ TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") ParseOptions options; options.captureComments = true; - ParseResult result = parseEx(R"( + ParseResult result = tryParse(R"( --[[ )", options); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 27cda146..644efed7 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2180,4 +2180,52 @@ b() CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); } +TEST_CASE_FIXTURE(Fixture, "length_operator_union") +{ + ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; + + CheckResult result = check(R"( +local x: {number} | {string} +local y = #x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "length_operator_intersection") +{ + ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; + + CheckResult result = check(R"( +local x: {number} & {z:string} -- mixed tables are evil +local y = #x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "length_operator_non_table_union") +{ + ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; + + CheckResult result = check(R"( +local x: {number} | any | string +local y = #x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "length_operator_union_errors") +{ + ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; + + CheckResult result = check(R"( +local x: {number} | number | string +local y = #x + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 7a056af5..7ee5253c 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5129,4 +5129,33 @@ local c = a or b LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "bound_typepack_promote") +{ + ScopedFastFlag luauCommittingTxnLogFreeTpPromote{"LuauCommittingTxnLogFreeTpPromote", true}; + + // No assertions should trigger + check(R"( +local function p() + local this = {} + this.pf = foo() + function this:IsActive() end + function this:Start(o) end + return this +end + +local function h(tp, o) + ep = tp + tp:Start(o) + tp.pf.Connect(function() + ep:IsActive() + end) +end + +function on() + local t = p() + h(t) +end + )"); +} + TEST_SUITE_END(); diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index f32d5bdc..7b057354 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -419,5 +419,20 @@ co = coroutine.create(function () return loadstring("return a")() end) +-- large closure size +do + local a1, a2, a3, a4, a5, a6, a7, a8, a9, a0 + local b1, b2, b3, b4, b5, b6, b7, b8, b9, b0 + local c1, c2, c3, c4, c5, c6, c7, c8, c9, c0 + local d1, d2, d3, d4, d5, d6, d7, d8, d9, d0 + + local f = function() + return + a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a0 + + b1 + b2 + b3 + b4 + b5 + b6 + b7 + b8 + b9 + b0 + + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c0 + + d1 + d2 + d3 + d4 + d5 + d6 + d7 + d8 + d9 + d0 + end +end return 'OK' diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index 6d9eb854..409cd224 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -291,4 +291,32 @@ do for i = 1,10 do table.insert(___Glob, newproxy(true)) end end +-- create threads that die together with their unmarked upvalues +do + local t = {} + + for i = 1,100 do + local c = coroutine.wrap(function() + local uv = {i + 1} + local function f() + return uv[1] * 10 + end + coroutine.yield(uv[1]) + uv = {i + 2} + coroutine.yield(f()) + end) + + assert(c() == i + 1) + table.insert(t, c) + end + + for i = 1,100 do + t[i] = nil + end + + collectgarbage() + +end + + return('OK') From 699660a4ebe33c582a71f73caacdc98440228b5d Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 21 Jan 2022 08:37:50 -0800 Subject: [PATCH 016/102] Fix MSVC warnings --- VM/src/lgc.cpp | 4 ++-- VM/src/lmem.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 50859b1e..9a8cb079 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -770,11 +770,11 @@ static int sweepgcopage(lua_State* L, lua_Page* page) { // if the last block was removed, page would be removed as well if (--busyBlocks == 0) - return (pos - start) / blockSize + 1; + return int(pos - start) / blockSize + 1; } } - return (end - start) / blockSize; + return int(end - start) / blockSize; } static size_t gcstep(lua_State* L, size_t limit) diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 6d3b7772..7a31d6c8 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -520,7 +520,7 @@ GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) } else { - lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + nsize, nsize, 1); + lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + int(nsize), int(nsize), 1); block = &page->data; ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); From 0062000d4674c44bb145584b2baa531388c9f355 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 21 Jan 2022 08:43:41 -0800 Subject: [PATCH 017/102] One more --- VM/src/lmem.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 7a31d6c8..beacca65 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -200,7 +200,7 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int global_State* g = L->global; - LUAU_ASSERT(pageSize - offsetof(lua_Page, data) >= blockSize * blockCount); + LUAU_ASSERT(pageSize - int(offsetof(lua_Page, data)) >= blockSize * blockCount); lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, pageSize); if (!page) From 9c15f6a6d79d769f4ac6cf80b1521826035f4d76 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 21 Jan 2022 08:52:48 -0800 Subject: [PATCH 018/102] And one more --- VM/src/lmem.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index beacca65..e1dbce50 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -376,7 +376,7 @@ static void* luaM_newgcoblock(lua_State* L, int sizeClass) LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); - LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass]); + LUAU_ASSERT(page->blockSize == kSizeClassConfig.sizeOfClass[sizeClass]); void* block; From 6e1e277cb8f19f18740259c830ef49920f76043d Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 27 Jan 2022 13:29:34 -0800 Subject: [PATCH 019/102] Sync to upstream/release/512 --- Analysis/include/Luau/AstQuery.h | 15 + Analysis/include/Luau/Module.h | 8 + Analysis/include/Luau/TypeInfer.h | 21 +- Analysis/include/Luau/TypeVar.h | 8 +- Analysis/include/Luau/Unifier.h | 7 + Analysis/src/Autocomplete.cpp | 105 ++--- Analysis/src/Frontend.cpp | 7 +- Analysis/src/Module.cpp | 28 +- Analysis/src/ToString.cpp | 11 +- Analysis/src/TypeAttach.cpp | 2 +- Analysis/src/TypeInfer.cpp | 484 +++++++++++++--------- Analysis/src/TypeVar.cpp | 138 ++++--- Analysis/src/Unifier.cpp | 319 ++++++--------- CLI/Analyze.cpp | 40 +- CLI/FileUtils.cpp | 24 +- CLI/FileUtils.h | 1 + CLI/Repl.cpp | 17 +- CLI/Repl.h | 12 + CLI/ReplEntry.cpp | 10 + CMakeLists.txt | 12 + Compiler/include/Luau/Bytecode.h | 3 + Compiler/src/Builtins.cpp | 5 + Compiler/src/Compiler.cpp | 353 ++++++++++------ Makefile | 7 +- Sources.cmake | 18 +- VM/src/lapi.cpp | 13 +- VM/src/lbuiltins.cpp | 30 ++ VM/src/lcorolib.cpp | 5 - VM/src/ldo.cpp | 4 +- VM/src/lgc.cpp | 35 +- VM/src/lgc.h | 1 + VM/src/lmem.cpp | 6 +- VM/src/lstate.h | 3 - tests/AstQuery.test.cpp | 6 - tests/Autocomplete.test.cpp | 36 +- tests/Compiler.test.cpp | 131 ++++++ tests/Conformance.test.cpp | 2 - tests/Frontend.test.cpp | 1 - tests/Parser.test.cpp | 134 +++--- tests/Repl.test.cpp | 117 ++++++ tests/ToString.test.cpp | 4 - tests/TypeInfer.aliases.test.cpp | 4 - tests/TypeInfer.generics.test.cpp | 15 +- tests/TypeInfer.provisional.test.cpp | 28 +- tests/TypeInfer.refinements.test.cpp | 591 ++++++++++++++++----------- tests/TypeInfer.singletons.test.cpp | 6 - tests/TypeInfer.tables.test.cpp | 16 +- tests/TypeInfer.test.cpp | 64 +-- tests/TypeInfer.typePacks.cpp | 1 - tests/TypeInfer.unionTypes.test.cpp | 2 - tests/TypeVar.test.cpp | 72 +++- 51 files changed, 1836 insertions(+), 1146 deletions(-) create mode 100644 CLI/Repl.h create mode 100644 CLI/ReplEntry.cpp create mode 100644 tests/Repl.test.cpp diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index d38976ef..dfe373a5 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -42,6 +42,21 @@ struct ExprOrLocal { return expr ? expr->location : (local ? local->location : std::optional{}); } + std::optional getName() + { + if (expr) + { + if (AstName name = getIdentifier(expr); name.value) + { + return name; + } + } + else if (local) + { + return local->name; + } + return std::nullopt; + } private: AstExpr* expr = nullptr; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 2e41674b..1bf0473c 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -13,6 +13,8 @@ #include #include +LUAU_FASTFLAG(LuauPrepopulateUnionOptionsBeforeAllocation) + namespace Luau { @@ -58,6 +60,12 @@ struct TypeArena template TypeId addType(T tv) { + if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation) + { + if constexpr (std::is_same_v) + LUAU_ASSERT(tv.options.size() >= 2); + } + return addTV(TypeVar(std::move(tv))); } diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index aa090014..b843509d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -135,7 +135,8 @@ struct TypeChecker void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); - ExprResult checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt); + ExprResult checkExpr( + const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt, bool forceSingleton = false); ExprResult checkExpr(const ScopePtr& scope, const AstExprLocal& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); @@ -160,14 +161,12 @@ struct TypeChecker // Returns the type of the lvalue. TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr); - // Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding). - // Note: the binding may be null. - // TODO: remove second return value with FFlagLuauUpdateFunctionNameBinding - std::pair checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); + // Returns the type of the lvalue. + TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, @@ -322,8 +321,6 @@ private: return addTV(TypeVar(tv)); } - TypeId addType(const UnionTypeVar& utv); - TypeId addTV(TypeVar&& tv); TypePackId addTypePack(TypePackVar&& tp); @@ -349,6 +346,8 @@ public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); private: + void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate); + std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 3f5e26d6..11dc9377 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -111,16 +111,16 @@ struct PrimitiveTypeVar // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Types for true and false -struct BoolSingleton +struct BooleanSingleton { bool value; - bool operator==(const BoolSingleton& rhs) const + bool operator==(const BooleanSingleton& rhs) const { return value == rhs.value; } - bool operator!=(const BoolSingleton& rhs) const + bool operator!=(const BooleanSingleton& rhs) const { return !(*this == rhs); } @@ -145,7 +145,7 @@ struct StringSingleton // No type for float singletons, partly because === isn't any equalivalence on floats // (NaN != NaN). -using SingletonVariant = Luau::Variant; +using SingletonVariant = Luau::Variant; struct SingletonTypeVar { diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index a3be739a..1b1671c0 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -85,6 +85,13 @@ public: Unifier makeChildUnifier(); + // A utility function that appends the given error to the unifier's error log. + // This allows setting a breakpoint wherever the unifier reports an error. + void reportError(TypeError error) + { + errors.push_back(error); + } + private: bool isNonstrictMode() const; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 7a801f97..85099e12 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,9 +14,9 @@ LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); +LUAU_FASTFLAGVARIABLE(PreferToCallFunctionsForIntersects, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -194,8 +194,6 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) { - LUAU_ASSERT(FFlag::LuauAutocompleteFirstArg); - auto expr = node->asExpr(); if (!expr) return std::nullopt; @@ -266,43 +264,63 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } }; - TypeId expectedType; + auto typeAtPosition = findExpectedTypeAt(module, node, position); - if (FFlag::LuauAutocompleteFirstArg) + if (!typeAtPosition) + return TypeCorrectKind::None; + + TypeId expectedType = follow(*typeAtPosition); + + if (FFlag::PreferToCallFunctionsForIntersects) { - auto typeAtPosition = findExpectedTypeAt(module, node, position); + auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) { + auto [retHead, retTail] = flatten(ftv->retType); - if (!typeAtPosition) - return TypeCorrectKind::None; + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + return true; - expectedType = follow(*typeAtPosition); + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return true; + } + + return false; + }; + + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + for (TypeId id : itv->parts) + { + if (const FunctionTypeVar* ftv = get(id); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + } + } } else { - auto expr = node->asExpr(); - if (!expr) - return TypeCorrectKind::None; - - auto it = module.astExpectedTypes.find(expr); - if (!it) - return TypeCorrectKind::None; - - expectedType = follow(*it); - } - - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty)) - { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty)) { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + auto [retHead, retTail] = flatten(ftv->retType); + + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return TypeCorrectKind::CorrectFunctionResult; + + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return TypeCorrectKind::CorrectFunctionResult; + } } } @@ -741,29 +759,12 @@ std::optional returnFirstNonnullOptionOfType(const UnionTypeVar* utv) static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) { - TypeId expectedType; + auto typeAtPosition = findExpectedTypeAt(module, node, position); - if (FFlag::LuauAutocompleteFirstArg) - { - auto typeAtPosition = findExpectedTypeAt(module, node, position); + if (!typeAtPosition) + return std::nullopt; - if (!typeAtPosition) - return std::nullopt; - - expectedType = follow(*typeAtPosition); - } - else - { - auto expr = node->asExpr(); - if (!expr) - return std::nullopt; - - auto it = module.astExpectedTypes.find(expr); - if (!it) - return std::nullopt; - - expectedType = follow(*it); - } + TypeId expectedType = follow(*typeAtPosition); if (get(expectedType)) return true; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index fe4b6529..9001b19d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) namespace Luau { @@ -102,8 +101,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - if (FFlag::LuauPersistDefinitionFileTypes) - persist(globalTy); + persist(globalTy); } for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) @@ -113,8 +111,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; - if (FFlag::LuauPersistDefinitionFileTypes) - persist(globalTy.type); + persist(globalTy.type); } return LoadDefinitionFileResult{true, parseResult, checkedModule}; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 9f352f4b..4fdff8f7 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -16,6 +16,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypeAliasDefaults) +LUAU_FASTFLAGVARIABLE(LuauPrepopulateUnionOptionsBeforeAllocation, false) + namespace Luau { @@ -377,14 +379,28 @@ void TypeCloner::operator()(const AnyTypeVar& t) void TypeCloner::operator()(const UnionTypeVar& t) { - TypeId result = dest.addType(UnionTypeVar{}); - seenTypes[typeId] = result; + if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation) + { + std::vector options; + options.reserve(t.options.size()); - UnionTypeVar* option = getMutable(result); - LUAU_ASSERT(option != nullptr); + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - for (TypeId ty : t.options) - option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + TypeId result = dest.addType(UnionTypeVar{std::move(options)}); + seenTypes[typeId] = result; + } + else + { + TypeId result = dest.addType(UnionTypeVar{}); + seenTypes[typeId] = result; + + UnionTypeVar* option = getMutable(result); + LUAU_ASSERT(option != nullptr); + + for (TypeId ty : t.options) + option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + } } void TypeCloner::operator()(const IntersectionTypeVar& t) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 4b898d3a..5e79b841 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAG(LuauTypeAliasDefaults) /* @@ -374,7 +373,7 @@ struct TypeVarStringifier void operator()(TypeId, const SingletonTypeVar& stv) { - if (const BoolSingleton* bs = Luau::get(&stv)) + if (const BooleanSingleton* bs = Luau::get(&stv)) state.emit(bs->value ? "true" : "false"); else if (const StringSingleton* ss = Luau::get(&stv)) { @@ -617,9 +616,7 @@ struct TypeVarStringifier std::string saved = std::move(state.result.name); - bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions - ? !state.cycleNames.count(el) && (get(el) || get(el)) - : get(el) || get(el); + bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -675,9 +672,7 @@ struct TypeVarStringifier std::string saved = std::move(state.result.name); - bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions - ? !state.cycleNames.count(el) && (get(el) || get(el)) - : get(el) || get(el); + bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); if (needParens) state.emit("("); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 2ec02093..2208213f 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -97,7 +97,7 @@ public: AstType* operator()(const SingletonTypeVar& stv) { - if (const BoolSingleton* bs = get(&stv)) + if (const BooleanSingleton* bs = get(&stv)) return allocator->alloc(Location(), bs->value); else if (const StringSingleton* ss = get(&stv)) { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e2d8a4fb..23fcc2d5 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -26,8 +26,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) -LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) @@ -37,6 +35,7 @@ LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) @@ -46,10 +45,8 @@ LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) -LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) -LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau { @@ -1139,33 +1136,25 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { - auto [leftType, leftTypeBinding] = checkLValueBinding(scope, *function.name); + TypeId leftType = checkLValueBinding(scope, *function.name); checkFunctionBody(funScope, ty, *function.func); unify(ty, leftType, function.location); - if (FFlag::LuauUpdateFunctionNameBinding) - { - LUAU_ASSERT(function.name->is() || function.name->is()); + LUAU_ASSERT(function.name->is() || function.name->is()); - if (auto exprIndexName = function.name->as()) + if (auto exprIndexName = function.name->as()) + { + if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) { - if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) + if (auto ttv = getMutableTableType(*typeIt)) { - if (auto ttv = getMutableTableType(*typeIt)) - { - if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) - it->second.type = follow(quantify(funScope, leftType, function.name->location)); - } + if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) + it->second.type = follow(quantify(funScope, leftType, function.name->location)); } } } - else - { - if (leftTypeBinding) - *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); - } } } @@ -1426,7 +1415,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType) +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) @@ -1443,14 +1432,14 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) { - if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType)))) result = {singletonType(bexpr->value)}; else result = {booleanType}; } else if (const AstExprConstantString* sexpr = expr.as()) { - if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType)))) result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; else result = {stringType}; @@ -1488,15 +1477,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result.type = follow(result.type); - if (FFlag::LuauStoreMatchingOverloadFnType) - { - if (!currentModule->astTypes.find(&expr)) - currentModule->astTypes[&expr] = result.type; - } - else - { + if (!currentModule->astTypes.find(&expr)) currentModule->astTypes[&expr] = result.type; - } if (expectedType) currentModule->astExpectedTypes[&expr] = *expectedType; @@ -2242,7 +2224,6 @@ TypeId TypeChecker::checkRelationalOperation( state.log.commit(); } - bool needsMetamethod = !isEquality; TypeId leftType = follow(lhsType); @@ -2250,10 +2231,11 @@ TypeId TypeChecker::checkRelationalOperation( { reportErrors(state.errors); - const PrimitiveTypeVar* ptv = get(leftType); - if (!isEquality && state.errors.empty() && (get(leftType) || (ptv && ptv->type == PrimitiveTypeVar::Boolean))) + if (!isEquality && state.errors.empty() && (get(leftType) || isBoolean(leftType))) + { reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())}); + } return booleanType; } @@ -2501,7 +2483,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi ExprResult rhs = checkExpr(innerScope, *expr.right); - return {checkBinaryOperation(innerScope, expr, lhs.type, rhs.type), {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + return {checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type), + {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } else if (expr.op == AstExprBinary::Or) { @@ -2513,7 +2496,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi ExprResult rhs = checkExpr(innerScope, *expr.right); // Because of C++, I'm not sure if lhs.predicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates); + TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type, lhs.predicates); return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) @@ -2521,8 +2504,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left); - ExprResult rhs = checkExpr(scope, *expr.right); + ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions); + ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions); PredicateVec predicates; @@ -2621,11 +2604,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIf TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) { - auto [ty, binding] = checkLValueBinding(scope, expr); - return ty; + return checkLValueBinding(scope, expr); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) return checkLValueBinding(scope, *a); @@ -2639,22 +2621,22 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { for (AstExpr* expr : a->expressions) checkExpr(scope, *expr); - return {errorRecoveryType(scope), nullptr}; + return errorRecoveryType(scope); } else ice("Unexpected AST node in checkLValue", expr.location); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) { if (std::optional ty = scope->lookup(expr.local)) - return {*ty, nullptr}; + return *ty; reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); - return {errorRecoveryType(scope), nullptr}; + return errorRecoveryType(scope); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) { Name name = expr.name.value; ScopePtr moduleScope = currentModule->getModuleScope(); @@ -2662,7 +2644,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto it = moduleScope->bindings.find(expr.name); if (it != moduleScope->bindings.end()) - return std::pair(it->second.typeId, &it->second.typeId); + return it->second.typeId; TypeId result = freshType(scope); Binding& binding = moduleScope->bindings[expr.name]; @@ -2673,15 +2655,15 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!isNonstrictMode()) reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - return std::pair(result, &binding.typeId); + return result; } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) { TypeId lhs = checkExpr(scope, *expr.expr).type; if (get(lhs) || get(lhs)) - return std::pair(lhs, nullptr); + return lhs; tablify(lhs); @@ -2694,7 +2676,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto& it = lhsTable->props.find(name); if (it != lhsTable->props.end()) { - return std::pair(it->second.type, &it->second.type); + return it->second.type; } else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { @@ -2702,7 +2684,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope Property& property = lhsTable->props[name]; property.type = theType; property.location = expr.indexLocation; - return std::pair(theType, &property.type); + return theType; } else if (auto indexer = lhsTable->indexer) { @@ -2720,17 +2702,17 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope else if (FFlag::LuauUseCommittingTxnLog) state.log.commit(); - return std::pair(retType, nullptr); + return retType; } else if (lhsTable->state == TableState::Sealed) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } else { reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } } else if (const ClassTypeVar* lhsClass = get(lhs)) @@ -2739,29 +2721,29 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } - return std::pair(prop->type, nullptr); + return prop->type; } else if (get(lhs)) { if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, false)) - return std::pair(*ty, nullptr); + return *ty; // If intersection has a table part, report that it cannot be extended just as a sealed table if (isTableIntersection(lhs)) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } } reportError(TypeError{expr.location, NotATable{lhs}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) { TypeId exprType = checkExpr(scope, *expr.expr).type; tablify(exprType); @@ -2771,7 +2753,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope TypeId indexType = checkExpr(scope, *expr.index).type; if (get(exprType) || get(exprType)) - return std::pair(exprType, nullptr); + return exprType; AstExprConstantString* value = expr.index->as(); @@ -2783,9 +2765,9 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } - return std::pair(prop->type, nullptr); + return prop->type; } } @@ -2794,7 +2776,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { reportError(TypeError{expr.expr->location, NotATable{exprType}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } if (value) @@ -2802,7 +2784,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto& it = exprTable->props.find(value->value.data); if (it != exprTable->props.end()) { - return std::pair(it->second.type, &it->second.type); + return it->second.type; } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { @@ -2810,7 +2792,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope Property& property = exprTable->props[value->value.data]; property.type = resultType; property.location = expr.index->location; - return std::pair(resultType, &property.type); + return resultType; } } @@ -2818,18 +2800,18 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { const TableIndexer& indexer = *exprTable->indexer; unify(indexType, indexer.indexType, expr.index->location); - return std::pair(indexer.indexResultType, nullptr); + return indexer.indexResultType; } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { TypeId resultType = freshType(scope); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; - return std::pair(resultType, nullptr); + return resultType; } else { TypeId resultType = freshType(scope); - return std::pair(resultType, nullptr); + return resultType; } } @@ -3326,7 +3308,7 @@ void TypeChecker::checkArgumentList( } // ok else { - state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); return; } ++paramIter; @@ -3348,7 +3330,7 @@ void TypeChecker::checkArgumentList( Location location = state.location; if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } TypePackId tail = state.log.follow(*paramIter.tail()); @@ -3405,7 +3387,7 @@ void TypeChecker::checkArgumentList( if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; // TODO: Better error message? - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } } @@ -3520,7 +3502,7 @@ void TypeChecker::checkArgumentList( } // ok else { - state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); return; } ++paramIter; @@ -3540,7 +3522,7 @@ void TypeChecker::checkArgumentList( Location location = state.location; if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } TypePackId tail = *paramIter.tail(); @@ -3606,7 +3588,7 @@ void TypeChecker::checkArgumentList( if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; // TODO: Better error message? - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } } @@ -3825,22 +3807,11 @@ std::optional> TypeChecker::checkCallOverload(const Scope metaArgLocations = *argLocations; metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); - if (FFlag::LuauFixRecursiveMetatableCall) - { - fn = instantiate(scope, *ty, expr.func->location); + fn = instantiate(scope, *ty, expr.func->location); - argPack = metaCallArgPack; - args = metaCallArgs; - argLocations = &metaArgLocations; - } - else - { - TypeId fn = *ty; - fn = instantiate(scope, fn, expr.func->location); - - return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, &metaArgLocations, argListResult, - overloadsThatMatchArgCount, overloadsThatDont, errors); - } + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; } } @@ -3932,8 +3903,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope } } - if (FFlag::LuauStoreMatchingOverloadFnType) - currentModule->astOverloadResolvedTypes[&expr] = fn; + currentModule->astOverloadResolvedTypes[&expr] = fn; // We select this overload return {{retPack}}; @@ -4776,7 +4746,7 @@ TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::singletonType(bool value) { // TODO: cache singleton types - return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value}))); + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BooleanSingleton{value}))); } TypeId TypeChecker::singletonType(std::string value) @@ -4813,13 +4783,6 @@ std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predic return std::nullopt; } -TypeId TypeChecker::addType(const UnionTypeVar& utv) -{ - LUAU_ASSERT(utv.options.size() > 1); - - return addTV(TypeVar(utv)); -} - TypeId TypeChecker::addTV(TypeVar&& tv) { return currentModule->internalTypes.addType(std::move(tv)); @@ -5347,54 +5310,35 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId instantiated = *maybeInstantiated; - if (FFlag::LuauCloneCorrectlyBeforeMutatingTableType) + // TODO: CLI-46926 it's not a good idea to rename the type here + TypeId target = follow(instantiated); + bool needsClone = follow(tf.type) == target; + TableTypeVar* ttv = getMutableTableType(target); + + if (ttv && needsClone) { - // TODO: CLI-46926 it's not a good idea to rename the type here - TypeId target = follow(instantiated); - bool needsClone = follow(tf.type) == target; - TableTypeVar* ttv = getMutableTableType(target); - - if (ttv && needsClone) + // Substitution::clone is a shallow clone. If this is a metatable type, we + // want to mutate its table, so we need to explicitly clone that table as + // well. If we don't, we will mutate another module's type surface and cause + // a use-after-free. + if (get(target)) { - // Substitution::clone is a shallow clone. If this is a metatable type, we - // want to mutate its table, so we need to explicitly clone that table as - // well. If we don't, we will mutate another module's type surface and cause - // a use-after-free. - if (get(target)) - { - instantiated = applyTypeFunction.clone(tf.type); - MetatableTypeVar* mtv = getMutable(instantiated); - mtv->table = applyTypeFunction.clone(mtv->table); - ttv = getMutable(mtv->table); - } - if (get(target)) - { - instantiated = applyTypeFunction.clone(tf.type); - ttv = getMutable(instantiated); - } + instantiated = applyTypeFunction.clone(tf.type); + MetatableTypeVar* mtv = getMutable(instantiated); + mtv->table = applyTypeFunction.clone(mtv->table); + ttv = getMutable(mtv->table); } - - if (ttv) + if (get(target)) { - ttv->instantiatedTypeParams = typeParams; - ttv->instantiatedTypePackParams = typePackParams; + instantiated = applyTypeFunction.clone(tf.type); + ttv = getMutable(instantiated); } } - else - { - if (TableTypeVar* ttv = getMutableTableType(instantiated)) - { - if (follow(tf.type) == instantiated) - { - // This can happen if a type alias has generics that it does not use at all. - // ex type FooBar = { a: number } - instantiated = applyTypeFunction.clone(tf.type); - ttv = getMutableTableType(instantiated); - } - ttv->instantiatedTypeParams = typeParams; - ttv->instantiatedTypePackParams = typePackParams; - } + if (ttv) + { + ttv->instantiatedTypeParams = typeParams; + ttv->instantiatedTypePackParams = typePackParams; } return instantiated; @@ -5482,6 +5426,85 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st return {generics, genericPacks}; } +void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) +{ + LUAU_ASSERT(FFlag::LuauDiscriminableUnions); + + const LValue* target = &lvalue; + std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. + + auto ty = resolveLValue(scope, *target); + if (!ty) + return; // Do nothing. An error was already reported. + + // If the provided lvalue is a local or global, then that's without a doubt the target. + // However, if there is a base lvalue, then we'll want that to be the target iff the base is a union type. + if (auto base = baseof(lvalue)) + { + std::optional baseTy = resolveLValue(scope, *base); + if (baseTy && get(follow(*baseTy))) + { + ty = baseTy; + target = base; + key = lvalue; + } + } + + // If we do not have a key, it means we're not trying to discriminate anything, so it's a simple matter of just filtering for a subset. + if (!key) + { + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, *target, *result); + else + addRefinement(refis, *target, errorRecoveryType(scope)); + + return; + } + + // Otherwise, we'll want to walk each option of ty, get its index type, and filter that. + auto utv = get(follow(*ty)); + LUAU_ASSERT(utv); + + std::unordered_set viableTargetOptions; + std::unordered_set viableChildOptions; // There may be additional refinements that apply. We add those here too. + + for (TypeId option : utv) + { + std::optional discriminantTy; + if (auto field = Luau::get(*key)) // need to fully qualify Luau::get because of ADL. + discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), false); + else + LUAU_ASSERT(!"Unhandled LValue alternative?"); + + if (!discriminantTy) + return; // Do nothing. An error was already reported, as per usual. + + if (std::optional result = filterMap(*discriminantTy, predicate)) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } + } + + auto intoType = [this](const std::unordered_set& s) -> std::optional { + if (s.empty()) + return std::nullopt; + + // TODO: allocate UnionTypeVar and just normalize. + std::vector options(s.begin(), s.end()); + if (options.size() == 1) + return options[0]; + + return addType(UnionTypeVar{std::move(options)}); + }; + + if (std::optional viableTargetType = intoType(viableTargetOptions)) + addRefinement(refis, *target, *viableTargetType); + + if (std::optional viableChildType = intoType(viableChildOptions)) + addRefinement(refis, lvalue, *viableChildType); +} + std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue) { if (!FFlag::LuauLValueAsKey) @@ -5645,18 +5668,29 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, Refi return std::nullopt; }; - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (!ty) - return; + if (FFlag::LuauDiscriminableUnions) + { + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (ty && fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); - // This is a hack. :( - // Without this, the expression 'a or b' might refine 'b' to be falsy. - // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. - if (fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); + refineLValue(truthyP.lvalue, refis, scope, predicate); + } + else + { + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (!ty) + return; - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, truthyP.lvalue, *result); + // This is a hack. :( + // Without this, the expression 'a or b' might refine 'b' to be falsy. + // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. + if (fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); + + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, truthyP.lvalue, *result); + } } void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) @@ -5746,16 +5780,23 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return res; }; - std::optional ty = resolveLValue(refis, scope, isaP.lvalue); - if (!ty) - return; - - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, isaP.lvalue, *result); + if (FFlag::LuauDiscriminableUnions) + { + refineLValue(isaP.lvalue, refis, scope, predicate); + } else { - addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); - errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); + std::optional ty = resolveLValue(refis, scope, isaP.lvalue); + if (!ty) + return; + + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, isaP.lvalue, *result); + else + { + addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); + errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); + } } } @@ -5814,21 +5855,30 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) { - if (std::optional result = filterMap(*ty, it->second(sense))) - addRefinement(refis, typeguardP.lvalue, *result); + if (FFlag::LuauDiscriminableUnions) + { + refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); + return; + } else { - addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - if (sense) - errVec.push_back( - TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); - } + if (std::optional result = filterMap(*ty, it->second(sense))) + addRefinement(refis, typeguardP.lvalue, *result); + else + { + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); + if (sense) + errVec.push_back( + TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); + } - return; + return; + } } auto fail = [&](const TypeErrorData& err) { - errVec.push_back(TypeError{typeguardP.location, err}); + if (!FFlag::LuauDiscriminableUnions) + errVec.push_back(TypeError{typeguardP.location, err}); addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); }; @@ -5853,55 +5903,85 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. - auto options = [](TypeId ty) -> std::vector { if (auto utv = get(follow(ty))) return std::vector(begin(utv), end(utv)); return {ty}; }; - if (FFlag::LuauWeakEqConstraint) + if (FFlag::LuauDiscriminableUnions) { - if (!sense && isNil(eqP.type)) - resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); - - return; - } - - if (FFlag::LuauEqConstraint) - { - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; - - std::vector lhs = options(*ty); std::vector rhs = options(eqP.type); - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) - { - addRefinement(refis, eqP.lvalue, eqP.type); - return; - } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - std::unordered_set set; - for (TypeId left : lhs) - { - for (TypeId right : rhs) + auto predicate = [&](TypeId option) -> std::optional { + if (sense && isUndecidable(option)) + return FFlag::LuauWeakEqConstraint ? option : eqP.type; + + if (!sense && isNil(eqP.type)) + return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; + + if (maybeSingleton(eqP.type)) { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); + // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. + if (!sense || canUnify(eqP.type, option, eqP.location).empty()) + return sense ? eqP.type : option; + + return std::nullopt; } + + return option; + }; + + refineLValue(eqP.lvalue, refis, scope, predicate); + } + else + { + if (FFlag::LuauWeakEqConstraint) + { + if (!sense && isNil(eqP.type)) + resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); + + return; } - if (set.empty()) - return; + if (FFlag::LuauEqConstraint) + { + std::optional ty = resolveLValue(refis, scope, eqP.lvalue); + if (!ty) + return; - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); + std::vector lhs = options(*ty); + std::vector rhs = options(eqP.type); + + if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) + { + addRefinement(refis, eqP.lvalue, eqP.type); + return; + } + else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + + std::unordered_set set; + for (TypeId left : lhs) + { + for (TypeId right : rhs) + { + // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. + if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) + set.insert(left); + } + } + + if (set.empty()) + return; + + std::vector viable(set.begin(), set.end()); + TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); + addRefinement(refis, eqP.lvalue, result); + } } } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index df5d76ed..5b162b31 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -18,14 +18,15 @@ #include #include +LUAU_FASTFLAG(DebugLuauFreezeArena) + LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauLengthOnCompositeType) LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) -LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) +LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) -LUAU_FASTFLAG(DebugLuauFreezeArena) namespace Luau { @@ -144,7 +145,20 @@ bool isNil(TypeId ty) bool isBoolean(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::Boolean); + if (FFlag::LuauRefactorTypeVarQuestions) + { + if (isPrim(ty, PrimitiveTypeVar::Boolean) || get(get(follow(ty)))) + return true; + + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isBoolean); + + return false; + } + else + { + return isPrim(ty, PrimitiveTypeVar::Boolean); + } } bool isNumber(TypeId ty) @@ -154,7 +168,20 @@ bool isNumber(TypeId ty) bool isString(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::String); + if (FFlag::LuauRefactorTypeVarQuestions) + { + if (isPrim(ty, PrimitiveTypeVar::String) || get(get(follow(ty)))) + return true; + + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isString); + + return false; + } + else + { + return isPrim(ty, PrimitiveTypeVar::String); + } } bool isThread(TypeId ty) @@ -167,37 +194,45 @@ bool isOptional(TypeId ty) if (isNil(ty)) return true; - if (!get(follow(ty))) - return false; - - std::unordered_set seen; - std::deque queue{ty}; - while (!queue.empty()) + if (FFlag::LuauRefactorTypeVarQuestions) { - TypeId current = follow(queue.front()); - queue.pop_front(); + auto utv = get(follow(ty)); + if (!utv) + return false; - if (seen.count(current)) - continue; - - seen.insert(current); - - if (isNil(current)) - return true; - - if (auto u = get(current)) + return std::any_of(begin(utv), end(utv), isNil); + } + else + { + std::unordered_set seen; + std::deque queue{ty}; + while (!queue.empty()) { - for (TypeId option : u->options) - { - if (isNil(option)) - return true; + TypeId current = follow(queue.front()); + queue.pop_front(); - queue.push_back(option); + if (seen.count(current)) + continue; + + seen.insert(current); + + if (isNil(current)) + return true; + + if (auto u = get(current)) + { + for (TypeId option : u->options) + { + if (isNil(option)) + return true; + + queue.push_back(option); + } } } - } - return false; + return false; + } } bool isTableIntersection(TypeId ty) @@ -228,13 +263,27 @@ std::optional getMetatable(TypeId type) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) return classType->metatable; - else if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) + else if (FFlag::LuauRefactorTypeVarQuestions) { - LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); - return primitiveType->metatable; + if (isString(type)) + { + auto ptv = get(getSingletonTypes().stringType); + LUAU_ASSERT(ptv && ptv->metatable); + return ptv->metatable; + } + else + return std::nullopt; } else - return std::nullopt; + { + if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) + { + LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); + return primitiveType->metatable; + } + else + return std::nullopt; + } } const TableTypeVar* getTableType(TypeId type) @@ -696,7 +745,7 @@ TypeId SingletonTypes::makeStringMetatable() {"reverse", {stringToStringType}}, {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalString}, {}, + {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}})})}}, {"pack", {arena->addType(FunctionTypeVar{ arena->addTypePack(TypePack{{stringType}, anyTypePack}), @@ -1108,30 +1157,14 @@ static Tags* getTags(TypeId ty) void attachTag(TypeId ty, const std::string& tagName) { - if (!FFlag::LuauRefactorTagging) - { - if (auto ftv = getMutable(ty)) - { - ftv->tags.emplace_back(tagName); - } - else - { - LUAU_ASSERT(!"Got a non functional type"); - } - } + if (auto tags = getTags(ty)) + tags->push_back(tagName); else - { - if (auto tags = getTags(ty)) - tags->push_back(tagName); - else - LUAU_ASSERT(!"This TypeId does not support tags"); - } + LUAU_ASSERT(!"This TypeId does not support tags"); } void attachTag(Property& prop, const std::string& tagName) { - LUAU_ASSERT(FFlag::LuauRefactorTagging); - prop.tags.push_back(tagName); } @@ -1140,7 +1173,6 @@ void attachTag(Property& prop, const std::string& tagName) // Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it. bool hasTag(const Tags& tags, const std::string& tagName) { - LUAU_ASSERT(FFlag::LuauRefactorTagging); return std::find(tags.begin(), tags.end(), tagName) != tags.end(); } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 2bd9cf83..17d9bf58 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -17,15 +17,11 @@ LUAU_FASTFLAGVARIABLE(LuauCommittingTxnLogFreeTpPromote, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); -LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) -LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) -LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) -LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) namespace Luau { @@ -229,8 +225,6 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { - LUAU_ASSERT(FFlag::LuauExtendedUnionMismatchError); - type = follow(type); if (auto ttv = get(type)) @@ -291,7 +285,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { - errors.push_back(TypeError{location, UnificationTooComplex{}}); + reportError(TypeError{location, UnificationTooComplex{}}); return; } @@ -403,7 +397,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (subGeneric && !subGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); + reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}}); return; } @@ -448,7 +442,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superGeneric && !superGeneric->level.subsumes(subFree->level)) { // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); + reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}}); return; } @@ -561,13 +555,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } if (unificationTooComplex) - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); else if (failed) { if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } } else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) @@ -582,50 +576,44 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool bool foundHeuristic = false; size_t startIndex = 0; - if (FFlag::LuauUnionHeuristic) + if (const std::string* subName = getName(subTy)) { - if (const std::string* subName = getName(subTy)) + for (size_t i = 0; i < uv->options.size(); ++i) { - for (size_t i = 0; i < uv->options.size(); ++i) + const std::string* optionName = getName(uv->options[i]); + if (optionName && *optionName == *subName) { - const std::string* optionName = getName(uv->options[i]); - if (optionName && *optionName == *subName) - { - foundHeuristic = true; - startIndex = i; - break; - } + foundHeuristic = true; + startIndex = i; + break; } } + } - if (FFlag::LuauExtendedUnionMismatchError) + if (auto subMatchTag = getTableMatchTag(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) { - if (auto subMatchTag = getTableMatchTag(subTy)) + auto optionMatchTag = getTableMatchTag(uv->options[i]); + if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) { - for (size_t i = 0; i < uv->options.size(); ++i) - { - auto optionMatchTag = getTableMatchTag(uv->options[i]); - if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) - { - foundHeuristic = true; - startIndex = i; - break; - } - } + foundHeuristic = true; + startIndex = i; + break; } } + } - if (!foundHeuristic && cacheEnabled) + if (!foundHeuristic && cacheEnabled) + { + for (size_t i = 0; i < uv->options.size(); ++i) { - for (size_t i = 0; i < uv->options.size(); ++i) - { - TypeId type = uv->options[i]; + TypeId type = uv->options[i]; - if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) - { - startIndex = i; - break; - } + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) + { + startIndex = i; + break; } } } @@ -650,7 +638,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { unificationTooComplex = e; } - else if (FFlag::LuauExtendedUnionMismatchError && !isNil(type)) + else if (!isNil(type)) { failedOptionCount++; @@ -664,15 +652,15 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (unificationTooComplex) { - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); } else if (!found) { - if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption) - errors.push_back( + if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + reportError( TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); } } else if (const IntersectionTypeVar* uv = @@ -702,9 +690,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } if (unificationTooComplex) - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); else if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } else if (const IntersectionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) @@ -754,10 +742,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } if (unificationTooComplex) - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); else if (!found) { - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || @@ -801,7 +789,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool tryUnifyWithClass(subTy, superTy, /*reversed*/ true); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); if (FFlag::LuauUseCommittingTxnLog) log.popSeen(superTy, subTy); @@ -1067,7 +1055,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { - errors.push_back(TypeError{location, UnificationTooComplex{}}); + reportError(TypeError{location, UnificationTooComplex{}}); return; } @@ -1166,7 +1154,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal { tryUnify_(*subIter, *superIter); - if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + if (!errors.empty() && !firstPackErrorPos) firstPackErrorPos = loopCount; superIter.advance(); @@ -1251,7 +1239,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); - errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); while (superIter.good()) { @@ -1272,7 +1260,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); } } else @@ -1372,7 +1360,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal { tryUnify_(*subIter, *superIter); - if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + if (!errors.empty() && !firstPackErrorPos) firstPackErrorPos = loopCount; superIter.advance(); @@ -1459,7 +1447,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); - errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); while (superIter.good()) { @@ -1480,7 +1468,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); } } } @@ -1493,7 +1481,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) ice("passed non primitive types to unifyPrimitives"); if (superPrim->type != subPrim->type) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) @@ -1508,13 +1496,13 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) if (superSingleton && *superSingleton == *subSingleton) return; - if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) return; if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) @@ -1536,10 +1524,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); - if (FFlag::LuauExtendedFunctionMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); - else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); } size_t numGenericPacks = superFunction->genericPacks.size(); @@ -1547,10 +1532,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); - if (FFlag::LuauExtendedFunctionMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); - else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); } for (size_t i = 0; i < numGenerics; i++) @@ -1567,48 +1549,35 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { Unifier innerState = makeChildUnifier(); - if (FFlag::LuauExtendedFunctionMismatchError) + innerState.ctx = CountMismatch::Arg; + innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); + + bool reported = !innerState.errors.empty(); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + reportError(*e); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + reportError( + TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + + innerState.ctx = CountMismatch::Result; + innerState.tryUnify_(subFunction->retType, superFunction->retType); + + if (!reported) { - innerState.ctx = CountMismatch::Arg; - innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - - bool reported = !innerState.errors.empty(); - if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); + reportError(*e); + else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - errors.push_back( - TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + reportError( + TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), innerState.errors.front()}}); else if (!innerState.errors.empty()) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); - - innerState.ctx = CountMismatch::Result; - innerState.tryUnify_(subFunction->retType, superFunction->retType); - - if (!reported) - { - if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); - else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); - else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - errors.push_back( - TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); - else if (!innerState.errors.empty()) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); - } - } - else - { - ctx = CountMismatch::Arg; - innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - - ctx = CountMismatch::Result; - innerState.tryUnify_(subFunction->retType, superFunction->retType); - - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); } if (FFlag::LuauUseCommittingTxnLog) @@ -1716,7 +1685,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } } @@ -1734,7 +1703,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } } @@ -1957,13 +1926,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } @@ -2051,7 +2020,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt return tryUnifySealedTables(subTy, superTy, isIntersection); else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Generic) || (superTable->state == TableState::Generic && subTable->state == TableState::Sealed)) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); else if ((superTable->state == TableState::Free) != (subTable->state == TableState::Free)) // one table is free and the other is not { TypeId freeTypeId = subTable->state == TableState::Free ? subTy : superTy; @@ -2090,7 +2059,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt { const auto& r = subTable->props.find(name); if (r == subTable->props.end()) - errors.push_back(TypeError{location, UnknownProperty{subTy, name}}); + reportError(TypeError{location, UnknownProperty{subTy, name}}); else tryUnify_(r->second.type, prop.type); } @@ -2113,7 +2082,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt } } else - errors.push_back(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); + reportError(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); } } else if (superTable->state == TableState::Sealed) @@ -2194,7 +2163,7 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) } } else - errors.push_back(TypeError{location, UnknownProperty{subTy, freeName}}); + reportError(TypeError{location, UnknownProperty{subTy, freeName}}); } } @@ -2268,7 +2237,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } } @@ -2284,7 +2253,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec missingPropertiesInSuper.push_back(it.first); - innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } else { @@ -2299,7 +2268,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (oldErrorSize != innerState.errors.size() && !errorReported) { errorReported = true; - errors.push_back(innerState.errors.back()); + reportError(innerState.errors.back()); } } else @@ -2340,7 +2309,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec } } else - innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } else { @@ -2369,7 +2338,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec } } else - innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } } @@ -2386,7 +2355,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } @@ -2413,7 +2382,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!extraPropertiesInSub.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); return; } } @@ -2437,9 +2406,9 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); + reportError(*e); else if (!innerState.errors.empty()) - errors.push_back( + reportError( TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); if (FFlag::LuauUseCommittingTxnLog) @@ -2470,7 +2439,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) case TableState::Sealed: case TableState::Unsealed: case TableState::Generic: - errors.push_back(mismatchError); + reportError(mismatchError); } } else if (FFlag::LuauUseCommittingTxnLog ? (log.getMutable(subTy) || log.getMutable(subTy)) @@ -2479,7 +2448,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) } else { - errors.push_back(mismatchError); + reportError(mismatchError); } } @@ -2491,9 +2460,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) auto fail = [&]() { if (!reversed) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); else - errors.push_back(TypeError{location, TypeMismatch{subTy, superTy}}); + reportError(TypeError{location, TypeMismatch{subTy, superTy}}); }; const ClassTypeVar* superClass = get(superTy); @@ -2538,7 +2507,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (!classProp) { ok = false; - errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); + reportError(TypeError{location, UnknownProperty{superTy, propName}}); } else { @@ -2577,7 +2546,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; - errors.push_back(TypeError{location, GenericError{msg}}); + reportError(TypeError{location, GenericError{msg}}); } if (!ok) @@ -2695,7 +2664,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else if (get(tail)) { - errors.push_back(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); + reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); } else if (get(tail)) { @@ -2709,7 +2678,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify variadic packs"}}); + reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}}); } } @@ -2886,7 +2855,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); log.replace(needle, *getSingletonTypes().errorRecoveryType()); return; @@ -2894,17 +2863,6 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (log.getMutable(haystack)) return; - else if (auto a = log.getMutable(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (TypePackIterator it(a->argTypes, &log); it != end(a->argTypes); ++it) - check(*it); - - for (TypePackIterator it(a->retType, &log); it != end(a->retType); ++it) - check(*it); - } - } else if (auto a = log.getMutable(haystack)) { for (TypeId ty : a->options) @@ -2934,7 +2892,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); DEPRECATED_log(needle); *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); return; @@ -2942,17 +2900,6 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (get(haystack)) return; - else if (auto a = get(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (TypeId ty : a->argTypes) - check(ty); - - for (TypeId ty : a->retType) - check(ty); - } - } else if (auto a = get(haystack)) { for (TypeId ty : a->options) @@ -2988,7 +2935,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (log.getMutable(needle)) return; - if (!get(needle)) + if (!log.getMutable(needle)) ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); @@ -2997,32 +2944,18 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); return; } - if (auto a = get(haystack)) + if (auto a = get(haystack); a && a->tail) { - for (const auto& ty : a->head) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - if (auto f = log.getMutable(log.follow(ty))) - { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); - } - } - } - - if (a->tail) - { - haystack = follow(*a->tail); - continue; - } + haystack = log.follow(*a->tail); + continue; } + break; } } @@ -3048,31 +2981,17 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); DEPRECATED_log(needle); *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); } - if (auto a = get(haystack)) + if (auto a = get(haystack); a && a->tail) { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (const auto& ty : a->head) - { - if (auto f = get(follow(ty))) - { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); - } - } - } - - if (a->tail) - { - haystack = follow(*a->tail); - continue; - } + haystack = follow(*a->tail); + continue; } + break; } } @@ -3094,17 +3013,17 @@ bool Unifier::isNonstrictMode() const void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType) { if (auto e = hasUnificationTooComplex(innerErrors)) - errors.push_back(*e); + reportError(*e); else if (!innerErrors.empty()) - errors.push_back(TypeError{location, TypeMismatch{wantedType, givenType}}); + reportError(TypeError{location, TypeMismatch{wantedType, givenType}}); } void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) { if (auto e = hasUnificationTooComplex(innerErrors)) - errors.push_back(*e); + reportError(*e); else if (!innerErrors.empty()) - errors.push_back( + reportError( TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front()}}); } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index e0dc3e0f..10cf17d2 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -43,14 +43,14 @@ static void report(ReportFormat format, const char* name, const Luau::Location& } } -static void reportError(ReportFormat format, const Luau::TypeError& error) +static void reportError(const Luau::Frontend& frontend, ReportFormat format, const Luau::TypeError& error) { - const char* name = error.moduleName.c_str(); + std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(error.moduleName); if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) - report(format, name, error.location, "SyntaxError", syntaxError->message.c_str()); + report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); else - report(format, name, error.location, "TypeError", Luau::toString(error).c_str()); + report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str()); } static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) @@ -72,14 +72,15 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat } for (auto& error : cr.errors) - reportError(format, error); + reportError(frontend, format, error); Luau::LintResult lr = frontend.lint(name); + std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); for (auto& error : lr.errors) - reportWarning(format, name, error); + reportWarning(format, humanReadableName.c_str(), error); for (auto& warning : lr.warnings) - reportWarning(format, name, warning); + reportWarning(format, humanReadableName.c_str(), warning); if (annotate) { @@ -120,11 +121,25 @@ struct CliFileResolver : Luau::FileResolver { std::optional readSource(const Luau::ModuleName& name) override { - std::optional source = readFile(name); + Luau::SourceCode::Type sourceType; + std::optional source = std::nullopt; + + // If the module name is "-", then read source from stdin + if (name == "-") + { + source = readStdin(); + sourceType = Luau::SourceCode::Script; + } + else + { + source = readFile(name); + sourceType = Luau::SourceCode::Module; + } + if (!source) return std::nullopt; - return Luau::SourceCode{*source, Luau::SourceCode::Module}; + return Luau::SourceCode{*source, sourceType}; } std::optional resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override @@ -143,6 +158,13 @@ struct CliFileResolver : Luau::FileResolver return std::nullopt; } + + std::string getHumanReadableModuleName(const Luau::ModuleName& name) const override + { + if (name == "-") + return "stdin"; + return name; + } }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index cb993dfe..c6807022 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -74,6 +74,21 @@ std::optional readFile(const std::string& name) return result; } +std::optional readStdin() +{ + std::string result; + char buffer[4096] = { }; + + while (fgets(buffer, sizeof(buffer), stdin) != nullptr) + result.append(buffer); + + // If eof was not reached for stdin, then a read error occurred + if (!feof(stdin)) + return std::nullopt; + + return result; +} + template static void joinPaths(std::basic_string& str, const Ch* lhs, const Ch* rhs) { @@ -190,7 +205,10 @@ bool traverseDirectory(const std::string& path, const std::function getSourceFiles(int argc, char** argv) for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') + // Treat '-' as a special file whose source is read from stdin + // All other arguments that start with '-' are skipped + if (argv[i][0] == '-' && argv[i][1] != '\0') continue; if (isDirectory(argv[i])) diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index da11f512..97471cdc 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -7,6 +7,7 @@ #include std::optional readFile(const std::string& name); +std::optional readStdin(); bool isDirectory(const std::string& path); bool traverseDirectory(const std::string& path, const std::function& callback); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index e5042152..ab0f0ed0 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -158,7 +158,7 @@ static int lua_collectgarbage(lua_State* L) luaL_error(L, "collectgarbage must be called with 'count' or 'collect'"); } -static void setupState(lua_State* L) +void setupState(lua_State* L) { luaL_openlibs(L); @@ -176,7 +176,7 @@ static void setupState(lua_State* L) luaL_sandbox(L); } -static std::string runCode(lua_State* L, const std::string& source) +std::string runCode(lua_State* L, const std::string& source) { std::string bytecode = Luau::compile(source, copts()); @@ -206,7 +206,13 @@ static std::string runCode(lua_State* L, const std::string& source) if (n) { luaL_checkstack(T, LUA_MINSTACK, "too many results to print"); - lua_getglobal(T, "print"); + lua_getglobal(T, "_PRETTYPRINT"); + // If _PRETTYPRINT is nil, then use the standard print function instead + if (lua_isnil(T, -1)) + { + lua_pop(T, 1); + lua_getglobal(T, "print"); + } lua_insert(T, 1); lua_pcall(T, n, 0, 0); } @@ -545,7 +551,7 @@ static int assertionHandler(const char* expr, const char* file, int line, const return 1; } -int main(int argc, char** argv) +int replMain(int argc, char** argv) { Luau::assertHandler() = assertionHandler; @@ -696,7 +702,6 @@ int main(int argc, char** argv) case CliMode::Unknown: default: LUAU_ASSERT(!"Unhandled cli mode."); + return 1; } } - - diff --git a/CLI/Repl.h b/CLI/Repl.h new file mode 100644 index 00000000..11a077ae --- /dev/null +++ b/CLI/Repl.h @@ -0,0 +1,12 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "lua.h" + +#include + +// Note: These are internal functions which are being exposed in a header +// so they can be included by unit tests. +int replMain(int argc, char** argv); +void setupState(lua_State* L); +std::string runCode(lua_State* L, const std::string& source); diff --git a/CLI/ReplEntry.cpp b/CLI/ReplEntry.cpp new file mode 100644 index 00000000..b3131712 --- /dev/null +++ b/CLI/ReplEntry.cpp @@ -0,0 +1,10 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Repl.h" + + + +int main(int argc, char** argv) +{ + return replMain(argc, argv); +} \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 77cf47e8..b9f7a9e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,7 @@ endif() if(LUAU_BUILD_TESTS) add_executable(Luau.UnitTest) add_executable(Luau.Conformance) + add_executable(Luau.CLI.Test) endif() if(LUAU_BUILD_WEB) @@ -109,6 +110,17 @@ if(LUAU_BUILD_TESTS) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Conformance PRIVATE extern) target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.VM) + + target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS}) + target_include_directories(Luau.CLI.Test PRIVATE extern CLI) + target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.VM) + if(UNIX) + find_library(LIBPTHREAD pthread) + if (LIBPTHREAD) + target_link_libraries(Luau.CLI.Test PRIVATE pthread) + endif() + endif() + endif() if(LUAU_BUILD_WEB) diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index d9694d7d..679712f6 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -472,6 +472,9 @@ enum LuauBuiltinFunction // bit32.count LBF_BIT32_COUNTLZ, LBF_BIT32_COUNTRZ, + + // select(_, ...) + LBF_SELECT_VARARG, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index e344eb91..a907271c 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,6 +4,8 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +LUAU_FASTFLAGVARIABLE(LuauCompileSelectBuiltin, false) + namespace Luau { namespace Compile @@ -62,6 +64,9 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) if (builtin.isGlobal("unpack")) return LBF_TABLE_UNPACK; + if (FFlag::LuauCompileSelectBuiltin && builtin.isGlobal("select")) + return LBF_SELECT_VARARG; + if (builtin.object == "math") { if (builtin.method == "abs") diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9758c4a9..7da85244 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -15,6 +15,9 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauCompileTableIndexOpt, false) +LUAU_FASTFLAG(LuauCompileSelectBuiltin) + namespace Luau { @@ -261,6 +264,122 @@ struct Compiler bytecode.emitABC(LOP_GETVARARGS, target, multRet ? 0 : uint8_t(targetCount + 1), 0); } + void compileExprSelectVararg(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs) + { + LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin); + LUAU_ASSERT(targetCount == 1); + LUAU_ASSERT(!expr->self); + LUAU_ASSERT(expr->args.size == 2 && expr->args.data[1]->is()); + + AstExpr* arg = expr->args.data[0]; + + uint8_t argreg; + + if (isExprLocalReg(arg)) + argreg = getLocal(arg->as()->local); + else + { + argreg = uint8_t(regs + 1); + compileExprTempTop(arg, argreg); + } + + size_t fastcallLabel = bytecode.emitLabel(); + + bytecode.emitABC(LOP_FASTCALL1, LBF_SELECT_VARARG, argreg, 0); + + // note, these instructions are normally not executed and are used as a fallback for FASTCALL + // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten + compileExprTemp(expr->func, regs); + + bytecode.emitABC(LOP_GETVARARGS, uint8_t(regs + 2), 0, 0); + + size_t callLabel = bytecode.emitLabel(); + if (!bytecode.patchSkipC(fastcallLabel, callLabel)) + CompileError::raise(expr->func->location, "Exceeded jump distance limit; simplify the code to compile"); + + // note, this is always multCall (last argument is variadic) + bytecode.emitABC(LOP_CALL, regs, 0, multRet ? 0 : uint8_t(targetCount + 1)); + + // if we didn't output results directly to target, we need to move them + if (!targetTop) + { + for (size_t i = 0; i < targetCount; ++i) + bytecode.emitABC(LOP_MOVE, uint8_t(target + i), uint8_t(regs + i), 0); + } + } + + void compileExprFastcallN(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs, int bfid) + { + LUAU_ASSERT(!expr->self); + LUAU_ASSERT(expr->args.size <= 2); + + LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2; + + uint32_t args[2] = {}; + + for (size_t i = 0; i < expr->args.size; ++i) + { + if (i > 0) + { + if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0) + { + opc = LOP_FASTCALL2K; + args[i] = cid; + break; + } + } + + if (isExprLocalReg(expr->args.data[i])) + args[i] = getLocal(expr->args.data[i]->as()->local); + else + { + args[i] = uint8_t(regs + 1 + i); + compileExprTempTop(expr->args.data[i], uint8_t(args[i])); + } + } + + size_t fastcallLabel = bytecode.emitLabel(); + + bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); + if (opc != LOP_FASTCALL1) + bytecode.emitAux(args[1]); + + // Set up a traditional Lua stack for the subsequent LOP_CALL. + // Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for + // these FASTCALL variants. + for (size_t i = 0; i < expr->args.size; ++i) + { + if (i > 0 && opc == LOP_FASTCALL2K) + { + emitLoadK(uint8_t(regs + 1 + i), args[i]); + break; + } + + if (args[i] != regs + 1 + i) + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); + } + + // note, these instructions are normally not executed and are used as a fallback for FASTCALL + // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten + compileExprTemp(expr->func, regs); + + size_t callLabel = bytecode.emitLabel(); + + // FASTCALL will skip over the instructions needed to compute function and jump over CALL which must immediately follow the instruction + // sequence after FASTCALL + if (!bytecode.patchSkipC(fastcallLabel, callLabel)) + CompileError::raise(expr->func->location, "Exceeded jump distance limit; simplify the code to compile"); + + bytecode.emitABC(LOP_CALL, regs, uint8_t(expr->args.size + 1), multRet ? 0 : uint8_t(targetCount + 1)); + + // if we didn't output results directly to target, we need to move them + if (!targetTop) + { + for (size_t i = 0; i < targetCount; ++i) + bytecode.emitABC(LOP_MOVE, uint8_t(target + i), uint8_t(regs + i), 0); + } + } + void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) { LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop); @@ -284,6 +403,25 @@ struct Compiler bfid = getBuiltinFunctionId(builtin, options); } + if (bfid == LBF_SELECT_VARARG) + { + LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin); + // Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly + // note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases + if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is()) + return compileExprSelectVararg(expr, target, targetCount, targetTop, multRet, regs); + else + bfid = -1; + } + + // Optimization: for 1/2 argument fast calls use specialized opcodes + if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) + { + AstExpr* last = expr->args.data[expr->args.size - 1]; + if (!last->is() && !last->is()) + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + } + if (expr->self) { AstExprIndexName* fi = expr->func->as(); @@ -309,24 +447,13 @@ struct Compiler compileExprTempTop(expr->func, regs); } - // Note: if the last argument is ExprVararg or ExprCall, we need to route that directly to the called function preserving the # of args bool multCall = false; - bool skipArgs = false; - if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) - { - AstExpr* last = expr->args.data[expr->args.size - 1]; - skipArgs = !(last->is() || last->is()); - } - - if (!skipArgs) - { - for (size_t i = 0; i < expr->args.size; ++i) - if (i + 1 == expr->args.size) - multCall = compileExprTempMultRet(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); - else - compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); - } + for (size_t i = 0; i < expr->args.size; ++i) + if (i + 1 == expr->args.size) + multCall = compileExprTempMultRet(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); + else + compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); setDebugLineEnd(expr->func); @@ -347,59 +474,8 @@ struct Compiler } else if (bfid >= 0) { - size_t fastcallLabel; - - if (skipArgs) - { - LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2; - - uint32_t args[2] = {}; - for (size_t i = 0; i < expr->args.size; ++i) - { - if (i > 0) - { - if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0) - { - opc = LOP_FASTCALL2K; - args[i] = cid; - break; - } - } - - if (isExprLocalReg(expr->args.data[i])) - args[i] = getLocal(expr->args.data[i]->as()->local); - else - { - args[i] = uint8_t(regs + 1 + i); - compileExprTempTop(expr->args.data[i], uint8_t(args[i])); - } - } - - fastcallLabel = bytecode.emitLabel(); - bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); - if (opc != LOP_FASTCALL1) - bytecode.emitAux(args[1]); - - // Set up a traditional Lua stack for the subsequent LOP_CALL. - // Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for - // these FASTCALL variants. - for (size_t i = 0; i < expr->args.size; ++i) - { - if (i > 0 && opc == LOP_FASTCALL2K) - { - emitLoadK(uint8_t(regs + 1 + i), args[i]); - break; - } - - if (args[i] != regs + 1 + i) - bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); - } - } - else - { - fastcallLabel = bytecode.emitLabel(); - bytecode.emitABC(LOP_FASTCALL, uint8_t(bfid), 0, 0); - } + size_t fastcallLabel = bytecode.emitLabel(); + bytecode.emitABC(LOP_FASTCALL, uint8_t(bfid), 0, 0); // note, these instructions are normally not executed and are used as a fallback for FASTCALL // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten @@ -1101,9 +1177,20 @@ struct Compiler for (size_t i = 0; i < expr->items.size; ++i) { const AstExprTable::Item& item = expr->items.data[i]; - AstExprConstantNumber* ckey = item.key->as(); + LUAU_ASSERT(item.key); // no list portion => all items have keys - indexSize += (ckey && ckey->value == double(indexSize + 1)); + if (FFlag::LuauCompileTableIndexOpt) + { + const Constant* ckey = constants.find(item.key); + + indexSize += (ckey && ckey->type == Constant::Type_Number && ckey->valueNumber == double(indexSize + 1)); + } + else + { + AstExprConstantNumber* ckey = item.key->as(); + + indexSize += (ckey && ckey->value == double(indexSize + 1)); + } } // we only perform the optimization if we don't have any other []-keys @@ -1200,37 +1287,47 @@ struct Compiler arrayChunkCurrent = 0; } - // items with a key are set one by one via SETTABLE/SETTABLEKS + // items with a key are set one by one via SETTABLE/SETTABLEKS/SETTABLEN if (key) { RegScope rsi(this); - // Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax - if (AstExprConstantString* ckey = key->as()) + if (FFlag::LuauCompileTableIndexOpt) { - BytecodeBuilder::StringRef cname = sref(ckey->value); - int32_t cid = bytecode.addConstantString(cname); - if (cid < 0) - CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - + LValue lv = compileLValueIndex(reg, key, rsi); uint8_t rv = compileExprAuto(value, rsi); - bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname))); - bytecode.emitAux(cid); - } - else if (AstExprConstantNumber* ckey = key->as(); - ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value) - { - uint8_t rv = compileExprAuto(value, rsi); - - bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1)); + compileAssign(lv, rv); } else { - uint8_t rk = compileExprAuto(key, rsi); - uint8_t rv = compileExprAuto(value, rsi); + // Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax + if (AstExprConstantString* ckey = key->as()) + { + BytecodeBuilder::StringRef cname = sref(ckey->value); + int32_t cid = bytecode.addConstantString(cname); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - bytecode.emitABC(LOP_SETTABLE, rv, reg, rk); + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname))); + bytecode.emitAux(cid); + } + else if (AstExprConstantNumber* ckey = key->as(); + ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value) + { + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1)); + } + else + { + uint8_t rk = compileExprAuto(key, rsi); + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLE, rv, reg, rk); + } } } // items without a key are set using SETLIST so that we can initialize large arrays quickly @@ -1339,6 +1436,9 @@ struct Compiler uint8_t rt = compileExprAuto(expr->expr, rs); uint8_t i = uint8_t(int(cv->valueNumber) - 1); + if (FFlag::LuauCompileTableIndexOpt) + setDebugLine(expr->index); + bytecode.emitABC(LOP_GETTABLEN, target, rt, i); } else if (cv && cv->type == Constant::Type_String) @@ -1350,6 +1450,9 @@ struct Compiler if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + if (FFlag::LuauCompileTableIndexOpt) + setDebugLine(expr->index); + bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); } @@ -1657,6 +1760,40 @@ struct Compiler Location location; }; + LValue compileLValueIndex(uint8_t reg, AstExpr* index, RegScope& rs) + { + const Constant* cv = constants.find(index); + + if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && + double(int(cv->valueNumber)) == cv->valueNumber) + { + LValue result = {LValue::Kind_IndexNumber}; + result.reg = reg; + result.number = uint8_t(int(cv->valueNumber) - 1); + result.location = index->location; + + return result; + } + else if (cv && cv->type == Constant::Type_String) + { + LValue result = {LValue::Kind_IndexName}; + result.reg = reg; + result.name = sref(cv->getString()); + result.location = index->location; + + return result; + } + else + { + LValue result = {LValue::Kind_IndexExpr}; + result.reg = reg; + result.index = compileExprAuto(index, rs); + result.location = index->location; + + return result; + } + } + LValue compileLValue(AstExpr* node, RegScope& rs) { setDebugLine(node); @@ -1699,36 +1836,9 @@ struct Compiler } else if (AstExprIndexExpr* expr = node->as()) { - const Constant* cv = constants.find(expr->index); + uint8_t reg = compileExprAuto(expr->expr, rs); - if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && - double(int(cv->valueNumber)) == cv->valueNumber) - { - LValue result = {LValue::Kind_IndexNumber}; - result.reg = compileExprAuto(expr->expr, rs); - result.number = uint8_t(int(cv->valueNumber) - 1); - result.location = node->location; - - return result; - } - else if (cv && cv->type == Constant::Type_String) - { - LValue result = {LValue::Kind_IndexName}; - result.reg = compileExprAuto(expr->expr, rs); - result.name = sref(cv->getString()); - result.location = node->location; - - return result; - } - else - { - LValue result = {LValue::Kind_IndexExpr}; - result.reg = compileExprAuto(expr->expr, rs); - result.index = compileExprAuto(expr->index, rs); - result.location = node->location; - - return result; - } + return compileLValueIndex(reg, expr->index, rs); } else { @@ -1740,6 +1850,9 @@ struct Compiler void compileLValueUse(const LValue& lv, uint8_t reg, bool set) { + if (FFlag::LuauCompileTableIndexOpt) + setDebugLine(lv.location); + switch (lv.kind) { case LValue::Kind_Local: diff --git a/Makefile b/Makefile index b144cac6..638c4c63 100644 --- a/Makefile +++ b/Makefile @@ -23,11 +23,11 @@ VM_SOURCES=$(wildcard VM/src/*.cpp) VM_OBJECTS=$(VM_SOURCES:%=$(BUILD)/%.o) VM_TARGET=$(BUILD)/libluauvm.a -TESTS_SOURCES=$(wildcard tests/*.cpp) +TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp +REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/ReplEntry.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau @@ -90,11 +90,12 @@ $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -IAst/include $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -IVM/include -$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -Iextern +$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICLI -Iextern $(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IVM/include -Iextern $(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include -Iextern $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include +$(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a build/libprotobuf-mutator/external.protobuf/lib/libprotobuf.a diff --git a/Sources.cmake b/Sources.cmake index bafe7594..22e7af22 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -176,7 +176,8 @@ if(TARGET Luau.Repl.CLI) CLI/FileUtils.cpp CLI/Profiler.h CLI/Profiler.cpp - CLI/Repl.cpp) + CLI/Repl.cpp + CLI/ReplEntry.cpp) endif() if(TARGET Luau.Analyze.CLI) @@ -243,6 +244,21 @@ if(TARGET Luau.Conformance) tests/main.cpp) endif() +if(TARGET Luau.CLI.Test) + # Luau.CLI.Test Sources + target_sources(Luau.CLI.Test PRIVATE + CLI/Coverage.h + CLI/Coverage.cpp + CLI/FileUtils.h + CLI/FileUtils.cpp + CLI/Profiler.h + CLI/Profiler.cpp + CLI/Repl.cpp + + tests/Repl.test.cpp + tests/main.cpp) +endif() + if(TARGET Luau.Web) # Luau.Web Sources target_sources(Luau.Web PRIVATE diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index d5416285..5cffba63 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,6 +14,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauGcForwardMetatableBarrier, false) + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -869,7 +871,16 @@ int lua_setmetatable(lua_State* L, int objindex) luaG_runerror(L, "Attempt to modify a readonly table"); hvalue(obj)->metatable = mt; if (mt) - luaC_objbarriert(L, hvalue(obj), mt); + { + if (FFlag::LuauGcForwardMetatableBarrier) + { + luaC_objbarrier(L, hvalue(obj), mt); + } + else + { + luaC_objbarriert(L, hvalue(obj), mt); + } + } break; } case LUA_TUSERDATA: diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 34e9ebc1..ecc14e87 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1087,6 +1087,34 @@ static int luauF_countrz(lua_State* L, StkId res, TValue* arg0, int nresults, St return -1; } +static int luauF_select(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams == 1 && nresults == 1) + { + int n = cast_int(L->base - L->ci->func) - clvalue(L->ci->func)->l.p->numparams - 1; + + if (ttisnumber(arg0)) + { + int i = int(nvalue(arg0)); + + // i >= 1 && i <= n + if (unsigned(i - 1) <= unsigned(n)) + { + setobj2s(L, res, L->base - n + (i - 1)); + return 1; + } + // note: for now we don't handle negative case (wrap around) and defer to fallback + } + else if (ttisstring(arg0) && *svalue(arg0) == '#') + { + setnvalue(res, double(n)); + return 1; + } + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1156,4 +1184,6 @@ luau_FastFunction luauF_table[256] = { luauF_countlz, luauF_countrz, + + luauF_select, }; diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index abcde779..19222861 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,8 +5,6 @@ #include "lstate.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false) - #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -235,9 +233,6 @@ static int coyieldable(lua_State* L) static int coclose(lua_State* L) { - if (!FFlag::LuauCoroutineClose) - luaL_error(L, "coroutine.close is not enabled"); - lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 581506a8..a3982bc6 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,8 +17,6 @@ #include -LUAU_FASTFLAG(LuauCoroutineClose) - /* ** {====================================================== ** Error-recovery functions @@ -300,7 +298,7 @@ static void resume(lua_State* L, void* ud) { // start coroutine LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base); - if (FFlag::LuauCoroutineClose && firstArg == L->base) + if (firstArg == L->base) luaG_runerror(L, "cannot resume dead coroutine"); if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA) diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 50859b1e..82ac0009 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -93,10 +93,8 @@ static void finishGcCycleStats(global_State* g) g->gcstats.lastcycle = g->gcstats.currcycle; g->gcstats.currcycle = GCCycleStats(); - g->gcstats.cyclestatsacc.markitems += g->gcstats.lastcycle.markitems; g->gcstats.cyclestatsacc.marktime += g->gcstats.lastcycle.marktime; g->gcstats.cyclestatsacc.atomictime += g->gcstats.lastcycle.atomictime; - g->gcstats.cyclestatsacc.sweepitems += g->gcstats.lastcycle.sweepitems; g->gcstats.cyclestatsacc.sweeptime += g->gcstats.lastcycle.sweeptime; } @@ -492,23 +490,22 @@ static void freeobj(lua_State* L, GCObject* o, lua_Page* page) } } -#define sweepwholelist(L, p, tc) sweeplist(L, p, SIZE_MAX, tc) +#define sweepwholelist(L, p) sweeplist(L, p, SIZE_MAX) -static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* traversedcount) +static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count) { LUAU_ASSERT(!FFlag::LuauGcPagedSweep); GCObject* curr; global_State* g = L->global; int deadmask = otherwhite(g); - size_t startcount = count; LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); /* make sure we never sweep fixed objects */ while ((curr = *p) != NULL && count-- > 0) { int alive = (curr->gch.marked ^ WHITEBITS) & deadmask; if (curr->gch.tt == LUA_TTHREAD) { - sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval, traversedcount); /* sweep open upvalues */ + sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval); /* sweep open upvalues */ lua_State* th = gco2th(curr); @@ -534,10 +531,6 @@ static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* tr } } - // if we didn't reach the end of the list it means that we've stopped because the count dropped below zero - if (traversedcount) - *traversedcount += startcount - (curr ? count + 1 : count); - return p; } @@ -721,8 +714,6 @@ static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) int alive = (gco->gch.marked ^ WHITEBITS) & deadmask; - g->gcstats.currcycle.sweepitems++; - if (gco->gch.tt == LUA_TTHREAD) { lua_State* th = gco2th(gco); @@ -770,11 +761,11 @@ static int sweepgcopage(lua_State* L, lua_Page* page) { // if the last block was removed, page would be removed as well if (--busyBlocks == 0) - return (pos - start) / blockSize + 1; + return int(pos - start) / blockSize + 1; } } - return (end - start) / blockSize; + return int(end - start) / blockSize; } static size_t gcstep(lua_State* L, size_t limit) @@ -793,8 +784,6 @@ static size_t gcstep(lua_State* L, size_t limit) { while (g->gray && cost < limit) { - g->gcstats.currcycle.markitems++; - cost += propagatemark(g); } @@ -812,8 +801,6 @@ static size_t gcstep(lua_State* L, size_t limit) { while (g->gray && cost < limit) { - g->gcstats.currcycle.markitems++; - cost += propagatemark(g); } @@ -842,10 +829,8 @@ static size_t gcstep(lua_State* L, size_t limit) while (g->sweepstrgc < g->strt.size && cost < limit) { - size_t traversedcount = 0; - sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++], &traversedcount); + sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++]); - g->gcstats.currcycle.sweepitems += traversedcount; cost += GC_SWEEPCOST; } @@ -855,12 +840,10 @@ static size_t gcstep(lua_State* L, size_t limit) // sweep string buffer list and preserve used string count uint32_t nuse = L->global->strt.nuse; - size_t traversedcount = 0; - sweepwholelist(L, (GCObject**)&g->strbufgc, &traversedcount); + sweepwholelist(L, (GCObject**)&g->strbufgc); L->global->strt.nuse = nuse; - g->gcstats.currcycle.sweepitems += traversedcount; g->gcstate = GCSsweep; // end sweep-string phase } break; @@ -893,10 +876,8 @@ static size_t gcstep(lua_State* L, size_t limit) { while (*g->sweepgc && cost < limit) { - size_t traversedcount = 0; - g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); + g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX); - g->gcstats.currcycle.sweepitems += traversedcount; cost += GC_SWEEPMAX * GC_SWEEPCOST; } diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 4455fec5..528d0944 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -113,6 +113,7 @@ luaC_barrierf(L, obj2gco(p), obj2gco(o)); \ } +// TODO: remove with FFlagLuauGcForwardMetatableBarrier #define luaC_objbarriert(L, t, o) \ { \ if (isblack(obj2gco(t)) && iswhite(obj2gco(o))) \ diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 6d3b7772..e1dbce50 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -200,7 +200,7 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int global_State* g = L->global; - LUAU_ASSERT(pageSize - offsetof(lua_Page, data) >= blockSize * blockCount); + LUAU_ASSERT(pageSize - int(offsetof(lua_Page, data)) >= blockSize * blockCount); lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, pageSize); if (!page) @@ -376,7 +376,7 @@ static void* luaM_newgcoblock(lua_State* L, int sizeClass) LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); - LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass]); + LUAU_ASSERT(page->blockSize == kSizeClassConfig.sizeOfClass[sizeClass]); void* block; @@ -520,7 +520,7 @@ GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) } else { - lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + nsize, nsize, 1); + lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + int(nsize), int(nsize), 1); block = &page->data; ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 080f0024..0708b71f 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -96,9 +96,6 @@ struct GCCycleStats double sweeptime = 0.0; - size_t markitems = 0; - size_t sweepitems = 0; - size_t assistwork = 0; size_t explicitwork = 0; diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 41b553b5..292625b0 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -44,10 +44,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "prop") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") { - ScopedFastFlag sffs[] = { - {"LuauPersistDefinitionFileTypes", true}, - }; - loadDefinition(R"( declare function Connect(fn: (string) -> ()) )"); @@ -63,8 +59,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") { - ScopedFastFlag sffs{"LuauStoreMatchingOverloadFnType", true}; - loadDefinition(R"( declare foo: ((string) -> number) & ((number) -> string) )"); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 211e1be1..e8e3b315 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2626,7 +2626,6 @@ local a: A<(number, s@1> TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompleteFirstArg("LuauAutocompleteFirstArg", true); check(R"( local function foo1() return 1 end @@ -2728,4 +2727,39 @@ end CHECK(ac.entryMap.count("getx")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauRefactorTypeVarQuestions", true}, + }; + + check(R"( + --!strict + local foo: "hello" | "bye" = "hello" + foo:@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("format")); +} + +TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag preferToCallFunctionsForIntersects("PreferToCallFunctionsForIntersects", true); + + check(R"( +local bar: ((number) -> number) & (number, number) -> number) +local abc = b@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("bar")); + CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 3b0d677d..4a28bdde 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -603,6 +603,37 @@ RETURN R0 1 )"); } +TEST_CASE("TableLiteralsIndexConstant") +{ + ScopedFastFlag sff("LuauCompileTableIndexOpt", true); + + // validate that we use SETTTABLEKS for constant variable keys + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = "key", "value" + return {[a] = 42, [b] = 0} +)"), R"( +NEWTABLE R0 2 0 +LOADN R1 42 +SETTABLEKS R1 R0 K0 +LOADN R1 0 +SETTABLEKS R1 R0 K1 +RETURN R0 1 +)"); + + // validate that we use SETTABLEN for constant variable keys *and* that we predict array size + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = 1, 2 + return {[a] = 42, [b] = 0} +)"), R"( +NEWTABLE R0 0 2 +LOADN R1 42 +SETTABLEN R1 R0 1 +LOADN R1 0 +SETTABLEN R1 R0 2 +RETURN R0 1 +)"); +} + TEST_CASE("TableSizePredictionBasic") { CHECK_EQ("\n" + compileFunction0(R"( @@ -2450,6 +2481,37 @@ return )"); } +TEST_CASE("DebugLineInfoAssignment") +{ + ScopedFastFlag sff("LuauCompileTableIndexOpt", true); + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( + local a = { b = { c = { d = 3 } } } + +a +["b"] +["c"] +["d"] = 4 +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: DUPTABLE R0 1 +2: DUPTABLE R1 3 +2: DUPTABLE R2 5 +2: LOADN R3 3 +2: SETTABLEKS R3 R2 K4 +2: SETTABLEKS R2 R1 K2 +2: SETTABLEKS R1 R0 K0 +5: GETTABLEKS R2 R0 K0 +6: GETTABLEKS R1 R2 K2 +7: LOADN R2 4 +7: SETTABLEKS R2 R1 K4 +8: RETURN R0 0 +)"); +} + TEST_CASE("DebugSource") { const char* source = R"( @@ -2763,6 +2825,75 @@ RETURN R1 -1 )"); } +TEST_CASE("FastcallSelect") +{ + ScopedFastFlag sff("LuauCompileSelectBuiltin", true); + + // select(_, ...) compiles to a builtin call + CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"( +LOADK R1 K0 +FASTCALL1 57 R1 +3 +GETIMPORT R0 2 +GETVARARGS R2 -1 +CALL R0 -1 1 +RETURN R0 1 +)"); + + // more complex example: select inside a for loop bound + select from a iterator + CHECK_EQ("\n" + compileFunction0(R"( +local sum = 0 +for i=1, select('#', ...) do + sum += select(i, ...) +end +return sum +)"), R"( +LOADN R0 0 +LOADN R3 1 +LOADK R5 K0 +FASTCALL1 57 R5 +3 +GETIMPORT R4 2 +GETVARARGS R6 -1 +CALL R4 -1 1 +MOVE R1 R4 +LOADN R2 1 +FORNPREP R1 +7 +FASTCALL1 57 R3 +3 +GETIMPORT R4 2 +GETVARARGS R6 -1 +CALL R4 -1 1 +ADD R0 R0 R4 +FORNLOOP R1 -7 +RETURN R0 1 +)"); + + // currently we assume a single value return to avoid dealing with stack resizing + CHECK_EQ("\n" + compileFunction0("return select('#', ...)"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +GETVARARGS R2 -1 +CALL R0 -1 -1 +RETURN R0 -1 +)"); + + // note that select with a non-variadic second argument doesn't get optimized + CHECK_EQ("\n" + compileFunction0("return select('#')"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +CALL R0 1 -1 +RETURN R0 -1 +)"); + + // note that select with a non-variadic second argument doesn't get optimized + CHECK_EQ("\n" + compileFunction0("return select('#', foo())"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +GETIMPORT R2 4 +CALL R2 0 -1 +CALL R0 -1 -1 +RETURN R0 -1 +)"); +} + TEST_CASE("LotsOfParameters") { const char* source = R"( diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 5222af33..914b881f 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -331,8 +331,6 @@ TEST_CASE("UTF8") TEST_CASE("Coroutine") { - ScopedFastFlag sff("LuauCoroutineClose", true); - runConformance("coroutine.lua"); } diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 405f26e0..ea1a08fe 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -956,7 +956,6 @@ TEST_CASE("no_use_after_free_with_type_fun_instantiation") { // This flag forces this test to crash if there's a UAF in this code. ScopedFastFlag sff_DebugLuauFreezeArena("DebugLuauFreezeArena", true); - ScopedFastFlag sff_LuauCloneCorrectlyBeforeMutatingTableType("LuauCloneCorrectlyBeforeMutatingTableType", true); FrontendFixture fix; diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index ac81005c..90831ee9 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2000,6 +2000,73 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); } +TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +{ + { + AstStat* stat = parse("return if true then 1 else 2"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr = str->list.data[0]->as(); + REQUIRE(ifElseExpr != nullptr); + } + + { + AstStat* stat = parse("return if true then 1 elseif true then 2 else 3"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr1 = str->list.data[0]->as(); + REQUIRE(ifElseExpr1 != nullptr); + auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); + REQUIRE(ifElseExpr2 != nullptr); + } + + // Use "else if" as opposed to elseif + { + AstStat* stat = parse("return if true then 1 else if true then 2 else 3"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr1 = str->list.data[0]->as(); + REQUIRE(ifElseExpr1 != nullptr); + auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); + REQUIRE(ifElseExpr2 != nullptr); + } + + // Use an if-else expression as the conditional expression of an if-else expression + { + AstStat* stat = parse("return if if true then false else true then 1 else 2"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr = str->list.data[0]->as(); + REQUIRE(ifElseExpr != nullptr); + auto* nestedIfElseExpr = ifElseExpr->condition->as(); + REQUIRE(nestedIfElseExpr != nullptr); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") +{ + AstStat* stat = parse(R"( +type Packed = () -> T... + +type A = Packed +type B = Packed<...number> +type C = Packed<(number, X...)> + )"); + REQUIRE(stat != nullptr); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -2504,71 +2571,4 @@ type Y = (T...) -> U... CHECK_EQ(1, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") -{ - { - AstStat* stat = parse("return if true then 1 else 2"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr = str->list.data[0]->as(); - REQUIRE(ifElseExpr != nullptr); - } - - { - AstStat* stat = parse("return if true then 1 elseif true then 2 else 3"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr1 = str->list.data[0]->as(); - REQUIRE(ifElseExpr1 != nullptr); - auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); - REQUIRE(ifElseExpr2 != nullptr); - } - - // Use "else if" as opposed to elseif - { - AstStat* stat = parse("return if true then 1 else if true then 2 else 3"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr1 = str->list.data[0]->as(); - REQUIRE(ifElseExpr1 != nullptr); - auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); - REQUIRE(ifElseExpr2 != nullptr); - } - - // Use an if-else expression as the conditional expression of an if-else expression - { - AstStat* stat = parse("return if if true then false else true then 1 else 2"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr = str->list.data[0]->as(); - REQUIRE(ifElseExpr != nullptr); - auto* nestedIfElseExpr = ifElseExpr->condition->as(); - REQUIRE(nestedIfElseExpr != nullptr); - } -} - -TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") -{ - AstStat* stat = parse(R"( -type Packed = () -> T... - -type A = Packed -type B = Packed<...number> -type C = Packed<(number, X...)> - )"); - REQUIRE(stat != nullptr); -} - TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp new file mode 100644 index 00000000..f660bcd3 --- /dev/null +++ b/tests/Repl.test.cpp @@ -0,0 +1,117 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Repl.h" + +#include "doctest.h" + +#include +#include +#include +#include + + +class ReplFixture +{ +public: + ReplFixture() + : luaState(luaL_newstate(), lua_close) + { + L = luaState.get(); + setupState(L); + luaL_sandboxthread(L); + + std::string result = runCode(L, prettyPrintSource); + } + + // Returns all of the output captured from the pretty printer + std::string getCapturedOutput() + { + lua_getglobal(L, "capturedoutput"); + const char* str = lua_tolstring(L, -1, nullptr); + std::string result(str); + lua_pop(L, 1); + return result; + } + lua_State* L; + +private: + std::unique_ptr luaState; + + // This is a simplicitic and incomplete pretty printer. + // It is included here to test that the pretty printer hook is being called. + // More elaborate tests to ensure correct output can be added if we introduce + // a more feature rich pretty printer. + std::string prettyPrintSource = R"( +-- Accumulate pretty printer output in `capturedoutput` +capturedoutput = "" + +function arraytostring(arr) + local strings = {} + table.foreachi(arr, function(k,v) table.insert(strings, pptostring(v)) end ) + return "{" .. table.concat(strings, ", ") .. "}" +end + +function pptostring(x) + if type(x) == "table" then + -- Just assume array-like tables for now. + return arraytostring(x) + elseif type(x) == "string" then + return '"' .. x .. '"' + else + return tostring(x) + end +end + +-- Note: Instead of calling print, the pretty printer just stores the output +-- in `capturedoutput` so we can check for the correct results. +function _PRETTYPRINT(...) + local args = table.pack(...) + local strings = {} + for i=1, args.n do + local item = args[i] + local str = pptostring(item, customoptions) + if i == 1 then + capturedoutput = capturedoutput .. str + else + capturedoutput = capturedoutput .. "\t" .. str + end + end +end +)"; +}; + +TEST_SUITE_BEGIN("ReplPrettyPrint"); + +TEST_CASE_FIXTURE(ReplFixture, "AdditionStatement") +{ + runCode(L, "return 30 + 12"); + CHECK(getCapturedOutput() == "42"); +} + +TEST_CASE_FIXTURE(ReplFixture, "TableLiteral") +{ + runCode(L, "return {1, 2, 3, 4}"); + CHECK(getCapturedOutput() == "{1, 2, 3, 4}"); +} + +TEST_CASE_FIXTURE(ReplFixture, "StringLiteral") +{ + runCode(L, "return 'str'"); + CHECK(getCapturedOutput() == "\"str\""); +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithStringLiterals") +{ + runCode(L, "return {1, 'two', 3, 'four'}"); + CHECK(getCapturedOutput() == "{1, \"two\", 3, \"four\"}"); +} + +TEST_CASE_FIXTURE(ReplFixture, "MultipleArguments") +{ + runCode(L, "return 3, 'three'"); + CHECK(getCapturedOutput() == "3\t\"three\""); +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 445ee532..bbb26291 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -435,8 +435,6 @@ TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_T TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( type F = ((() -> number)?) -> F? local function f(p) return f end @@ -450,8 +448,6 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union" TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_intersection") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( function f() return f end local a: ((number) -> ()) & typeof(f) diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 822bd727..76ab23b3 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -11,8 +11,6 @@ TEST_SUITE_BEGIN("TypeAliases"); TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( type F = () -> F? local function f() @@ -194,8 +192,6 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( type A = () -> (number, B) type B = () -> (string, A) diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 114679e3..a7f27551 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,8 +9,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) - TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -656,11 +654,7 @@ local d: D = c LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauExtendedFunctionMismatchError) - CHECK_EQ( - toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack") @@ -675,11 +669,8 @@ local d: D = c LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauExtendedFunctionMismatchError) - CHECK_EQ(toString(result.errors[0]), - R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); + CHECK_EQ(toString(result.errors[0]), + R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index e6d3d4d4..47c13be9 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -271,6 +271,32 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b } +// Also belongs in TypeInfer.refinements.test.cpp. +// Just needs to fully support equality refinement. Which is annoying without type states. +TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") +{ + ScopedFastFlag sff{"LuauDiscriminableUnions", true}; + + CheckResult result = check(R"( + type T = {x: string, y: number} | {x: nil, y: nil} + + local function f(t: T) + if t.x ~= nil then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{| x: string, y: number |}", toString(requireTypeAtPosition({5, 28}))); + + // Should be {| x: nil, y: nil |} + CHECK_EQ("{| x: nil, y: nil |} | {| x: string, y: number |}", toString(requireTypeAtPosition({7, 28}))); +} + TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; @@ -590,8 +616,6 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; - // Mutability in type function application right now can create strange recursive types // TODO: instantiation right now is problematic, in this example should either leave the Table type alone // or it should rename the type to 'Self' so that the result will be 'Self
' diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index d76b920b..f346ddfd 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -6,11 +6,77 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauDiscriminableUnions) LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; +namespace +{ +std::optional> magicFunctionInstanceIsA( + TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +{ + if (expr.args.size != 1) + return std::nullopt; + + auto index = expr.func->as(); + auto str = expr.args.data[0]->as(); + if (!index || !str) + return std::nullopt; + + std::optional lvalue = tryGetLValue(*index->expr); + std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); + if (!lvalue || !tfun) + return std::nullopt; + + unfreeze(typeChecker.globalTypes); + TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); + freeze(typeChecker.globalTypes); + return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; +} + +struct RefinementClassFixture : Fixture +{ + RefinementClassFixture() + { + TypeArena& arena = typeChecker.globalTypes; + + unfreeze(arena); + TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); + getMutable(vec3)->props = { + {"X", Property{typeChecker.numberType}}, + {"Y", Property{typeChecker.numberType}}, + {"Z", Property{typeChecker.numberType}}, + }; + + TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); + + TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); + TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); + TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); + getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + + getMutable(inst)->props = { + {"Name", Property{typeChecker.stringType}}, + {"IsA", Property{isA}}, + }; + + TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); + TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); + getMutable(part)->props = { + {"Position", Property{vec3}}, + }; + + typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; + typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; + typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; + typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; + freeze(typeChecker.globalTypes); + } +}; +} // namespace + TEST_SUITE_BEGIN("RefinementTest"); TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint") @@ -196,8 +262,18 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); + if (FFlag::LuauDiscriminableUnions) + { + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -237,7 +313,6 @@ TEST_CASE_FIXTURE(Fixture, "impossible_type_narrow_is_not_an_error") TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") { - CheckResult result = check(R"( local t: {x: number?} = {x = 1} @@ -254,7 +329,6 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") { - CheckResult result = check(R"( local t: {x: {y: string}?} = {x = {y = "hello!"}} @@ -360,7 +434,10 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauSingletonTypes", true}, + }; CheckResult result = check(R"( local function f(a: (string | number)?) @@ -374,16 +451,8 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == "hello" - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "string"); // a == "hello" - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello" + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" } TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") @@ -416,7 +485,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; + ScopedFastFlag sff{"LuauDiscriminableUnions", true}; + ScopedFastFlag sff2{"LuauWeakEqConstraint", true}; CheckResult result = check(R"( local function f(a, b: string?) @@ -428,16 +498,8 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "string?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b - } + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b } TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") @@ -527,9 +589,17 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") end )"); - // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); + if (FFlag::LuauDiscriminableUnions) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); + } + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); } @@ -614,214 +684,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" } -namespace -{ -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) -{ - if (expr.args.size != 1) - return std::nullopt; - - auto index = expr.func->as(); - auto str = expr.args.data[0]->as(); - if (!index || !str) - return std::nullopt; - - std::optional lvalue = tryGetLValue(*index->expr); - std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); - if (!lvalue || !tfun) - return std::nullopt; - - unfreeze(typeChecker.globalTypes); - TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); - freeze(typeChecker.globalTypes); - return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; -} - -struct RefinementClassFixture : Fixture -{ - RefinementClassFixture() - { - TypeArena& arena = typeChecker.globalTypes; - - unfreeze(arena); - TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); - getMutable(vec3)->props = { - {"X", Property{typeChecker.numberType}}, - {"Y", Property{typeChecker.numberType}}, - {"Z", Property{typeChecker.numberType}}, - }; - - TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); - - TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); - TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); - TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); - getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - - getMutable(inst)->props = { - {"Name", Property{typeChecker.stringType}}, - {"IsA", Property{isA}}, - }; - - TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); - TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); - getMutable(part)->props = { - {"Position", Property{vec3}}, - }; - - typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; - typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; - typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; - typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; - freeze(typeChecker.globalTypes); - } -}; -} // namespace - -TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") -{ - CheckResult result = check(R"( - local function f(vec) - local X, Y, Z = vec.X, vec.Y, vec.Z - - if type(vec) == "vector" then - local foo = vec - elseif typeof(vec) == "Instance" then - local foo = vec - else - local foo = vec - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); - else - CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); - - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" - else - CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") -{ - CheckResult result = check(R"( - local function f(x: Instance | Vector3) - if typeof(x) == "Vector3" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") -{ - CheckResult result = check(R"( - local function f(x: string | number | Instance | Vector3) - if type(x) == "userdata" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Instance | Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") -{ - CheckResult result = check(R"( - local function f(x: Part | Folder | string) - if typeof(x) == "Instance" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") -{ - CheckResult result = check(R"( - local function f(x: Part | Folder | Instance | string | Vector3 | any) - if typeof(x) == "Instance" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") -{ - CheckResult result = check(R"( - --!nonstrict - - local function f(x) - if typeof(x) == "Instance" and x:IsA("Folder") then - local foo = x - elseif typeof(x) == "table" then - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") -{ - CheckResult result = check(R"( - local function f(x: Part | Folder | string) - if typeof(x) ~= "Instance" or not x:IsA("Part") then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder | string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); -} - TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") { CheckResult result = check(R"( @@ -1145,4 +1007,259 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") +{ + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type T = {tag: "missing", x: nil} | {tag: "exists", x: string} + + local function f(t: T) + if t.x then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "discriminate_tag") +{ + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Cat = {tag: "Cat", name: string, catfood: string} + type Dog = {tag: "Dog", name: string, dogfood: string} + type Animal = Cat | Dog + + local function f(animal: Animal) + if animal.tag == "Cat" then + local cat: Cat = animal + elseif animal.tag == "Dog" then + local dog: Dog = animal + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); +} + +TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") +{ + ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; + + CheckResult result = check(R"( + type T = { [string]: { prop: number }? } + local t: T = {} + + if t["hello"] then + local foo = t["hello"].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") +{ + CheckResult result = check(R"( + local function len(a: {any}) + return a and #a or nil + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") +{ + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type T = {tag: "Part", x: Part} | {tag: "Folder", x: Folder} + + local function f(t: T) + if t.x:IsA("Part") then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") +{ + CheckResult result = check(R"( + local function f(vec) + local X, Y, Z = vec.X, vec.Y, vec.Z + + if type(vec) == "vector" then + local foo = vec + elseif typeof(vec) == "Instance" then + local foo = vec + else + local foo = vec + end + end + )"); + + if (FFlag::LuauDiscriminableUnions) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); + else + CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + } + + CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" + + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" + + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + else + CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") +{ + CheckResult result = check(R"( + local function f(x: Instance | Vector3) + if typeof(x) == "Vector3" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") +{ + CheckResult result = check(R"( + local function f(x: string | number | Instance | Vector3) + if type(x) == "userdata" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Instance | Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | string) + if typeof(x) == "Instance" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | Instance | string | Vector3 | any) + if typeof(x) == "Instance" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") +{ + CheckResult result = check(R"( + --!nonstrict + + local function f(x) + if typeof(x) == "Instance" and x:IsA("Folder") then + local foo = x + elseif typeof(x) == "table" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | string) + if typeof(x) ~= "Instance" or not x:IsA("Part") then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 94cfb643..df365fda 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -379,9 +379,7 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauUnionHeuristic", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauExtendedUnionMismatchError", true}, }; CheckResult result = check(R"( @@ -404,9 +402,7 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauUnionHeuristic", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauExtendedUnionMismatchError", true}, }; CheckResult result = check(R"( @@ -429,9 +425,7 @@ TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauUnionHeuristic", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauExtendedUnionMismatchError", true}, {"LuauIfElseExpectedType2", true}, {"LuauIfElseBranchTypeUnion", true}, }; diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 644efed7..48310921 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -12,8 +12,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) - TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -2075,22 +2073,11 @@ caused by: caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); - if (FFlag::LuauExtendedFunctionMismatchError) - { - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); - } - else - { - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' -caused by: - Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' -caused by: - Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); - } } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") @@ -2166,7 +2153,6 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") { ScopedFastFlag sff[]{ - {"LuauFixRecursiveMetatableCall", true}, {"LuauUnsealedTableLiteral", true}, }; diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 7ee5253c..c9b30e1a 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -16,7 +16,6 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) LUAU_FASTFLAG(LuauEqConstraint) -LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) using namespace Luau; @@ -959,8 +958,6 @@ TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( function f() return f @@ -973,8 +970,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( function f(g) return f(f) @@ -1699,8 +1694,6 @@ TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( --!strict local s @@ -1711,8 +1704,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") TEST_CASE_FIXTURE(Fixture, "occurs_check_does_not_recurse_forever_if_asked_to_traverse_a_cyclic_type") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( --!strict function u(t, w) @@ -3326,11 +3317,12 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable") +TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs") { CheckResult result = check(R"( - local x - print((x == true and (x .. "y")) .. 1) + local function f(x) + return x .. "y" + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -3340,13 +3332,14 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") { CheckResult result = check(R"( - local x - print("foo" .. x) + local function f(x) + return "foo" .. x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("(string) -> string", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") @@ -4374,8 +4367,6 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") TEST_CASE_FIXTURE(Fixture, "record_matching_overload") { - ScopedFastFlag sffs("LuauStoreMatchingOverloadFnType", true); - CheckResult result = check(R"( type Overload = ((string) -> string) & ((number) -> number) local abc: Overload @@ -4475,17 +4466,10 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauExtendedFunctionMismatchError) - { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' caused by: Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); - } - else - { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number')", toString(result.errors[0])); - } + toString(result.errors[0])); // Infer from variadic packs into elements result = check(R"( @@ -4618,17 +4602,9 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauExtendedFunctionMismatchError) - { - CHECK_EQ( - "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " - "parameters", - toString(result.errors[0])); - } - else - { - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); - } + CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") @@ -4741,8 +4717,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { - ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; - CheckResult result = check(R"( type A = { x: number } local a: A = { x = 1 } @@ -4965,8 +4939,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_ TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> string type B = (number) -> string @@ -4983,8 +4955,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> string type B = (number, string) -> string @@ -5001,8 +4971,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> (number) type B = (number, number) -> (number, boolean) @@ -5019,8 +4987,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> string type B = (number, number) -> number @@ -5037,8 +5003,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> (number, string) type B = (number, number) -> (number, boolean) @@ -5069,8 +5033,6 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") { - ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true}; - CheckResult result = check(R"( local t = {} diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index d4878d14..079870f5 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -931,7 +931,6 @@ type R = { m: F } TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") { ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true}; - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; CheckResult result = check(R"( local a: () -> (number, ...string) diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index b54ba996..759794e6 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -464,8 +464,6 @@ local a: XYZ = { w = 4 } TEST_CASE_FIXTURE(Fixture, "error_detailed_optional") { - ScopedFastFlag luauExtendedUnionMismatchError{"LuauExtendedUnionMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 2e0d149e..329e7b1f 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -268,8 +268,6 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TEST_CASE("tagging_tables") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypeVar ttv{TableTypeVar{}}; CHECK(!Luau::hasTag(&ttv, "foo")); Luau::attachTag(&ttv, "foo"); @@ -278,8 +276,6 @@ TEST_CASE("tagging_tables") TEST_CASE("tagging_classes") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; CHECK(!Luau::hasTag(&base, "foo")); Luau::attachTag(&base, "foo"); @@ -288,8 +284,6 @@ TEST_CASE("tagging_classes") TEST_CASE("tagging_subclasses") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}}; @@ -307,8 +301,6 @@ TEST_CASE("tagging_subclasses") TEST_CASE("tagging_functions") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypePackVar empty{TypePack{}}; TypeVar ftv{FunctionTypeVar{&empty, &empty}}; CHECK(!Luau::hasTag(&ftv, "foo")); @@ -318,8 +310,6 @@ TEST_CASE("tagging_functions") TEST_CASE("tagging_props") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - Property prop{}; CHECK(!Luau::hasTag(prop, "foo")); Luau::attachTag(prop, "foo"); @@ -370,4 +360,66 @@ local b: (T, T, T) -> T CHECK_EQ(count, 1); } +TEST_CASE("isString_on_string_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + CHECK(isString(&helloString)); +} + +TEST_CASE("isString_on_unions_of_various_string_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; + TypeVar union_{UnionTypeVar{{&helloString, &byeString}}}; + + CHECK(isString(&union_)); +} + +TEST_CASE("proof_that_isString_uses_all_of") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; + TypeVar booleanType{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}}; + TypeVar union_{UnionTypeVar{{&helloString, &byeString, &booleanType}}}; + + CHECK(!isString(&union_)); +} + +TEST_CASE("isBoolean_on_boolean_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + CHECK(isBoolean(&trueBool)); +} + +TEST_CASE("isBoolean_on_unions_of_true_or_false_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; + TypeVar union_{UnionTypeVar{{&trueBool, &falseBool}}}; + + CHECK(isBoolean(&union_)); +} + +TEST_CASE("proof_that_isBoolean_uses_all_of") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; + TypeVar stringType{PrimitiveTypeVar{PrimitiveTypeVar::String}}; + TypeVar union_{UnionTypeVar{{&trueBool, &falseBool, &stringType}}}; + + CHECK(!isBoolean(&union_)); +} + TEST_SUITE_END(); From 78039f45355c10064bfb635ca6b316b2e9303294 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 27 Jan 2022 13:52:56 -0800 Subject: [PATCH 020/102] Thanks gcc, we know you can't compile code. --- Analysis/src/TypeInfer.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 23fcc2d5..d6b3b5b3 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5929,7 +5929,9 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa if (!sense || canUnify(eqP.type, option, eqP.location).empty()) return sense ? eqP.type : option; - return std::nullopt; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; } return option; From f6b4cc9442f57db2ef9c186d7fea008dde8b2f68 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 3 Feb 2022 15:09:37 -0800 Subject: [PATCH 021/102] Sync to upstream/release/513 --- Analysis/include/Luau/Error.h | 12 +- Analysis/include/Luau/LValue.h | 14 +- Analysis/include/Luau/Substitution.h | 6 + Analysis/include/Luau/TxnLog.h | 8 +- Analysis/include/Luau/TypeInfer.h | 9 +- Analysis/include/Luau/TypedAllocator.h | 22 +- Analysis/include/Luau/Unifier.h | 4 + Analysis/src/BuiltinDefinitions.cpp | 41 +- Analysis/src/LValue.cpp | 52 +- Analysis/src/Linter.cpp | 4 +- Analysis/src/Substitution.cpp | 65 +- Analysis/src/TxnLog.cpp | 71 +- Analysis/src/TypeInfer.cpp | 244 +-- Analysis/src/TypeVar.cpp | 8 +- Analysis/src/TypedAllocator.cpp | 1 - Analysis/src/Unifier.cpp | 541 ++--- Ast/src/Parser.cpp | 3 +- CLI/Coverage.cpp | 2 +- CLI/FileUtils.cpp | 2 +- CLI/Repl.cpp | 55 +- CLI/ReplEntry.cpp | 3 +- CMakeLists.txt | 19 +- Compiler/src/Builtins.cpp | 4 +- Compiler/src/Compiler.cpp | 9 +- Compiler/src/TableShape.cpp | 8 - Makefile | 24 +- Sources.cmake | 5 + VM/include/luaconf.h | 4 +- VM/src/lcorolib.cpp | 2 +- VM/src/ldebug.cpp | 3 +- VM/src/lfunc.cpp | 6 +- VM/src/lgc.cpp | 2 + VM/src/lmem.cpp | 53 +- VM/src/lstring.cpp | 2 +- VM/src/lvmexecute.cpp | 11 +- VM/src/lvmload.cpp | 8 +- extern/isocline/.gitignore | 16 + extern/isocline/LICENSE | 21 + extern/isocline/include/isocline.h | 627 ++++++ extern/isocline/readme.md | 460 ++++ extern/isocline/src/attr.c | 294 +++ extern/isocline/src/attr.h | 70 + extern/isocline/src/bbcode.c | 842 +++++++ extern/isocline/src/bbcode.h | 37 + extern/isocline/src/bbcode_colors.c | 194 ++ extern/isocline/src/common.c | 347 +++ extern/isocline/src/common.h | 187 ++ extern/isocline/src/completers.c | 675 ++++++ extern/isocline/src/completions.c | 326 +++ extern/isocline/src/completions.h | 52 + extern/isocline/src/editline.c | 1142 ++++++++++ extern/isocline/src/editline_completion.c | 277 +++ extern/isocline/src/editline_help.c | 140 ++ extern/isocline/src/editline_history.c | 260 +++ extern/isocline/src/env.h | 60 + extern/isocline/src/highlight.c | 259 +++ extern/isocline/src/highlight.h | 24 + extern/isocline/src/history.c | 269 +++ extern/isocline/src/history.h | 38 + extern/isocline/src/isocline.c | 589 +++++ extern/isocline/src/stringbuf.c | 1038 +++++++++ extern/isocline/src/stringbuf.h | 121 ++ extern/isocline/src/term.c | 1124 ++++++++++ extern/isocline/src/term.h | 85 + extern/isocline/src/term_color.c | 371 ++++ extern/isocline/src/tty.c | 889 ++++++++ extern/isocline/src/tty.h | 160 ++ extern/isocline/src/tty_esc.c | 401 ++++ extern/isocline/src/undo.c | 67 + extern/isocline/src/undo.h | 24 + extern/isocline/src/wcwidth.c | 292 +++ extern/linenoise.hpp | 2415 --------------------- tests/Autocomplete.test.cpp | 100 +- tests/Compiler.test.cpp | 20 +- tests/Conformance.test.cpp | 1 - tests/LValue.test.cpp | 60 +- tests/Linter.test.cpp | 25 +- tests/Parser.test.cpp | 7 +- tests/TypeInfer.annotations.test.cpp | 3 - tests/TypeInfer.builtins.test.cpp | 51 + tests/TypeInfer.provisional.test.cpp | 13 - tests/TypeInfer.refinements.test.cpp | 65 +- tests/TypeInfer.tables.test.cpp | 8 +- tests/TypeInfer.test.cpp | 86 +- tests/TypeInfer.tryUnify.test.cpp | 13 + tests/conformance/basic.lua | 4 + tests/conformance/vararg.lua | 50 +- 87 files changed, 12814 insertions(+), 3212 deletions(-) create mode 100644 extern/isocline/.gitignore create mode 100644 extern/isocline/LICENSE create mode 100644 extern/isocline/include/isocline.h create mode 100644 extern/isocline/readme.md create mode 100644 extern/isocline/src/attr.c create mode 100644 extern/isocline/src/attr.h create mode 100644 extern/isocline/src/bbcode.c create mode 100644 extern/isocline/src/bbcode.h create mode 100644 extern/isocline/src/bbcode_colors.c create mode 100644 extern/isocline/src/common.c create mode 100644 extern/isocline/src/common.h create mode 100644 extern/isocline/src/completers.c create mode 100644 extern/isocline/src/completions.c create mode 100644 extern/isocline/src/completions.h create mode 100644 extern/isocline/src/editline.c create mode 100644 extern/isocline/src/editline_completion.c create mode 100644 extern/isocline/src/editline_help.c create mode 100644 extern/isocline/src/editline_history.c create mode 100644 extern/isocline/src/env.h create mode 100644 extern/isocline/src/highlight.c create mode 100644 extern/isocline/src/highlight.h create mode 100644 extern/isocline/src/history.c create mode 100644 extern/isocline/src/history.h create mode 100644 extern/isocline/src/isocline.c create mode 100644 extern/isocline/src/stringbuf.c create mode 100644 extern/isocline/src/stringbuf.h create mode 100644 extern/isocline/src/term.c create mode 100644 extern/isocline/src/term.h create mode 100644 extern/isocline/src/term_color.c create mode 100644 extern/isocline/src/tty.c create mode 100644 extern/isocline/src/tty.h create mode 100644 extern/isocline/src/tty_esc.c create mode 100644 extern/isocline/src/undo.c create mode 100644 extern/isocline/src/undo.h create mode 100644 extern/isocline/src/wcwidth.c delete mode 100644 extern/linenoise.hpp diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index aff3c4d9..a71e0224 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -285,12 +285,12 @@ struct TypesAreUnrelated bool operator==(const TypesAreUnrelated& rhs) const; }; -using TypeErrorData = Variant; +using TypeErrorData = + Variant; struct TypeError { diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 8fd96f05..3d510d5f 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -4,7 +4,6 @@ #include "Luau/Variant.h" #include "Luau/Symbol.h" -#include // TODO: Kill with LuauLValueAsKey. #include #include @@ -38,24 +37,13 @@ std::optional tryGetLValue(const class AstExpr& expr); // Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. std::pair> getFullName(const LValue& lvalue); -// Kill with LuauLValueAsKey. -std::string toString(const LValue& lvalue); - template const T* get(const LValue& lvalue) { return get_if(&lvalue); } -using NEW_RefinementMap = std::unordered_map; -using DEPRECATED_RefinementMap = std::map; - -// Transient. Kill with LuauLValueAsKey. -struct RefinementMap -{ - NEW_RefinementMap NEW_refinements; - DEPRECATED_RefinementMap DEPRECATED_refinements; -}; +using RefinementMap = std::unordered_map; void merge(RefinementMap& l, const RefinementMap& r, std::function f); void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 80a14e8f..4f3307cd 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -55,6 +55,8 @@ namespace Luau { +struct TxnLog; + enum class TarjanResult { TooManyChildren, @@ -89,6 +91,10 @@ struct Tarjan int childCount = 0; + // This should never be null; ensure you initialize it before calling + // substitution methods. + const TxnLog* log; + std::vector edgesTy; std::vector edgesTp; std::vector worklist; diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index dc45bebf..02b87374 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -72,6 +72,9 @@ struct PendingType } }; +std::string toString(PendingType* pending); +std::string dump(PendingType* pending); + // Pending state for a TypePackVar. Generated by a TxnLog and committed via // TxnLog::commit. struct PendingTypePack @@ -85,6 +88,9 @@ struct PendingTypePack } }; +std::string toString(PendingTypePack* pending); +std::string dump(PendingTypePack* pending); + template T* getMutable(PendingType* pending) { @@ -237,7 +243,7 @@ struct TxnLog // Follows a type, accounting for pending type states. The returned type may have // pending state; you should use `pending` or `get` to find out. - TypeId follow(TypeId ty); + TypeId follow(TypeId ty) const; // Follows a type pack, accounting for pending type states. The returned type pack // may have pending state; you should use `pending` or `get` to find out. diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index b843509d..90dc9f42 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -262,7 +262,7 @@ public: * {method: ({method: () -> a}) -> a} * */ - TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location); + TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log = TxnLog::empty()); // Replace any free types or type packs by `any`. // This is used when exporting types from modules, to make sure free types don't leak. @@ -308,9 +308,15 @@ private: TypeId singletonType(bool value); TypeId singletonType(std::string value); + TypeIdPredicate mkTruthyPredicate(bool sense); + // Returns nullopt if the predicate filters down the TypeId to 0 options. std::optional filterMap(TypeId type, TypeIdPredicate predicate); +public: + std::optional pickTypesFromSense(TypeId type, bool sense); + +private: TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true); // ex @@ -349,7 +355,6 @@ private: void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate); std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); - std::optional DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); diff --git a/Analysis/include/Luau/TypedAllocator.h b/Analysis/include/Luau/TypedAllocator.h index 64227e7c..c1c04d10 100644 --- a/Analysis/include/Luau/TypedAllocator.h +++ b/Analysis/include/Luau/TypedAllocator.h @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAG(LuauTypedAllocatorZeroStart) - namespace Luau { @@ -22,10 +20,7 @@ class TypedAllocator public: TypedAllocator() { - if (FFlag::LuauTypedAllocatorZeroStart) - currentBlockSize = kBlockSize; - else - appendBlock(); + currentBlockSize = kBlockSize; } ~TypedAllocator() @@ -64,18 +59,12 @@ public: bool empty() const { - if (FFlag::LuauTypedAllocatorZeroStart) - return stuff.empty(); - else - return stuff.size() == 1 && currentBlockSize == 0; + return stuff.empty(); } size_t size() const { - if (FFlag::LuauTypedAllocatorZeroStart) - return stuff.empty() ? 0 : kBlockSize * (stuff.size() - 1) + currentBlockSize; - else - return kBlockSize * (stuff.size() - 1) + currentBlockSize; + return stuff.empty() ? 0 : kBlockSize * (stuff.size() - 1) + currentBlockSize; } void clear() @@ -84,10 +73,7 @@ public: unfreeze(); free(); - if (FFlag::LuauTypedAllocatorZeroStart) - currentBlockSize = kBlockSize; - else - appendBlock(); + currentBlockSize = kBlockSize; } void freeze() diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 1b1671c0..9db4e22b 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -51,6 +51,10 @@ struct Unifier private: void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId superTy); + void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall); + void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv); + void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); void tryUnifyPrimitives(TypeId subTy, TypeId superTy); void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index d527414a..d72422a5 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,6 +8,8 @@ #include +LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) + /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. For instance, we do not yet have a way to talk @@ -391,12 +393,41 @@ static std::optional> magicFunctionAssert( { auto [paramPack, predicates] = exprResult; - if (expr.args.size < 1) + if (FFlag::LuauAssertStripsFalsyTypes) + { + TypeArena& arena = typechecker.currentModule->internalTypes; + + auto [head, tail] = flatten(paramPack); + if (head.empty() && tail) + { + std::optional fst = first(*tail); + if (!fst) + return ExprResult{paramPack}; + head.push_back(*fst); + } + + typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); + + if (head.size() > 0) + { + std::optional newhead = typechecker.pickTypesFromSense(head[0], true); + if (!newhead) + head = {typechecker.nilType}; + else + head[0] = *newhead; + } + + return ExprResult{arena.addTypePack(TypePack{std::move(head), tail})}; + } + else + { + if (expr.args.size < 1) + return ExprResult{paramPack}; + + typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); + return ExprResult{paramPack}; - - typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); - - return ExprResult{paramPack}; + } } static std::optional> magicFunctionPack( diff --git a/Analysis/src/LValue.cpp b/Analysis/src/LValue.cpp index da6804c6..c9466a40 100644 --- a/Analysis/src/LValue.cpp +++ b/Analysis/src/LValue.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAG(LuauLValueAsKey) - namespace Luau { @@ -94,17 +92,7 @@ std::pair> getFullName(const LValue& lvalue) return {*symbol, std::vector(keys.rbegin(), keys.rend())}; } -// Kill with LuauLValueAsKey. -std::string toString(const LValue& lvalue) -{ - auto [symbol, keys] = getFullName(lvalue); - std::string s = toString(symbol); - for (std::string key : keys) - s += "." + key; - return s; -} - -static void merge(NEW_RefinementMap& l, const NEW_RefinementMap& r, std::function f) +void merge(RefinementMap& l, const RefinementMap& r, std::function f) { for (const auto& [k, a] : r) { @@ -115,45 +103,9 @@ static void merge(NEW_RefinementMap& l, const NEW_RefinementMap& r, std::functio } } -static void merge(DEPRECATED_RefinementMap& l, const DEPRECATED_RefinementMap& r, std::function f) -{ - auto itL = l.begin(); - auto itR = r.begin(); - while (itL != l.end() && itR != r.end()) - { - const auto& [k, a] = *itR; - if (itL->first == k) - { - l[k] = f(itL->second, a); - ++itL; - ++itR; - } - else if (itL->first < k) - ++itL; - else - { - l[k] = a; - ++itR; - } - } - - l.insert(itR, r.end()); -} - -void merge(RefinementMap& l, const RefinementMap& r, std::function f) -{ - if (FFlag::LuauLValueAsKey) - return merge(l.NEW_refinements, r.NEW_refinements, f); - else - return merge(l.DEPRECATED_refinements, r.DEPRECATED_refinements, f); -} - void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty) { - if (FFlag::LuauLValueAsKey) - refis.NEW_refinements[lvalue] = ty; - else - refis.DEPRECATED_refinements[toString(lvalue)] = ty; + refis[lvalue] = ty; } } // namespace Luau diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 905b70bf..57a33e93 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauLintTableCreateTable, false) - namespace Luau { @@ -2155,7 +2153,7 @@ private: "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); } - if (FFlag::LuauLintTableCreateTable && func->index == "create" && node->args.size == 2) + if (func->index == "create" && node->args.size == 2) { // table.create(n, {...}) if (args[1]->is()) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 3d004bee..bacbca76 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -2,6 +2,7 @@ #include "Luau/Substitution.h" #include "Luau/Common.h" +#include "Luau/TxnLog.h" #include #include @@ -13,17 +14,17 @@ namespace Luau void Tarjan::visitChildren(TypeId ty, int index) { - ty = follow(ty); + ty = log->follow(ty); if (ignoreChildren(ty)) return; - if (const FunctionTypeVar* ftv = get(ty)) + if (const FunctionTypeVar* ftv = log->getMutable(ty)) { visitChild(ftv->argTypes); visitChild(ftv->retType); } - else if (const TableTypeVar* ttv = get(ty)) + else if (const TableTypeVar* ttv = log->getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) @@ -40,17 +41,17 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId itp : ttv->instantiatedTypePackParams) visitChild(itp); } - else if (const MetatableTypeVar* mtv = get(ty)) + else if (const MetatableTypeVar* mtv = log->getMutable(ty)) { visitChild(mtv->table); visitChild(mtv->metatable); } - else if (const UnionTypeVar* utv = get(ty)) + else if (const UnionTypeVar* utv = log->getMutable(ty)) { for (TypeId opt : utv->options) visitChild(opt); } - else if (const IntersectionTypeVar* itv = get(ty)) + else if (const IntersectionTypeVar* itv = log->getMutable(ty)) { for (TypeId part : itv->parts) visitChild(part); @@ -59,19 +60,19 @@ void Tarjan::visitChildren(TypeId ty, int index) void Tarjan::visitChildren(TypePackId tp, int index) { - tp = follow(tp); + tp = log->follow(tp); if (ignoreChildren(tp)) return; - if (const TypePack* tpp = get(tp)) + if (const TypePack* tpp = log->getMutable(tp)) { for (TypeId tv : tpp->head) visitChild(tv); if (tpp->tail) visitChild(*tpp->tail); } - else if (const VariadicTypePack* vtp = get(tp)) + else if (const VariadicTypePack* vtp = log->getMutable(tp)) { visitChild(vtp->ty); } @@ -79,7 +80,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) std::pair Tarjan::indexify(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); bool fresh = !typeToIndex.contains(ty); int& index = typeToIndex[ty]; @@ -97,7 +98,7 @@ std::pair Tarjan::indexify(TypeId ty) std::pair Tarjan::indexify(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); bool fresh = !packToIndex.contains(tp); int& index = packToIndex[tp]; @@ -115,7 +116,7 @@ std::pair Tarjan::indexify(TypePackId tp) void Tarjan::visitChild(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); edgesTy.push_back(ty); edgesTp.push_back(nullptr); @@ -123,7 +124,7 @@ void Tarjan::visitChild(TypeId ty) void Tarjan::visitChild(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); edgesTy.push_back(nullptr); edgesTp.push_back(tp); @@ -243,7 +244,7 @@ void Tarjan::clear() TarjanResult Tarjan::visitRoot(TypeId ty) { childCount = 0; - ty = follow(ty); + ty = log->follow(ty); clear(); auto [index, fresh] = indexify(ty); @@ -254,7 +255,7 @@ TarjanResult Tarjan::visitRoot(TypeId ty) TarjanResult Tarjan::visitRoot(TypePackId tp) { childCount = 0; - tp = follow(tp); + tp = log->follow(tp); clear(); auto [index, fresh] = indexify(tp); @@ -325,7 +326,7 @@ TarjanResult FindDirty::findDirty(TypePackId tp) std::optional Substitution::substitute(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); newTypes.clear(); newPacks.clear(); @@ -345,7 +346,7 @@ std::optional Substitution::substitute(TypeId ty) std::optional Substitution::substitute(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); newTypes.clear(); newPacks.clear(); @@ -365,11 +366,11 @@ std::optional Substitution::substitute(TypePackId tp) TypeId Substitution::clone(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); TypeId result = ty; - if (const FunctionTypeVar* ftv = get(ty)) + if (const FunctionTypeVar* ftv = log->getMutable(ty)) { FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; @@ -379,7 +380,7 @@ TypeId Substitution::clone(TypeId ty) clone.argNames = ftv->argNames; result = addType(std::move(clone)); } - else if (const TableTypeVar* ttv = get(ty)) + else if (const TableTypeVar* ttv = log->getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -392,19 +393,19 @@ TypeId Substitution::clone(TypeId ty) clone.tags = ttv->tags; result = addType(std::move(clone)); } - else if (const MetatableTypeVar* mtv = get(ty)) + else if (const MetatableTypeVar* mtv = log->getMutable(ty)) { MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; clone.syntheticName = mtv->syntheticName; result = addType(std::move(clone)); } - else if (const UnionTypeVar* utv = get(ty)) + else if (const UnionTypeVar* utv = log->getMutable(ty)) { UnionTypeVar clone; clone.options = utv->options; result = addType(std::move(clone)); } - else if (const IntersectionTypeVar* itv = get(ty)) + else if (const IntersectionTypeVar* itv = log->getMutable(ty)) { IntersectionTypeVar clone; clone.parts = itv->parts; @@ -417,15 +418,15 @@ TypeId Substitution::clone(TypeId ty) TypePackId Substitution::clone(TypePackId tp) { - tp = follow(tp); - if (const TypePack* tpp = get(tp)) + tp = log->follow(tp); + if (const TypePack* tpp = log->getMutable(tp)) { TypePack clone; clone.head = tpp->head; clone.tail = tpp->tail; return addTypePack(std::move(clone)); } - else if (const VariadicTypePack* vtp = get(tp)) + else if (const VariadicTypePack* vtp = log->getMutable(tp)) { VariadicTypePack clone; clone.ty = vtp->ty; @@ -437,7 +438,7 @@ TypePackId Substitution::clone(TypePackId tp) void Substitution::foundDirty(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); if (isDirty(ty)) newTypes[ty] = clean(ty); else @@ -446,7 +447,7 @@ void Substitution::foundDirty(TypeId ty) void Substitution::foundDirty(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); if (isDirty(tp)) newPacks[tp] = clean(tp); else @@ -455,7 +456,7 @@ void Substitution::foundDirty(TypePackId tp) TypeId Substitution::replace(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); if (TypeId* prevTy = newTypes.find(ty)) return *prevTy; else @@ -464,7 +465,7 @@ TypeId Substitution::replace(TypeId ty) TypePackId Substitution::replace(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); if (TypePackId* prevTp = newPacks.find(tp)) return *prevTp; else @@ -473,7 +474,7 @@ TypePackId Substitution::replace(TypePackId tp) void Substitution::replaceChildren(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); if (ignoreChildren(ty)) return; @@ -519,7 +520,7 @@ void Substitution::replaceChildren(TypeId ty) void Substitution::replaceChildren(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); if (ignoreChildren(tp)) return; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index a46ac0c3..0968a4c1 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TxnLog.h" +#include "Luau/ToString.h" #include "Luau/TypePack.h" #include @@ -80,6 +81,56 @@ void DEPRECATED_TxnLog::popSeen(TypeId lhs, TypeId rhs) sharedSeen->pop_back(); } +const std::string nullPendingResult = ""; + +std::string toString(PendingType* pending) +{ + if (pending == nullptr) + return nullPendingResult; + + return toString(pending->pending); +} + +std::string dump(PendingType* pending) +{ + if (pending == nullptr) + { + printf("%s\n", nullPendingResult.c_str()); + return nullPendingResult; + } + + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string result = toString(pending->pending, opts); + printf("%s\n", result.c_str()); + return result; +} + +std::string toString(PendingTypePack* pending) +{ + if (pending == nullptr) + return nullPendingResult; + + return toString(pending->pending); +} + +std::string dump(PendingTypePack* pending) +{ + if (pending == nullptr) + { + printf("%s\n", nullPendingResult.c_str()); + return nullPendingResult; + } + + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string result = toString(pending->pending, opts); + printf("%s\n", result.c_str()); + return result; +} + static const TxnLog emptyLog; const TxnLog* TxnLog::empty() @@ -199,8 +250,6 @@ PendingTypePack* TxnLog::queue(TypePackId tp) PendingType* TxnLog::pending(TypeId ty) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - for (const TxnLog* current = this; current; current = current->parent) { if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) @@ -212,8 +261,6 @@ PendingType* TxnLog::pending(TypeId ty) const PendingTypePack* TxnLog::pending(TypePackId tp) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - for (const TxnLog* current = this; current; current = current->parent) { if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) @@ -225,8 +272,6 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - PendingType* newTy = queue(ty); newTy->pending = replacement; return newTy; @@ -234,8 +279,6 @@ PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - PendingTypePack* newTp = queue(tp); newTp->pending = replacement; return newTp; @@ -243,7 +286,6 @@ PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(get(ty)); PendingType* newTy = queue(ty); @@ -255,7 +297,6 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); @@ -278,7 +319,6 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(get(tp)); PendingTypePack* newTp = queue(tp); @@ -292,7 +332,6 @@ PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexer) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(get(ty)); PendingType* newTy = queue(ty); @@ -306,8 +345,6 @@ PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexe std::optional TxnLog::getLevel(TypeId ty) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - if (FreeTypeVar* ftv = getMutable(ty)) return ftv->level; else if (TableTypeVar* ttv = getMutable(ty); ttv && (ttv->state == TableState::Free || ttv->state == TableState::Generic)) @@ -318,10 +355,8 @@ std::optional TxnLog::getLevel(TypeId ty) const return std::nullopt; } -TypeId TxnLog::follow(TypeId ty) +TypeId TxnLog::follow(TypeId ty) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - return Luau::follow(ty, [this](TypeId ty) { PendingType* state = this->pending(ty); @@ -337,8 +372,6 @@ TypeId TxnLog::follow(TypeId ty) TypePackId TxnLog::follow(TypePackId tp) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - return Luau::follow(tp, [this](TypePackId tp) { PendingTypePack* state = this->pending(tp); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 23fcc2d5..4d25fe2e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,6 +32,7 @@ LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) +LUAU_FASTFLAGVARIABLE(LuauNoSealedTypeMod, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) @@ -40,13 +41,12 @@ LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) -LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) -LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) +LUAU_FASTFLAG(LuauUnionTagMatchFix) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) +LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) namespace Luau { @@ -1117,7 +1117,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco ty = follow(ty); - if (tableSelf && !selfTy->persistent) + if (tableSelf && (FFlag::LuauNoSealedTypeMod ? tableSelf->state != TableState::Sealed : !selfTy->persistent)) tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; const FunctionTypeVar* funTy = get(ty); @@ -1130,7 +1130,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (tableSelf && !selfTy->persistent) + if (tableSelf && (FFlag::LuauNoSealedTypeMod ? tableSelf->state != TableState::Sealed : !selfTy->persistent)) tableSelf->props[indexName->index.value] = { follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; } @@ -1657,7 +1657,7 @@ std::optional TypeChecker::getIndexTypeFromType( RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); // Not needed when we normalize types. - if (FFlag::LuauLValueAsKey && get(follow(t))) + if (get(follow(t))) return t; if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) @@ -1802,12 +1802,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn { TypeId ty = checkLValue(scope, expr); - if (FFlag::LuauRefiLookupFromIndexExpr) - { - if (std::optional lvalue = tryGetLValue(expr)) - if (std::optional refiTy = resolveLValue(scope, *lvalue)) - return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; - } + if (std::optional lvalue = tryGetLValue(expr)) + if (std::optional refiTy = resolveLValue(scope, *lvalue)) + return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; return {ty}; } @@ -2471,33 +2468,28 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi { if (expr.op == AstExprBinary::And) { - ExprResult lhs = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left); - // We can't just report errors here. - // This function can be called from AstStatLocal or from AstStatIf, or even from AstExprBinary (and others). - // For now, ignore the errors returned by the predicate resolver. - // We may need an extra property for each predicate set that indicates it has been resolved. - // Requires a slight modification to the data structure. ScopePtr innerScope = childScope(scope, expr.location); - resolve(lhs.predicates, innerScope, true); + resolve(lhsPredicates, innerScope, true); - ExprResult rhs = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); - return {checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type), - {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + return {checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhsTy, rhsTy), + {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::Or) { - ExprResult lhs = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left); ScopePtr innerScope = childScope(scope, expr.location); - resolve(lhs.predicates, innerScope, false); + resolve(lhsPredicates, innerScope, false); - ExprResult rhs = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); - // Because of C++, I'm not sure if lhs.predicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type, lhs.predicates); - return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. + TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhsTy, rhsTy, lhsPredicates); + return {result, {OrPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { @@ -2535,27 +2527,15 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy TypeId annotationType = resolveType(scope, *expr.annotation); ExprResult result = checkExpr(scope, *expr.expr, annotationType); - if (FFlag::LuauBidirectionalAsExpr) - { - // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (canUnify(annotationType, result.type, expr.location).empty()) - return {annotationType, std::move(result.predicates)}; - - if (canUnify(result.type, annotationType, expr.location).empty()) - return {annotationType, std::move(result.predicates)}; - - reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); - return {errorRecoveryType(annotationType), std::move(result.predicates)}; - } - else - { - ErrorVec errorVec = canUnify(annotationType, result.type, expr.location); - reportErrors(errorVec); - if (!errorVec.empty()) - annotationType = errorRecoveryType(annotationType); - + // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. + if (canUnify(annotationType, result.type, expr.location).empty()) return {annotationType, std::move(result.predicates)}; - } + + if (canUnify(result.type, annotationType, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; + + reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); + return {errorRecoveryType(annotationType), std::move(result.predicates)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) @@ -4295,7 +4275,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); if (!child.errors.empty()) { - TypeId instantiated = instantiate(scope, subTy, state.location); + TypeId instantiated = instantiate(scope, subTy, state.location, &child.log); if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors @@ -4330,7 +4310,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s bool Instantiation::isDirty(TypeId ty) { - if (get(ty)) + if (log->getMutable(ty)) return true; else return false; @@ -4343,7 +4323,7 @@ bool Instantiation::isDirty(TypePackId tp) bool Instantiation::ignoreChildren(TypeId ty) { - if (get(ty)) + if (log->getMutable(ty)) return true; else return false; @@ -4351,7 +4331,7 @@ bool Instantiation::ignoreChildren(TypeId ty) TypeId Instantiation::clean(TypeId ty) { - const FunctionTypeVar* ftv = get(ty); + const FunctionTypeVar* ftv = log->getMutable(ty); LUAU_ASSERT(ftv); FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; @@ -4362,6 +4342,7 @@ TypeId Instantiation::clean(TypeId ty) // Annoyingly, we have to do this even if there are no generics, // to replace any generic tables. + replaceGenerics.log = log; replaceGenerics.level = level; replaceGenerics.currentModule = currentModule; replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); @@ -4383,7 +4364,7 @@ TypePackId Instantiation::clean(TypePackId tp) bool ReplaceGenerics::ignoreChildren(TypeId ty) { - if (const FunctionTypeVar* ftv = get(ty)) + if (const FunctionTypeVar* ftv = log->getMutable(ty)) // We aren't recursing in the case of a generic function which // binds the same generics. This can happen if, for example, there's recursive types. // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. @@ -4396,9 +4377,9 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) bool ReplaceGenerics::isDirty(TypeId ty) { - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) return ttv->state == TableState::Generic; - else if (get(ty)) + else if (log->getMutable(ty)) return std::find(generics.begin(), generics.end(), ty) != generics.end(); else return false; @@ -4406,7 +4387,7 @@ bool ReplaceGenerics::isDirty(TypeId ty) bool ReplaceGenerics::isDirty(TypePackId tp) { - if (get(tp)) + if (log->getMutable(tp)) return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); else return false; @@ -4415,7 +4396,7 @@ bool ReplaceGenerics::isDirty(TypePackId tp) TypeId ReplaceGenerics::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; @@ -4434,9 +4415,9 @@ TypePackId ReplaceGenerics::clean(TypePackId tp) bool Quantification::isDirty(TypeId ty) { - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) return level.subsumes(ttv->level) && ((ttv->state == TableState::Free) || (ttv->state == TableState::Unsealed)); - else if (const FreeTypeVar* ftv = get(ty)) + else if (const FreeTypeVar* ftv = log->getMutable(ty)) return level.subsumes(ftv->level); else return false; @@ -4444,7 +4425,7 @@ bool Quantification::isDirty(TypeId ty) bool Quantification::isDirty(TypePackId tp) { - if (const FreeTypePack* ftv = get(tp)) + if (const FreeTypePack* ftv = log->getMutable(tp)) return level.subsumes(ftv->level); else return false; @@ -4453,7 +4434,7 @@ bool Quantification::isDirty(TypePackId tp) TypeId Quantification::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) { TableState state = (ttv->state == TableState::Unsealed ? TableState::Sealed : TableState::Generic); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, state}; @@ -4479,9 +4460,9 @@ TypePackId Quantification::clean(TypePackId tp) bool Anyification::isDirty(TypeId ty) { - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); - else if (get(ty)) + else if (log->getMutable(ty)) return true; else return false; @@ -4489,7 +4470,7 @@ bool Anyification::isDirty(TypeId ty) bool Anyification::isDirty(TypePackId tp) { - if (get(tp)) + if (log->getMutable(tp)) return true; else return false; @@ -4498,7 +4479,7 @@ bool Anyification::isDirty(TypePackId tp) TypeId Anyification::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; @@ -4535,6 +4516,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location return ty; } + quantification.log = TxnLog::empty(); quantification.level = scope->level; quantification.generics.clear(); quantification.genericPacks.clear(); @@ -4558,8 +4540,11 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location return *qty; } -TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location) +TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { + LUAU_ASSERT(log != nullptr); + + instantiation.log = FFlag::LuauUseCommittingTxnLog ? log : TxnLog::empty(); instantiation.level = scope->level; instantiation.currentModule = currentModule; std::optional instantiated = instantiation.substitute(ty); @@ -4574,6 +4559,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { + anyification.log = TxnLog::empty(); anyification.anyType = anyType; anyification.anyTypePack = anyTypePack; anyification.currentModule = currentModule; @@ -4589,6 +4575,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { + anyification.log = TxnLog::empty(); anyification.anyType = anyType; anyification.anyTypePack = anyTypePack; anyification.currentModule = currentModule; @@ -4660,7 +4647,7 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d } }; - if (auto ttv = getTableType(follow(utk->table))) + if (auto ttv = getTableType(FFlag::LuauUnionTagMatchFix ? utk->table : follow(utk->table))) accumulate(ttv->props); else if (auto ctv = get(follow(utk->table))) { @@ -4775,6 +4762,29 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) return getSingletonTypes().errorRecoveryTypePack(guess); } +TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) { + return [this, sense](TypeId ty) -> std::optional { + // any/error/free gets a special pass unconditionally because they can't be decided. + if (get(ty) || get(ty) || get(ty)) + return ty; + + // maps boolean primitive to the corresponding singleton equal to sense + if (isPrim(ty, PrimitiveTypeVar::Boolean)) + return singletonType(sense); + + // if we have boolean singleton, eliminate it if the sense doesn't match with that singleton + if (auto boolean = get(get(ty))) + return boolean->value == sense ? std::optional(ty) : std::nullopt; + + // if we have nil, eliminate it if sense is true, otherwise take it + if (isNil(ty)) + return sense ? std::nullopt : std::optional(ty); + + // at this point, anything else is kept if sense is true, or eliminated otherwise + return sense ? std::optional(ty) : std::nullopt; + }; +} + std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); @@ -4783,6 +4793,11 @@ std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predic return std::nullopt; } +std::optional TypeChecker::pickTypesFromSense(TypeId type, bool sense) +{ + return filterMap(type, mkTruthyPredicate(sense)); +} + TypeId TypeChecker::addTV(TypeVar&& tv) { return currentModule->internalTypes.addType(std::move(tv)); @@ -5293,6 +5308,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, for (size_t i = 0; i < tf.typePackParams.size(); ++i) applyTypeFunction.typePackArguments[tf.typePackParams[i].tp] = typePackParams[i]; + applyTypeFunction.log = TxnLog::empty(); applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; applyTypeFunction.encounteredForwardedType = false; @@ -5507,9 +5523,6 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue) { - if (!FFlag::LuauLValueAsKey) - return DEPRECATED_resolveLValue(scope, lvalue); - // We want to be walking the Scope parents. // We'll also want to walk up the LValue path. As we do this, we need to save each LValue because we must walk back. // For example: @@ -5529,7 +5542,7 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV const LValue* currentLValue = &lvalue; while (currentLValue) { - if (auto it = currentScope->refinements.NEW_refinements.find(*currentLValue); it != currentScope->refinements.NEW_refinements.end()) + if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) { found = it->second; break; @@ -5576,43 +5589,9 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV return std::nullopt; } -std::optional TypeChecker::DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue) -{ - auto [symbol, keys] = getFullName(lvalue); - - ScopePtr currentScope = scope; - while (currentScope) - { - if (auto it = currentScope->refinements.DEPRECATED_refinements.find(toString(lvalue)); it != currentScope->refinements.DEPRECATED_refinements.end()) - return it->second; - - // Should not be using scope->lookup. This is already recursive. - if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) - { - std::optional currentTy = it->second.typeId; - - for (std::string key : keys) - { - // TODO: This function probably doesn't need Location at all, or at least should hide the argument. - currentTy = getIndexTypeFromType(scope, *currentTy, key, Location(), false); - if (!currentTy) - break; - } - - return currentTy; - } - - currentScope = currentScope->parent; - } - - return std::nullopt; -} - std::optional TypeChecker::resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue) { - if (auto it = refis.DEPRECATED_refinements.find(toString(lvalue)); it != refis.DEPRECATED_refinements.end()) - return it->second; - else if (auto it = refis.NEW_refinements.find(lvalue); it != refis.NEW_refinements.end()) + if (auto it = refis.find(lvalue); it != refis.end()) return it->second; else return resolveLValue(scope, lvalue); @@ -5661,35 +5640,46 @@ void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, Refineme void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { - auto predicate = [sense](TypeId option) -> std::optional { - if (isUndecidable(option) || isBoolean(option) || isNil(option) != sense) - return option; - - return std::nullopt; - }; - - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauAssertStripsFalsyTypes) { std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); if (ty && fromOr) return addRefinement(refis, truthyP.lvalue, *ty); - refineLValue(truthyP.lvalue, refis, scope, predicate); + refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); } else { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (!ty) - return; + auto predicate = [sense](TypeId option) -> std::optional { + if (isUndecidable(option) || isBoolean(option) || isNil(option) != sense) + return option; - // This is a hack. :( - // Without this, the expression 'a or b' might refine 'b' to be falsy. - // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. - if (fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); + return std::nullopt; + }; - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, truthyP.lvalue, *result); + if (FFlag::LuauDiscriminableUnions) + { + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (ty && fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); + + refineLValue(truthyP.lvalue, refis, scope, predicate); + } + else + { + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (!ty) + return; + + // This is a hack. :( + // Without this, the expression 'a or b' might refine 'b' to be falsy. + // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. + if (fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); + + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, truthyP.lvalue, *result); + } } } @@ -5929,7 +5919,9 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa if (!sense || canUnify(eqP.type, option, eqP.location).empty()) return sense ? eqP.type : option; - return std::nullopt; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; } return option; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 5b162b31..2321eafd 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -27,6 +27,7 @@ LUAU_FASTFLAG(LuauLengthOnCompositeType) LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) +LUAU_FASTFLAG(LuauUnionTagMatchFix) namespace Luau { @@ -288,10 +289,13 @@ std::optional getMetatable(TypeId type) const TableTypeVar* getTableType(TypeId type) { + if (FFlag::LuauUnionTagMatchFix) + type = follow(type); + if (const TableTypeVar* ttv = get(type)) return ttv; else if (const MetatableTypeVar* mtv = get(type)) - return get(mtv->table); + return get(FFlag::LuauUnionTagMatchFix ? follow(mtv->table) : mtv->table); else return nullptr; } @@ -308,7 +312,7 @@ const std::string* getName(TypeId type) { if (mtv->syntheticName) return &*mtv->syntheticName; - type = mtv->table; + type = FFlag::LuauUnionTagMatchFix ? follow(mtv->table) : mtv->table; } if (auto ttv = get(type)) diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index 1f7ef8c2..f037351e 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -20,7 +20,6 @@ const size_t kPageSize = sysconf(_SC_PAGESIZE); #include LUAU_FASTFLAG(DebugLuauFreezeArena) -LUAU_FASTFLAGVARIABLE(LuauTypedAllocatorZeroStart, false) namespace Luau { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 17d9bf58..89e4ae23 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -8,6 +8,7 @@ #include "Luau/TypeUtils.h" #include "Luau/TimeTrace.h" #include "Luau/VisitTypeVar.h" +#include "Luau/ToString.h" #include @@ -22,6 +23,7 @@ LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) +LUAU_FASTFLAGVARIABLE(LuauUnionTagMatchFix, false) namespace Luau { @@ -225,19 +227,33 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { - type = follow(type); - - if (auto ttv = get(type)) + if (FFlag::LuauUnionTagMatchFix) { - for (auto&& [name, prop] : ttv->props) + if (auto ttv = getTableType(type)) { - if (auto sing = get(follow(prop.type))) - return {{name, sing}}; + for (auto&& [name, prop] : ttv->props) + { + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; + } } } - else if (auto mttv = get(type)) + else { - return getTableMatchTag(mttv->table); + type = follow(type); + + if (auto ttv = get(type)) + { + for (auto&& [name, prop] : ttv->props) + { + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; + } + } + else if (auto mttv = get(type)) + { + return getTableMatchTag(mttv->table); + } } return std::nullopt; @@ -508,245 +524,21 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { - // A | B <: T if A <: T and B <: T - bool failed = false; - std::optional unificationTooComplex; - std::optional firstFailedOption; - - size_t count = uv->options.size(); - size_t i = 0; - - for (TypeId type : uv->options) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, superTy); - - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - else if (!innerState.errors.empty()) - { - // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' - if (!firstFailedOption && !isNil(type)) - firstFailedOption = {innerState.errors.front()}; - - failed = true; - } - - if (FFlag::LuauUseCommittingTxnLog) - { - if (i == count - 1) - { - log.concat(std::move(innerState.log)); - } - } - else - { - if (i != count - 1) - { - innerState.DEPRECATED_log.rollback(); - } - else - { - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - } - } - - ++i; - } - - if (unificationTooComplex) - reportError(*unificationTooComplex); - else if (failed) - { - if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); - else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - } + tryUnifyUnionWithType(subTy, uv, superTy); } else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) { - // T <: A | B if T <: A or T <: B - bool found = false; - std::optional unificationTooComplex; - - size_t failedOptionCount = 0; - std::optional failedOption; - - bool foundHeuristic = false; - size_t startIndex = 0; - - if (const std::string* subName = getName(subTy)) - { - for (size_t i = 0; i < uv->options.size(); ++i) - { - const std::string* optionName = getName(uv->options[i]); - if (optionName && *optionName == *subName) - { - foundHeuristic = true; - startIndex = i; - break; - } - } - } - - if (auto subMatchTag = getTableMatchTag(subTy)) - { - for (size_t i = 0; i < uv->options.size(); ++i) - { - auto optionMatchTag = getTableMatchTag(uv->options[i]); - if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) - { - foundHeuristic = true; - startIndex = i; - break; - } - } - } - - if (!foundHeuristic && cacheEnabled) - { - for (size_t i = 0; i < uv->options.size(); ++i) - { - TypeId type = uv->options[i]; - - if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) - { - startIndex = i; - break; - } - } - } - - for (size_t i = 0; i < uv->options.size(); ++i) - { - TypeId type = uv->options[(i + startIndex) % uv->options.size()]; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTy, type, isFunctionCall); - - if (innerState.errors.empty()) - { - found = true; - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - - break; - } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - { - unificationTooComplex = e; - } - else if (!isNil(type)) - { - failedOptionCount++; - - if (!failedOption) - failedOption = {innerState.errors.front()}; - } - - if (!FFlag::LuauUseCommittingTxnLog) - innerState.DEPRECATED_log.rollback(); - } - - if (unificationTooComplex) - { - reportError(*unificationTooComplex); - } - else if (!found) - { - if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - reportError( - TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); - else - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); - } + tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); } else if (const IntersectionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) { - std::optional unificationTooComplex; - std::optional firstFailedOption; - - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); - - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - else if (!innerState.errors.empty()) - { - if (!firstFailedOption) - firstFailedOption = {innerState.errors.front()}; - } - - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - } - - if (unificationTooComplex) - reportError(*unificationTooComplex); - else if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); + tryUnifyTypeWithIntersection(subTy, superTy, uv); } else if (const IntersectionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { - // A & B <: T if T <: A or T <: B - bool found = false; - std::optional unificationTooComplex; - - size_t startIndex = 0; - - if (cacheEnabled) - { - for (size_t i = 0; i < uv->parts.size(); ++i) - { - TypeId type = uv->parts[i]; - - if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) - { - startIndex = i; - break; - } - } - } - - for (size_t i = 0; i < uv->parts.size(); ++i) - { - TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, superTy, isFunctionCall); - - if (innerState.errors.empty()) - { - found = true; - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - break; - } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - { - unificationTooComplex = e; - } - - if (!FFlag::LuauUseCommittingTxnLog) - innerState.DEPRECATED_log.rollback(); - } - - if (unificationTooComplex) - reportError(*unificationTooComplex); - else if (!found) - { - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); - } + tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); } else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) @@ -797,6 +589,253 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool DEPRECATED_log.popSeen(superTy, subTy); } +void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId superTy) +{ + // A | B <: T if A <: T and B <: T + bool failed = false; + std::optional unificationTooComplex; + std::optional firstFailedOption; + + size_t count = uv->options.size(); + size_t i = 0; + + for (TypeId type : uv->options) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, superTy); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) + { + // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' + if (!firstFailedOption && !isNil(type)) + firstFailedOption = {innerState.errors.front()}; + + failed = true; + } + + if (FFlag::LuauUseCommittingTxnLog) + { + if (i == count - 1) + { + log.concat(std::move(innerState.log)); + } + } + else + { + if (i != count - 1) + { + innerState.DEPRECATED_log.rollback(); + } + else + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } + } + + ++i; + } + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (failed) + { + if (firstFailedOption) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); + else + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + } +} + +void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall) +{ + // T <: A | B if T <: A or T <: B + bool found = false; + std::optional unificationTooComplex; + + size_t failedOptionCount = 0; + std::optional failedOption; + + bool foundHeuristic = false; + size_t startIndex = 0; + + if (const std::string* subName = getName(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + const std::string* optionName = getName(uv->options[i]); + if (optionName && *optionName == *subName) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + + if (auto subMatchTag = getTableMatchTag(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + auto optionMatchTag = getTableMatchTag(uv->options[i]); + if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + + if (!foundHeuristic && cacheEnabled) + { + auto& cache = sharedState.cachedUnify; + + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[i]; + + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) + { + startIndex = i; + break; + } + } + } + + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[(i + startIndex) % uv->options.size()]; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(subTy, type, isFunctionCall); + + if (innerState.errors.empty()) + { + found = true; + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + { + unificationTooComplex = e; + } + else if (!isNil(type)) + { + failedOptionCount++; + + if (!failedOption) + failedOption = {innerState.errors.front()}; + } + + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); + } + + if (unificationTooComplex) + { + reportError(*unificationTooComplex); + } + else if (!found) + { + if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); + else + reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + } +} + +void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv) +{ + std::optional unificationTooComplex; + std::optional firstFailedOption; + + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) + { + if (!firstFailedOption) + firstFailedOption = {innerState.errors.front()}; + } + + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (firstFailedOption) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); +} + +void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) +{ + // A & B <: T if T <: A or T <: B + bool found = false; + std::optional unificationTooComplex; + + size_t startIndex = 0; + + if (cacheEnabled) + { + auto& cache = sharedState.cachedUnify; + + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[i]; + + if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) + { + startIndex = i; + break; + } + } + } + + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, superTy, isFunctionCall); + + if (innerState.errors.empty()) + { + found = true; + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + { + unificationTooComplex = e; + } + + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); + } + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (!found) + { + reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + } +} + void Unifier::cacheResult(TypeId subTy, TypeId superTy) { bool* superTyInfo = sharedState.skipCacheForType.find(superTy); @@ -1119,8 +1158,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal auto [superTypes, superTail] = logAwareFlatten(superTp, log); auto [subTypes, subTail] = logAwareFlatten(subTp, log); - bool noInfiniteGrowth = - (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); + bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && + (subTail && log.getMutable(*subTail)); auto superIter = WeirdIter(superTp, log); auto subIter = WeirdIter(subTp, log); @@ -1667,6 +1706,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TableTypeVar* superTable = getMutable(superTy); TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1679,7 +1725,11 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) for (const auto& [propName, superProp] : superTable->props) { auto subIter = subTable->props.find(propName); - if (subIter == subTable->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) + + bool isAny = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(superProp.type)) : get(follow(superProp.type)); + + if (subIter == subTable->props.end() && !isOptional(superProp.type) && !isAny) missingProperties.push_back(propName); } @@ -1697,7 +1747,10 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) for (const auto& [propName, subProp] : subTable->props) { auto superIter = superTable->props.find(propName); - if (superIter == superTable->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) + + bool isAny = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(subProp.type)) : get(follow(subProp.type)); + if (superIter == superTable->props.end() && !isOptional(subProp.type) && !isAny) extraProperties.push_back(propName); } @@ -1775,6 +1828,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TableTypeVar* ttv = getMutable(pendingSub); LUAU_ASSERT(ttv); ttv->props[name] = prop; + subTable = ttv; } else { @@ -1831,6 +1885,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) PendingType* pendingSuper = log.queue(superTy); TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); pendingSuperTtv->props[name] = clone; + superTable = pendingSuperTtv; } else { @@ -1853,6 +1908,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) PendingType* pendingSuper = log.queue(superTy); TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); pendingSuperTtv->props[name] = prop; + superTable = pendingSuperTtv; } else { @@ -1967,7 +2023,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } else { - DEPRECATED_log(subTy); + DEPRECATED_log(subTable); subTable->boundTo = superTy; } } @@ -2408,8 +2464,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty()) - reportError( - TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); + reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); if (FFlag::LuauUseCommittingTxnLog) log.concat(std::move(innerState.log)); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 3c607d24..f559e2e0 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,7 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) @@ -957,7 +956,7 @@ AstStat* Parser::parseAssignment(AstExpr* initial) { nextLexeme(); - AstExpr* expr = parsePrimaryExpr(/* asStatement= */ FFlag::LuauFixAmbiguousErrorRecoveryInAssign); + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); if (!isExprLValue(expr)) expr = reportExprError(expr->location, copy({expr}), "Assigned expression must be a variable or a field"); diff --git a/CLI/Coverage.cpp b/CLI/Coverage.cpp index 254df3f0..a509ab89 100644 --- a/CLI/Coverage.cpp +++ b/CLI/Coverage.cpp @@ -68,7 +68,7 @@ void coverageDump(const char* path) fprintf(f, "TN:\n"); - for (int fref: gCoverage.functions) + for (int fref : gCoverage.functions) { lua_getref(L, fref); diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index c6807022..fe005aec 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -77,7 +77,7 @@ std::optional readFile(const std::string& name) std::optional readStdin() { std::string result; - char buffer[4096] = { }; + char buffer[4096] = {}; while (fgets(buffer, sizeof(buffer), stdin) != nullptr) result.append(buffer); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index ab0f0ed0..5af6b508 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -10,7 +10,7 @@ #include "Profiler.h" #include "Coverage.h" -#include "linenoise.hpp" +#include "isocline.h" #include @@ -240,9 +240,10 @@ std::string runCode(lua_State* L, const std::string& source) return std::string(); } -static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) +static void completeIndexer(ic_completion_env_t* cenv, const char* editBuffer) { - std::string_view lookup = editBuffer + start; + auto* L = reinterpret_cast(ic_completion_arg(cenv)); + std::string_view lookup = editBuffer; char lastSep = 0; for (;;) @@ -268,13 +269,14 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, if (!key.empty() && requiredValueType && Luau::startsWith(key, prefix)) { - std::string completion(editBuffer + std::string(key.substr(prefix.size()))); + std::string completedComponent(key.substr(prefix.size())); + std::string completion(editBuffer + completedComponent); if (valueType == LUA_TFUNCTION) { // Add an opening paren for function calls by default. completion += "("; } - completions.push_back(completion); + ic_add_completion_ex(cenv, completion.data(), key.data(), nullptr); } } lua_pop(L, 1); @@ -310,19 +312,23 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, lua_pop(L, 1); } -static void completeRepl(lua_State* L, const char* editBuffer, std::vector& completions) +static bool isMethodOrFunctionChar(const char* s, long len) { - size_t start = strlen(editBuffer); - while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.' || editBuffer[start - 1] == ':' || editBuffer[start - 1] == '_')) - start--; + char c = *s; + return len == 1 && (isalnum(c) || c == '.' || c == ':' || c == '_'); +} + +static void completeRepl(ic_completion_env_t* cenv, const char* editBuffer) +{ + auto* L = reinterpret_cast(ic_completion_arg(cenv)); // look the value up in current global table first lua_pushvalue(L, LUA_GLOBALSINDEX); - completeIndexer(L, editBuffer, start, completions); + ic_complete_word(cenv, editBuffer, completeIndexer, isMethodOrFunctionChar); // and in actual global table after that lua_getglobal(L, "_G"); - completeIndexer(L, editBuffer, start, completions); + ic_complete_word(cenv, editBuffer, completeIndexer, isMethodOrFunctionChar); } struct LinenoiseScopedHistory @@ -341,13 +347,11 @@ struct LinenoiseScopedHistory } if (!historyFilepath.empty()) - linenoise::LoadHistory(historyFilepath.c_str()); + ic_set_history(historyFilepath.c_str(), -1 /* default entries (= 200) */); } ~LinenoiseScopedHistory() { - if (!historyFilepath.empty()) - linenoise::SaveHistory(historyFilepath.c_str()); } std::string historyFilepath; @@ -355,28 +359,32 @@ struct LinenoiseScopedHistory static void runReplImpl(lua_State* L) { - linenoise::SetCompletionCallback([L](const char* editBuffer, std::vector& completions) { - completeRepl(L, editBuffer, completions); - }); + ic_set_default_completer(completeRepl, L); + + // Make brace matching easier to see + ic_style_def("ic-bracematch", "teal"); + + // Prevent auto insertion of braces + ic_enable_brace_insertion(false); std::string buffer; LinenoiseScopedHistory scopedHistory; for (;;) { - bool quit = false; - std::string line = linenoise::Readline(buffer.empty() ? "> " : ">> ", quit); - if (quit) + const char* line = ic_readline(buffer.empty() ? "" : ">"); + if (!line) break; if (buffer.empty() && runCode(L, std::string("return ") + line) == std::string()) { - linenoise::AddHistory(line.c_str()); + ic_history_add(line); continue; } + if (!buffer.empty()) + buffer += "\n"; buffer += line; - buffer += " "; // linenoise doesn't work very well with multiline history entries std::string error = runCode(L, buffer); @@ -390,8 +398,9 @@ static void runReplImpl(lua_State* L) fprintf(stdout, "%s\n", error.c_str()); } - linenoise::AddHistory(buffer.c_str()); + ic_history_add(buffer.c_str()); buffer.clear(); + free((void*)line); } } diff --git a/CLI/ReplEntry.cpp b/CLI/ReplEntry.cpp index b3131712..75995e6a 100644 --- a/CLI/ReplEntry.cpp +++ b/CLI/ReplEntry.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #include "Repl.h" @@ -7,4 +6,4 @@ int main(int argc, char** argv) { return replMain(argc, argv); -} \ No newline at end of file +} diff --git a/CMakeLists.txt b/CMakeLists.txt index b9f7a9e1..881d3c3f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,7 @@ if(EXT_PLATFORM_STRING) endif() cmake_minimum_required(VERSION 3.0) -project(Luau LANGUAGES CXX) +project(Luau LANGUAGES CXX C) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) @@ -16,6 +16,7 @@ add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) add_library(Luau.Analysis STATIC) add_library(Luau.VM STATIC) +add_library(isocline STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) @@ -52,6 +53,8 @@ target_link_libraries(Luau.Analysis PUBLIC Luau.Ast) target_compile_features(Luau.VM PRIVATE cxx_std_11) target_include_directories(Luau.VM PUBLIC VM/include) +target_include_directories(isocline PUBLIC extern/isocline/include) + set(LUAU_OPTIONS) if(MSVC) @@ -75,9 +78,16 @@ if(LUAU_BUILD_WEB) list(APPEND LUAU_OPTIONS -fexceptions) endif() +set(ISOCLINE_OPTIONS) + +if (NOT MSVC) + list(APPEND ISOCLINE_OPTIONS -Wno-unused-function) +endif() + target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) +target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS}) if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: @@ -89,8 +99,9 @@ if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) - target_include_directories(Luau.Repl.CLI PRIVATE extern) - target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) + target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include) + + target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM isocline) if(UNIX) find_library(LIBPTHREAD pthread) @@ -113,7 +124,7 @@ if(LUAU_BUILD_TESTS) target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.CLI.Test PRIVATE extern CLI) - target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.VM) + target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.VM isocline) if(UNIX) find_library(LIBPTHREAD pthread) if (LIBPTHREAD) diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index a907271c..26360c49 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,7 +4,7 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" -LUAU_FASTFLAGVARIABLE(LuauCompileSelectBuiltin, false) +LUAU_FASTFLAGVARIABLE(LuauCompileSelectBuiltin2, false) namespace Luau { @@ -64,7 +64,7 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) if (builtin.isGlobal("unpack")) return LBF_TABLE_UNPACK; - if (FFlag::LuauCompileSelectBuiltin && builtin.isGlobal("select")) + if (FFlag::LuauCompileSelectBuiltin2 && builtin.isGlobal("select")) return LBF_SELECT_VARARG; if (builtin.object == "math") diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 7da85244..e4253adc 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -16,7 +16,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauCompileTableIndexOpt, false) -LUAU_FASTFLAG(LuauCompileSelectBuiltin) +LUAU_FASTFLAG(LuauCompileSelectBuiltin2) namespace Luau { @@ -266,7 +266,7 @@ struct Compiler void compileExprSelectVararg(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs) { - LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin); + LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin2); LUAU_ASSERT(targetCount == 1); LUAU_ASSERT(!expr->self); LUAU_ASSERT(expr->args.size == 2 && expr->args.data[1]->is()); @@ -291,6 +291,9 @@ struct Compiler // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten compileExprTemp(expr->func, regs); + if (argreg != regs + 1) + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1), argreg, 0); + bytecode.emitABC(LOP_GETVARARGS, uint8_t(regs + 2), 0, 0); size_t callLabel = bytecode.emitLabel(); @@ -405,7 +408,7 @@ struct Compiler if (bfid == LBF_SELECT_VARARG) { - LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin); + LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin2); // Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly // note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is()) diff --git a/Compiler/src/TableShape.cpp b/Compiler/src/TableShape.cpp index 9dc2f0a4..5a866e87 100644 --- a/Compiler/src/TableShape.cpp +++ b/Compiler/src/TableShape.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "TableShape.h" -LUAU_FASTFLAGVARIABLE(LuauPredictTableSizeLoop, false) - namespace Luau { namespace Compile @@ -87,9 +85,6 @@ struct ShapeVisitor : AstVisitor } else if (AstExprLocal* iter = index->as()) { - if (!FFlag::LuauPredictTableSizeLoop) - return; - if (const unsigned int* bound = loops.find(iter->local)) { TableShape& shape = shapes[*table]; @@ -143,9 +138,6 @@ struct ShapeVisitor : AstVisitor bool visit(AstStatFor* node) override { - if (!FFlag::LuauPredictTableSizeLoop) - return true; - AstExprConstantNumber* from = node->from->as(); AstExprConstantNumber* to = node->to->as(); diff --git a/Makefile b/Makefile index 638c4c63..80eff018 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,10 @@ VM_SOURCES=$(wildcard VM/src/*.cpp) VM_OBJECTS=$(VM_SOURCES:%=$(BUILD)/%.o) VM_TARGET=$(BUILD)/libluauvm.a +ISOCLINE_SOURCES=extern/isocline/src/isocline.c +ISOCLINE_OBJECTS=$(ISOCLINE_SOURCES:%=$(BUILD)/%.o) +ISOCLINE_TARGET=$(BUILD)/libisocline.a + TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests @@ -43,7 +47,7 @@ ifneq ($(flags),) TESTS_ARGS+=--fflags=$(flags) endif -OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) +OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) # common flags CXXFLAGS=-g -Wall @@ -90,8 +94,9 @@ $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -IAst/include $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -IVM/include +$(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include $(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICLI -Iextern -$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IVM/include -Iextern +$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IVM/include -Iextern -Iextern/isocline/include $(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include -Iextern $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include @@ -116,8 +121,8 @@ coverage: $(TESTS_TARGET) $(TESTS_TARGET) llvm-profdata merge default.profraw default-flags.profraw -o default.profdata rm default.profraw default-flags.profraw - llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests - llvm-cov report -ignore-filename-regex=\(tests\|extern\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests + llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests + llvm-cov report -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests llvm-cov export -format lcov --instr-profile default.profdata build/coverage/luau-tests >coverage.info format: @@ -135,8 +140,8 @@ luau-analyze: $(ANALYZE_CLI_TARGET) ln -fs $^ $@ # executable targets -$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) -$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) +$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) +$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(AST_TARGET) $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET): @@ -154,8 +159,9 @@ $(AST_TARGET): $(AST_OBJECTS) $(COMPILER_TARGET): $(COMPILER_OBJECTS) $(ANALYSIS_TARGET): $(ANALYSIS_OBJECTS) $(VM_TARGET): $(VM_OBJECTS) +$(ISOCLINE_TARGET): $(ISOCLINE_OBJECTS) -$(AST_TARGET) $(COMPILER_TARGET) $(ANALYSIS_TARGET) $(VM_TARGET): +$(AST_TARGET) $(COMPILER_TARGET) $(ANALYSIS_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET): ar rcs $@ $^ # object file targets @@ -163,6 +169,10 @@ $(BUILD)/%.cpp.o: %.cpp @mkdir -p $(dir $@) $(CXX) $< $(CXXFLAGS) -c -MMD -MP -o $@ +$(BUILD)/%.c.o: %.c + @mkdir -p $(dir $@) + $(CXX) -x c $< $(CXXFLAGS) -c -MMD -MP -o $@ + # protobuf fuzzer setup fuzz/luau.pb.cpp: fuzz/luau.proto build/libprotobuf-mutator cd fuzz && ../build/libprotobuf-mutator/external.protobuf/bin/protoc luau.proto --cpp_out=. diff --git a/Sources.cmake b/Sources.cmake index 22e7af22..b36b6db5 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -167,6 +167,11 @@ target_sources(Luau.VM PRIVATE VM/src/lvm.h ) +target_sources(isocline PRIVATE + extern/isocline/include/isocline.h + extern/isocline/src/isocline.c +) + if(TARGET Luau.Repl.CLI) # Luau.Repl.CLI Sources target_sources(Luau.Repl.CLI PRIVATE diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index 7e0832e7..c5bf1c18 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -83,7 +83,7 @@ #endif #ifndef LUAI_GCSTEPSIZE -#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ +#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ #endif /* LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function */ @@ -153,6 +153,6 @@ long l; \ } -#define LUA_VECTOR_SIZE 3 /* must be 3 or 4 */ +#define LUA_VECTOR_SIZE 3 /* must be 3 or 4 */ #define LUA_EXTRA_SIZE LUA_VECTOR_SIZE - 2 diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 19222861..7592a14c 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -250,7 +250,7 @@ static int coclose(lua_State* L) { lua_pushboolean(L, false); if (lua_gettop(co)) - lua_xmove(co, L, 1); /* move error message */ + lua_xmove(co, L, 1); /* move error message */ lua_resetthread(co); return 2; } diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 2b5382bb..e9930f7a 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,7 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauBytecodeV2Read) LUAU_FASTFLAG(LuauBytecodeV2Force) static const char* getfuncname(Closure* f); @@ -96,7 +95,7 @@ static int getlinedefined(Proto* p) { if (FFlag::LuauBytecodeV2Force) return p->linedefined; - else if (FFlag::LuauBytecodeV2Read && p->linedefined >= 0) + else if (p->linedefined >= 0) return p->linedefined; else return luaG_getline(p, 0); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 6088f71c..582d4627 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -90,7 +90,7 @@ UpVal* luaF_findupval(lua_State* L, StkId level) uv->tt = LUA_TUPVAL; uv->marked = luaC_white(g); uv->memcat = L->activememcat; - uv->v = level; /* current value lives in the stack */ + uv->v = level; /* current value lives in the stack */ // chain the upvalue in the threads open upvalue list at the proper position UpVal* next = *pp; @@ -138,8 +138,8 @@ void luaF_unlinkupval(UpVal* uv) void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) { - if (uv->v != &uv->u.value) /* is it open? */ - luaF_unlinkupval(uv); /* remove from open list */ + if (uv->v != &uv->u.value) /* is it open? */ + luaF_unlinkupval(uv); /* remove from open list */ luaM_freegco(L, uv, sizeof(UpVal), uv->memcat, page); /* free upvalue */ } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 82ac0009..835572fa 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -759,6 +759,8 @@ static int sweepgcopage(lua_State* L, lua_Page* page) // when true is returned it means that the element was deleted if (sweepgco(L, page, gco)) { + LUAU_ASSERT(busyBlocks > 0); + // if the last block was removed, page would be removed as well if (--busyBlocks == 0) return int(pos - start) / blockSize + 1; diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index e1dbce50..de85cf59 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -57,8 +57,7 @@ const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; const size_t kPageSize = 16 * 1024 - 24; // slightly under 16KB since that results in less fragmentation due to heap metadata const size_t kBlockHeader = sizeof(double) > sizeof(void*) ? sizeof(double) : sizeof(void*); // suitable for aligning double & void* on all platforms -// TODO (FFlagLuauGcPagedSweep): when 'next' is removed, 'kBlockHeader' can be used unconditionally -const size_t kGCOHeader = sizeof(GCheader) > kBlockHeader ? sizeof(GCheader) : kBlockHeader; +const size_t kGCOLinkOffset = (sizeof(GCheader) + sizeof(void*) - 1) & ~(sizeof(void*) - 1); // GCO pages contain freelist links after the GC header struct SizeClassConfig { @@ -101,12 +100,12 @@ struct SizeClassConfig const SizeClassConfig kSizeClassConfig; -// size class for a block of size sz +// size class for a block of size sz; returns -1 for size=0 because empty allocations take no space #define sizeclass(sz) (size_t((sz)-1) < kMaxSmallSize ? kSizeClassConfig.classForSize[sz] : -1) // metadata for a block is stored in the first pointer of the block #define metadata(block) (*(void**)(block)) -#define freegcolink(block) (*(void**)((char*)block + kGCOHeader)) +#define freegcolink(block) (*(void**)((char*)block + kGCOLinkOffset)) /* ** About the realloc function: @@ -157,7 +156,7 @@ l_noret luaM_toobig(lua_State* L) luaG_runerror(L, "memory allocation error: block too big"); } -static lua_Page* luaM_newpage(lua_State* L, uint8_t sizeClass) +static lua_Page* newpageold(lua_State* L, uint8_t sizeClass) { LUAU_ASSERT(!FFlag::LuauGcPagedSweep); @@ -253,7 +252,7 @@ static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** g return page; } -static void luaM_freepage(lua_State* L, lua_Page* page, uint8_t sizeClass) +static void freepageold(lua_State* L, lua_Page* page, uint8_t sizeClass) { LUAU_ASSERT(!FFlag::LuauGcPagedSweep); @@ -310,7 +309,7 @@ static void freeclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopa freepage(L, gcopageset, page); } -static void* luaM_newblock(lua_State* L, int sizeClass) +static void* newblock(lua_State* L, int sizeClass) { global_State* g = L->global; lua_Page* page = g->freepages[sizeClass]; @@ -321,7 +320,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass) if (FFlag::LuauGcPagedSweep) page = newclasspage(L, g->freepages, NULL, sizeClass, true); else - page = luaM_newpage(L, sizeClass); + page = newpageold(L, sizeClass); } LUAU_ASSERT(!page->prev); @@ -363,7 +362,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass) return (char*)block + kBlockHeader; } -static void* luaM_newgcoblock(lua_State* L, int sizeClass) +static void* newgcoblock(lua_State* L, int sizeClass) { LUAU_ASSERT(FFlag::LuauGcPagedSweep); @@ -390,11 +389,10 @@ static void* luaM_newgcoblock(lua_State* L, int sizeClass) } else { + block = page->freeList; + ASAN_UNPOISON_MEMORY_REGION((char*)block + sizeof(GCheader), page->blockSize - sizeof(GCheader)); + // when separate block metadata is not used, free list link is stored inside the block data itself - block = (char*)page->freeList - kGCOHeader; - - ASAN_UNPOISON_MEMORY_REGION((char*)block + kGCOHeader, page->blockSize - kGCOHeader); - page->freeList = freegcolink(block); page->busyBlocks++; } @@ -412,7 +410,7 @@ static void* luaM_newgcoblock(lua_State* L, int sizeClass) return (char*)block; } -static void luaM_freeblock(lua_State* L, int sizeClass, void* block) +static void freeblock(lua_State* L, int sizeClass, void* block) { global_State* g = L->global; @@ -450,11 +448,11 @@ static void luaM_freeblock(lua_State* L, int sizeClass, void* block) if (FFlag::LuauGcPagedSweep) freeclasspage(L, g->freepages, NULL, page, sizeClass); else - luaM_freepage(L, page, sizeClass); + freepageold(L, page, sizeClass); } } -static void luaM_freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) +static void freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) { LUAU_ASSERT(FFlag::LuauGcPagedSweep); @@ -474,9 +472,9 @@ static void luaM_freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page // when separate block metadata is not used, free list link is stored inside the block data itself freegcolink(block) = page->freeList; - page->freeList = (char*)block + kGCOHeader; + page->freeList = block; - ASAN_POISON_MEMORY_REGION((char*)block + kGCOHeader, page->blockSize - kGCOHeader); + ASAN_POISON_MEMORY_REGION((char*)block + sizeof(GCheader), page->blockSize - sizeof(GCheader)); page->busyBlocks--; @@ -491,7 +489,7 @@ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) int nclass = sizeclass(nsize); - void* block = nclass >= 0 ? luaM_newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); + void* block = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); if (block == NULL && nsize > 0) luaD_throw(L, LUA_ERRMEM); @@ -506,6 +504,9 @@ GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) if (!FFlag::LuauGcPagedSweep) return (GCObject*)luaM_new_(L, nsize, memcat); + // we need to accommodate space for link for free blocks (freegcolink) + LUAU_ASSERT(nsize >= kGCOLinkOffset + sizeof(void*)); + global_State* g = L->global; int nclass = sizeclass(nsize); @@ -514,9 +515,7 @@ GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) if (nclass >= 0) { - LUAU_ASSERT(nsize > 8); - - block = luaM_newgcoblock(L, nclass); + block = newgcoblock(L, nclass); } else { @@ -546,7 +545,7 @@ void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) int oclass = sizeclass(osize); if (oclass >= 0) - luaM_freeblock(L, oclass, block); + freeblock(L, oclass, block); else (*g->frealloc)(L, g->ud, block, osize, 0); @@ -571,7 +570,7 @@ void luaM_freegco_(lua_State* L, GCObject* block, size_t osize, uint8_t memcat, { block->gch.tt = LUA_TNIL; - luaM_freegcoblock(L, oclass, block, page); + freegcoblock(L, oclass, block, page); } else { @@ -596,7 +595,7 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 // if either block needs to be allocated using a block allocator, we can't use realloc directly if (nclass >= 0 || oclass >= 0) { - result = nclass >= 0 ? luaM_newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); + result = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); if (result == NULL && nsize > 0) luaD_throw(L, LUA_ERRMEM); @@ -604,7 +603,7 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 memcpy(result, block, osize < nsize ? osize : nsize); if (oclass >= 0) - luaM_freeblock(L, oclass, block); + freeblock(L, oclass, block); else (*g->frealloc)(L, g->ud, block, osize, 0); } @@ -659,6 +658,8 @@ void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context // when true is returned it means that the element was deleted if (visitor(context, page, gco)) { + LUAU_ASSERT(busyBlocks > 0); + // if the last block was removed, page would be removed as well if (--busyBlocks == 0) break; diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index cb22cc23..9bbc43de 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -57,7 +57,7 @@ void luaS_resize(lua_State* L, int newsize) { TString* p = tb->hash[i]; while (p) - { /* for each node in the list */ + { /* for each node in the list */ // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required TString* next = (TString*)p->next; /* save next */ unsigned int h = p->hash; diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index e58ff2a8..cba3670a 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -676,14 +676,9 @@ static void luau_execute(lua_State* L) VM_PROTECT_PC(); // set may fail TValue* res = luaH_setstr(L, h, tsvalue(kv)); - - if (res != luaO_nilobject) - { - int cachedslot = gval2slot(h, res); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, cachedslot); - } - + int cachedslot = gval2slot(h, res); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, cachedslot); setobj(L, res, ra); luaC_barriert(L, h, ra); VM_NEXT(); diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index cdb276c0..2472cd90 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, false) LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens @@ -157,11 +156,12 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size return 1; } - if (FFlag::LuauBytecodeV2Force ? (version != LBC_VERSION_FUTURE) : FFlag::LuauBytecodeV2Read ? (version != LBC_VERSION && version != LBC_VERSION_FUTURE) : (version != LBC_VERSION)) + if (FFlag::LuauBytecodeV2Force ? (version != LBC_VERSION_FUTURE) : (version != LBC_VERSION && version != LBC_VERSION_FUTURE)) { char chunkid[LUA_IDSIZE]; luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); - lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, FFlag::LuauBytecodeV2Force ? LBC_VERSION_FUTURE : LBC_VERSION, version); + lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, + FFlag::LuauBytecodeV2Force ? LBC_VERSION_FUTURE : LBC_VERSION, version); return 1; } @@ -292,7 +292,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->p[j] = protos[fid]; } - if (FFlag::LuauBytecodeV2Force || (FFlag::LuauBytecodeV2Read && version == LBC_VERSION_FUTURE)) + if (FFlag::LuauBytecodeV2Force || version == LBC_VERSION_FUTURE) p->linedefined = readVarInt(data, size, offset); else p->linedefined = -1; diff --git a/extern/isocline/.gitignore b/extern/isocline/.gitignore new file mode 100644 index 00000000..470cc813 --- /dev/null +++ b/extern/isocline/.gitignore @@ -0,0 +1,16 @@ +out/ +build/ +dist/ +doc/html/ +.vs/ +.vscode/ +.stack-work/ +.DS_Store +*.user +*.exe +*.hi +*.o +*_stub.h +*.lock +history.txt +isocline.debug.txt \ No newline at end of file diff --git a/extern/isocline/LICENSE b/extern/isocline/LICENSE new file mode 100644 index 00000000..7ac3104b --- /dev/null +++ b/extern/isocline/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Daan Leijen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/extern/isocline/include/isocline.h b/extern/isocline/include/isocline.h new file mode 100644 index 00000000..0d46cf3f --- /dev/null +++ b/extern/isocline/include/isocline.h @@ -0,0 +1,627 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_ISOCLINE_H +#define IC_ISOCLINE_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include // size_t +#include // bool +#include // uint32_t +#include // term_vprintf + + +/*! \mainpage +Isocline C API reference. + +Isocline is a pure C library that can be used as an alternative to the GNU readline library. + +See the [Github repository](https://github.com/daanx/isocline#readme) +for general information and building the library. + +Contents: +- \ref readline +- \ref bbcode +- \ref history +- \ref completion +- \ref highlight +- \ref options +- \ref helper +- \ref completex +- \ref term +- \ref async +- \ref alloc +*/ + +/// \defgroup readline Readline +/// The basic readline interface. +/// \{ + +/// Isocline version: 102 = 1.0.2. +#define IC_VERSION (104) + + +/// Read input from the user using rich editing abilities. +/// @param prompt_text The prompt text, can be NULL for the default (""). +/// The displayed prompt becomes `prompt_text` followed by the `prompt_marker` ("> "). +/// @returns the heap allocated input on succes, which should be `free`d by the caller. +/// Returns NULL on error, or if the user typed ctrl+d or ctrl+c. +/// +/// If the standard input (`stdin`) has no editing capability +/// (like a dumb terminal (e.g. `TERM`=`dumb`), running in a debuggen, a pipe or redirected file, etc.) +/// the input is read directly from the input stream up to the +/// next line without editing capability. +/// See also \a ic_set_prompt_marker(), \a ic_style_def() +/// +/// @see ic_set_prompt_marker(), ic_style_def() +char* ic_readline(const char* prompt_text); + +/// \} + + +//-------------------------------------------------------------- +/// \defgroup bbcode Formatted Text +/// Formatted text using [bbcode markup](https://github.com/daanx/isocline#bbcode-format). +/// \{ + +/// Print to the terminal while respection bbcode markup. +/// Any unclosed tags are closed automatically at the end of the print. +/// For example: +/// ``` +/// ic_print("[b]bold, [i]bold and italic[/i], [red]red and bold[/][/b] default."); +/// ic_print("[b]bold[/], [i b]bold and italic[/], [yellow on blue]yellow on blue background"); +/// ic_style_add("em","i color=#888800"); +/// ic_print("[em]emphasis"); +/// ``` +/// Properties that can be assigned are: +/// * `color=` _clr_, `bgcolor=` _clr_: where _clr_ is either a hex value `#`RRGGBB or `#`RGB, a +/// standard HTML color name, or an ANSI palette name, like `ansi-maroon`, `ansi-default`, etc. +/// * `bold`,`italic`,`reverse`,`underline`: can be `on` or `off`. +/// * everything else is a style; all HTML and ANSI color names are also a style (so we can just use `red` +/// instead of `color=red`, or `on red` instead of `bgcolor=red`), and there are +/// the `b`, `i`, `u`, and `r` styles for bold, italic, underline, and reverse. +/// +/// See [here](https://github.com/daanx/isocline#bbcode-format) for a description of the full bbcode format. +void ic_print( const char* s ); + +/// Print with bbcode markup ending with a newline. +/// @see ic_print() +void ic_println( const char* s ); + +/// Print formatted with bbcode markup. +/// @see ic_print() +void ic_printf(const char* fmt, ...); + +/// Print formatted with bbcode markup. +/// @see ic_print +void ic_vprintf(const char* fmt, va_list args); + +/// Define or redefine a style. +/// @param style_name The name of the style. +/// @param fmt The `fmt` string is the content of a tag and can contain +/// other styles. This is very useful to theme the output of a program +/// by assigning standard styles like `em` or `warning` etc. +void ic_style_def( const char* style_name, const char* fmt ); + +/// Start a global style that is only reset when calling a matching ic_style_close(). +void ic_style_open( const char* fmt ); + +/// End a global style. +void ic_style_close(void); + +/// \} + + +//-------------------------------------------------------------- +// History +//-------------------------------------------------------------- +/// \defgroup history History +/// Readline input history. +/// \{ + +/// Enable history. +/// Use a \a NULL filename to not persist the history. Use -1 for max_entries to get the default (200). +void ic_set_history(const char* fname, long max_entries ); + +/// Remove the last entry in the history. +/// The last returned input from ic_readline() is automatically added to the history; this function removes it. +void ic_history_remove_last(void); + +/// Clear the history. +void ic_history_clear(void); + +/// Add an entry to the history +void ic_history_add( const char* entry ); + +/// \} + +//-------------------------------------------------------------- +// Basic Completion +//-------------------------------------------------------------- + +/// \defgroup completion Completion +/// Basic word completion. +/// \{ + +/// A completion environment +struct ic_completion_env_s; + +/// A completion environment +typedef struct ic_completion_env_s ic_completion_env_t; + +/// A completion callback that is called by isocline when tab is pressed. +/// It is passed a completion environment (containing the current input and the current cursor position), +/// the current input up-to the cursor (`prefix`) +/// and the user given argument when the callback was set. +/// When using completion transformers, like `ic_complete_quoted_word` the `prefix` contains the +/// the word to be completed without escape characters or quotes. +typedef void (ic_completer_fun_t)(ic_completion_env_t* cenv, const char* prefix ); + +/// Set the default completion handler. +/// @param completer The completion function +/// @param arg Argument passed to the \a completer. +/// There can only be one default completion function, setting it again disables the previous one. +/// The initial completer use `ic_complete_filename`. +void ic_set_default_completer( ic_completer_fun_t* completer, void* arg); + + +/// In a completion callback (usually from ic_complete_word()), use this function to add a completion. +/// (the completion string is copied by isocline and do not need to be preserved or allocated). +/// +/// Returns `true` if the callback should continue trying to find more possible completions. +/// If `false` is returned, the callback should try to return and not add more completions (for improved latency). +bool ic_add_completion(ic_completion_env_t* cenv, const char* completion); + +/// In a completion callback (usually from ic_complete_word()), use this function to add a completion. +/// The `display` is used to display the completion in the completion menu, and `help` is +/// displayed for hints for example. Both can be `NULL` for the default. +/// (all are copied by isocline and do not need to be preserved or allocated). +/// +/// Returns `true` if the callback should continue trying to find more possible completions. +/// If `false` is returned, the callback should try to return and not add more completions (for improved latency). +bool ic_add_completion_ex( ic_completion_env_t* cenv, const char* completion, const char* display, const char* help ); + +/// In a completion callback (usually from ic_complete_word()), use this function to add completions. +/// The `completions` array should be terminated with a NULL element, and all elements +/// are added as completions if they start with `prefix`. +/// +/// Returns `true` if the callback should continue trying to find more possible completions. +/// If `false` is returned, the callback should try to return and not add more completions (for improved latency). +bool ic_add_completions(ic_completion_env_t* cenv, const char* prefix, const char** completions); + +/// Complete a filename. +/// Complete a filename given a semi-colon separated list of root directories `roots` and +/// semi-colon separated list of possible extensions (excluding directories). +/// If `roots` is NULL, the current directory is the root ("."). +/// If `extensions` is NULL, any extension will match. +/// Each root directory should _not_ end with a directory separator. +/// If a directory is completed, the `dir_separator` is added at the end if it is not `0`. +/// Usually the `dir_separator` is `/` but it can be set to `\\` on Windows systems. +/// For example: +/// ``` +/// /ho --> /home/ +/// /home/.ba --> /home/.bashrc +/// ``` +/// (This already uses ic_complete_quoted_word() so do not call it from inside a word handler). +void ic_complete_filename( ic_completion_env_t* cenv, const char* prefix, char dir_separator, const char* roots, const char* extensions ); + + + +/// Function that returns whether a (utf8) character (of length `len`) is in a certain character class +/// @see ic_char_is_separator() etc. +typedef bool (ic_is_char_class_fun_t)(const char* s, long len); + + +/// Complete a _word_ (i.e. _token_). +/// Calls the user provided function `fun` to complete on the +/// current _word_. Almost all user provided completers should use this function. +/// If `is_word_char` is NULL, the default `&ic_char_is_nonseparator` is used. +/// The `prefix` passed to `fun` is modified to only contain the current word, and +/// any results from `ic_add_completion` are automatically adjusted to replace that part. +/// For example, on the input "hello w", a the user `fun` only gets `w` and can just complete +/// with "world" resulting in "hello world" without needing to consider `delete_before` etc. +/// @see ic_complete_qword() for completing quoted and escaped tokens. +void ic_complete_word(ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, ic_is_char_class_fun_t* is_word_char); + + +/// Complete a quoted _word_. +/// Calls the user provided function `fun` to complete while taking +/// care of quotes and escape characters. Almost all user provided completers should use +/// this function. The `prefix` passed to `fun` is modified to be unquoted and unescaped, and +/// any results from `ic_add_completion` are automatically quoted and escaped again. +/// For example, completing `hello world`, the `fun` always just completes `hel` or `hello w` to `hello world`, +/// but depending on user input, it will complete as: +/// ``` +/// hel --> hello\ world +/// hello\ w --> hello\ world +/// hello w --> # no completion, the word is just 'w'> +/// "hel --> "hello world" +/// "hello w --> "hello world" +/// ``` +/// with proper quotes and escapes. +/// If `is_word_char` is NULL, the default `&ic_char_is_nonseparator` is used. +/// @see ic_complete_quoted_word() to customize the word boundary, quotes etc. +void ic_complete_qword( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, ic_is_char_class_fun_t* is_word_char ); + + + +/// Complete a _word_. +/// Calls the user provided function `fun` to complete while taking +/// care of quotes and escape characters. Almost all user provided completers should use this function. +/// The `is_word_char` is a set of characters that are part of a "word". Use NULL for the default (`&ic_char_is_nonseparator`). +/// The `escape_char` is the escaping character, usually `\` but use 0 to not have escape characters. +/// The `quote_chars` define the quotes, use NULL for the default `"\'\""` quotes. +/// @see ic_complete_word() which uses the default values for `non_word_chars`, `quote_chars` and `\` for escape characters. +void ic_complete_qword_ex( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t fun, + ic_is_char_class_fun_t* is_word_char, char escape_char, const char* quote_chars ); + +/// \} + +//-------------------------------------------------------------- +/// \defgroup highlight Syntax Highlighting +/// Basic syntax highlighting. +/// \{ + +/// A syntax highlight environment +struct ic_highlight_env_s; +typedef struct ic_highlight_env_s ic_highlight_env_t; + +/// A syntax highlighter callback that is called by readline to syntax highlight user input. +typedef void (ic_highlight_fun_t)(ic_highlight_env_t* henv, const char* input, void* arg); + +/// Set a syntax highlighter. +/// There can only be one highlight function, setting it again disables the previous one. +void ic_set_default_highlighter(ic_highlight_fun_t* highlighter, void* arg); + +/// Set the style of characters starting at position `pos`. +void ic_highlight(ic_highlight_env_t* henv, long pos, long count, const char* style ); + +/// Experimental: Convenience callback for a function that highlights `s` using bbcode's. +/// The returned string should be allocated and is free'd by the caller. +typedef char* (ic_highlight_format_fun_t)(const char* s, void* arg); + +/// Experimental: Convenience function for highlighting with bbcodes. +/// Can be called in a `ic_highlight_fun_t` callback to colorize the `input` using the +/// the provided `formatted` input that is the styled `input` with bbcodes. The +/// content of `formatted` without bbcode tags should match `input` exactly. +void ic_highlight_formatted(ic_highlight_env_t* henv, const char* input, const char* formatted); + +/// \} + +//-------------------------------------------------------------- +// Readline with a specific completer and highlighter +//-------------------------------------------------------------- + +/// \defgroup readline +/// \{ + +/// Read input from the user using rich editing abilities, +/// using a particular completion function and highlighter for this call only. +/// both can be NULL in which case the defaults are used. +/// @see ic_readline(), ic_set_prompt_marker(), ic_set_default_completer(), ic_set_default_highlighter(). +char* ic_readline_ex(const char* prompt_text, ic_completer_fun_t* completer, void* completer_arg, + ic_highlight_fun_t* highlighter, void* highlighter_arg); + +/// \} + + +//-------------------------------------------------------------- +// Options +//-------------------------------------------------------------- + +/// \defgroup options Options +/// \{ + +/// Set a prompt marker and a potential marker for extra lines with multiline input. +/// Pass \a NULL for the `prompt_marker` for the default marker (`"> "`). +/// Pass \a NULL for continuation prompt marker to make it equal to the `prompt_marker`. +void ic_set_prompt_marker( const char* prompt_marker, const char* continuation_prompt_marker ); + +/// Get the current prompt marker. +const char* ic_get_prompt_marker(void); + +/// Get the current continuation prompt marker. +const char* ic_get_continuation_prompt_marker(void); + +/// Disable or enable multi-line input (enabled by default). +/// Returns the previous setting. +bool ic_enable_multiline( bool enable ); + +/// Disable or enable sound (enabled by default). +/// A beep is used when tab cannot find any completion for example. +/// Returns the previous setting. +bool ic_enable_beep( bool enable ); + +/// Disable or enable color output (enabled by default). +/// Returns the previous setting. +bool ic_enable_color( bool enable ); + +/// Disable or enable duplicate entries in the history (disabled by default). +/// Returns the previous setting. +bool ic_enable_history_duplicates( bool enable ); + +/// Disable or enable automatic tab completion after a completion +/// to expand as far as possible if the completions are unique. (disabled by default). +/// Returns the previous setting. +bool ic_enable_auto_tab( bool enable ); + +/// Disable or enable preview of a completion selection (enabled by default) +/// Returns the previous setting. +bool ic_enable_completion_preview( bool enable ); + +/// Disable or enable automatic identation of continuation lines in multiline +/// input so it aligns with the initial prompt. +/// Returns the previous setting. +bool ic_enable_multiline_indent(bool enable); + +/// Disable or enable display of short help messages for history search etc. +/// (full help is always dispayed when pressing F1 regardless of this setting) +/// @returns the previous setting. +bool ic_enable_inline_help(bool enable); + +/// Disable or enable hinting (enabled by default) +/// Shows a hint inline when there is a single possible completion. +/// @returns the previous setting. +bool ic_enable_hint(bool enable); + +/// Set millisecond delay before a hint is displayed. Can be zero. (500ms by default). +long ic_set_hint_delay(long delay_ms); + +/// Disable or enable syntax highlighting (enabled by default). +/// This applies regardless whether a syntax highlighter callback was set (`ic_set_highlighter`) +/// Returns the previous setting. +bool ic_enable_highlight(bool enable); + + +/// Set millisecond delay for reading escape sequences in order to distinguish +/// a lone ESC from the start of a escape sequence. The defaults are 100ms and 10ms, +/// but it may be increased if working with very slow terminals. +void ic_set_tty_esc_delay(long initial_delay_ms, long followup_delay_ms); + +/// Enable highlighting of matching braces (and error highlight unmatched braces).` +bool ic_enable_brace_matching(bool enable); + +/// Set matching brace pairs. +/// Pass \a NULL for the default `"()[]{}"`. +void ic_set_matching_braces(const char* brace_pairs); + +/// Enable automatic brace insertion (enabled by default). +bool ic_enable_brace_insertion(bool enable); + +/// Set matching brace pairs for automatic insertion. +/// Pass \a NULL for the default `()[]{}\"\"''` +void ic_set_insertion_braces(const char* brace_pairs); + +/// \} + + +//-------------------------------------------------------------- +// Advanced Completion +//-------------------------------------------------------------- + +/// \defgroup completex Advanced Completion +/// \{ + +/// Get the raw current input (and cursor position if `cursor` != NULL) for the completion. +/// Usually completer functions should look at their `prefix` though as transformers +/// like `ic_complete_word` may modify the prefix (for example, unescape it). +const char* ic_completion_input( ic_completion_env_t* cenv, long* cursor ); + +/// Get the completion argument passed to `ic_set_completer`. +void* ic_completion_arg( const ic_completion_env_t* cenv ); + +/// Do we have already some completions? +bool ic_has_completions( const ic_completion_env_t* cenv ); + +/// Do we already have enough completions and should we return if possible? (for improved latency) +bool ic_stop_completing( const ic_completion_env_t* cenv); + + +/// Primitive completion, cannot be used with most transformers (like `ic_complete_word` and `ic_complete_qword`). +/// When completed, `delete_before` _bytes_ are deleted before the cursor position, +/// `delete_after` _bytes_ are deleted after the cursor, and finally `completion` is inserted. +/// The `display` is used to display the completion in the completion menu, and `help` is displayed +/// with hinting. Both `display` and `help` can be NULL. +/// (all are copied by isocline and do not need to be preserved or allocated). +/// +/// Returns `true` if the callback should continue trying to find more possible completions. +/// If `false` is returned, the callback should try to return and not add more completions (for improved latency). +bool ic_add_completion_prim( ic_completion_env_t* cenv, const char* completion, + const char* display, const char* help, + long delete_before, long delete_after); + +/// \} + +//-------------------------------------------------------------- +/// \defgroup helper Character Classes. +/// Convenience functions for character classes, highlighting and completion. +/// \{ + +/// Convenience: return the position of a previous code point in a UTF-8 string `s` from postion `pos`. +/// Returns `-1` if `pos <= 0` or `pos > strlen(s)` (or other errors). +long ic_prev_char( const char* s, long pos ); + +/// Convenience: return the position of the next code point in a UTF-8 string `s` from postion `pos`. +/// Returns `-1` if `pos < 0` or `pos >= strlen(s)` (or other errors). +long ic_next_char( const char* s, long pos ); + +/// Convenience: does a string `s` starts with a given `prefix` ? +bool ic_starts_with( const char* s, const char* prefix ); + +/// Convenience: does a string `s` starts with a given `prefix` ignoring (ascii) case? +bool ic_istarts_with( const char* s, const char* prefix ); + + +/// Convenience: character class for whitespace `[ \t\r\n]`. +bool ic_char_is_white(const char* s, long len); + +/// Convenience: character class for non-whitespace `[^ \t\r\n]`. +bool ic_char_is_nonwhite(const char* s, long len); + +/// Convenience: character class for separators. +/// (``[ \t\r\n,.;:/\\(){}\[\]]``.) +/// This is used for word boundaries in isocline. +bool ic_char_is_separator(const char* s, long len); + +/// Convenience: character class for non-separators. +bool ic_char_is_nonseparator(const char* s, long len); + +/// Convenience: character class for letters (`[A-Za-z]` and any unicode > 0x80). +bool ic_char_is_letter(const char* s, long len); + +/// Convenience: character class for digits (`[0-9]`). +bool ic_char_is_digit(const char* s, long len); + +/// Convenience: character class for hexadecimal digits (`[A-Fa-f0-9]`). +bool ic_char_is_hexdigit(const char* s, long len); + +/// Convenience: character class for identifier letters (`[A-Za-z0-9_-]` and any unicode > 0x80). +bool ic_char_is_idletter(const char* s, long len); + +/// Convenience: character class for filename letters (_not in_ " \t\r\n`@$><=;|&\{\}\(\)\[\]]"). +bool ic_char_is_filename_letter(const char* s, long len); + + +/// Convenience: If this is a token start, return the length. Otherwise return 0. +long ic_is_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char); + +/// Convenience: Does this match the specified token? +/// Ensures not to match prefixes or suffixes, and returns the length of the match (in bytes). +/// E.g. `ic_match_token("function",0,&ic_char_is_letter,"fun")` returns 0. +/// while `ic_match_token("fun x",0,&ic_char_is_letter,"fun"})` returns 3. +long ic_match_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char, const char* token); + + +/// Convenience: Do any of the specified tokens match? +/// Ensures not to match prefixes or suffixes, and returns the length of the match (in bytes). +/// E.g. `ic_match_any_token("function",0,&ic_char_is_letter,{"fun","func",NULL})` returns 0. +/// while `ic_match_any_token("func x",0,&ic_char_is_letter,{"fun","func",NULL})` returns 4. +long ic_match_any_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char, const char** tokens); + +/// \} + +//-------------------------------------------------------------- +/// \defgroup term Terminal +/// +/// Experimental: Low level terminal output. +/// Ensures basic ANSI SGR escape sequences are processed +/// in a portable way (e.g. on Windows) +/// \{ + +/// Initialize for terminal output. +/// Call this before using the terminal write functions (`ic_term_write`) +/// Does nothing on most platforms but on Windows it sets the console to UTF8 output and possible +/// enables virtual terminal processing. +void ic_term_init(void); + +/// Call this when done with the terminal functions. +void ic_term_done(void); + +/// Flush the terminal output. +/// (happens automatically on newline characters ('\n') as well). +void ic_term_flush(void); + +/// Write a string to the console (and process CSI escape sequences). +void ic_term_write(const char* s); + +/// Write a string to the console and end with a newline +/// (and process CSI escape sequences). +void ic_term_writeln(const char* s); + +/// Write a formatted string to the console. +/// (and process CSI escape sequences) +void ic_term_writef(const char* fmt, ...); + +/// Write a formatted string to the console. +void ic_term_vwritef(const char* fmt, va_list args); + +/// Set text attributes from a style. +void ic_term_style( const char* style ); + +/// Set text attribute to bold. +void ic_term_bold(bool enable); + +/// Set text attribute to underline. +void ic_term_underline(bool enable); + +/// Set text attribute to italic. +void ic_term_italic(bool enable); + +/// Set text attribute to reverse video. +void ic_term_reverse(bool enable); + +/// Set text attribute to ansi color palette index between 0 and 255 (or 256 for the ANSI "default" color). +/// (auto matched to smaller palette if not supported) +void ic_term_color_ansi(bool foreground, int color); + +/// Set text attribute to 24-bit RGB color (between `0x000000` and `0xFFFFFF`). +/// (auto matched to smaller palette if not supported) +void ic_term_color_rgb(bool foreground, uint32_t color ); + +/// Reset the text attributes. +void ic_term_reset( void ); + +/// Get the palette used by the terminal: +/// This is usually initialized from the COLORTERM environment variable. The +/// possible values of COLORTERM for each palette are given in parenthesis. +/// +/// - 1: monochrome (`monochrome`) +/// - 3: old ANSI terminal with 8 colors, using bold for bright (`8color`/`3bit`) +/// - 4: regular ANSI terminal with 16 colors. (`16color`/`4bit`) +/// - 8: terminal with ANSI 256 color palette. (`256color`/`8bit`) +/// - 24: true-color terminal with full RGB colors. (`truecolor`/`24bit`/`direct`) +int ic_term_get_color_bits( void ); + +/// \} + +//-------------------------------------------------------------- +/// \defgroup async ASync +/// Async support +/// \{ + +/// Thread-safe way to asynchronously unblock a readline. +/// Behaves as if the user pressed the `ctrl-C` character +/// (resulting in returning NULL from `ic_readline`). +/// Returns `true` if the event was successfully delivered. +/// (This may not be supported on all platforms, but it is +/// functional on Linux, macOS and Windows). +bool ic_async_stop(void); + +/// \} + +//-------------------------------------------------------------- +/// \defgroup alloc Custom Allocation +/// Register allocation functions for custom allocators +/// \{ + +typedef void* (ic_malloc_fun_t)( size_t size ); +typedef void* (ic_realloc_fun_t)( void* p, size_t newsize ); +typedef void (ic_free_fun_t)( void* p ); + +/// Initialize with custom allocation functions. +/// This must be called as the first function in a program! +void ic_init_custom_alloc( ic_malloc_fun_t* _malloc, ic_realloc_fun_t* _realloc, ic_free_fun_t* _free ); + +/// Free a potentially custom alloc'd pointer (in particular, the result returned from `ic_readline`) +void ic_free( void* p ); + +/// Allocate using the current memory allocator. +void* ic_malloc(size_t sz); + +/// Duplicate a string using the current memory allocator. +const char* ic_strdup( const char* s ); + +/// \} + +#ifdef __cplusplus +} +#endif + +#endif /// IC_ISOCLINE_H diff --git a/extern/isocline/readme.md b/extern/isocline/readme.md new file mode 100644 index 00000000..1f4709bd --- /dev/null +++ b/extern/isocline/readme.md @@ -0,0 +1,460 @@ + + + + +# Isocline: a portable readline alternative. + +Isocline is a pure C library that can be used as an alternative to the GNU readline library (latest release v1.0.9, 2022-01-15). + +- Small: less than 8k lines and can be compiled as a single C file without + any dependencies or configuration (e.g. `gcc -c src/isocline.c`). + +- Portable: works on Unix, Windows, and macOS, and uses a minimal + subset of ANSI escape sequences. + +- Features: extensive multi-line editing mode (`shift-tab`), (24-bit) color, history, completion, unicode, + undo/redo, incremental history search, inline hints, syntax highlighting, brace matching, + closing brace insertion, auto indentation, graceful fallback, support for custom allocators, etc. + +- License: MIT. + +- Comes with a Haskell binding ([`System.Console.Isocline`][hdoc]. + +Enjoy, + Daan + + + +# Demo + +![recording](doc/record-macos.svg) + +Shows in order: unicode, syntax highlighting, brace matching, jump to matching brace, auto indent, multiline editing, 24-bit colors, inline hinting, filename completion, and incremental history search. +(screen capture was made with [termtosvg] by Nicolas Bedos) + +# Usage + +Include the isocline header in your C or C++ source: +```C +#include +``` + +and call `ic_readline` to get user input with rich editing abilities: +```C +char* input; +while( (input = ic_readline("prompt")) != NULL ) { // ctrl+d/c or errors return NULL + printf("you typed:\n%s\n", input); // use the input + free(input); +} +``` + +See the [example] for a full example with completion, syntax highligting, history, etc. + +# Run the Example + +You can compile and run the [example] as: +``` +$ gcc -o example -Iinclude test/example.c src/isocline.c +$ ./example +``` + +or, the Haskell [example][HaskellExample]: +``` +$ ghc -ihaskell test/Example.hs src/isocline.c +$ ./test/Example +``` + + +# Editing with Isocline + +Isocline tries to be as compatible as possible with standard [GNU Readline] key bindings. + +### Overview: +```apl + home/ctrl-a cursor end/ctrl-e + ┌─────────────────┼───────────────┐ (navigate) + │ ctrl-left │ ctrl-right │ + │ ┌───────┼──────┐ │ ctrl-r : search history + ▼ ▼ ▼ ▼ ▼ tab : complete word + prompt> it is the quintessential language shift-tab: insert new line + ▲ ▲ ▲ ▲ esc : delete input, done + │ └──────────────┘ │ ctrl-z : undo + │ alt-backsp alt-d │ + └─────────────────────────────────┘ (delete) + ctrl-u ctrl-k +``` + +Note: on macOS, the meta (alt) key is not directly available in most terminals. +Terminal/iTerm2 users can activate the meta key through +`Terminal` → `Preferences` → `Settings` → `Use option as meta key`. + +### Key Bindings + +These are also shown when pressing `F1` on a Isocline prompt. We use `^` as a shorthand for `ctrl-`: + +| Navigation | | +|-------------------|-------------------------------------------------| +| `left`,`^b` | go one character to the left | +| `right`,`^f ` | go one character to the right | +| `up ` | go one row up, or back in the history | +| `down ` | go one row down, or forward in the history | +| `^left ` | go to the start of the previous word | +| `^right ` | go to the end the current word | +| `home`,`^a ` | go to the start of the current line | +| `end`,`^e ` | go to the end of the current line | +| `pgup`,`^home ` | go to the start of the current input | +| `pgdn`,`^end ` | go to the end of the current input | +| `alt-m ` | jump to matching brace | +| `^p ` | go back in the history | +| `^n ` | go forward in the history | +| `^r`,`^s ` | search the history starting with the current word | + + +| Deletion | | +|-------------------|-------------------------------------------------| +| `del`,`^d ` | delete the current character | +| `backsp`,`^h ` | delete the previous character | +| `^w ` | delete to preceding white space | +| `alt-backsp ` | delete to the start of the current word | +| `alt-d ` | delete to the end of the current word | +| `^u ` | delete to the start of the current line | +| `^k ` | delete to the end of the current line | +| `esc ` | delete the current input, or done with empty input | + + +| Editing | | +|-------------------|-------------------------------------------------| +| `enter ` | accept current input | +| `^enter`,`^j`,`shift-tab` | create a new line for multi-line input | +| `^l ` | clear screen | +| `^t ` | swap with previous character (move character backward) | +| `^z`,`^_ ` | undo | +| `^y ` | redo | +| `tab ` | try to complete the current input | + + +| Completion menu | | +|-------------------|-------------------------------------------------| +| `enter`,`left` | use the currently selected completion | +| `1` - `9` | use completion N from the menu | +| `tab, down ` | select the next completion | +| `shift-tab, up` | select the previous completion | +| `esc ` | exit menu without completing | +| `pgdn`,`^enter`,`^j` | show all further possible completions | + + +| Incremental history search | | +|-------------------|-------------------------------------------------| +| `enter ` | use the currently found history entry | +| `backsp`,`^z ` | go back to the previous match (undo) | +| `tab`,`^r`,`up` | find the next match | +| `shift-tab`,`^s`,`down` | find an earlier match | +| `esc ` | exit search | + + +# Build the Library + +### Build as a Single Source + +Copy the sources (in `include` and `src`) into your project, or add the library as a [submodule]: +``` +$ git submodule add https://github.com/daanx/isocline +``` +and add `isocline/src/isocline.c` to your build rules -- no configuration is needed. + +### Build with CMake + +Clone the repository and run cmake to build a static library (`.a`/`.lib`): +``` +$ git clone https://github.com/daanx/isocline +$ cd isocline +$ mkdir -p build/release +$ cd build/release +$ cmake ../.. +$ cmake --build . +``` +This builds a static library `libisocline.a` (or `isocline.lib` on Windows) +and the example program: +``` +$ ./example +``` + +### Build the Haskell Library + +See the Haskell [readme][Haskell] for instructions to build and use the Haskell library. + + +# API Reference + +* See the [C API reference][docapi] and the [example] for example usage of history, completion, etc. + +* See the [Haskell API reference][hdoc] on Hackage and the Haskell [example][HaskellExample]. + + +# Motivation + +Isocline was created for use in the [Koka] interactive compiler. +This required: pure C (no dependency on a C++ runtime or other libraries), +portable (across Linux, macOS, and Windows), unicode support, +a BSD-style license, and good functionality for completion and multi-line editing. + +Some other excellent libraries that we considered: +[GNU readline], +[editline](https://github.com/troglobit/editline), +[linenoise](https://github.com/antirez/linenoise), +[replxx](https://github.com/AmokHuginnsson/replxx), and +[Haskeline](https://github.com/judah/haskeline). + + +# Formatted Output + +Isocline also exposes functions for rich terminal output +as `ic_print` (and `ic_println` and `ic_printf`). +Inspired by the (Python) [Rich][RichBBcode] library, +this supports a form of [bbcode]'s to format the output: +```c +ic_println( "[b]bold [red]and red[/red][/b]" ); +``` +Each print automatically closes any open tags that were +not yet closed. Also, you can use a general close +tag as `[/]` to close the innermost tag, so the +following print is equivalent to the earlier one: +```c +ic_println( "[b]bold [red]and red[/]" ); +``` +There can be multiple styles in one tag +(where the first name is used for the closing tag): +```c +ic_println( "[u #FFD700]underlined gold[/]" ); +``` + +Sometimes, you need to display arbitrary messages +that may contain sequences that you would not like +to be interpreted as bbcode tags. One way to do +this is the `[!`_tag_`]` which ignores formatting +up to a close tag of the form `[/`_tag_`]`. +```c +ic_printf( "[red]red? [!pre]%s[/pre].\n", "[blue]not blue!" ); +``` + +Predefined styles include `b` (bold), +`u` (underline), `i` (italic), and `r` (reverse video), but +you can (re)define any style yourself as: +```c +ic_style_def("warning", "crimson u"); +``` + +and use them like any builtin style or property: +```c +ic_println( "[warning]this is a warning![/]" ); +``` +which is great for adding themes to your application. + +Each `ic_print` function always closes any unclosed tags automatically. +To open a style persistently, use `ic_style_open` with a matching +`ic_style_close` which scopes over any `ic_print` statements in between. +```c +ic_style_open("warning"); +ic_println("[b]crimson underlined and bold[/]"); +ic_style_close(); +``` + +# Advanced + + +## BBCode Format + +An open tag can have multiple white space separated +entries that are +either a _style name_, or a primitive _property_[`=`_value_]. + +### Styles + +Isocline provides the following builtin styles as property shorthands: +`b` (bold), `u` (underline), `i` (italic), `r` (reverse video), +and some builtin styles for syntax highlighting: +`keyword`, `control` (control-flow keywords), `string`, +`comment`, `number`, `type`, `constant`. + +Predefined styles used by Isocline itself are: + +- `ic-prompt`: prompt style, e.g. `ic_style_def("ic-prompt", "yellow on blue")`. +- `ic-info`: information (like the numbers in a completion menu). +- `ic-diminish`: dim text (used for example in history search). +- `ic-emphasis`: emphasized text (also used in history search). +- `ic-hint`: color of an inline hint. +- `ic-error`: error color (like an unmatched brace). +- `ic-bracematch`: color of matching parenthesis. + +### Properties + +Boolean properties are by default `on`: + +- `bold` [`=`(`on`|`off`)] +- `italic` [`=`(`on`|`off`)] +- `underline` [`=`(`on`|`off`)] +- `reverse` [`=`(`on`|`off`)] + +Color properties can be assigned a _color_: + +- `color=`_color_ +- `bgcolor=`_color_ +- _color_: equivalent to `color=`_color_. +- `on` _color_: equivalent to `bgcolor=`_color_. + +A color value can be specified in many ways: + +- any standard HTML [color name][htmlcolors]. +- any of the 16 standard ANSI [color names][ansicolors] by prefixing `ansi-` + (like `ansi-black` or `ansi-maroon`). + The actual color value of these depend on the a terminal theme. +- `#`_rrggbb_ or `#`_rgb_ for a specific 24-bit color. +- `ansi-color=`_idx_: where 0 <= _idx_ <= 256 specifies an entry in the + standard ANSI 256 [color palette][ansicolor256], where 256 is used for the ANSI + default color. + + +## Environment Variables + +- `NO_COLOR`: if present no colors are displayed. +- `CLICOLOR=1`: if set, the `LSCOLORS` or `LS_COLORS` environment variables are used to colorize + filename completions. +- `COLORTERM=`(`truecolor`|`256color`|`16color`|`8color`|`monochrome`): enable a certain color palette, see the next section. +- `TERM`: used on some systems to determine the color + +## Colors + +Isocline supports 24-bit colors and any RGB colors are automatically +mapped to a reduced palette on older terminals if these do not +support true color. Detection of full color support +is not always possible to do automatically and you can +set the `COLORTERM` environment variable expicitly to force Isocline to use +a specific palette: +- `COLORTERM=truecolor`: use 24-bit colors. + +- `COLORTERM=256color`: use the ANSI 256 color palette. + +- `COLORTERM=16color` : use the regular ANSI 16 color + palette (8 normal and 8 bright colors). + +- `COLORTERM=8color`: use bold for bright colors. +- `COLORTERM=monochrome`: use no color. + +The above screenshots are made with the +[`test_colors.c`](https://github.com/daanx/isocline/blob/main/test/test_colors.c) program. You can test your own +terminal as: +``` +$ gcc -o test_colors -Iinclude test/test_colors.c src/isocline.c +$ ./test_colors +$ COLORTERM=truecolor ./test_colors +$ COLORTERM=16color ./test_colors +``` + +## ANSI Escape Sequences + +Isocline uses just few ANSI escape sequences that are widely +supported: +- `ESC[`_n_`A`, `ESC[`_n_`B`, `ESC[`_n_`C`, and `ESC[`_n_`D`, + for moving the cursor _n_ places up, down, right, and left. +- `ESC[K` to clear the line from the cursor. +- `ESC[`_n_`m` for colors, with _n_ one of: 0 (reset), 1,22 (bold), 3,23 (italic), + 4,24 (underline), 7,27 (reverse), 30-37,40-47,90-97,100-107 (color), + and 39,49 (select default color). +- `ESC[38;5;`_n_`m`, `ESC[48;5;`_n_`m`, `ESC[38;2;`_r_`;`_g_`;`_b_`m`, `ESC[48;2;`_r_`;`_g_`;`_b_`m`: + on terminals that support it, select + entry _n_ from the + 256 color ANSI palette (used with `XTERM=xterm-256color` for example), or directly specify + any 24-bit _rgb_ color (used with `COLORTERM=truecolor`) for the foreground or background. + +On Windows the above functionality is implemented using the Windows console API +(except if running in the new Windows Terminal which supports these escape +sequences natively). + +## Async and Threads + +Isocline is _not_ thread-safe and `ic_readline`_xxx_ and `ic_print`_xxx_ should +be used from one thread only. + +The best way to use `ic_readline` asynchronously is +to run it in a (blocking) dedicated thread and deliver +results from there to the async event loop. Isocline has the +```C +bool ic_async_stop(void) +``` +function that is thread-safe and can deliver an +asynchronous event to Isocline that unblocks a current +`ic_readline` and makes it behave as if the user pressed +`ctrl-c` (which returns NULL from the read line call). + +## Color Mapping + +To map full RGB colors to an ANSI 256 or 16-color palette +Isocline finds a palette color with the minimal "color distance" to +the original color. There are various +ways of calculating this: one way is to take the euclidean distance +in the sRGB space (_simple-rgb_), a slightly better way is to +take a weighted distance where the weight distribution is adjusted +according to how big the red component is ([redmean](https://en.wikipedia.org/wiki/Color_difference), +denoted as _delta-rgb_ in the figure), +this is used by Isocline), +and finally, we can first translate into a perceptually uniform color space +(CIElab) and calculate the distance there using the [CIEDE2000](https://en.wikipedia.org/wiki/Color_difference) +algorithm (_ciede2000_). Here are these three methods compared on +some colors: + +![color space comparison](doc/color/colorspace-map.png) + +Each top row is the true 24-bit RGB color. Surprisingly, +the sophisticated CIEDE2000 distance seems less good here compared to the +simpler methods (as in the upper left block for example) +(perhaps because this algorithm was created to find close +perceptual colors in images where lightness differences may be given +less weight?). CIEDE2000 also leads to more "outliers", for example as seen +in column 5. Given these results, Isocline uses _redmean_ for +color mapping. We also add a gray correction that makes it less +likely to substitute a color for a gray value (and the other way +around). + + +## Possible Future Extensions + +- Vi key bindings. +- kill buffer. +- make the `ic_print`_xxx_ functions thread-safe. +- extended low-level terminal functions. +- status and progress bars. +- prompt variants: confirm, etc. +- ... + +Contact me if you are interested in doing any of these :-) + + +# Releases + +* `2022-01-15`: v1.0.9: fix missing `ic_completion_arg` (issue #6), + fix null ptr check in ic_print (issue #7), fix crash when using /dev/null as both input and output. +* `2021-09-05`: v1.0.5: use our own wcwidth for consistency; + thanks to Hans-Georg Breunig for helping with testing on NetBSD. +* `2021-08-28`: v1.0.4: fix color query on Ubuntu/Gnome +* `2021-08-27`: v1.0.3: fix duplicates in completions +* `2021-08-23`: v1.0.2: fix windows eol wrapping +* `2021-08-21`: v1.0.1: fix line-buffering +* `2021-08-20`: v1.0.0: initial release + + + +[GNU readline]: https://tiswww.case.edu/php/chet/readline/rltop.html +[koka]: http://www.koka-lang.org +[submodule]: https://git-scm.com/book/en/v2/Git-Tools-Submodules +[Haskell]: https://github.com/daanx/isocline/tree/main/haskell +[HaskellExample]: https://github.com/daanx/isocline/blob/main/test/Example.hs +[example]: https://github.com/daanx/isocline/blob/main/test/example.c +[termtosvg]: https://github.com/nbedos/termtosvg +[Rich]: https://github.com/willmcgugan/rich +[RichBBcode]: https://rich.readthedocs.io/en/latest/markup.html +[bbcode]: https://en.wikipedia.org/wiki/BBCode +[htmlcolors]: https://en.wikipedia.org/wiki/Web_colors#HTML_color_names +[ansicolors]: https://en.wikipedia.org/wiki/Web_colors#Basic_colors +[ansicolor256]: https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit +[docapi]: https://daanx.github.io/isocline +[hdoc]: https://hackage.haskell.org/package/isocline/docs/System-Console-Isocline.html diff --git a/extern/isocline/src/attr.c b/extern/isocline/src/attr.c new file mode 100644 index 00000000..b5ad78f8 --- /dev/null +++ b/extern/isocline/src/attr.c @@ -0,0 +1,294 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include + +#include "common.h" +#include "stringbuf.h" // str_next_ofs +#include "attr.h" +#include "term.h" // color_from_ansi256 + +//------------------------------------------------------------- +// Attributes +//------------------------------------------------------------- + +ic_private attr_t attr_none(void) { + attr_t attr; + attr.value = 0; + return attr; +} + +ic_private attr_t attr_default(void) { + attr_t attr = attr_none(); + attr.x.color = IC_ANSI_DEFAULT; + attr.x.bgcolor = IC_ANSI_DEFAULT; + attr.x.bold = IC_OFF; + attr.x.underline = IC_OFF; + attr.x.reverse = IC_OFF; + attr.x.italic = IC_OFF; + return attr; +} + +ic_private bool attr_is_none(attr_t attr) { + return (attr.value == 0); +} + +ic_private bool attr_is_eq(attr_t attr1, attr_t attr2) { + return (attr1.value == attr2.value); +} + +ic_private attr_t attr_from_color( ic_color_t color ) { + attr_t attr = attr_none(); + attr.x.color = color; + return attr; +} + + +ic_private attr_t attr_update_with( attr_t oldattr, attr_t newattr ) { + attr_t attr = oldattr; + if (newattr.x.color != IC_COLOR_NONE) { attr.x.color = newattr.x.color; } + if (newattr.x.bgcolor != IC_COLOR_NONE) { attr.x.bgcolor = newattr.x.bgcolor; } + if (newattr.x.bold != IC_NONE) { attr.x.bold = newattr.x.bold; } + if (newattr.x.italic != IC_NONE) { attr.x.italic = newattr.x.italic; } + if (newattr.x.reverse != IC_NONE) { attr.x.reverse = newattr.x.reverse; } + if (newattr.x.underline != IC_NONE) { attr.x.underline = newattr.x.underline; } + return attr; +} + +static bool sgr_is_digit(char c) { + return (c >= '0' && c <= '9'); +} + +static bool sgr_is_sep( char c ) { + return (c==';' || c==':'); +} + +static bool sgr_next_par(const char* s, ssize_t* pi, ssize_t* par) { + const ssize_t i = *pi; + ssize_t n = 0; + while( sgr_is_digit(s[i+n])) { + n++; + } + if (n==0) { + *par = 0; + return true; + } + else { + *pi = i+n; + return ic_atoz(s+i, par); + } +} + +static bool sgr_next_par3(const char* s, ssize_t* pi, ssize_t* p1, ssize_t* p2, ssize_t* p3) { + bool ok = false; + ssize_t i = *pi; + if (sgr_next_par(s,&i,p1) && sgr_is_sep(s[i])) { + i++; + if (sgr_next_par(s,&i,p2) && sgr_is_sep(s[i])) { + i++; + if (sgr_next_par(s,&i,p3)) { + ok = true; + }; + } + } + *pi = i; + return ok; +} + +ic_private attr_t attr_from_sgr( const char* s, ssize_t len) { + attr_t attr = attr_none(); + for( ssize_t i = 0; i < len && s[i] != 0; i++) { + ssize_t cmd = 0; + if (!sgr_next_par(s,&i,&cmd)) continue; + switch(cmd) { + case 0: attr = attr_default(); break; + case 1: attr.x.bold = IC_ON; break; + case 3: attr.x.italic = IC_ON; break; + case 4: attr.x.underline = IC_ON; break; + case 7: attr.x.reverse = IC_ON; break; + case 22: attr.x.bold = IC_OFF; break; + case 23: attr.x.italic = IC_OFF; break; + case 24: attr.x.underline = IC_OFF; break; + case 27: attr.x.reverse = IC_OFF; break; + case 39: attr.x.color = IC_ANSI_DEFAULT; break; + case 49: attr.x.bgcolor = IC_ANSI_DEFAULT; break; + default: { + if (cmd >= 30 && cmd <= 37) { + attr.x.color = IC_ANSI_BLACK + (unsigned)(cmd - 30); + } + else if (cmd >= 40 && cmd <= 47) { + attr.x.bgcolor = IC_ANSI_BLACK + (unsigned)(cmd - 40); + } + else if (cmd >= 90 && cmd <= 97) { + attr.x.color = IC_ANSI_DARKGRAY + (unsigned)(cmd - 90); + } + else if (cmd >= 100 && cmd <= 107) { + attr.x.bgcolor = IC_ANSI_DARKGRAY + (unsigned)(cmd - 100); + } + else if ((cmd == 38 || cmd == 48) && sgr_is_sep(s[i])) { + // non-associative SGR :-( + ssize_t par = 0; + i++; + if (sgr_next_par(s, &i, &par)) { + if (par==5 && sgr_is_sep(s[i])) { + // ansi 256 index + i++; + if (sgr_next_par(s, &i, &par) && par >= 0 && par <= 0xFF) { + ic_color_t color = color_from_ansi256(par); + if (cmd==38) { attr.x.color = color; } + else { attr.x.bgcolor = color; } + } + } + else if (par == 2 && sgr_is_sep(s[i])) { + // rgb value + i++; + ssize_t r,g,b; + if (sgr_next_par3(s, &i, &r,&g,&b)) { + ic_color_t color = ic_rgbx(r,g,b); + if (cmd==38) { attr.x.color = color; } + else { attr.x.bgcolor = color; } + } + } + } + } + else { + debug_msg("attr: unknow ANSI SGR code: %zd\n", cmd ); + } + } + } + } + return attr; +} + +ic_private attr_t attr_from_esc_sgr( const char* s, ssize_t len) { + if (len <= 2 || s[0] != '\x1B' || s[1] != '[' || s[len-1] != 'm') return attr_none(); + return attr_from_sgr(s+2, len-2); +} + + +//------------------------------------------------------------- +// Attribute buffer +//------------------------------------------------------------- +struct attrbuf_s { + attr_t* attrs; + ssize_t capacity; + ssize_t count; + alloc_t* mem; +}; + +static bool attrbuf_ensure_capacity( attrbuf_t* ab, ssize_t needed ) { + if (needed <= ab->capacity) return true; + ssize_t newcap = (ab->capacity <= 0 ? 240 : (ab->capacity > 1000 ? ab->capacity + 1000 : 2*ab->capacity)); + if (needed > newcap) { newcap = needed; } + attr_t* newattrs = mem_realloc_tp( ab->mem, attr_t, ab->attrs, newcap ); + if (newattrs == NULL) return false; + ab->attrs = newattrs; + ab->capacity = newcap; + assert(needed <= ab->capacity); + return true; +} + +static bool attrbuf_ensure_extra( attrbuf_t* ab, ssize_t extra ) { + const ssize_t needed = ab->count + extra; + return attrbuf_ensure_capacity( ab, needed ); +} + + +ic_private attrbuf_t* attrbuf_new( alloc_t* mem ) { + attrbuf_t* ab = mem_zalloc_tp(mem,attrbuf_t); + if (ab == NULL) return NULL; + ab->mem = mem; + attrbuf_ensure_extra(ab,1); + return ab; +} + +ic_private void attrbuf_free( attrbuf_t* ab ) { + if (ab==NULL) return; + mem_free(ab->mem, ab->attrs); + mem_free(ab->mem, ab); +} + +ic_private void attrbuf_clear(attrbuf_t* ab) { + if (ab == NULL) return; + ab->count = 0; +} + +ic_private ssize_t attrbuf_len( attrbuf_t* ab ) { + return (ab==NULL ? 0 : ab->count); +} + +ic_private const attr_t* attrbuf_attrs( attrbuf_t* ab, ssize_t expected_len ) { + assert(expected_len <= ab->count ); + // expand if needed + if (ab->count < expected_len) { + if (!attrbuf_ensure_capacity(ab,expected_len)) return NULL; + for(ssize_t i = ab->count; i < expected_len; i++) { + ab->attrs[i] = attr_none(); + } + ab->count = expected_len; + } + return ab->attrs; +} + + + +static void attrbuf_update_set_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr, bool update ) { + const ssize_t end = pos + count; + if (!attrbuf_ensure_capacity(ab, end)) return; + ssize_t i; + // initialize if end is beyond the count (todo: avoid duplicate init and set if update==false?) + if (ab->count < end) { + for(i = ab->count; i < end; i++) { + ab->attrs[i] = attr_none(); + } + ab->count = end; + } + // fill pos to end with attr + for(i = pos; i < end; i++) { + ab->attrs[i] = (update ? attr_update_with(ab->attrs[i],attr) : attr); + } +} + +ic_private void attrbuf_set_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ) { + attrbuf_update_set_at(ab, pos, count, attr, false); +} + +ic_private void attrbuf_update_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ) { + attrbuf_update_set_at(ab, pos, count, attr, true); +} + +ic_private void attrbuf_insert_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ) { + if (pos < 0 || pos > ab->count || count <= 0) return; + if (!attrbuf_ensure_extra(ab,count)) return; + ic_memmove( ab->attrs + pos + count, ab->attrs + pos, (ab->count - pos)*ssizeof(attr_t) ); + ab->count += count; + attrbuf_set_at( ab, pos, count, attr ); +} + + +// note: must allow ab == NULL! +ic_private ssize_t attrbuf_append_n( stringbuf_t* sb, attrbuf_t* ab, const char* s, ssize_t len, attr_t attr ) { + if (s == NULL || len == 0) return sbuf_len(sb); + if (ab != NULL) { + if (!attrbuf_ensure_extra(ab,len)) return sbuf_len(sb); + attrbuf_set_at(ab, ab->count, len, attr); + } + return sbuf_append_n(sb,s,len); +} + +ic_private attr_t attrbuf_attr_at( attrbuf_t* ab, ssize_t pos ) { + if (ab==NULL || pos < 0 || pos > ab->count) return attr_none(); + return ab->attrs[pos]; +} + +ic_private void attrbuf_delete_at( attrbuf_t* ab, ssize_t pos, ssize_t count ) { + if (ab==NULL || pos < 0 || pos > ab->count) return; + if (pos + count > ab->count) { count = ab->count - pos; } + if (count == 0) return; + assert(pos + count <= ab->count); + ic_memmove( ab->attrs + pos, ab->attrs + pos + count, ab->count - (pos + count) ); + ab->count -= count; +} diff --git a/extern/isocline/src/attr.h b/extern/isocline/src/attr.h new file mode 100644 index 00000000..8f37d050 --- /dev/null +++ b/extern/isocline/src/attr.h @@ -0,0 +1,70 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_ATTR_H +#define IC_ATTR_H + +#include "common.h" +#include "stringbuf.h" + +//------------------------------------------------------------- +// text attributes +//------------------------------------------------------------- + +#define IC_ON (1) +#define IC_OFF (-1) +#define IC_NONE (0) + +// try to fit in 64 bits +// note: order is important for some compilers +// note: each color can actually be 25 bits +typedef union attr_s { + struct { + unsigned int color:28; + signed int bold:2; + signed int reverse:2; + unsigned int bgcolor:28; + signed int underline:2; + signed int italic:2; + } x; + uint64_t value; +} attr_t; + +ic_private attr_t attr_none(void); +ic_private attr_t attr_default(void); +ic_private attr_t attr_from_color( ic_color_t color ); + +ic_private bool attr_is_none(attr_t attr); +ic_private bool attr_is_eq(attr_t attr1, attr_t attr2); + +ic_private attr_t attr_update_with( attr_t attr, attr_t newattr ); + +ic_private attr_t attr_from_sgr( const char* s, ssize_t len); +ic_private attr_t attr_from_esc_sgr( const char* s, ssize_t len); + +//------------------------------------------------------------- +// attribute buffer used for rich rendering +//------------------------------------------------------------- + +struct attrbuf_s; +typedef struct attrbuf_s attrbuf_t; + +ic_private attrbuf_t* attrbuf_new( alloc_t* mem ); +ic_private void attrbuf_free( attrbuf_t* ab ); // ab can be NULL +ic_private void attrbuf_clear( attrbuf_t* ab ); // ab can be NULL +ic_private ssize_t attrbuf_len( attrbuf_t* ab); // ab can be NULL +ic_private const attr_t* attrbuf_attrs( attrbuf_t* ab, ssize_t expected_len ); +ic_private ssize_t attrbuf_append_n( stringbuf_t* sb, attrbuf_t* ab, const char* s, ssize_t len, attr_t attr ); + +ic_private void attrbuf_set_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ); +ic_private void attrbuf_update_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ); +ic_private void attrbuf_insert_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ); + +ic_private attr_t attrbuf_attr_at( attrbuf_t* ab, ssize_t pos ); +ic_private void attrbuf_delete_at( attrbuf_t* ab, ssize_t pos, ssize_t count ); + +#endif // IC_ATTR_H diff --git a/extern/isocline/src/bbcode.c b/extern/isocline/src/bbcode.c new file mode 100644 index 00000000..4d11ac38 --- /dev/null +++ b/extern/isocline/src/bbcode.c @@ -0,0 +1,842 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include +#include +#include + +#include "common.h" +#include "attr.h" +#include "term.h" +#include "bbcode.h" + +//------------------------------------------------------------- +// HTML color table +//------------------------------------------------------------- + +#include "bbcode_colors.c" + +//------------------------------------------------------------- +// Types +//------------------------------------------------------------- + +typedef struct style_s { + const char* name; // name of the style + attr_t attr; // attribute to apply +} style_t; + +typedef enum align_e { + IC_ALIGN_LEFT, + IC_ALIGN_CENTER, + IC_ALIGN_RIGHT +} align_t; + +typedef struct width_s { + ssize_t w; // > 0 + align_t align; + bool dots; // "..." (e.g. "sentence...") + char fill; // " " (e.g. "hello ") +} width_t; + +typedef struct tag_s { + const char* name; // tag open name + attr_t attr; // the saved attribute before applying the style + width_t width; // start sequence of at most "width" columns + ssize_t pos; // start position in the output (used for width restriction) +} tag_t; + + +static void tag_init(tag_t* tag) { + memset(tag,0,sizeof(*tag)); +} + +struct bbcode_s { + tag_t* tags; // stack of tags; one entry for each open tag + ssize_t tags_capacity; + ssize_t tags_nesting; + style_t* styles; // list of used defined styles + ssize_t styles_capacity; + ssize_t styles_count; + term_t* term; // terminal + alloc_t* mem; // allocator + // caches + stringbuf_t* out; // print buffer + attrbuf_t* out_attrs; + stringbuf_t* vout; // vprintf buffer +}; + + +//------------------------------------------------------------- +// Create, helpers +//------------------------------------------------------------- + +ic_private bbcode_t* bbcode_new( alloc_t* mem, term_t* term ) { + bbcode_t* bb = mem_zalloc_tp(mem,bbcode_t); + if (bb==NULL) return NULL; + bb->mem = mem; + bb->term = term; + bb->out = sbuf_new(mem); + bb->out_attrs = attrbuf_new(mem); + bb->vout = sbuf_new(mem); + return bb; +} + +ic_private void bbcode_free( bbcode_t* bb ) { + for(ssize_t i = 0; i < bb->styles_count; i++) { + mem_free(bb->mem, bb->styles[i].name); + } + mem_free(bb->mem, bb->tags); + mem_free(bb->mem, bb->styles); + sbuf_free(bb->vout); + sbuf_free(bb->out); + attrbuf_free(bb->out_attrs); + mem_free(bb->mem, bb); +} + +ic_private void bbcode_style_add( bbcode_t* bb, const char* style_name, attr_t attr ) { + if (bb->styles_count >= bb->styles_capacity) { + ssize_t newlen = bb->styles_capacity + 32; + style_t* p = mem_realloc_tp( bb->mem, style_t, bb->styles, newlen ); + if (p == NULL) return; + bb->styles = p; + bb->styles_capacity = newlen; + } + assert(bb->styles_count < bb->styles_capacity); + bb->styles[bb->styles_count].name = mem_strdup( bb->mem, style_name ); + bb->styles[bb->styles_count].attr = attr; + bb->styles_count++; +} + +static ssize_t bbcode_tag_push( bbcode_t* bb, const tag_t* tag ) { + if (bb->tags_nesting >= bb->tags_capacity) { + ssize_t newcap = bb->tags_capacity + 32; + tag_t* p = mem_realloc_tp( bb->mem, tag_t, bb->tags, newcap ); + if (p == NULL) return -1; + bb->tags = p; + bb->tags_capacity = newcap; + } + assert(bb->tags_nesting < bb->tags_capacity); + bb->tags[bb->tags_nesting] = *tag; + bb->tags_nesting++; + return (bb->tags_nesting-1); +} + +static void bbcode_tag_pop( bbcode_t* bb, tag_t* tag ) { + if (bb->tags_nesting <= 0) { + if (tag != NULL) { tag_init(tag); } + } + else { + bb->tags_nesting--; + if (tag != NULL) { + *tag = bb->tags[bb->tags_nesting]; + } + } +} + +//------------------------------------------------------------- +// Invalid parse/values/balance +//------------------------------------------------------------- + +static void bbcode_invalid(const char* fmt, ... ) { + if (getenv("ISOCLINE_BBCODE_DEBUG") != NULL) { + va_list args; + va_start(args,fmt); + vfprintf(stderr,fmt,args); + va_end(args); + } +} + +//------------------------------------------------------------- +// Set attributes +//------------------------------------------------------------- + + +static attr_t bbcode_open( bbcode_t* bb, ssize_t out_pos, const tag_t* tag, attr_t current ) { + // save current and set + tag_t cur; + tag_init(&cur); + cur.name = tag->name; + cur.attr = current; + cur.width = tag->width; + cur.pos = out_pos; + bbcode_tag_push(bb,&cur); + return attr_update_with( current, tag->attr ); +} + +static bool bbcode_close( bbcode_t* bb, ssize_t base, const char* name, tag_t* pprev ) { + // pop until match + while (bb->tags_nesting > base) { + tag_t prev; + bbcode_tag_pop(bb,&prev); + if (name==NULL || prev.name==NULL || ic_stricmp(prev.name,name) == 0) { + // matched + if (pprev != NULL) { *pprev = prev; } + return true; + } + else { + // unbalanced: we either continue popping or restore the tags depending if there is a matching open tag in our tags. + bool has_open_tag = false; + if (name != NULL) { + for( ssize_t i = bb->tags_nesting - 1; i > base; i--) { + if (bb->tags[i].name != NULL && ic_stricmp(bb->tags[i].name, name) == 0) { + has_open_tag = true; + break; + } + } + } + bbcode_invalid("bbcode: unbalanced tags: open [%s], close [/%s]\n", prev.name, name); + if (!has_open_tag) { + bbcode_tag_push( bb, &prev ); // restore the tags and ignore this close tag + break; + } + else { + // continue until we hit our open tag + } + } + } + if (pprev != NULL) { memset(pprev,0,sizeof(*pprev)); } + return false; +} + +//------------------------------------------------------------- +// Update attributes +//------------------------------------------------------------- + +static const char* attr_update_bool( const char* fname, signed int* field, const char* value ) { + if (value == NULL || value[0] == 0 || strcmp(value,"on") || strcmp(value,"true") || strcmp(value,"1")) { + *field = IC_ON; + } + else if (strcmp(value,"off") || strcmp(value,"false") || strcmp(value,"0")) { + *field = IC_OFF; + } + else { + bbcode_invalid("bbcode: invalid %s value: %s\n", fname, value ); + } + return fname; +} + +static const char* attr_update_color( const char* fname, ic_color_t* field, const char* value ) { + if (value == NULL || value[0] == 0 || strcmp(value,"none") == 0) { + *field = IC_COLOR_NONE; + return fname; + } + + // hex value + if (value[0] == '#') { + uint32_t rgb = 0; + if (sscanf(value,"#%x",&rgb) == 1) { + *field = ic_rgb(rgb); + } + else { + bbcode_invalid("bbcode: invalid color code: %s\n", value); + } + return fname; + } + + // search color names + ssize_t lo = 0; + ssize_t hi = IC_HTML_COLOR_COUNT-1; + while( lo <= hi ) { + ssize_t mid = (lo + hi) / 2; + style_color_t* info = &html_colors[mid]; + int cmp = strcmp(info->name,value); + if (cmp < 0) { + lo = mid+1; + } + else if (cmp > 0) { + hi = mid-1; + } + else { + *field = info->color; + return fname; + } + } + bbcode_invalid("bbcode: unknown %s: %s\n", fname, value); + *field = IC_COLOR_NONE; + return fname; +} + +static const char* attr_update_sgr( const char* fname, attr_t* attr, const char* value ) { + *attr = attr_update_with(*attr, attr_from_sgr(value, ic_strlen(value))); + return fname; +} + +static void attr_update_width( width_t* pwidth, char default_fill, const char* value ) { + // parse width value: ;;; + width_t width; + memset(&width, 0, sizeof(width)); + width.fill = default_fill; // use 0 for no-fill (as for max-width) + if (ic_atoz(value, &width.w)) { + ssize_t i = 0; + while( value[i] != ';' && value[i] != 0 ) { i++; } + if (value[i] == ';') { + i++; + ssize_t len = 0; + while( value[i+len] != ';' && value[i+len] != 0 ) { len++; } + if (len == 4 && ic_istarts_with(value+i,"left")) { + width.align = IC_ALIGN_LEFT; + } + else if (len == 5 && ic_istarts_with(value+i,"right")) { + width.align = IC_ALIGN_RIGHT; + } + else if (len == 6 && ic_istarts_with(value+i,"center")) { + width.align = IC_ALIGN_CENTER; + } + i += len; + if (value[i] == ';') { + i++; len = 0; + while( value[i+len] != ';' && value[i+len] != 0 ) { len++; } + if (len == 1) { width.fill = value[i]; } + i+= len; + if (value[i] == ';') { + i++; len = 0; + while( value[i+len] != ';' && value[i+len] != 0 ) { len++; } + if ((len == 2 && ic_istarts_with(value+i,"on")) || (len == 1 && value[i] == '1')) { width.dots = true; } + i += len; + } + } + } + } + else { + bbcode_invalid("bbcode: illegal width: %s\n", value); + } + *pwidth = width; +} + +static const char* attr_update_ansi_color( const char* fname, ic_color_t* color, const char* value ) { + ssize_t num = 0; + if (ic_atoz(value, &num) && num >= 0 && num <= 256) { + *color = color_from_ansi256(num); + } + return fname; +} + + +static const char* attr_update_property( tag_t* tag, const char* attr_name, const char* value ) { + const char* fname = NULL; + fname = "bold"; + if (strcmp(attr_name,fname) == 0) { + signed int b = IC_NONE; + attr_update_bool(fname,&b, value); + if (b != IC_NONE) { tag->attr.x.bold = b; } + return fname; + } + fname = "italic"; + if (strcmp(attr_name,fname) == 0) { + signed int b = IC_NONE; + attr_update_bool(fname,&b, value); + if (b != IC_NONE) { tag->attr.x.italic = b; } + return fname; + } + fname = "underline"; + if (strcmp(attr_name,fname) == 0) { + signed int b = IC_NONE; + attr_update_bool(fname,&b, value); + if (b != IC_NONE) { tag->attr.x.underline = b; } + return fname; + } + fname = "reverse"; + if (strcmp(attr_name,fname) == 0) { + signed int b = IC_NONE; + attr_update_bool(fname,&b, value); + if (b != IC_NONE) { tag->attr.x.reverse = b; } + return fname; + } + fname = "color"; + if (strcmp(attr_name,fname) == 0) { + unsigned int color = IC_COLOR_NONE; + attr_update_color(fname, &color, value); + if (color != IC_COLOR_NONE) { tag->attr.x.color = color; } + return fname; + } + fname = "bgcolor"; + if (strcmp(attr_name,fname) == 0) { + unsigned int color = IC_COLOR_NONE; + attr_update_color(fname, &color, value); + if (color != IC_COLOR_NONE) { tag->attr.x.bgcolor = color; } + return fname; + } + fname = "ansi-sgr"; + if (strcmp(attr_name,fname) == 0) { + attr_update_sgr(fname, &tag->attr, value); + return fname; + } + fname = "ansi-color"; + if (strcmp(attr_name,fname) == 0) { + ic_color_t color = IC_COLOR_NONE;; + attr_update_ansi_color(fname, &color, value); + if (color != IC_COLOR_NONE) { tag->attr.x.color = color; } + return fname; + } + fname = "ansi-bgcolor"; + if (strcmp(attr_name,fname) == 0) { + ic_color_t color = IC_COLOR_NONE;; + attr_update_ansi_color(fname, &color, value); + if (color != IC_COLOR_NONE) { tag->attr.x.bgcolor = color; } + return fname; + } + fname = "width"; + if (strcmp(attr_name,fname) == 0) { + attr_update_width(&tag->width, ' ', value); + return fname; + } + fname = "max-width"; + if (strcmp(attr_name,fname) == 0) { + attr_update_width(&tag->width, 0, value); + return "width"; + } + else { + return NULL; + } +} + +static const style_t builtin_styles[] = { + { "b", { { IC_COLOR_NONE, IC_ON , IC_NONE, IC_COLOR_NONE, IC_NONE, IC_NONE } } }, + { "r", { { IC_COLOR_NONE, IC_NONE, IC_ON , IC_COLOR_NONE, IC_NONE, IC_NONE } } }, + { "u", { { IC_COLOR_NONE, IC_NONE, IC_NONE, IC_COLOR_NONE, IC_ON , IC_NONE } } }, + { "i", { { IC_COLOR_NONE, IC_NONE, IC_NONE, IC_COLOR_NONE, IC_NONE, IC_ON } } }, + { "em", { { IC_COLOR_NONE, IC_ON , IC_NONE, IC_COLOR_NONE, IC_NONE, IC_NONE } } }, // bold + { "url",{ { IC_COLOR_NONE, IC_NONE, IC_NONE, IC_COLOR_NONE, IC_ON, IC_NONE } } }, // underline + { NULL, { { IC_COLOR_NONE, IC_NONE, IC_NONE, IC_COLOR_NONE, IC_NONE, IC_NONE } } } +}; + +static void attr_update_with_styles( tag_t* tag, const char* attr_name, const char* value, + bool usebgcolor, const style_t* styles, ssize_t count ) +{ + // direct hex color? + if (attr_name[0] == '#' && (value == NULL || value[0]==0)) { + value = attr_name; + attr_name = (usebgcolor ? "bgcolor" : "color"); + } + // first try if it is a builtin property + const char* name; + if ((name = attr_update_property(tag,attr_name,value)) != NULL) { + if (tag->name != NULL) tag->name = name; + return; + } + // then check all styles + while( count-- > 0 ) { + const style_t* style = styles + count; + if (strcmp(style->name,attr_name) == 0) { + tag->attr = attr_update_with(tag->attr,style->attr); + if (tag->name != NULL) tag->name = style->name; + return; + } + } + // check builtin styles; todo: binary search? + for( const style_t* style = builtin_styles; style->name != NULL; style++) { + if (strcmp(style->name,attr_name) == 0) { + tag->attr = attr_update_with(tag->attr,style->attr); + if (tag->name != NULL) tag->name = style->name; + return; + } + } + // check colors as a style + ssize_t lo = 0; + ssize_t hi = IC_HTML_COLOR_COUNT-1; + while( lo <= hi ) { + ssize_t mid = (lo + hi) / 2; + style_color_t* info = &html_colors[mid]; + int cmp = strcmp(info->name,attr_name); + if (cmp < 0) { + lo = mid+1; + } + else if (cmp > 0) { + hi = mid-1; + } + else { + attr_t cattr = attr_none(); + if (usebgcolor) { cattr.x.bgcolor = info->color; } + else { cattr.x.color = info->color; } + tag->attr = attr_update_with(tag->attr,cattr); + if (tag->name != NULL) tag->name = info->name; + return; + } + } + // not found + bbcode_invalid("bbcode: unknown style: %s\n", attr_name); +} + + +ic_private attr_t bbcode_style( bbcode_t* bb, const char* style_name ) { + tag_t tag; + tag_init(&tag); + attr_update_with_styles( &tag, style_name, NULL, false, bb->styles, bb->styles_count ); + return tag.attr; +} + +//------------------------------------------------------------- +// Parse tags +//------------------------------------------------------------- + +ic_private const char* parse_skip_white(const char* s) { + while( *s != 0 && *s != ']') { + if (!(*s == ' ' || *s == '\t' || *s == '\n' || *s == '\r')) break; + s++; + } + return s; +} + +ic_private const char* parse_skip_to_white(const char* s) { + while( *s != 0 && *s != ']') { + if (*s == ' ' || *s == '\t' || *s == '\n' || *s == '\r') break; + s++; + } + return parse_skip_white(s); +} + +ic_private const char* parse_skip_to_end(const char* s) { + while( *s != 0 && *s != ']' ) { s++; } + return s; +} + +ic_private const char* parse_attr_name(const char* s) { + if (*s == '#') { + s++; // hex rgb color as id + while( *s != 0 && *s != ']') { + if (!((*s >= 'a' && *s <= 'f') || (*s >= 'A' && *s <= 'Z') || (*s >= '0' && *s <= '9'))) break; + s++; + } + } + else { + while( *s != 0 && *s != ']') { + if (!((*s >= 'a' && *s <= 'z') || (*s >= 'A' && *s <= 'Z') || + (*s >= '0' && *s <= '9') || *s == '_' || *s == '-')) break; + s++; + } + } + return s; +} + +ic_private const char* parse_value(const char* s, const char** start, const char** end) { + if (*s == '"') { + s++; + *start = s; + while( *s != 0 ) { + if (*s == '"') break; + s++; + } + *end = s; + if (*s == '"') { s++; } + } + else if (*s == '#') { + *start = s; + s++; + while( *s != 0 ) { + if (!((*s >= 'a' && *s <= 'f') || (*s >= 'A' && *s <= 'Z') || (*s >= '0' && *s <= '9'))) break; + s++; + } + *end = s; + } + else { + *start = s; + while( *s != 0 ) { + if (!((*s >= 'a' && *s <= 'z') || (*s >= 'A' && *s <= 'F') || (*s >= '0' && *s <= '9') || *s == '-' || *s == '_')) break; + s++; + } + *end = s; + } + return s; +} + +ic_private const char* parse_tag_value( tag_t* tag, char* idbuf, const char* s, const style_t* styles, ssize_t scount ) { + // parse: \s*[\w-]+\s*(=\s*) + bool usebgcolor = false; + const char* id = s; + const char* idend = parse_attr_name(id); + const char* val = NULL; + const char* valend = NULL; + if (id == idend) { + bbcode_invalid("bbcode: empty identifier? %.10s...\n", id ); + return parse_skip_to_white(id); + } + // "on" bgcolor? + s = parse_skip_white(idend); + if (idend - id == 2 && ic_strnicmp(id,"on",2) == 0 && *s != '=') { + usebgcolor = true; + id = s; + idend = parse_attr_name(id); + if (id == idend) { + bbcode_invalid("bbcode: empty identifier follows 'on'? %.10s...\n", id ); + return parse_skip_to_white(id); + } + s = parse_skip_white(idend); + } + // value + if (*s == '=') { + s++; + s = parse_skip_white(s); + s = parse_value(s, &val, &valend); + s = parse_skip_white(s); + } + // limit name and attr to 128 bytes + char valbuf[128]; + ic_strncpy( idbuf, 128, id, idend - id); + ic_strncpy( valbuf, 128, val, valend - val); + ic_str_tolower(idbuf); + ic_str_tolower(valbuf); + attr_update_with_styles( tag, idbuf, valbuf, usebgcolor, styles, scount ); + return s; +} + +static const char* parse_tag_values( tag_t* tag, char* idbuf, const char* s, const style_t* styles, ssize_t scount ) { + s = parse_skip_white(s); + idbuf[0] = 0; + ssize_t count = 0; + while( *s != 0 && *s != ']') { + char idbuf_next[128]; + s = parse_tag_value(tag, (count==0 ? idbuf : idbuf_next), s, styles, scount); + count++; + } + if (*s == ']') { s++; } + return s; +} + +static const char* parse_tag( tag_t* tag, char* idbuf, bool* open, bool* pre, const char* s, const style_t* styles, ssize_t scount ) { + *open = true; + *pre = false; + if (*s != '[') return s; + s = parse_skip_white(s+1); + if (*s == '!') { // pre + *pre = true; + s = parse_skip_white(s+1); + } + else if (*s == '/') { + *open = false; + s = parse_skip_white(s+1); + }; + s = parse_tag_values( tag, idbuf, s, styles, scount); + return s; +} + + +//--------------------------------------------------------- +// Styles +//--------------------------------------------------------- + +static void bbcode_parse_tag_content( bbcode_t* bb, const char* s, tag_t* tag ) { + tag_init(tag); + if (s != NULL) { + char idbuf[128]; + parse_tag_values(tag, idbuf, s, bb->styles, bb->styles_count); + } +} + +ic_private void bbcode_style_def( bbcode_t* bb, const char* style_name, const char* s ) { + tag_t tag; + bbcode_parse_tag_content( bb, s, &tag); + bbcode_style_add(bb, style_name, tag.attr); +} + +ic_private void bbcode_style_open( bbcode_t* bb, const char* fmt ) { + tag_t tag; + bbcode_parse_tag_content(bb, fmt, &tag); + term_set_attr( bb->term, bbcode_open(bb, 0, &tag, term_get_attr(bb->term)) ); +} + +ic_private void bbcode_style_close( bbcode_t* bb, const char* fmt ) { + const ssize_t base = bb->tags_nesting - 1; // as we end a style + tag_t tag; + bbcode_parse_tag_content(bb, fmt, &tag); + tag_t prev; + if (bbcode_close(bb, base, tag.name, &prev)) { + term_set_attr( bb->term, prev.attr ); + } +} + +//--------------------------------------------------------- +// Restrict to width +//--------------------------------------------------------- + +static void bbcode_restrict_width( ssize_t start, width_t width, stringbuf_t* out, attrbuf_t* attr_out ) { + if (width.w <= 0) return; + assert(start <= sbuf_len(out)); + assert(attr_out == NULL || sbuf_len(out) == attrbuf_len(attr_out)); + const char* s = sbuf_string(out) + start; + const ssize_t len = sbuf_len(out) - start; + const ssize_t w = str_column_width(s); + if (w == width.w) return; // fits exactly + if (w > width.w) { + // too large + ssize_t innerw = (width.dots && width.w > 3 ? width.w-3 : width.w); + if (width.align == IC_ALIGN_RIGHT) { + // right align + const ssize_t ndel = str_skip_until_fit( s, innerw ); + sbuf_delete_at( out, start, ndel ); + attrbuf_delete_at( attr_out, start, ndel ); + if (innerw < width.w) { + // add dots + sbuf_insert_at( out, "...", start ); + attr_t attr = attrbuf_attr_at(attr_out, start); + attrbuf_insert_at( attr_out, start, 3, attr); + } + } + else { + // left or center align + ssize_t count = str_take_while_fit( s, innerw ); + sbuf_delete_at( out, start + count, len - count ); + attrbuf_delete_at( attr_out, start + count, len - count ); + if (innerw < width.w) { + // add dots + attr_t attr = attrbuf_attr_at(attr_out,start); + attrbuf_append_n( out, attr_out, "...", 3, attr ); + } + } + } + else { + // too short, pad to width + const ssize_t diff = (width.w - w); + const ssize_t pad_left = (width.align == IC_ALIGN_RIGHT ? diff : (width.align == IC_ALIGN_LEFT ? 0 : diff / 2)); + const ssize_t pad_right = (width.align == IC_ALIGN_LEFT ? diff : (width.align == IC_ALIGN_RIGHT ? 0 : diff - pad_left)); + if (width.fill != 0 && pad_left > 0) { + const attr_t attr = attrbuf_attr_at(attr_out,start); + for( ssize_t i = 0; i < pad_left; i++) { // todo: optimize + sbuf_insert_char_at(out, width.fill, start); + } + attrbuf_insert_at( attr_out, start, pad_left, attr ); + } + if (width.fill != 0 && pad_right > 0) { + const attr_t attr = attrbuf_attr_at(attr_out,sbuf_len(out) - 1); + char buf[2]; + buf[0] = width.fill; + buf[1] = 0; + for( ssize_t i = 0; i < pad_right; i++) { // todo: optimize + attrbuf_append_n( out, attr_out, buf, 1, attr ); + } + } + } +} + +//--------------------------------------------------------- +// Print +//--------------------------------------------------------- + +ic_private ssize_t bbcode_process_tag( bbcode_t* bb, const char* s, const ssize_t nesting_base, + stringbuf_t* out, attrbuf_t* attr_out, attr_t* cur_attr ) { + assert(*s == '['); + tag_t tag; + tag_init(&tag); + bool open = true; + bool ispre = false; + char idbuf[128]; + const char* end = parse_tag( &tag, idbuf, &open, &ispre, s, bb->styles, bb->styles_count ); // todo: styles + assert(end > s); + if (open) { + if (!ispre) { + // open tag + *cur_attr = bbcode_open( bb, sbuf_len(out), &tag, *cur_attr ); + } + else { + // scan pre to end tag + attr_t attr = attr_update_with(*cur_attr, tag.attr); + char pre[132]; + if (snprintf(pre, 132, "[/%s]", idbuf) < ssizeof(pre)) { + const char* etag = strstr(end,pre); + if (etag == NULL) { + const ssize_t len = ic_strlen(end); + attrbuf_append_n(out, attr_out, end, len, attr); + end += len; + } + else { + attrbuf_append_n(out, attr_out, end, (etag - end), attr); + end = etag + ic_strlen(pre); + } + } + } + } + else { + // pop the tag + tag_t prev; + if (bbcode_close( bb, nesting_base, tag.name, &prev)) { + *cur_attr = prev.attr; + if (prev.width.w > 0) { + // closed a width tag; restrict the output to width + bbcode_restrict_width( prev.pos, prev.width, out, attr_out); + } + } + } + return (end - s); +} + +ic_private void bbcode_append( bbcode_t* bb, const char* s, stringbuf_t* out, attrbuf_t* attr_out ) { + if (bb == NULL || s == NULL) return; + attr_t attr = attr_none(); + const ssize_t base = bb->tags_nesting; // base; will not be popped + ssize_t i = 0; + while( s[i] != 0 ) { + // handle no tags in bulk + ssize_t nobb = 0; + char c; + while( (c = s[i+nobb]) != 0) { + if (c == '[' || c == '\\') { break; } + if (c == '\x1B' && s[i+nobb+1] == '[') { + nobb++; // don't count 'ESC[' as a tag opener + } + nobb++; + } + if (nobb > 0) { attrbuf_append_n(out, attr_out, s+i, nobb, attr); } + i += nobb; + // tag + if (s[i] == '[') { + i += bbcode_process_tag(bb, s+i, base, out, attr_out, &attr); + } + else if (s[i] == '\\') { + if (s[i+1] == '\\' || s[i+1] == '[') { + attrbuf_append_n(out, attr_out, s+i+1, 1, attr); // escape '\[' and '\\' + i += 2; + } + else { + attrbuf_append_n(out, attr_out, s+i, 1, attr); // pass '\\' as is + i++; + } + } + } + // pop unclosed openings + assert(bb->tags_nesting >= base); + while( bb->tags_nesting > base ) { + bbcode_tag_pop(bb,NULL); + }; +} + +ic_private void bbcode_print( bbcode_t* bb, const char* s ) { + if (bb->out == NULL || bb->out_attrs == NULL || s == NULL) return; + assert(sbuf_len(bb->out) == 0 && attrbuf_len(bb->out_attrs) == 0); + bbcode_append( bb, s, bb->out, bb->out_attrs ); + term_write_formatted( bb->term, sbuf_string(bb->out), attrbuf_attrs(bb->out_attrs,sbuf_len(bb->out)) ); + attrbuf_clear(bb->out_attrs); + sbuf_clear(bb->out); +} + +ic_private void bbcode_println( bbcode_t* bb, const char* s ) { + bbcode_print(bb,s); + term_writeln(bb->term, ""); +} + +ic_private void bbcode_vprintf( bbcode_t* bb, const char* fmt, va_list args ) { + if (bb->vout == NULL || fmt == NULL) return; + assert(sbuf_len(bb->vout) == 0); + sbuf_append_vprintf(bb->vout,fmt,args); + bbcode_print(bb, sbuf_string(bb->vout)); + sbuf_clear(bb->vout); +} + +ic_private void bbcode_printf( bbcode_t* bb, const char* fmt, ... ) { + va_list args; + va_start(args,fmt); + bbcode_vprintf(bb,fmt,args); + va_end(args); +} + +ic_private ssize_t bbcode_column_width( bbcode_t* bb, const char* s ) { + if (s==NULL || s[0] == 0) return 0; + if (bb->vout == NULL) { return str_column_width(s); } + assert(sbuf_len(bb->vout) == 0); + bbcode_append( bb, s, bb->vout, NULL); + const ssize_t w = str_column_width(sbuf_string(bb->vout)); + sbuf_clear(bb->vout); + return w; +} diff --git a/extern/isocline/src/bbcode.h b/extern/isocline/src/bbcode.h new file mode 100644 index 00000000..be96bfe2 --- /dev/null +++ b/extern/isocline/src/bbcode.h @@ -0,0 +1,37 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_BBCODE_H +#define IC_BBCODE_H + +#include +#include "common.h" +#include "term.h" + +struct bbcode_s; +typedef struct bbcode_s bbcode_t; + +ic_private bbcode_t* bbcode_new( alloc_t* mem, term_t* term ); +ic_private void bbcode_free( bbcode_t* bb ); + +ic_private void bbcode_style_add( bbcode_t* bb, const char* style_name, attr_t attr ); +ic_private void bbcode_style_def( bbcode_t* bb, const char* style_name, const char* s ); +ic_private void bbcode_style_open( bbcode_t* bb, const char* fmt ); +ic_private void bbcode_style_close( bbcode_t* bb, const char* fmt ); +ic_private attr_t bbcode_style( bbcode_t* bb, const char* style_name ); + +ic_private void bbcode_print( bbcode_t* bb, const char* s ); +ic_private void bbcode_println( bbcode_t* bb, const char* s ); +ic_private void bbcode_printf( bbcode_t* bb, const char* fmt, ... ); +ic_private void bbcode_vprintf( bbcode_t* bb, const char* fmt, va_list args ); + +ic_private ssize_t bbcode_column_width( bbcode_t* bb, const char* s ); + +// allows `attr_out == NULL`. +ic_private void bbcode_append( bbcode_t* bb, const char* s, stringbuf_t* out, attrbuf_t* attr_out ); + +#endif // IC_BBCODE_H diff --git a/extern/isocline/src/bbcode_colors.c b/extern/isocline/src/bbcode_colors.c new file mode 100644 index 00000000..245cd3de --- /dev/null +++ b/extern/isocline/src/bbcode_colors.c @@ -0,0 +1,194 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +// This file is included from "bbcode.c" and contains html color names + +#include "common.h" + +typedef struct style_color_s { + const char* name; + ic_color_t color; +} style_color_t; + +#define IC_HTML_COLOR_COUNT (172) + +// ordered list of HTML color names (so we can use binary search) +static style_color_t html_colors[IC_HTML_COLOR_COUNT+1] = { + { "aliceblue", IC_RGB(0xf0f8ff) }, + { "ansi-aqua", IC_ANSI_AQUA }, + { "ansi-black", IC_ANSI_BLACK }, + { "ansi-blue", IC_ANSI_BLUE }, + { "ansi-cyan", IC_ANSI_CYAN }, + { "ansi-darkgray", IC_ANSI_DARKGRAY }, + { "ansi-darkgrey", IC_ANSI_DARKGRAY }, + { "ansi-default", IC_ANSI_DEFAULT }, + { "ansi-fuchsia", IC_ANSI_FUCHSIA }, + { "ansi-gray", IC_ANSI_GRAY }, + { "ansi-green", IC_ANSI_GREEN }, + { "ansi-grey", IC_ANSI_GRAY }, + { "ansi-lightgray", IC_ANSI_LIGHTGRAY }, + { "ansi-lightgrey", IC_ANSI_LIGHTGRAY }, + { "ansi-lime" , IC_ANSI_LIME }, + { "ansi-magenta", IC_ANSI_MAGENTA }, + { "ansi-maroon", IC_ANSI_MAROON }, + { "ansi-navy", IC_ANSI_NAVY }, + { "ansi-olive", IC_ANSI_OLIVE }, + { "ansi-purple", IC_ANSI_PURPLE }, + { "ansi-red", IC_ANSI_RED }, + { "ansi-silver", IC_ANSI_SILVER }, + { "ansi-teal", IC_ANSI_TEAL }, + { "ansi-white", IC_ANSI_WHITE }, + { "ansi-yellow", IC_ANSI_YELLOW }, + { "antiquewhite", IC_RGB(0xfaebd7) }, + { "aqua", IC_RGB(0x00ffff) }, + { "aquamarine", IC_RGB(0x7fffd4) }, + { "azure", IC_RGB(0xf0ffff) }, + { "beige", IC_RGB(0xf5f5dc) }, + { "bisque", IC_RGB(0xffe4c4) }, + { "black", IC_RGB(0x000000) }, + { "blanchedalmond", IC_RGB(0xffebcd) }, + { "blue", IC_RGB(0x0000ff) }, + { "blueviolet", IC_RGB(0x8a2be2) }, + { "brown", IC_RGB(0xa52a2a) }, + { "burlywood", IC_RGB(0xdeb887) }, + { "cadetblue", IC_RGB(0x5f9ea0) }, + { "chartreuse", IC_RGB(0x7fff00) }, + { "chocolate", IC_RGB(0xd2691e) }, + { "coral", IC_RGB(0xff7f50) }, + { "cornflowerblue", IC_RGB(0x6495ed) }, + { "cornsilk", IC_RGB(0xfff8dc) }, + { "crimson", IC_RGB(0xdc143c) }, + { "cyan", IC_RGB(0x00ffff) }, + { "darkblue", IC_RGB(0x00008b) }, + { "darkcyan", IC_RGB(0x008b8b) }, + { "darkgoldenrod", IC_RGB(0xb8860b) }, + { "darkgray", IC_RGB(0xa9a9a9) }, + { "darkgreen", IC_RGB(0x006400) }, + { "darkgrey", IC_RGB(0xa9a9a9) }, + { "darkkhaki", IC_RGB(0xbdb76b) }, + { "darkmagenta", IC_RGB(0x8b008b) }, + { "darkolivegreen", IC_RGB(0x556b2f) }, + { "darkorange", IC_RGB(0xff8c00) }, + { "darkorchid", IC_RGB(0x9932cc) }, + { "darkred", IC_RGB(0x8b0000) }, + { "darksalmon", IC_RGB(0xe9967a) }, + { "darkseagreen", IC_RGB(0x8fbc8f) }, + { "darkslateblue", IC_RGB(0x483d8b) }, + { "darkslategray", IC_RGB(0x2f4f4f) }, + { "darkslategrey", IC_RGB(0x2f4f4f) }, + { "darkturquoise", IC_RGB(0x00ced1) }, + { "darkviolet", IC_RGB(0x9400d3) }, + { "deeppink", IC_RGB(0xff1493) }, + { "deepskyblue", IC_RGB(0x00bfff) }, + { "dimgray", IC_RGB(0x696969) }, + { "dimgrey", IC_RGB(0x696969) }, + { "dodgerblue", IC_RGB(0x1e90ff) }, + { "firebrick", IC_RGB(0xb22222) }, + { "floralwhite", IC_RGB(0xfffaf0) }, + { "forestgreen", IC_RGB(0x228b22) }, + { "fuchsia", IC_RGB(0xff00ff) }, + { "gainsboro", IC_RGB(0xdcdcdc) }, + { "ghostwhite", IC_RGB(0xf8f8ff) }, + { "gold", IC_RGB(0xffd700) }, + { "goldenrod", IC_RGB(0xdaa520) }, + { "gray", IC_RGB(0x808080) }, + { "green", IC_RGB(0x008000) }, + { "greenyellow", IC_RGB(0xadff2f) }, + { "grey", IC_RGB(0x808080) }, + { "honeydew", IC_RGB(0xf0fff0) }, + { "hotpink", IC_RGB(0xff69b4) }, + { "indianred", IC_RGB(0xcd5c5c) }, + { "indigo", IC_RGB(0x4b0082) }, + { "ivory", IC_RGB(0xfffff0) }, + { "khaki", IC_RGB(0xf0e68c) }, + { "lavender", IC_RGB(0xe6e6fa) }, + { "lavenderblush", IC_RGB(0xfff0f5) }, + { "lawngreen", IC_RGB(0x7cfc00) }, + { "lemonchiffon", IC_RGB(0xfffacd) }, + { "lightblue", IC_RGB(0xadd8e6) }, + { "lightcoral", IC_RGB(0xf08080) }, + { "lightcyan", IC_RGB(0xe0ffff) }, + { "lightgoldenrodyellow", IC_RGB(0xfafad2) }, + { "lightgray", IC_RGB(0xd3d3d3) }, + { "lightgreen", IC_RGB(0x90ee90) }, + { "lightgrey", IC_RGB(0xd3d3d3) }, + { "lightpink", IC_RGB(0xffb6c1) }, + { "lightsalmon", IC_RGB(0xffa07a) }, + { "lightseagreen", IC_RGB(0x20b2aa) }, + { "lightskyblue", IC_RGB(0x87cefa) }, + { "lightslategray", IC_RGB(0x778899) }, + { "lightslategrey", IC_RGB(0x778899) }, + { "lightsteelblue", IC_RGB(0xb0c4de) }, + { "lightyellow", IC_RGB(0xffffe0) }, + { "lime", IC_RGB(0x00ff00) }, + { "limegreen", IC_RGB(0x32cd32) }, + { "linen", IC_RGB(0xfaf0e6) }, + { "magenta", IC_RGB(0xff00ff) }, + { "maroon", IC_RGB(0x800000) }, + { "mediumaquamarine", IC_RGB(0x66cdaa) }, + { "mediumblue", IC_RGB(0x0000cd) }, + { "mediumorchid", IC_RGB(0xba55d3) }, + { "mediumpurple", IC_RGB(0x9370db) }, + { "mediumseagreen", IC_RGB(0x3cb371) }, + { "mediumslateblue", IC_RGB(0x7b68ee) }, + { "mediumspringgreen", IC_RGB(0x00fa9a) }, + { "mediumturquoise", IC_RGB(0x48d1cc) }, + { "mediumvioletred", IC_RGB(0xc71585) }, + { "midnightblue", IC_RGB(0x191970) }, + { "mintcream", IC_RGB(0xf5fffa) }, + { "mistyrose", IC_RGB(0xffe4e1) }, + { "moccasin", IC_RGB(0xffe4b5) }, + { "navajowhite", IC_RGB(0xffdead) }, + { "navy", IC_RGB(0x000080) }, + { "oldlace", IC_RGB(0xfdf5e6) }, + { "olive", IC_RGB(0x808000) }, + { "olivedrab", IC_RGB(0x6b8e23) }, + { "orange", IC_RGB(0xffa500) }, + { "orangered", IC_RGB(0xff4500) }, + { "orchid", IC_RGB(0xda70d6) }, + { "palegoldenrod", IC_RGB(0xeee8aa) }, + { "palegreen", IC_RGB(0x98fb98) }, + { "paleturquoise", IC_RGB(0xafeeee) }, + { "palevioletred", IC_RGB(0xdb7093) }, + { "papayawhip", IC_RGB(0xffefd5) }, + { "peachpuff", IC_RGB(0xffdab9) }, + { "peru", IC_RGB(0xcd853f) }, + { "pink", IC_RGB(0xffc0cb) }, + { "plum", IC_RGB(0xdda0dd) }, + { "powderblue", IC_RGB(0xb0e0e6) }, + { "purple", IC_RGB(0x800080) }, + { "rebeccapurple", IC_RGB(0x663399) }, + { "red", IC_RGB(0xff0000) }, + { "rosybrown", IC_RGB(0xbc8f8f) }, + { "royalblue", IC_RGB(0x4169e1) }, + { "saddlebrown", IC_RGB(0x8b4513) }, + { "salmon", IC_RGB(0xfa8072) }, + { "sandybrown", IC_RGB(0xf4a460) }, + { "seagreen", IC_RGB(0x2e8b57) }, + { "seashell", IC_RGB(0xfff5ee) }, + { "sienna", IC_RGB(0xa0522d) }, + { "silver", IC_RGB(0xc0c0c0) }, + { "skyblue", IC_RGB(0x87ceeb) }, + { "slateblue", IC_RGB(0x6a5acd) }, + { "slategray", IC_RGB(0x708090) }, + { "slategrey", IC_RGB(0x708090) }, + { "snow", IC_RGB(0xfffafa) }, + { "springgreen", IC_RGB(0x00ff7f) }, + { "steelblue", IC_RGB(0x4682b4) }, + { "tan", IC_RGB(0xd2b48c) }, + { "teal", IC_RGB(0x008080) }, + { "thistle", IC_RGB(0xd8bfd8) }, + { "tomato", IC_RGB(0xff6347) }, + { "turquoise", IC_RGB(0x40e0d0) }, + { "violet", IC_RGB(0xee82ee) }, + { "wheat", IC_RGB(0xf5deb3) }, + { "white", IC_RGB(0xffffff) }, + { "whitesmoke", IC_RGB(0xf5f5f5) }, + { "yellow", IC_RGB(0xffff00) }, + { "yellowgreen", IC_RGB(0x9acd32) }, + {NULL, 0} +}; diff --git a/extern/isocline/src/common.c b/extern/isocline/src/common.c new file mode 100644 index 00000000..1d9fb566 --- /dev/null +++ b/extern/isocline/src/common.c @@ -0,0 +1,347 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +#include +#include +#include +#include +#include "common.h" + + +//------------------------------------------------------------- +// String wrappers for ssize_t +//------------------------------------------------------------- + +ic_private ssize_t ic_strlen( const char* s ) { + if (s==NULL) return 0; + return to_ssize_t(strlen(s)); +} + +ic_private void ic_memmove( void* dest, const void* src, ssize_t n ) { + assert(dest!=NULL && src != NULL); + if (n <= 0) return; + memmove(dest,src,to_size_t(n)); +} + + +ic_private void ic_memcpy( void* dest, const void* src, ssize_t n ) { + assert(dest!=NULL && src != NULL); + if (dest == NULL || src == NULL || n <= 0) return; + memcpy(dest,src,to_size_t(n)); +} + +ic_private void ic_memset(void* dest, uint8_t value, ssize_t n) { + assert(dest!=NULL); + if (dest == NULL || n <= 0) return; + memset(dest,(int8_t)value,to_size_t(n)); +} + +ic_private bool ic_memnmove( void* dest, ssize_t dest_size, const void* src, ssize_t n ) { + assert(dest!=NULL && src != NULL); + if (n <= 0) return true; + if (dest_size < n) { assert(false); return false; } + memmove(dest,src,to_size_t(n)); + return true; +} + +ic_private bool ic_strcpy( char* dest, ssize_t dest_size /* including 0 */, const char* src) { + assert(dest!=NULL && src != NULL); + if (dest == NULL || dest_size <= 0) return false; + ssize_t slen = ic_strlen(src); + if (slen >= dest_size) return false; + strcpy(dest,src); + assert(dest[slen] == 0); + return true; +} + + +ic_private bool ic_strncpy( char* dest, ssize_t dest_size /* including 0 */, const char* src, ssize_t n) { + assert(dest!=NULL && n < dest_size); + if (dest == NULL || dest_size <= 0) return false; + if (n >= dest_size) return false; + if (src==NULL || n <= 0) { + dest[0] = 0; + } + else { + strncpy(dest,src,to_size_t(n)); + dest[n] = 0; + } + return true; +} + +//------------------------------------------------------------- +// String matching +//------------------------------------------------------------- + +ic_public bool ic_starts_with( const char* s, const char* prefix ) { + if (s==prefix) return true; + if (prefix==NULL) return true; + if (s==NULL) return false; + + ssize_t i; + for( i = 0; s[i] != 0 && prefix[i] != 0; i++) { + if (s[i] != prefix[i]) return false; + } + return (prefix[i] == 0); +} + +ic_private char ic_tolower( char c ) { + return (c >= 'A' && c <= 'Z' ? c - 'A' + 'a' : c); +} + +ic_private void ic_str_tolower(char* s) { + while(*s != 0) { + *s = ic_tolower(*s); + s++; + } +} + +ic_public bool ic_istarts_with( const char* s, const char* prefix ) { + if (s==prefix) return true; + if (prefix==NULL) return true; + if (s==NULL) return false; + + ssize_t i; + for( i = 0; s[i] != 0 && prefix[i] != 0; i++) { + if (ic_tolower(s[i]) != ic_tolower(prefix[i])) return false; + } + return (prefix[i] == 0); +} + + +ic_private int ic_strnicmp(const char* s1, const char* s2, ssize_t n) { + if (s1 == NULL && s2 == NULL) return 0; + if (s1 == NULL) return -1; + if (s2 == NULL) return 1; + ssize_t i; + for (i = 0; s1[i] != 0 && i < n; i++) { // note: if s2[i] == 0 the loop will stop as c1 != c2 + char c1 = ic_tolower(s1[i]); + char c2 = ic_tolower(s2[i]); + if (c1 < c2) return -1; + if (c1 > c2) return 1; + } + return ((i >= n || s2[i] == 0) ? 0 : -1); +} + +ic_private int ic_stricmp(const char* s1, const char* s2) { + ssize_t len1 = ic_strlen(s1); + ssize_t len2 = ic_strlen(s2); + if (len1 < len2) return -1; + if (len1 > len2) return 1; + return (ic_strnicmp(s1, s2, (len1 >= len2 ? len1 : len2))); +} + + +static const char* ic_stristr(const char* s, const char* pat) { + if (s==NULL) return NULL; + if (pat==NULL || pat[0] == 0) return s; + ssize_t patlen = ic_strlen(pat); + for (ssize_t i = 0; s[i] != 0; i++) { + if (ic_strnicmp(s + i, pat, patlen) == 0) return (s+i); + } + return NULL; +} + +ic_private bool ic_contains(const char* big, const char* s) { + if (big == NULL) return false; + if (s == NULL) return true; + return (strstr(big,s) != NULL); +} + +ic_private bool ic_icontains(const char* big, const char* s) { + if (big == NULL) return false; + if (s == NULL) return true; + return (ic_stristr(big,s) != NULL); +} + + +//------------------------------------------------------------- +// Unicode +// QUTF-8: See +// Raw bytes are code points 0xEE000 - 0xEE0FF +//------------------------------------------------------------- +#define IC_UNICODE_RAW ((unicode_t)(0xEE000U)) + +ic_private unicode_t unicode_from_raw(uint8_t c) { + return (IC_UNICODE_RAW + c); +} + +ic_private bool unicode_is_raw(unicode_t u, uint8_t* c) { + if (u >= IC_UNICODE_RAW && u <= IC_UNICODE_RAW + 0xFF) { + *c = (uint8_t)(u - IC_UNICODE_RAW); + return true; + } + else { + return false; + } +} + +ic_private void unicode_to_qutf8(unicode_t u, uint8_t buf[5]) { + memset(buf, 0, 5); + if (u <= 0x7F) { + buf[0] = (uint8_t)u; + } + else if (u <= 0x07FF) { + buf[0] = (0xC0 | ((uint8_t)(u >> 6))); + buf[1] = (0x80 | (((uint8_t)u) & 0x3F)); + } + else if (u <= 0xFFFF) { + buf[0] = (0xE0 | ((uint8_t)(u >> 12))); + buf[1] = (0x80 | (((uint8_t)(u >> 6)) & 0x3F)); + buf[2] = (0x80 | (((uint8_t)u) & 0x3F)); + } + else if (u <= 0x10FFFF) { + if (unicode_is_raw(u, &buf[0])) { + buf[1] = 0; + } + else { + buf[0] = (0xF0 | ((uint8_t)(u >> 18))); + buf[1] = (0x80 | (((uint8_t)(u >> 12)) & 0x3F)); + buf[2] = (0x80 | (((uint8_t)(u >> 6)) & 0x3F)); + buf[3] = (0x80 | (((uint8_t)u) & 0x3F)); + } + } +} + +// is this a utf8 continuation byte? +ic_private bool utf8_is_cont(uint8_t c) { + return ((c & 0xC0) == 0x80); +} + +ic_private unicode_t unicode_from_qutf8(const uint8_t* s, ssize_t len, ssize_t* count) { + unicode_t c0 = 0; + if (len <= 0 || s == NULL) { + goto fail; + } + // 1 byte + c0 = s[0]; + if (c0 <= 0x7F && len >= 1) { + if (count != NULL) *count = 1; + return c0; + } + else if (c0 <= 0xC1) { // invalid continuation byte or invalid 0xC0, 0xC1 + goto fail; + } + // 2 bytes + else if (c0 <= 0xDF && len >= 2 && utf8_is_cont(s[1])) { + if (count != NULL) *count = 2; + return (((c0 & 0x1F) << 6) | (s[1] & 0x3F)); + } + // 3 bytes: reject overlong and surrogate halves + else if (len >= 3 && + ((c0 == 0xE0 && s[1] >= 0xA0 && s[1] <= 0xBF && utf8_is_cont(s[2])) || + (c0 >= 0xE1 && c0 <= 0xEC && utf8_is_cont(s[1]) && utf8_is_cont(s[2])) + )) + { + if (count != NULL) *count = 3; + return (((c0 & 0x0F) << 12) | ((unicode_t)(s[1] & 0x3F) << 6) | (s[2] & 0x3F)); + } + // 4 bytes: reject overlong + else if (len >= 4 && + (((c0 == 0xF0 && s[1] >= 0x90 && s[1] <= 0xBF && utf8_is_cont(s[2]) && utf8_is_cont(s[3])) || + (c0 >= 0xF1 && c0 <= 0xF3 && utf8_is_cont(s[1]) && utf8_is_cont(s[2]) && utf8_is_cont(s[3])) || + (c0 == 0xF4 && s[1] >= 0x80 && s[1] <= 0x8F && utf8_is_cont(s[2]) && utf8_is_cont(s[3]))) + )) + { + if (count != NULL) *count = 4; + return (((c0 & 0x07) << 18) | ((unicode_t)(s[1] & 0x3F) << 12) | ((unicode_t)(s[2] & 0x3F) << 6) | (s[3] & 0x3F)); + } +fail: + if (count != NULL) *count = 1; + return unicode_from_raw(s[0]); +} + + +//------------------------------------------------------------- +// Debug +//------------------------------------------------------------- + +#if defined(IC_NO_DEBUG_MSG) +// nothing +#elif !defined(IC_DEBUG_TO_FILE) +ic_private void debug_msg(const char* fmt, ...) { + if (getenv("ISOCLINE_DEBUG")) { + va_list args; + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + } +} +#else +ic_private void debug_msg(const char* fmt, ...) { + static int debug_init; + static const char* debug_fname = "isocline.debug.txt"; + // initialize? + if (debug_init==0) { + debug_init = -1; + const char* rdebug = getenv("ISOCLINE_DEBUG"); + if (rdebug!=NULL && strcmp(rdebug,"1") == 0) { + FILE* fdbg = fopen(debug_fname, "w"); + if (fdbg!=NULL) { + debug_init = 1; + fclose(fdbg); + } + } + } + if (debug_init <= 0) return; + + // write debug messages + FILE* fdbg = fopen(debug_fname, "a"); + if (fdbg==NULL) return; + va_list args; + va_start(args, fmt); + vfprintf(fdbg, fmt, args); + fclose(fdbg); + va_end(args); +} +#endif + + +//------------------------------------------------------------- +// Allocation +//------------------------------------------------------------- + +ic_private void* mem_malloc(alloc_t* mem, ssize_t sz) { + return mem->malloc(to_size_t(sz)); +} + +ic_private void* mem_zalloc(alloc_t* mem, ssize_t sz) { + void* p = mem_malloc(mem, sz); + if (p != NULL) memset(p, 0, to_size_t(sz)); + return p; +} + +ic_private void* mem_realloc(alloc_t* mem, void* p, ssize_t newsz) { + return mem->realloc(p, to_size_t(newsz)); +} + +ic_private void mem_free(alloc_t* mem, const void* p) { + mem->free((void*)p); +} + +ic_private char* mem_strdup(alloc_t* mem, const char* s) { + if (s==NULL) return NULL; + ssize_t n = ic_strlen(s); + char* p = mem_malloc_tp_n(mem, char, n+1); + if (p == NULL) return NULL; + ic_memcpy(p, s, n+1); + return p; +} + +ic_private char* mem_strndup(alloc_t* mem, const char* s, ssize_t n) { + if (s==NULL || n < 0) return NULL; + char* p = mem_malloc_tp_n(mem, char, n+1); + if (p == NULL) return NULL; + ssize_t i; + for (i = 0; i < n && s[i] != 0; i++) { + p[i] = s[i]; + } + assert(i <= n); + p[i] = 0; + return p; +} + diff --git a/extern/isocline/src/common.h b/extern/isocline/src/common.h new file mode 100644 index 00000000..dd5b2569 --- /dev/null +++ b/extern/isocline/src/common.h @@ -0,0 +1,187 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +#pragma once +#ifndef IC_COMMON_H +#define IC_COMMON_H + +//------------------------------------------------------------- +// Headers and defines +//------------------------------------------------------------- + +#include // ssize_t +#include +#include +#include +#include +#include +#include "../include/isocline.h" // ic_malloc_fun_t, ic_color_t etc. + +# ifdef __cplusplus +# define ic_extern_c extern "C" +# else +# define ic_extern_c +# endif + +#if defined(IC_SEPARATE_OBJS) +# define ic_public ic_extern_c +# if defined(__GNUC__) // includes clang and icc +# define ic_private __attribute__((visibility("hidden"))) +# else +# define ic_private +# endif +#else +# define ic_private static +# define ic_public ic_extern_c +#endif + +#define ic_unused(x) (void)(x) + + +//------------------------------------------------------------- +// ssize_t +//------------------------------------------------------------- + +#if defined(_MSC_VER) +typedef intptr_t ssize_t; +#endif + +#define ssizeof(tp) (ssize_t)(sizeof(tp)) +static inline size_t to_size_t(ssize_t sz) { return (sz >= 0 ? (size_t)sz : 0); } +static inline ssize_t to_ssize_t(size_t sz) { return (sz <= SIZE_MAX/2 ? (ssize_t)sz : 0); } + +ic_private void ic_memmove(void* dest, const void* src, ssize_t n); +ic_private void ic_memcpy(void* dest, const void* src, ssize_t n); +ic_private void ic_memset(void* dest, uint8_t value, ssize_t n); +ic_private bool ic_memnmove(void* dest, ssize_t dest_size, const void* src, ssize_t n); + +ic_private ssize_t ic_strlen(const char* s); +ic_private bool ic_strcpy(char* dest, ssize_t dest_size /* including 0 */, const char* src); +ic_private bool ic_strncpy(char* dest, ssize_t dest_size /* including 0 */, const char* src, ssize_t n); + +ic_private bool ic_contains(const char* big, const char* s); +ic_private bool ic_icontains(const char* big, const char* s); +ic_private char ic_tolower(char c); +ic_private void ic_str_tolower(char* s); +ic_private int ic_stricmp(const char* s1, const char* s2); +ic_private int ic_strnicmp(const char* s1, const char* s2, ssize_t n); + + + +//--------------------------------------------------------------------- +// Unicode +// +// We use "qutf-8" (quite like utf-8) encoding and decoding. +// Internally we always use valid utf-8. If we encounter invalid +// utf-8 bytes (or bytes >= 0x80 from any other encoding) we encode +// these as special code points in the "raw plane" (0xEE000 - 0xEE0FF). +// When decoding we are then able to restore such raw bytes as-is. +// See +//--------------------------------------------------------------------- + +typedef uint32_t unicode_t; + +ic_private void unicode_to_qutf8(unicode_t u, uint8_t buf[5]); +ic_private unicode_t unicode_from_qutf8(const uint8_t* s, ssize_t len, ssize_t* nread); // validating + +ic_private unicode_t unicode_from_raw(uint8_t c); +ic_private bool unicode_is_raw(unicode_t u, uint8_t* c); + +ic_private bool utf8_is_cont(uint8_t c); + + +//------------------------------------------------------------- +// Colors +//------------------------------------------------------------- + +// A color is either RGB or an ANSI code. +// (RGB colors have bit 24 set to distinguish them from the ANSI color palette colors.) +// (Isocline will automatically convert from RGB on terminals that do not support full colors) +typedef uint32_t ic_color_t; + +// Create a color from a 24-bit color value. +ic_private ic_color_t ic_rgb(uint32_t hex); + +// Create a color from a 8-bit red/green/blue components. +// The value of each component is capped between 0 and 255. +ic_private ic_color_t ic_rgbx(ssize_t r, ssize_t g, ssize_t b); + +#define IC_COLOR_NONE (0) +#define IC_RGB(rgb) (0x1000000 | (uint32_t)(rgb)) // ic_rgb(rgb) // define to it can be used as a constant + +// ANSI colors. +// The actual colors used is usually determined by the terminal theme +// See +#define IC_ANSI_BLACK (30) +#define IC_ANSI_MAROON (31) +#define IC_ANSI_GREEN (32) +#define IC_ANSI_OLIVE (33) +#define IC_ANSI_NAVY (34) +#define IC_ANSI_PURPLE (35) +#define IC_ANSI_TEAL (36) +#define IC_ANSI_SILVER (37) +#define IC_ANSI_DEFAULT (39) + +#define IC_ANSI_GRAY (90) +#define IC_ANSI_RED (91) +#define IC_ANSI_LIME (92) +#define IC_ANSI_YELLOW (93) +#define IC_ANSI_BLUE (94) +#define IC_ANSI_FUCHSIA (95) +#define IC_ANSI_AQUA (96) +#define IC_ANSI_WHITE (97) + +#define IC_ANSI_DARKGRAY IC_ANSI_GRAY +#define IC_ANSI_LIGHTGRAY IC_ANSI_SILVER +#define IC_ANSI_MAGENTA IC_ANSI_FUCHSIA +#define IC_ANSI_CYAN IC_ANSI_AQUA + + + +//------------------------------------------------------------- +// Debug +//------------------------------------------------------------- + +#if defined(IC_NO_DEBUG_MSG) +#define debug_msg(fmt,...) (void)(0) +#else +ic_private void debug_msg( const char* fmt, ... ); +#endif + + +//------------------------------------------------------------- +// Abstract environment +//------------------------------------------------------------- +struct ic_env_s; +typedef struct ic_env_s ic_env_t; + + +//------------------------------------------------------------- +// Allocation +//------------------------------------------------------------- + +typedef struct alloc_s { + ic_malloc_fun_t* malloc; + ic_realloc_fun_t* realloc; + ic_free_fun_t* free; +} alloc_t; + + +ic_private void* mem_malloc( alloc_t* mem, ssize_t sz ); +ic_private void* mem_zalloc( alloc_t* mem, ssize_t sz ); +ic_private void* mem_realloc( alloc_t* mem, void* p, ssize_t newsz ); +ic_private void mem_free( alloc_t* mem, const void* p ); +ic_private char* mem_strdup( alloc_t* mem, const char* s); +ic_private char* mem_strndup( alloc_t* mem, const char* s, ssize_t n); + +#define mem_zalloc_tp(mem,tp) (tp*)mem_zalloc(mem,ssizeof(tp)) +#define mem_malloc_tp_n(mem,tp,n) (tp*)mem_malloc(mem,(n)*ssizeof(tp)) +#define mem_zalloc_tp_n(mem,tp,n) (tp*)mem_zalloc(mem,(n)*ssizeof(tp)) +#define mem_realloc_tp(mem,tp,p,n) (tp*)mem_realloc(mem,p,(n)*ssizeof(tp)) + + +#endif // IC_COMMON_H diff --git a/extern/isocline/src/completers.c b/extern/isocline/src/completers.c new file mode 100644 index 00000000..e9701c16 --- /dev/null +++ b/extern/isocline/src/completers.c @@ -0,0 +1,675 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include + +#include "../include/isocline.h" +#include "common.h" +#include "env.h" +#include "stringbuf.h" +#include "completions.h" + + + +//------------------------------------------------------------- +// Word completion +//------------------------------------------------------------- + +// free variables for word completion +typedef struct word_closure_s { + long delete_before_adjust; + void* prev_env; + ic_completion_fun_t* prev_complete; +} word_closure_t; + + +// word completion callback +static bool token_add_completion_ex(ic_env_t* env, void* closure, const char* replacement, const char* display, const char* help, long delete_before, long delete_after) { + word_closure_t* wenv = (word_closure_t*)(closure); + // call the previous completer with an adjusted delete-before + return (*wenv->prev_complete)(env, wenv->prev_env, replacement, display, help, wenv->delete_before_adjust + delete_before, delete_after); +} + + +ic_public void ic_complete_word(ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, + ic_is_char_class_fun_t* is_word_char) +{ + if (is_word_char == NULL) is_word_char = &ic_char_is_nonseparator; + + ssize_t len = ic_strlen(prefix); + ssize_t pos = len; // will be start of the 'word' (excluding a potential start quote) + while (pos > 0) { + // go back one code point + ssize_t ofs = str_prev_ofs(prefix, pos, NULL); + if (ofs <= 0) break; + if (!(*is_word_char)(prefix + (pos - ofs), (long)ofs)) { + break; + } + pos -= ofs; + } + if (pos < 0) { pos = 0; } + + // stop if empty word + // if (len == pos) return; + + // set up the closure + word_closure_t wenv; + wenv.delete_before_adjust = (long)(len - pos); + wenv.prev_complete = cenv->complete; + wenv.prev_env = cenv->env; + cenv->complete = &token_add_completion_ex; + cenv->closure = &wenv; + + // and call the user completion routine + (*fun)(cenv, prefix + pos); + + // restore the original environment + cenv->complete = wenv.prev_complete; + cenv->closure = wenv.prev_env; +} + + +//------------------------------------------------------------- +// Quoted word completion (with escape characters) +//------------------------------------------------------------- + +// free variables for word completion +typedef struct qword_closure_s { + char escape_char; + char quote; + long delete_before_adjust; + stringbuf_t* sbuf; + void* prev_env; + ic_is_char_class_fun_t* is_word_char; + ic_completion_fun_t* prev_complete; +} qword_closure_t; + + +// word completion callback +static bool qword_add_completion_ex(ic_env_t* env, void* closure, const char* replacement, const char* display, const char* help, + long delete_before, long delete_after) { + qword_closure_t* wenv = (qword_closure_t*)(closure); + sbuf_replace( wenv->sbuf, replacement ); + if (wenv->quote != 0) { + // add end quote + sbuf_append_char( wenv->sbuf, wenv->quote); + } + else { + // escape non-word characters if it was not quoted + ssize_t pos = 0; + ssize_t next; + while ( (next = sbuf_next_ofs(wenv->sbuf, pos, NULL)) > 0 ) + { + if (!(*wenv->is_word_char)(sbuf_string(wenv->sbuf) + pos, (long)next)) { // strchr(wenv->non_word_char, sbuf_char_at( wenv->sbuf, pos )) != NULL) { + sbuf_insert_char_at( wenv->sbuf, wenv->escape_char, pos); + pos++; + } + pos += next; + } + } + // and call the previous completion function + return (*wenv->prev_complete)( env, wenv->prev_env, sbuf_string(wenv->sbuf), display, help, wenv->delete_before_adjust + delete_before, delete_after ); +} + + +ic_public void ic_complete_qword( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, ic_is_char_class_fun_t* is_word_char ) { + ic_complete_qword_ex( cenv, prefix, fun, is_word_char, '\\', NULL); +} + + +ic_public void ic_complete_qword_ex( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, + ic_is_char_class_fun_t* is_word_char, char escape_char, const char* quote_chars ) { + if (is_word_char == NULL) is_word_char = &ic_char_is_nonseparator ; + if (quote_chars == NULL) quote_chars = "'\""; + + ssize_t len = ic_strlen(prefix); + ssize_t pos; // will be start of the 'word' (excluding a potential start quote) + char quote = 0; + ssize_t quote_len = 0; + + // 1. look for a starting quote + if (quote_chars[0] != 0) { + // we go forward and count all quotes; if it is uneven, we need to complete quoted. + ssize_t qpos_open = -1; + ssize_t qpos_close = -1; + ssize_t qcount = 0; + pos = 0; + while(pos < len) { + if (prefix[pos] == escape_char && prefix[pos+1] != 0 && + !(*is_word_char)(prefix + pos + 1, 1)) // strchr(non_word_char, prefix[pos+1]) != NULL + { + pos++; // skip escape and next char + } + else if (qcount % 2 == 0 && strchr(quote_chars, prefix[pos]) != NULL) { + // open quote + qpos_open = pos; + quote = prefix[pos]; + qcount++; + } + else if (qcount % 2 == 1 && prefix[pos] == quote) { + // close quote + qpos_close = pos; + qcount++; + } + else if (!(*is_word_char)(prefix + pos, 1)) { // strchr(non_word_char, prefix[pos]) != NULL) { + qpos_close = -1; + } + ssize_t ofs = str_next_ofs( prefix, len, pos, NULL ); + if (ofs <= 0) break; + pos += ofs; + } + if ((qcount % 2 == 0 && qpos_close >= 0) || // if the last quote is only followed by word chars, we still complete it + (qcount % 2 == 1)) // opening quote found + { + quote_len = (len - qpos_open - 1); + pos = qpos_open + 1; // pos points to the word start just after the quote. + } + else { + quote = 0; + } + } + + // 2. if we did not find a quoted word, look for non-word-chars + if (quote == 0) { + pos = len; + while(pos > 0) { + // go back one code point + ssize_t ofs = str_prev_ofs(prefix, pos, NULL ); + if (ofs <= 0) break; + if (!(*is_word_char)(prefix + (pos - ofs), (long)ofs)) { // strchr(non_word_char, prefix[pos - ofs]) != NULL) { + // non word char, break if it is not escaped + if (pos <= ofs || prefix[pos - ofs - 1] != escape_char) break; + // otherwise go on + pos--; // skip escaped char + } + pos -= ofs; + } + } + + // stop if empty word + // if (len == pos) return; + + // allocate new unescaped word prefix + char* word = mem_strndup( cenv->env->mem, prefix + pos, (quote==0 ? len - pos : quote_len)); + if (word == NULL) return; + + if (quote == 0) { + // unescape prefix + ssize_t wlen = len - pos; + ssize_t wpos = 0; + while (wpos < wlen) { + ssize_t ofs = str_next_ofs(word, wlen, wpos, NULL); + if (ofs <= 0) break; + if (word[wpos] == escape_char && word[wpos+1] != 0 && + !(*is_word_char)(word + wpos + 1, (long)ofs)) // strchr(non_word_char, word[wpos+1]) != NULL) { + { + ic_memmove(word + wpos, word + wpos + 1, wlen - wpos /* including 0 */); + } + wpos += ofs; + } + } + #ifdef _WIN32 + else { + // remove inner quote: "c:\Program Files\"Win + ssize_t wlen = len - pos; + ssize_t wpos = 0; + while (wpos < wlen) { + ssize_t ofs = str_next_ofs(word, wlen, wpos, NULL); + if (ofs <= 0) break; + if (word[wpos] == escape_char && word[wpos+1] == quote) { + word[wpos+1] = escape_char; + ic_memmove(word + wpos, word + wpos + 1, wlen - wpos /* including 0 */); + } + wpos += ofs; + } + } + #endif + + // set up the closure + qword_closure_t wenv; + wenv.quote = quote; + wenv.is_word_char = is_word_char; + wenv.escape_char = escape_char; + wenv.delete_before_adjust = (long)(len - pos); + wenv.prev_complete = cenv->complete; + wenv.prev_env = cenv->env; + wenv.sbuf = sbuf_new(cenv->env->mem); + if (wenv.sbuf == NULL) { mem_free(cenv->env->mem, word); return; } + cenv->complete = &qword_add_completion_ex; + cenv->closure = &wenv; + + // and call the user completion routine + (*fun)( cenv, word ); + + // restore the original environment + cenv->complete = wenv.prev_complete; + cenv->closure = wenv.prev_env; + + sbuf_free(wenv.sbuf); + mem_free(cenv->env->mem, word); +} + + + + +//------------------------------------------------------------- +// Complete file names +// Listing files +//------------------------------------------------------------- +#include + +typedef enum file_type_e { + // must follow BSD style LSCOLORS order + FT_DEFAULT = 0, + FT_DIR, + FT_SYM, + FT_SOCK, + FT_PIPE, + FT_BLOCK, + FT_CHAR, + FT_SETUID, + FT_SETGID, + FT_DIR_OW_STICKY, + FT_DIR_OW, + FT_DIR_STICKY, + FT_EXE, + FT_LAST +} file_type_t; + +static int cli_color; // 1 enabled, 0 not initialized, -1 disabled +static const char* lscolors = "exfxcxdxbxegedabagacad"; // default BSD setting +static const char* ls_colors; +static const char* ls_colors_names[] = { "no=","di=","ln=","so=","pi=","bd=","cd=","su=","sg=","tw=","ow=","st=","ex=", NULL }; + +static bool ls_colors_init(void) { + if (cli_color != 0) return (cli_color >= 1); + // colors enabled? + const char* s = getenv("CLICOLOR"); + if (s==NULL || (strcmp(s, "1")!=0 && strcmp(s, "") != 0)) { + cli_color = -1; + return false; + } + cli_color = 1; + s = getenv("LS_COLORS"); + if (s != NULL) { ls_colors = s; } + s = getenv("LSCOLORS"); + if (s != NULL) { lscolors = s; } + return true; +} + +static bool ls_valid_esc(ssize_t c) { + return ((c==0 || c==1 || c==4 || c==7 || c==22 || c==24 || c==27) || + (c >= 30 && c <= 37) || (c >= 40 && c <= 47) || + (c >= 90 && c <= 97) || (c >= 100 && c <= 107)); +} + +static bool ls_colors_from_key(stringbuf_t* sb, const char* key) { + // find key + ssize_t keylen = ic_strlen(key); + if (keylen <= 0) return false; + const char* p = strstr(ls_colors, key); + if (p == NULL) return false; + p += keylen; + if (key[keylen-1] != '=') { + if (*p != '=') return false; + p++; + } + ssize_t len = 0; + while (p[len] != 0 && p[len] != ':') { + len++; + } + if (len <= 0) return false; + sbuf_append(sb, "[ansi-sgr=\"" ); + sbuf_append_n(sb, p, len ); + sbuf_append(sb, "\"]"); + return true; +} + +static int ls_colors_from_char(char c) { + if (c >= 'a' && c <= 'h') { return (c - 'a'); } + else if (c >= 'A' && c <= 'H') { return (c - 'A') + 8; } + else if (c == 'x') { return 256; } + else return 256; // default +} + +static bool ls_colors_append(stringbuf_t* sb, file_type_t ft, const char* ext) { + if (!ls_colors_init()) return false; + if (ls_colors != NULL) { + // GNU style + if (ft == FT_DEFAULT && ext != NULL) { + // first try extension match + if (ls_colors_from_key(sb, ext)) return true; + } + if (ft >= FT_DEFAULT && ft < FT_LAST) { + // then a filetype match + const char* key = ls_colors_names[ft]; + if (ls_colors_from_key(sb, key)) return true; + } + } + else if (lscolors != NULL) { + // BSD style + char fg = 'x'; + char bg = 'x'; + if (ic_strlen(lscolors) > (2*(ssize_t)ft)+1) { + fg = lscolors[2*ft]; + bg = lscolors[2*ft + 1]; + } + sbuf_appendf(sb, "[ansi-color=%d ansi-bgcolor=%d]", ls_colors_from_char(fg), ls_colors_from_char(bg) ); + return true; + } + return false; +} + +static void ls_colorize(bool no_lscolor, stringbuf_t* sb, file_type_t ft, const char* name, const char* ext, char dirsep) { + bool close = (no_lscolor ? false : ls_colors_append( sb, ft, ext)); + sbuf_append(sb, "[!pre]" ); + sbuf_append(sb, name); + if (dirsep != 0) sbuf_append_char(sb, dirsep); + sbuf_append(sb,"[/pre]" ); + if (close) { sbuf_append(sb, "[/]"); } +} + +#if defined(_WIN32) +#include +#include + +static bool os_is_dir(const char* cpath) { + struct _stat64 st = { 0 }; + _stat64(cpath, &st); + return ((st.st_mode & _S_IFDIR) != 0); +} + +static file_type_t os_get_filetype(const char* cpath) { + struct _stat64 st = { 0 }; + _stat64(cpath, &st); + if (((st.st_mode) & _S_IFDIR) != 0) return FT_DIR; + if (((st.st_mode) & _S_IFCHR) != 0) return FT_CHAR; + if (((st.st_mode) & _S_IFIFO) != 0) return FT_PIPE; + if (((st.st_mode) & _S_IEXEC) != 0) return FT_EXE; + return FT_DEFAULT; +} + + +#define dir_cursor intptr_t +#define dir_entry struct __finddata64_t + +static bool os_findfirst(alloc_t* mem, const char* path, dir_cursor* d, dir_entry* entry) { + stringbuf_t* spath = sbuf_new(mem); + if (spath == NULL) return false; + sbuf_append(spath, path); + sbuf_append(spath, "\\*"); + *d = _findfirsti64(sbuf_string(spath), entry); + mem_free(mem,spath); + return (*d != -1); +} + +static bool os_findnext(dir_cursor d, dir_entry* entry) { + return (_findnexti64(d, entry) == 0); +} + +static void os_findclose(dir_cursor d) { + _findclose(d); +} + +static const char* os_direntry_name(dir_entry* entry) { + return entry->name; +} + +static bool os_path_is_absolute( const char* path ) { + if (path != NULL && path[0] != 0 && path[1] == ':' && (path[2] == '\\' || path[2] == '/' || path[2] == 0)) { + char drive = path[0]; + return ((drive >= 'A' && drive <= 'Z') || (drive >= 'a' && drive <= 'z')); + } + else return false; +} + +ic_private char ic_dirsep(void) { + return '\\'; +} +#else + +#include +#include +#include +#include + +static bool os_is_dir(const char* cpath) { + struct stat st; + memset(&st, 0, sizeof(st)); + stat(cpath, &st); + return (S_ISDIR(st.st_mode)); +} + +static file_type_t os_get_filetype(const char* cpath) { + struct stat st; + memset(&st, 0, sizeof(st)); + lstat(cpath, &st); + switch ((st.st_mode)&S_IFMT) { + case S_IFSOCK: return FT_SOCK; + case S_IFLNK: { + return FT_SYM; + } + case S_IFIFO: return FT_PIPE; + case S_IFCHR: return FT_CHAR; + case S_IFBLK: return FT_BLOCK; + case S_IFDIR: { + if ((st.st_mode & S_ISUID) != 0) return FT_SETUID; + if ((st.st_mode & S_ISGID) != 0) return FT_SETGID; + if ((st.st_mode & S_IWGRP) != 0 && (st.st_mode & S_ISVTX) != 0) return FT_DIR_OW_STICKY; + if ((st.st_mode & S_IWGRP)) return FT_DIR_OW; + if ((st.st_mode & S_ISVTX)) return FT_DIR_STICKY; + return FT_DIR; + } + case S_IFREG: + default: { + if ((st.st_mode & S_IXUSR) != 0) return FT_EXE; + return FT_DEFAULT; + } + } +} + + +#define dir_cursor DIR* +#define dir_entry struct dirent* + +static bool os_findnext(dir_cursor d, dir_entry* entry) { + *entry = readdir(d); + return (*entry != NULL); +} + +static bool os_findfirst(alloc_t* mem, const char* cpath, dir_cursor* d, dir_entry* entry) { + ic_unused(mem); + *d = opendir(cpath); + if (*d == NULL) { + return false; + } + else { + return os_findnext(*d, entry); + } +} + +static void os_findclose(dir_cursor d) { + closedir(d); +} + +static const char* os_direntry_name(dir_entry* entry) { + return (*entry)->d_name; +} + +static bool os_path_is_absolute( const char* path ) { + return (path != NULL && path[0] == '/'); +} + +ic_private char ic_dirsep(void) { + return '/'; +} +#endif + + + +//------------------------------------------------------------- +// File completion +//------------------------------------------------------------- + +static bool ends_with_n(const char* name, ssize_t name_len, const char* ending, ssize_t len) { + if (name_len < len) return false; + if (ending == NULL || len <= 0) return true; + for (ssize_t i = 1; i <= len; i++) { + char c1 = name[name_len - i]; + char c2 = ending[len - i]; + #ifdef _WIN32 + if (ic_tolower(c1) != ic_tolower(c2)) return false; + #else + if (c1 != c2) return false; + #endif + } + return true; +} + +static bool match_extension(const char* name, const char* extensions) { + if (extensions == NULL || extensions[0] == 0) return true; + if (name == NULL) return false; + ssize_t name_len = ic_strlen(name); + ssize_t len = ic_strlen(extensions); + ssize_t cur = 0; + //debug_msg("match extensions: %s ~ %s", name, extensions); + for (ssize_t end = 0; end <= len; end++) { + if (extensions[end] == ';' || extensions[end] == 0) { + if (ends_with_n(name, name_len, extensions+cur, (end - cur))) { + return true; + } + cur = end+1; + } + } + return false; +} + +static bool filename_complete_indir( ic_completion_env_t* cenv, stringbuf_t* dir, + stringbuf_t* dir_prefix, stringbuf_t* display, + const char* base_prefix, + char dir_sep, const char* extensions ) +{ + dir_cursor d = 0; + dir_entry entry; + bool cont = true; + if (os_findfirst(cenv->env->mem, sbuf_string(dir), &d, &entry)) { + do { + const char* name = os_direntry_name(&entry); + if (name != NULL && strcmp(name, ".") != 0 && strcmp(name, "..") != 0 && + ic_istarts_with(name, base_prefix)) + { + // possible match, first check if it is a directory + file_type_t ft; + bool isdir; + const ssize_t plen = sbuf_len(dir_prefix); + sbuf_append(dir_prefix, name); + { // check directory and potentially add a dirsep to the dir_prefix + const ssize_t dlen = sbuf_len(dir); + sbuf_append_char(dir,ic_dirsep()); + sbuf_append(dir,name); + ft = os_get_filetype(sbuf_string(dir)); + isdir = os_is_dir(sbuf_string(dir)); + if (isdir && dir_sep != 0) { + sbuf_append_char(dir_prefix,dir_sep); + } + sbuf_delete_from(dir,dlen); // restore dir + } + if (isdir || match_extension(name, extensions)) { + // add completion + sbuf_clear(display); + ls_colorize(cenv->env->no_lscolors, display, ft, name, NULL, (isdir ? dir_sep : 0)); + cont = ic_add_completion_ex(cenv, sbuf_string(dir_prefix), sbuf_string(display), NULL); + } + sbuf_delete_from( dir_prefix, plen ); // restore dir_prefix + } + } while (cont && os_findnext(d, &entry)); + os_findclose(d); + } + return cont; +} + +typedef struct filename_closure_s { + const char* roots; + const char* extensions; + char dir_sep; +} filename_closure_t; + +static void filename_completer( ic_completion_env_t* cenv, const char* prefix ) { + if (prefix == NULL) return; + filename_closure_t* fclosure = (filename_closure_t*)cenv->arg; + stringbuf_t* root_dir = sbuf_new(cenv->env->mem); + stringbuf_t* dir_prefix = sbuf_new(cenv->env->mem); + stringbuf_t* display = sbuf_new(cenv->env->mem); + if (root_dir!=NULL && dir_prefix != NULL && display != NULL) + { + // split prefix in dir_prefix / base. + const char* base = strrchr(prefix,'/'); + #ifdef _WIN32 + const char* base2 = strrchr(prefix,'\\'); + if (base == NULL || base2 > base) base = base2; + #endif + if (base != NULL) { + base++; + sbuf_append_n(dir_prefix, prefix, base - prefix ); // includes dir separator + } + + // absolute path + if (os_path_is_absolute(prefix)) { + // do not use roots but try to complete directly + if (base != NULL) { + sbuf_append_n( root_dir, prefix, (base - prefix)); // include dir separator + } + filename_complete_indir( cenv, root_dir, dir_prefix, display, + (base != NULL ? base : prefix), + fclosure->dir_sep, fclosure->extensions ); + } + else { + // relative path, complete with respect to every root. + const char* next; + const char* root = fclosure->roots; + while ( root != NULL ) { + // create full root in `root_dir` + sbuf_clear(root_dir); + next = strchr(root,';'); + if (next == NULL) { + sbuf_append( root_dir, root ); + root = NULL; + } + else { + sbuf_append_n( root_dir, root, next - root ); + root = next + 1; + } + sbuf_append_char( root_dir, ic_dirsep()); + + // add the dir_prefix to the root + if (base != NULL) { + sbuf_append_n( root_dir, prefix, (base - prefix) - 1); + } + + // and complete in this directory + filename_complete_indir( cenv, root_dir, dir_prefix, display, + (base != NULL ? base : prefix), + fclosure->dir_sep, fclosure->extensions); + } + } + } + sbuf_free(display); + sbuf_free(root_dir); + sbuf_free(dir_prefix); +} + +ic_public void ic_complete_filename( ic_completion_env_t* cenv, const char* prefix, char dir_sep, const char* roots, const char* extensions ) { + if (roots == NULL) roots = "."; + if (extensions == NULL) extensions = ""; + if (dir_sep == 0) dir_sep = ic_dirsep(); + filename_closure_t fclosure; + fclosure.dir_sep = dir_sep; + fclosure.roots = roots; + fclosure.extensions = extensions; + cenv->arg = &fclosure; + ic_complete_qword_ex( cenv, prefix, &filename_completer, &ic_char_is_filename_letter, '\\', "'\""); +} diff --git a/extern/isocline/src/completions.c b/extern/isocline/src/completions.c new file mode 100644 index 00000000..01453efc --- /dev/null +++ b/extern/isocline/src/completions.c @@ -0,0 +1,326 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include +#include + +#include "../include/isocline.h" +#include "common.h" +#include "env.h" +#include "stringbuf.h" +#include "completions.h" + + +//------------------------------------------------------------- +// Completions +//------------------------------------------------------------- + +typedef struct completion_s { + const char* replacement; + const char* display; + const char* help; + ssize_t delete_before; + ssize_t delete_after; +} completion_t; + +struct completions_s { + ic_completer_fun_t* completer; + void* completer_arg; + ssize_t completer_max; + ssize_t count; + ssize_t len; + completion_t* elems; + alloc_t* mem; +}; + +static void default_filename_completer( ic_completion_env_t* cenv, const char* prefix ); + +ic_private completions_t* completions_new(alloc_t* mem) { + completions_t* cms = mem_zalloc_tp(mem, completions_t); + if (cms == NULL) return NULL; + cms->mem = mem; + cms->completer = &default_filename_completer; + return cms; +} + +ic_private void completions_free(completions_t* cms) { + if (cms == NULL) return; + completions_clear(cms); + if (cms->elems != NULL) { + mem_free(cms->mem, cms->elems); + cms->elems = NULL; + cms->count = 0; + cms->len = 0; + } + mem_free(cms->mem, cms); // free ourselves +} + + +ic_private void completions_clear(completions_t* cms) { + while (cms->count > 0) { + completion_t* cm = cms->elems + cms->count - 1; + mem_free( cms->mem, cm->display); + mem_free( cms->mem, cm->replacement); + mem_free( cms->mem, cm->help); + memset(cm,0,sizeof(*cm)); + cms->count--; + } +} + +static void completions_push(completions_t* cms, const char* replacement, const char* display, const char* help, ssize_t delete_before, ssize_t delete_after) +{ + if (cms->count >= cms->len) { + ssize_t newlen = (cms->len <= 0 ? 32 : cms->len*2); + completion_t* newelems = mem_realloc_tp(cms->mem, completion_t, cms->elems, newlen ); + if (newelems == NULL) return; + cms->elems = newelems; + cms->len = newlen; + } + assert(cms->count < cms->len); + completion_t* cm = cms->elems + cms->count; + cm->replacement = mem_strdup(cms->mem,replacement); + cm->display = mem_strdup(cms->mem,display); + cm->help = mem_strdup(cms->mem,help); + cm->delete_before = delete_before; + cm->delete_after = delete_after; + cms->count++; +} + +ic_private ssize_t completions_count(completions_t* cms) { + return cms->count; +} + +static bool completions_contains(completions_t* cms, const char* replacement) { + for( ssize_t i = 0; i < cms->count; i++ ) { + const completion_t* c = cms->elems + i; + if (strcmp(replacement,c->replacement) == 0) { return true; } + } + return false; +} + +ic_private bool completions_add(completions_t* cms, const char* replacement, const char* display, const char* help, ssize_t delete_before, ssize_t delete_after) { + if (cms->completer_max <= 0) return false; + cms->completer_max--; + //debug_msg("completion: add: %d,%d, %s\n", delete_before, delete_after, replacement); + if (!completions_contains(cms,replacement)) { + completions_push(cms, replacement, display, help, delete_before, delete_after); + } + return true; +} + +static completion_t* completions_get(completions_t* cms, ssize_t index) { + if (index < 0 || cms->count <= 0 || index >= cms->count) return NULL; + return &cms->elems[index]; +} + +ic_private const char* completions_get_display( completions_t* cms, ssize_t index, const char** help ) { + if (help != NULL) { *help = NULL; } + completion_t* cm = completions_get(cms, index); + if (cm == NULL) return NULL; + if (help != NULL) { *help = cm->help; } + return (cm->display != NULL ? cm->display : cm->replacement); +} + +ic_private const char* completions_get_help( completions_t* cms, ssize_t index ) { + completion_t* cm = completions_get(cms, index); + if (cm == NULL) return NULL; + return cm->help; +} + +ic_private const char* completions_get_hint(completions_t* cms, ssize_t index, const char** help) { + if (help != NULL) { *help = NULL; } + completion_t* cm = completions_get(cms, index); + if (cm == NULL) return NULL; + ssize_t len = ic_strlen(cm->replacement); + if (len < cm->delete_before) return NULL; + const char* hint = (cm->replacement + cm->delete_before); + if (*hint == 0 || utf8_is_cont((uint8_t)(*hint))) return NULL; // utf8 boundary? + if (help != NULL) { *help = cm->help; } + return hint; +} + +ic_private void completions_set_completer(completions_t* cms, ic_completer_fun_t* completer, void* arg) { + cms->completer = completer; + cms->completer_arg = arg; +} + +ic_private void completions_get_completer(completions_t* cms, ic_completer_fun_t** completer, void** arg) { + *completer = cms->completer; + *arg = cms->completer_arg; +} + + +ic_public void* ic_completion_arg( const ic_completion_env_t* cenv ) { + return (cenv == NULL ? NULL : cenv->env->completions->completer_arg); +} + +ic_public bool ic_has_completions( const ic_completion_env_t* cenv ) { + return (cenv == NULL ? false : cenv->env->completions->count > 0); +} + +ic_public bool ic_stop_completing( const ic_completion_env_t* cenv) { + return (cenv == NULL ? true : cenv->env->completions->completer_max <= 0); +} + + +static ssize_t completion_apply( completion_t* cm, stringbuf_t* sbuf, ssize_t pos ) { + if (cm == NULL) return -1; + debug_msg( "completion: apply: %s at %zd\n", cm->replacement, pos); + ssize_t start = pos - cm->delete_before; + if (start < 0) start = 0; + ssize_t n = cm->delete_before + cm->delete_after; + if (ic_strlen(cm->replacement) == n && strncmp(sbuf_string_at(sbuf,start), cm->replacement, to_size_t(n)) == 0) { + // no changes + return -1; + } + else { + sbuf_delete_from_to( sbuf, start, pos + cm->delete_after ); + return sbuf_insert_at(sbuf, cm->replacement, start); + } +} + +ic_private ssize_t completions_apply( completions_t* cms, ssize_t index, stringbuf_t* sbuf, ssize_t pos ) { + completion_t* cm = completions_get(cms, index); + return completion_apply( cm, sbuf, pos ); +} + + +static int completion_compare(const void* p1, const void* p2) { + if (p1 == NULL || p2 == NULL) return 0; + const completion_t* cm1 = (const completion_t*)p1; + const completion_t* cm2 = (const completion_t*)p2; + return ic_stricmp(cm1->replacement, cm2->replacement); +} + +ic_private void completions_sort(completions_t* cms) { + if (cms->count <= 0) return; + qsort(cms->elems, to_size_t(cms->count), sizeof(cms->elems[0]), &completion_compare); +} + +#define IC_MAX_PREFIX (256) + +// find longest common prefix and complete with that. +ic_private ssize_t completions_apply_longest_prefix(completions_t* cms, stringbuf_t* sbuf, ssize_t pos) { + if (cms->count <= 1) { + return completions_apply(cms,0,sbuf,pos); + } + + // set initial prefix to the first entry + completion_t* cm = completions_get(cms, 0); + if (cm == NULL) return -1; + + char prefix[IC_MAX_PREFIX+1]; + ssize_t delete_before = cm->delete_before; + ic_strncpy( prefix, IC_MAX_PREFIX+1, cm->replacement, IC_MAX_PREFIX ); + prefix[IC_MAX_PREFIX] = 0; + + // and visit all others to find the longest common prefix + for(ssize_t i = 1; i < cms->count; i++) { + cm = completions_get(cms,i); + if (cm->delete_before != delete_before) { // deletions must match delete_before + prefix[0] = 0; + break; + } + // check if it is still a prefix + const char* r = cm->replacement; + ssize_t j; + for(j = 0; prefix[j] != 0 && r[j] != 0; j++) { + if (prefix[j] != r[j]) break; + } + prefix[j] = 0; + if (j <= 0) break; + } + + // check the length + ssize_t len = ic_strlen(prefix); + if (len <= 0 || len < delete_before) return -1; + + // we found a prefix :-) + completion_t cprefix; + memset(&cprefix,0,sizeof(cprefix)); + cprefix.delete_before = delete_before; + cprefix.replacement = prefix; + ssize_t newpos = completion_apply( &cprefix, sbuf, pos); + if (newpos < 0) return newpos; + + // adjust all delete_before for the new replacement + for( ssize_t i = 0; i < cms->count; i++) { + cm = completions_get(cms,i); + cm->delete_before = len; + } + + return newpos; +} + + +//------------------------------------------------------------- +// Completer functions +//------------------------------------------------------------- + +ic_public bool ic_add_completions(ic_completion_env_t* cenv, const char* prefix, const char** completions) { + for (const char** pc = completions; *pc != NULL; pc++) { + if (ic_istarts_with(*pc, prefix)) { + if (!ic_add_completion_ex(cenv, *pc, NULL, NULL)) return false; + } + } + return true; +} + +ic_public bool ic_add_completion(ic_completion_env_t* cenv, const char* replacement) { + return ic_add_completion_ex(cenv, replacement, NULL, NULL); +} + +ic_public bool ic_add_completion_ex( ic_completion_env_t* cenv, const char* replacement, const char* display, const char* help ) { + return ic_add_completion_prim(cenv,replacement,display,help,0,0); +} + +ic_public bool ic_add_completion_prim(ic_completion_env_t* cenv, const char* replacement, const char* display, const char* help, long delete_before, long delete_after) { + return (*cenv->complete)(cenv->env, cenv->closure, replacement, display, help, delete_before, delete_after ); +} + +static bool prim_add_completion(ic_env_t* env, void* funenv, const char* replacement, const char* display, const char* help, long delete_before, long delete_after) { + ic_unused(funenv); + return completions_add(env->completions, replacement, display, help, delete_before, delete_after); +} + +ic_public void ic_set_default_completer(ic_completer_fun_t* completer, void* arg) { + ic_env_t* env = ic_get_env(); if (env == NULL) return; + completions_set_completer(env->completions, completer, arg); +} + +ic_private ssize_t completions_generate(struct ic_env_s* env, completions_t* cms, const char* input, ssize_t pos, ssize_t max) { + completions_clear(cms); + if (cms->completer == NULL || input == NULL || ic_strlen(input) < pos) return 0; + + // set up env + ic_completion_env_t cenv; + cenv.env = env; + cenv.input = input, + cenv.cursor = (long)pos; + cenv.arg = cms->completer_arg; + cenv.complete = &prim_add_completion; + cenv.closure = NULL; + const char* prefix = mem_strndup(cms->mem, input, pos); + cms->completer_max = max; + + // and complete + cms->completer(&cenv,prefix); + + // restore + mem_free(cms->mem,prefix); + return completions_count(cms); +} + +// The default completer is no completion is set +static void default_filename_completer( ic_completion_env_t* cenv, const char* prefix ) { + #ifdef _WIN32 + const char sep = '\\'; + #else + const char sep = '/'; + #endif + ic_complete_filename( cenv, prefix, sep, ".", NULL); +} diff --git a/extern/isocline/src/completions.h b/extern/isocline/src/completions.h new file mode 100644 index 00000000..8361d507 --- /dev/null +++ b/extern/isocline/src/completions.h @@ -0,0 +1,52 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_COMPLETIONS_H +#define IC_COMPLETIONS_H + +#include "common.h" +#include "stringbuf.h" + + +//------------------------------------------------------------- +// Completions +//------------------------------------------------------------- +#define IC_MAX_COMPLETIONS_TO_SHOW (1000) +#define IC_MAX_COMPLETIONS_TO_TRY (IC_MAX_COMPLETIONS_TO_SHOW/4) + +typedef struct completions_s completions_t; + +ic_private completions_t* completions_new(alloc_t* mem); +ic_private void completions_free(completions_t* cms); +ic_private void completions_clear(completions_t* cms); +ic_private bool completions_add(completions_t* cms , const char* replacement, const char* display, const char* help, ssize_t delete_before, ssize_t delete_after); +ic_private ssize_t completions_count(completions_t* cms); +ic_private ssize_t completions_generate(struct ic_env_s* env, completions_t* cms , const char* input, ssize_t pos, ssize_t max); +ic_private void completions_sort(completions_t* cms); +ic_private void completions_set_completer(completions_t* cms, ic_completer_fun_t* completer, void* arg); +ic_private const char* completions_get_display(completions_t* cms , ssize_t index, const char** help); +ic_private const char* completions_get_hint(completions_t* cms, ssize_t index, const char** help); +ic_private void completions_get_completer(completions_t* cms, ic_completer_fun_t** completer, void** arg); + +ic_private ssize_t completions_apply(completions_t* cms, ssize_t index, stringbuf_t* sbuf, ssize_t pos); +ic_private ssize_t completions_apply_longest_prefix(completions_t* cms, stringbuf_t* sbuf, ssize_t pos); + +//------------------------------------------------------------- +// Completion environment +//------------------------------------------------------------- +typedef bool (ic_completion_fun_t)( ic_env_t* env, void* funenv, const char* replacement, const char* display, const char* help, long delete_before, long delete_after ); + +struct ic_completion_env_s { + ic_env_t* env; // the isocline environment + const char* input; // current full input + long cursor; // current cursor position + void* arg; // argument given to `ic_set_completer` + void* closure; // free variables for function composition + ic_completion_fun_t* complete; // function that adds a completion +}; + +#endif // IC_COMPLETIONS_H diff --git a/extern/isocline/src/editline.c b/extern/isocline/src/editline.c new file mode 100644 index 00000000..270c42d9 --- /dev/null +++ b/extern/isocline/src/editline.c @@ -0,0 +1,1142 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include + +#include "common.h" +#include "term.h" +#include "tty.h" +#include "env.h" +#include "stringbuf.h" +#include "history.h" +#include "completions.h" +#include "undo.h" +#include "highlight.h" + +//------------------------------------------------------------- +// The editor state +//------------------------------------------------------------- + + + +// editor state +typedef struct editor_s { + stringbuf_t* input; // current user input + stringbuf_t* extra; // extra displayed info (for completion menu etc) + stringbuf_t* hint; // hint displayed as part of the input + stringbuf_t* hint_help; // help for a hint. + ssize_t pos; // current cursor position in the input + ssize_t cur_rows; // current used rows to display our content (including extra content) + ssize_t cur_row; // current row that has the cursor (0 based, relative to the prompt) + ssize_t termw; + bool modified; // has a modification happened? (used for history navigation for example) + bool disable_undo; // temporarily disable auto undo (for history search) + ssize_t history_idx; // current index in the history + editstate_t* undo; // undo buffer + editstate_t* redo; // redo buffer + const char* prompt_text; // text of the prompt before the prompt marker + alloc_t* mem; // allocator + // caches + attrbuf_t* attrs; // reuse attribute buffers + attrbuf_t* attrs_extra; +} editor_t; + + + + + +//------------------------------------------------------------- +// Main edit line +//------------------------------------------------------------- +static char* edit_line( ic_env_t* env, const char* prompt_text ); // defined at bottom +static void edit_refresh(ic_env_t* env, editor_t* eb); + +ic_private char* ic_editline(ic_env_t* env, const char* prompt_text) { + tty_start_raw(env->tty); + term_start_raw(env->term); + char* line = edit_line(env,prompt_text); + term_end_raw(env->term,false); + tty_end_raw(env->tty); + term_writeln(env->term,""); + term_flush(env->term); + return line; +} + + +//------------------------------------------------------------- +// Undo/Redo +//------------------------------------------------------------- + +// capture the current edit state +static void editor_capture(editor_t* eb, editstate_t** es ) { + if (!eb->disable_undo) { + editstate_capture( eb->mem, es, sbuf_string(eb->input), eb->pos ); + } +} + +static void editor_undo_capture(editor_t* eb ) { + editor_capture(eb, &eb->undo ); +} + +static void editor_undo_forget(editor_t* eb) { + if (eb->disable_undo) return; + const char* input = NULL; + ssize_t pos = 0; + editstate_restore(eb->mem, &eb->undo, &input, &pos); + mem_free(eb->mem, input); +} + +static void editor_restore(editor_t* eb, editstate_t** from, editstate_t** to ) { + if (eb->disable_undo) return; + if (*from == NULL) return; + const char* input; + if (to != NULL) { editor_capture( eb, to ); } + if (!editstate_restore( eb->mem, from, &input, &eb->pos )) return; + sbuf_replace( eb->input, input ); + mem_free(eb->mem, input); + eb->modified = false; +} + +static void editor_undo_restore(editor_t* eb, bool with_redo ) { + editor_restore(eb, &eb->undo, (with_redo ? &eb->redo : NULL)); +} + +static void editor_redo_restore(editor_t* eb ) { + editor_restore(eb, &eb->redo, &eb->undo); + eb->modified = false; +} + +static void editor_start_modify(editor_t* eb ) { + editor_undo_capture(eb); + editstate_done(eb->mem, &eb->redo); // clear redo + eb->modified = true; +} + + + +static bool editor_pos_is_at_end(editor_t* eb ) { + return (eb->pos == sbuf_len(eb->input)); +} + +//------------------------------------------------------------- +// Row/Column width and positioning +//------------------------------------------------------------- + + +static void edit_get_prompt_width( ic_env_t* env, editor_t* eb, bool in_extra, ssize_t* promptw, ssize_t* cpromptw ) { + if (in_extra) { + *promptw = 0; + *cpromptw = 0; + } + else { + // todo: cache prompt widths + ssize_t textw = bbcode_column_width(env->bbcode, eb->prompt_text); + ssize_t markerw = bbcode_column_width(env->bbcode, env->prompt_marker); + ssize_t cmarkerw = bbcode_column_width(env->bbcode, env->cprompt_marker); + *promptw = markerw + textw; + *cpromptw = (env->no_multiline_indent || *promptw < cmarkerw ? cmarkerw : *promptw); + } +} + +static ssize_t edit_get_rowcol( ic_env_t* env, editor_t* eb, rowcol_t* rc ) { + ssize_t promptw, cpromptw; + edit_get_prompt_width(env, eb, false, &promptw, &cpromptw); + return sbuf_get_rc_at_pos( eb->input, eb->termw, promptw, cpromptw, eb->pos, rc ); +} + +static void edit_set_pos_at_rowcol( ic_env_t* env, editor_t* eb, ssize_t row, ssize_t col ) { + ssize_t promptw, cpromptw; + edit_get_prompt_width(env, eb, false, &promptw, &cpromptw); + ssize_t pos = sbuf_get_pos_at_rc( eb->input, eb->termw, promptw, cpromptw, row, col ); + if (pos < 0) return; + eb->pos = pos; + edit_refresh(env, eb); +} + +static bool edit_pos_is_at_row_end( ic_env_t* env, editor_t* eb ) { + rowcol_t rc; + edit_get_rowcol( env, eb, &rc ); + return rc.last_on_row; +} + +static void edit_write_prompt( ic_env_t* env, editor_t* eb, ssize_t row, bool in_extra ) { + if (in_extra) return; + bbcode_style_open(env->bbcode, "ic-prompt"); + if (row==0) { + // regular prompt text + bbcode_print( env->bbcode, eb->prompt_text ); + } + else if (!env->no_multiline_indent) { + // multiline continuation indentation + // todo: cache prompt widths + ssize_t textw = bbcode_column_width(env->bbcode, eb->prompt_text ); + ssize_t markerw = bbcode_column_width(env->bbcode, env->prompt_marker); + ssize_t cmarkerw = bbcode_column_width(env->bbcode, env->cprompt_marker); + if (cmarkerw < markerw + textw) { + term_write_repeat(env->term, " ", markerw + textw - cmarkerw ); + } + } + // the marker + bbcode_print(env->bbcode, (row == 0 ? env->prompt_marker : env->cprompt_marker )); + bbcode_style_close(env->bbcode,NULL); +} + +//------------------------------------------------------------- +// Refresh +//------------------------------------------------------------- + +typedef struct refresh_info_s { + ic_env_t* env; + editor_t* eb; + attrbuf_t* attrs; + bool in_extra; + ssize_t first_row; + ssize_t last_row; +} refresh_info_t; + +static bool edit_refresh_rows_iter( + const char* s, + ssize_t row, ssize_t row_start, ssize_t row_len, + ssize_t startw, bool is_wrap, const void* arg, void* res) +{ + ic_unused(res); ic_unused(startw); + const refresh_info_t* info = (const refresh_info_t*)(arg); + term_t* term = info->env->term; + + // debug_msg("edit: line refresh: row %zd, len: %zd\n", row, row_len); + if (row < info->first_row) return false; + if (row > info->last_row) return true; // should not occur + + // term_clear_line(term); + edit_write_prompt(info->env, info->eb, row, info->in_extra); + + //' write output + if (info->attrs == NULL || (info->env->no_highlight && info->env->no_bracematch)) { + term_write_n( term, s + row_start, row_len ); + } + else { + term_write_formatted_n( term, s + row_start, attrbuf_attrs(info->attrs, row_start + row_len) + row_start, row_len ); + } + + // write line ending + if (row < info->last_row) { + if (is_wrap && tty_is_utf8(info->env->tty)) { + #ifndef __APPLE__ + bbcode_print( info->env->bbcode, "[ic-dim]\xE2\x86\x90"); // left arrow + #else + bbcode_print( info->env->bbcode, "[ic-dim]\xE2\x86\xB5" ); // return symbol + #endif + } + term_clear_to_end_of_line(term); + term_writeln(term, ""); + } + else { + term_clear_to_end_of_line(term); + } + return (row >= info->last_row); +} + +static void edit_refresh_rows(ic_env_t* env, editor_t* eb, stringbuf_t* input, attrbuf_t* attrs, + ssize_t promptw, ssize_t cpromptw, bool in_extra, + ssize_t first_row, ssize_t last_row) +{ + if (input == NULL) return; + refresh_info_t info; + info.env = env; + info.eb = eb; + info.attrs = attrs; + info.in_extra = in_extra; + info.first_row = first_row; + info.last_row = last_row; + sbuf_for_each_row( input, eb->termw, promptw, cpromptw, &edit_refresh_rows_iter, &info, NULL); +} + + +static void edit_refresh(ic_env_t* env, editor_t* eb) +{ + // calculate the new cursor row and total rows needed + ssize_t promptw, cpromptw; + edit_get_prompt_width( env, eb, false, &promptw, &cpromptw ); + + if (eb->attrs != NULL) { + highlight( env->mem, env->bbcode, sbuf_string(eb->input), eb->attrs, + (env->no_highlight ? NULL : env->highlighter), env->highlighter_arg ); + } + + // highlight matching braces + if (eb->attrs != NULL && !env->no_bracematch) { + highlight_match_braces(sbuf_string(eb->input), eb->attrs, eb->pos, ic_env_get_match_braces(env), + bbcode_style(env->bbcode,"ic-bracematch"), bbcode_style(env->bbcode,"ic-error")); + } + + // insert hint + if (sbuf_len(eb->hint) > 0) { + if (eb->attrs != NULL) { + attrbuf_insert_at( eb->attrs, eb->pos, sbuf_len(eb->hint), bbcode_style(env->bbcode, "ic-hint") ); + } + sbuf_insert_at(eb->input, sbuf_string(eb->hint), eb->pos ); + } + + // render extra (like a completion menu) + stringbuf_t* extra = NULL; + if (sbuf_len(eb->extra) > 0) { + extra = sbuf_new(eb->mem); + if (extra != NULL) { + if (sbuf_len(eb->hint_help) > 0) { + bbcode_append(env->bbcode, sbuf_string(eb->hint_help), extra, eb->attrs_extra); + } + bbcode_append(env->bbcode, sbuf_string(eb->extra), extra, eb->attrs_extra); + } + } + + // calculate rows and row/col position + rowcol_t rc = { 0 }; + const ssize_t rows_input = sbuf_get_rc_at_pos( eb->input, eb->termw, promptw, cpromptw, eb->pos, &rc ); + rowcol_t rc_extra = { 0 }; + ssize_t rows_extra = 0; + if (extra != NULL) { + rows_extra = sbuf_get_rc_at_pos( extra, eb->termw, 0, 0, 0 /*pos*/, &rc_extra ); + } + const ssize_t rows = rows_input + rows_extra; + debug_msg("edit: refresh: rows %zd, cursor: %zd,%zd (previous rows %zd, cursor row %zd)\n", rows, rc.row, rc.col, eb->cur_rows, eb->cur_row); + + // only render at most terminal height rows + const ssize_t termh = term_get_height(env->term); + ssize_t first_row = 0; // first visible row + ssize_t last_row = rows - 1; // last visible row + if (rows > termh) { + first_row = rc.row - termh + 1; // ensure cursor is visible + if (first_row < 0) first_row = 0; + last_row = first_row + termh - 1; + } + assert(last_row - first_row < termh); + + // reduce flicker + buffer_mode_t bmode = term_set_buffer_mode(env->term, BUFFERED); + + // back up to the first line + term_start_of_line(env->term); + term_up(env->term, (eb->cur_row >= termh ? termh-1 : eb->cur_row) ); + // term_clear_lines_to_end(env->term); // gives flicker in old Windows cmd prompt + + // render rows + edit_refresh_rows( env, eb, eb->input, eb->attrs, promptw, cpromptw, false, first_row, last_row ); + if (rows_extra > 0) { + assert(extra != NULL); + const ssize_t first_rowx = (first_row > rows_input ? first_row - rows_input : 0); + const ssize_t last_rowx = last_row - rows_input; assert(last_rowx >= 0); + edit_refresh_rows(env, eb, extra, eb->attrs_extra, 0, 0, true, first_rowx, last_rowx); + } + + // overwrite trailing rows we do not use anymore + ssize_t rrows = last_row - first_row + 1; // rendered rows + if (rrows < termh && rows < eb->cur_rows) { + ssize_t clear = eb->cur_rows - rows; + while (rrows < termh && clear > 0) { + clear--; + rrows++; + term_writeln(env->term,""); + term_clear_line(env->term); + } + } + + // move cursor back to edit position + term_start_of_line(env->term); + term_up(env->term, first_row + rrows - 1 - rc.row ); + term_right(env->term, rc.col + (rc.row == 0 ? promptw : cpromptw)); + + // and refresh + term_flush(env->term); + + // stop buffering + term_set_buffer_mode(env->term, bmode); + + // restore input by removing the hint + sbuf_delete_at(eb->input, eb->pos, sbuf_len(eb->hint)); + sbuf_delete_at(eb->extra, 0, sbuf_len(eb->hint_help)); + attrbuf_clear(eb->attrs); + attrbuf_clear(eb->attrs_extra); + sbuf_free(extra); + + // update previous + eb->cur_rows = rows; + eb->cur_row = rc.row; +} + +// clear current output +static void edit_clear(ic_env_t* env, editor_t* eb ) { + term_attr_reset(env->term); + term_up(env->term, eb->cur_row); + + // overwrite all rows + for( ssize_t i = 0; i < eb->cur_rows; i++) { + term_clear_line(env->term); + term_writeln(env->term, ""); + } + + // move cursor back + term_up(env->term, eb->cur_rows - eb->cur_row ); +} + + +// clear screen and refresh +static void edit_clear_screen(ic_env_t* env, editor_t* eb ) { + ssize_t cur_rows = eb->cur_rows; + eb->cur_rows = term_get_height(env->term) - 1; + edit_clear(env,eb); + eb->cur_rows = cur_rows; + edit_refresh(env,eb); +} + + +// refresh after a terminal window resized (but before doing further edit operations!) +static bool edit_resize(ic_env_t* env, editor_t* eb ) { + // update dimensions + term_update_dim(env->term); + ssize_t newtermw = term_get_width(env->term); + if (eb->termw == newtermw) return false; + + // recalculate the row layout assuming the hardwrapping for the new terminal width + ssize_t promptw, cpromptw; + edit_get_prompt_width( env, eb, false, &promptw, &cpromptw ); + sbuf_insert_at(eb->input, sbuf_string(eb->hint), eb->pos); // insert used hint + + // render extra (like a completion menu) + stringbuf_t* extra = NULL; + if (sbuf_len(eb->extra) > 0) { + extra = sbuf_new(eb->mem); + if (extra != NULL) { + if (sbuf_len(eb->hint_help) > 0) { + bbcode_append(env->bbcode, sbuf_string(eb->hint_help), extra, NULL); + } + bbcode_append(env->bbcode, sbuf_string(eb->extra), extra, NULL); + } + } + rowcol_t rc = { 0 }; + const ssize_t rows_input = sbuf_get_wrapped_rc_at_pos( eb->input, eb->termw, newtermw, promptw, cpromptw, eb->pos, &rc ); + rowcol_t rc_extra = { 0 }; + ssize_t rows_extra = 0; + if (extra != NULL) { + rows_extra = sbuf_get_wrapped_rc_at_pos(extra, eb->termw, newtermw, 0, 0, 0 /*pos*/, &rc_extra); + } + ssize_t rows = rows_input + rows_extra; + debug_msg("edit: resize: new rows: %zd, cursor row: %zd (previous: rows: %zd, cursor row %zd)\n", rows, rc.row, eb->cur_rows, eb->cur_row); + + // update the newly calculated row and rows + eb->cur_row = rc.row; + if (rows > eb->cur_rows) { + eb->cur_rows = rows; + } + eb->termw = newtermw; + edit_refresh(env,eb); + + // remove hint again + sbuf_delete_at(eb->input, eb->pos, sbuf_len(eb->hint)); + sbuf_free(extra); + return true; +} + +static void editor_append_hint_help(editor_t* eb, const char* help) { + sbuf_clear(eb->hint_help); + if (help != NULL) { + sbuf_replace(eb->hint_help, "[ic-info]"); + sbuf_append(eb->hint_help, help); + sbuf_append(eb->hint_help, "[/ic-info]\n"); + } +} + +// refresh with possible hint +static void edit_refresh_hint(ic_env_t* env, editor_t* eb) { + if (env->no_hint || env->hint_delay > 0) { + // refresh without hint first + edit_refresh(env, eb); + if (env->no_hint) return; + } + + // and see if we can construct a hint (displayed after a delay) + ssize_t count = completions_generate(env, env->completions, sbuf_string(eb->input), eb->pos, 2); + if (count == 1) { + const char* help = NULL; + const char* hint = completions_get_hint(env->completions, 0, &help); + if (hint != NULL) { + sbuf_replace(eb->hint, hint); + editor_append_hint_help(eb, help); + // do auto-tabbing? + if (env->complete_autotab) { + stringbuf_t* sb = sbuf_new(env->mem); // temporary buffer for completion + if (sb != NULL) { + sbuf_replace( sb, sbuf_string(eb->input) ); + ssize_t pos = eb->pos; + const char* extra_hint = hint; + do { + ssize_t newpos = sbuf_insert_at( sb, extra_hint, pos ); + if (newpos <= pos) break; + pos = newpos; + count = completions_generate(env, env->completions, sbuf_string(sb), pos, 2); + if (count == 1) { + const char* extra_help = NULL; + extra_hint = completions_get_hint(env->completions, 0, &extra_help); + if (extra_hint != NULL) { + editor_append_hint_help(eb, extra_help); + sbuf_append(eb->hint, extra_hint); + } + } + } + while(count == 1); + sbuf_free(sb); + } + } + } + } + + if (env->hint_delay <= 0) { + // refresh with hint directly + edit_refresh(env, eb); + } +} + +//------------------------------------------------------------- +// Edit operations +//------------------------------------------------------------- + +static void edit_history_prev(ic_env_t* env, editor_t* eb); +static void edit_history_next(ic_env_t* env, editor_t* eb); + +static void edit_undo_restore(ic_env_t* env, editor_t* eb) { + editor_undo_restore(eb, true); + edit_refresh(env,eb); +} + +static void edit_redo_restore(ic_env_t* env, editor_t* eb) { + editor_redo_restore(eb); + edit_refresh(env,eb); +} + +static void edit_cursor_left(ic_env_t* env, editor_t* eb) { + ssize_t cwidth = 1; + ssize_t prev = sbuf_prev(eb->input,eb->pos,&cwidth); + if (prev < 0) return; + rowcol_t rc; + edit_get_rowcol( env, eb, &rc); + eb->pos = prev; + edit_refresh(env,eb); +} + +static void edit_cursor_right(ic_env_t* env, editor_t* eb) { + ssize_t cwidth = 1; + ssize_t next = sbuf_next(eb->input,eb->pos,&cwidth); + if (next < 0) return; + rowcol_t rc; + edit_get_rowcol( env, eb, &rc); + eb->pos = next; + edit_refresh(env,eb); +} + +static void edit_cursor_line_end(ic_env_t* env, editor_t* eb) { + ssize_t end = sbuf_find_line_end(eb->input,eb->pos); + if (end < 0) return; + eb->pos = end; + edit_refresh(env,eb); +} + +static void edit_cursor_line_start(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_line_start(eb->input,eb->pos); + if (start < 0) return; + eb->pos = start; + edit_refresh(env,eb); +} + +static void edit_cursor_next_word(ic_env_t* env, editor_t* eb) { + ssize_t end = sbuf_find_word_end(eb->input,eb->pos); + if (end < 0) return; + eb->pos = end; + edit_refresh(env,eb); +} + +static void edit_cursor_prev_word(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_word_start(eb->input,eb->pos); + if (start < 0) return; + eb->pos = start; + edit_refresh(env,eb); +} + +static void edit_cursor_next_ws_word(ic_env_t* env, editor_t* eb) { + ssize_t end = sbuf_find_ws_word_end(eb->input, eb->pos); + if (end < 0) return; + eb->pos = end; + edit_refresh(env, eb); +} + +static void edit_cursor_prev_ws_word(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_ws_word_start(eb->input, eb->pos); + if (start < 0) return; + eb->pos = start; + edit_refresh(env, eb); +} + +static void edit_cursor_to_start(ic_env_t* env, editor_t* eb) { + eb->pos = 0; + edit_refresh(env,eb); +} + +static void edit_cursor_to_end(ic_env_t* env, editor_t* eb) { + eb->pos = sbuf_len(eb->input); + edit_refresh(env,eb); +} + + +static void edit_cursor_row_up(ic_env_t* env, editor_t* eb) { + rowcol_t rc; + edit_get_rowcol( env, eb, &rc); + if (rc.row == 0) { + edit_history_prev(env,eb); + } + else { + edit_set_pos_at_rowcol( env, eb, rc.row - 1, rc.col ); + } +} + +static void edit_cursor_row_down(ic_env_t* env, editor_t* eb) { + rowcol_t rc; + ssize_t rows = edit_get_rowcol( env, eb, &rc); + if (rc.row + 1 >= rows) { + edit_history_next(env,eb); + } + else { + edit_set_pos_at_rowcol( env, eb, rc.row + 1, rc.col ); + } +} + + +static void edit_cursor_match_brace(ic_env_t* env, editor_t* eb) { + ssize_t match = find_matching_brace( sbuf_string(eb->input), eb->pos, ic_env_get_match_braces(env), NULL ); + if (match < 0) return; + eb->pos = match; + edit_refresh(env,eb); +} + +static void edit_backspace(ic_env_t* env, editor_t* eb) { + if (eb->pos <= 0) return; + editor_start_modify(eb); + eb->pos = sbuf_delete_char_before(eb->input,eb->pos); + edit_refresh(env,eb); +} + +static void edit_delete_char(ic_env_t* env, editor_t* eb) { + if (eb->pos >= sbuf_len(eb->input)) return; + editor_start_modify(eb); + sbuf_delete_char_at(eb->input,eb->pos); + edit_refresh(env,eb); +} + +static void edit_delete_all(ic_env_t* env, editor_t* eb) { + if (sbuf_len(eb->input) <= 0) return; + editor_start_modify(eb); + sbuf_clear(eb->input); + eb->pos = 0; + edit_refresh(env,eb); +} + +static void edit_delete_to_end_of_line(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_line_start(eb->input,eb->pos); + if (start < 0) return; + ssize_t end = sbuf_find_line_end(eb->input,eb->pos); + if (end < 0) return; + editor_start_modify(eb); + // if on an empty line, remove it completely + if (start == end && sbuf_char_at(eb->input,end) == '\n') { + end++; + } + else if (start == end && sbuf_char_at(eb->input,start - 1) == '\n') { + eb->pos--; + } + sbuf_delete_from_to( eb->input, eb->pos, end ); + edit_refresh(env,eb); +} + +static void edit_delete_to_start_of_line(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_line_start(eb->input,eb->pos); + if (start < 0) return; + ssize_t end = sbuf_find_line_end(eb->input,eb->pos); + if (end < 0) return; + editor_start_modify(eb); + // delete start newline if it was an empty line + bool goright = false; + if (start > 0 && sbuf_char_at(eb->input,start-1) == '\n' && start == end) { + // if it is an empty line remove it + start--; + // afterwards, move to start of next line if it exists (so the cursor stays on the same row) + goright = true; + } + sbuf_delete_from_to( eb->input, start, eb->pos ); + eb->pos = start; + if (goright) edit_cursor_right(env,eb); + edit_refresh(env,eb); +} + +static void edit_delete_line(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_line_start(eb->input,eb->pos); + if (start < 0) return; + ssize_t end = sbuf_find_line_end(eb->input,eb->pos); + if (end < 0) return; + editor_start_modify(eb); + // delete newline as well so no empty line is left; + bool goright = false; + if (start > 0 && sbuf_char_at(eb->input,start-1) == '\n') { + start--; + // afterwards, move to start of next line if it exists (so the cursor stays on the same row) + goright = true; + } + else if (sbuf_char_at(eb->input,end) == '\n') { + end++; + } + sbuf_delete_from_to(eb->input,start,end); + eb->pos = start; + if (goright) edit_cursor_right(env,eb); + edit_refresh(env,eb); +} + +static void edit_delete_to_start_of_word(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_word_start(eb->input,eb->pos); + if (start < 0) return; + editor_start_modify(eb); + sbuf_delete_from_to( eb->input, start, eb->pos ); + eb->pos = start; + edit_refresh(env,eb); +} + +static void edit_delete_to_end_of_word(ic_env_t* env, editor_t* eb) { + ssize_t end = sbuf_find_word_end(eb->input,eb->pos); + if (end < 0) return; + editor_start_modify(eb); + sbuf_delete_from_to( eb->input, eb->pos, end ); + edit_refresh(env,eb); +} + +static void edit_delete_to_start_of_ws_word(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_ws_word_start(eb->input, eb->pos); + if (start < 0) return; + editor_start_modify(eb); + sbuf_delete_from_to(eb->input, start, eb->pos); + eb->pos = start; + edit_refresh(env, eb); +} + +static void edit_delete_to_end_of_ws_word(ic_env_t* env, editor_t* eb) { + ssize_t end = sbuf_find_ws_word_end(eb->input, eb->pos); + if (end < 0) return; + editor_start_modify(eb); + sbuf_delete_from_to(eb->input, eb->pos, end); + edit_refresh(env, eb); +} + + +static void edit_delete_word(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_word_start(eb->input,eb->pos); + if (start < 0) return; + ssize_t end = sbuf_find_word_end(eb->input,eb->pos); + if (end < 0) return; + editor_start_modify(eb); + sbuf_delete_from_to(eb->input,start,end); + eb->pos = start; + edit_refresh(env,eb); +} + +static void edit_swap_char( ic_env_t* env, editor_t* eb ) { + if (eb->pos <= 0 || eb->pos == sbuf_len(eb->input)) return; + editor_start_modify(eb); + eb->pos = sbuf_swap_char(eb->input,eb->pos); + edit_refresh(env,eb); +} + +static void edit_multiline_eol(ic_env_t* env, editor_t* eb) { + if (eb->pos <= 0) return; + if (sbuf_string(eb->input)[eb->pos-1] != env->multiline_eol) return; + editor_start_modify(eb); + // replace line continuation with a real newline + sbuf_delete_at( eb->input, eb->pos-1, 1); + sbuf_insert_at( eb->input, "\n", eb->pos-1); + edit_refresh(env,eb); +} + +static void edit_insert_unicode(ic_env_t* env, editor_t* eb, unicode_t u) { + editor_start_modify(eb); + ssize_t nextpos = sbuf_insert_unicode_at(eb->input, u, eb->pos); + if (nextpos >= 0) eb->pos = nextpos; + edit_refresh_hint(env, eb); +} + +static void edit_auto_brace(ic_env_t* env, editor_t* eb, char c) { + if (env->no_autobrace) return; + const char* braces = ic_env_get_auto_braces(env); + for (const char* b = braces; *b != 0; b += 2) { + if (*b == c) { + const char close = b[1]; + //if (sbuf_char_at(eb->input, eb->pos) != close) { + sbuf_insert_char_at(eb->input, close, eb->pos); + bool balanced = false; + find_matching_brace(sbuf_string(eb->input), eb->pos, braces, &balanced ); + if (!balanced) { + // don't insert if it leads to an unbalanced expression. + sbuf_delete_char_at(eb->input, eb->pos); + } + //} + return; + } + else if (b[1] == c) { + // close brace, check if there we don't overwrite to the right + if (sbuf_char_at(eb->input, eb->pos) == c) { + sbuf_delete_char_at(eb->input, eb->pos); + } + return; + } + } +} + +static void editor_auto_indent(editor_t* eb, const char* pre, const char* post ) { + assert(eb->pos > 0 && sbuf_char_at(eb->input,eb->pos-1) == '\n'); + ssize_t prelen = ic_strlen(pre); + if (prelen > 0) { + if (eb->pos - 1 < prelen) return; + if (!ic_starts_with(sbuf_string(eb->input) + eb->pos - 1 - prelen, pre)) return; + if (!ic_starts_with(sbuf_string(eb->input) + eb->pos, post)) return; + eb->pos = sbuf_insert_at(eb->input, " ", eb->pos); + sbuf_insert_char_at(eb->input, '\n', eb->pos); + } +} + +static void edit_insert_char(ic_env_t* env, editor_t* eb, char c) { + editor_start_modify(eb); + ssize_t nextpos = sbuf_insert_char_at( eb->input, c, eb->pos ); + if (nextpos >= 0) eb->pos = nextpos; + edit_auto_brace(env, eb, c); + if (c=='\n') { + editor_auto_indent(eb, "{", "}"); // todo: custom auto indent tokens? + } + edit_refresh_hint(env,eb); +} + +//------------------------------------------------------------- +// Help +//------------------------------------------------------------- + +#include "editline_help.c" + +//------------------------------------------------------------- +// History +//------------------------------------------------------------- + +#include "editline_history.c" + +//------------------------------------------------------------- +// Completion +//------------------------------------------------------------- + +#include "editline_completion.c" + + +//------------------------------------------------------------- +// Edit line: main edit loop +//------------------------------------------------------------- + +static char* edit_line( ic_env_t* env, const char* prompt_text ) +{ + // set up an edit buffer + editor_t eb; + memset(&eb, 0, sizeof(eb)); + eb.mem = env->mem; + eb.input = sbuf_new(env->mem); + eb.extra = sbuf_new(env->mem); + eb.hint = sbuf_new(env->mem); + eb.hint_help= sbuf_new(env->mem); + eb.termw = term_get_width(env->term); + eb.pos = 0; + eb.cur_rows = 1; + eb.cur_row = 0; + eb.modified = false; + eb.prompt_text = (prompt_text != NULL ? prompt_text : ""); + eb.history_idx = 0; + editstate_init(&eb.undo); + editstate_init(&eb.redo); + if (eb.input==NULL || eb.extra==NULL || eb.hint==NULL || eb.hint_help==NULL) { + return NULL; + } + + // caching + if (!(env->no_highlight && env->no_bracematch)) { + eb.attrs = attrbuf_new(env->mem); + eb.attrs_extra = attrbuf_new(env->mem); + } + + // show prompt + edit_write_prompt(env, &eb, 0, false); + + // always a history entry for the current input + history_push(env->history, ""); + + // process keys + code_t c; // current key code + while(true) { + // read a character + term_flush(env->term); + if (env->hint_delay <= 0 || sbuf_len(eb.hint) == 0) { + // blocking read + c = tty_read(env->tty); + } + else { + // timeout to display hint + if (!tty_read_timeout(env->tty, env->hint_delay, &c)) { + // timed-out + if (sbuf_len(eb.hint) > 0) { + // display hint + edit_refresh(env, &eb); + } + c = tty_read(env->tty); + } + else { + // clear the pending hint if we got input before the delay expired + sbuf_clear(eb.hint); + sbuf_clear(eb.hint_help); + } + } + + // update terminal in case of a resize + if (tty_term_resize_event(env->tty)) { + edit_resize(env,&eb); + } + + // clear hint only after a potential resize (so resize row calculations are correct) + const bool had_hint = (sbuf_len(eb.hint) > 0); + sbuf_clear(eb.hint); + sbuf_clear(eb.hint_help); + + // if the user tries to move into a hint with left-cursor or end, we complete it first + if ((c == KEY_RIGHT || c == KEY_END) && had_hint) { + edit_generate_completions(env, &eb, true); + c = KEY_NONE; + } + + // Operations that may return + if (c == KEY_ENTER) { + if (!env->singleline_only && eb.pos > 0 && + sbuf_string(eb.input)[eb.pos-1] == env->multiline_eol && + edit_pos_is_at_row_end(env,&eb)) + { + // replace line-continuation with newline + edit_multiline_eol(env,&eb); + } + else { + // otherwise done + break; + } + } + else if (c == KEY_CTRL_D) { + if (eb.pos == 0 && editor_pos_is_at_end(&eb)) break; // ctrl+D on empty quits with NULL + edit_delete_char(env,&eb); // otherwise it is like delete + } + else if (c == KEY_CTRL_C || c == KEY_EVENT_STOP) { + break; // ctrl+C or STOP event quits with NULL + } + else if (c == KEY_ESC) { + if (eb.pos == 0 && editor_pos_is_at_end(&eb)) break; // ESC on empty input returns with empty input + edit_delete_all(env,&eb); // otherwise delete the current input + // edit_delete_line(env,&eb); // otherwise delete the current line + } + else if (c == KEY_BELL /* ^G */) { + edit_delete_all(env,&eb); + break; // ctrl+G cancels (and returns empty input) + } + + // Editing Operations + else switch(c) { + // events + case KEY_EVENT_RESIZE: // not used + edit_resize(env,&eb); + break; + case KEY_EVENT_AUTOTAB: + edit_generate_completions(env, &eb, true); + break; + + // completion, history, help, undo + case KEY_TAB: + case WITH_ALT('?'): + edit_generate_completions(env,&eb,false); + break; + case KEY_CTRL_R: + case KEY_CTRL_S: + edit_history_search_with_current_word(env,&eb); + break; + case KEY_CTRL_P: + edit_history_prev(env, &eb); + break; + case KEY_CTRL_N: + edit_history_next(env, &eb); + break; + case KEY_CTRL_L: + edit_clear_screen(env, &eb); + break; + case KEY_CTRL_Z: + case WITH_CTRL('_'): + edit_undo_restore(env, &eb); + break; + case KEY_CTRL_Y: + edit_redo_restore(env, &eb); + break; + case KEY_F1: + edit_show_help(env, &eb); + break; + + // navigation + case KEY_LEFT: + case KEY_CTRL_B: + edit_cursor_left(env,&eb); + break; + case KEY_RIGHT: + case KEY_CTRL_F: + if (eb.pos == sbuf_len(eb.input)) { + edit_generate_completions( env, &eb, false ); + } + else { + edit_cursor_right(env,&eb); + } + break; + case KEY_UP: + edit_cursor_row_up(env,&eb); + break; + case KEY_DOWN: + edit_cursor_row_down(env,&eb); + break; + case KEY_HOME: + case KEY_CTRL_A: + edit_cursor_line_start(env,&eb); + break; + case KEY_END: + case KEY_CTRL_E: + edit_cursor_line_end(env,&eb); + break; + case KEY_CTRL_LEFT: + case WITH_SHIFT(KEY_LEFT): + case WITH_ALT('b'): + edit_cursor_prev_word(env,&eb); + break; + case KEY_CTRL_RIGHT: + case WITH_SHIFT(KEY_RIGHT): + case WITH_ALT('f'): + if (eb.pos == sbuf_len(eb.input)) { + edit_generate_completions( env, &eb, false ); + } + else { + edit_cursor_next_word(env,&eb); + } + break; + case KEY_CTRL_HOME: + case WITH_SHIFT(KEY_HOME): + case KEY_PAGEUP: + case WITH_ALT('<'): + edit_cursor_to_start(env,&eb); + break; + case KEY_CTRL_END: + case WITH_SHIFT(KEY_END): + case KEY_PAGEDOWN: + case WITH_ALT('>'): + edit_cursor_to_end(env,&eb); + break; + case WITH_ALT('m'): + edit_cursor_match_brace(env,&eb); + break; + + // deletion + case KEY_BACKSP: + edit_backspace(env,&eb); + break; + case KEY_DEL: + edit_delete_char(env,&eb); + break; + case WITH_ALT('d'): + edit_delete_to_end_of_word(env,&eb); + break; + case KEY_CTRL_W: + edit_delete_to_start_of_ws_word(env, &eb); + break; + case WITH_ALT(KEY_DEL): + case WITH_ALT(KEY_BACKSP): + edit_delete_to_start_of_word(env,&eb); + break; + case KEY_CTRL_U: + edit_delete_to_start_of_line(env,&eb); + break; + case KEY_CTRL_K: + edit_delete_to_end_of_line(env,&eb); + break; + case KEY_CTRL_T: + edit_swap_char(env,&eb); + break; + + // Editing + case KEY_SHIFT_TAB: + case KEY_LINEFEED: // '\n' (ctrl+J, shift+enter) + if (!env->singleline_only) { + edit_insert_char(env, &eb, '\n'); + } + break; + default: { + char chr; + unicode_t uchr; + if (code_is_ascii_char(c,&chr)) { + edit_insert_char(env,&eb,chr); + } + else if (code_is_unicode(c, &uchr)) { + edit_insert_unicode(env,&eb, uchr); + } + else { + debug_msg( "edit: ignore code: 0x%04x\n", c); + } + break; + } + } + + } + + // goto end + eb.pos = sbuf_len(eb.input); + + // refresh once more but without brace matching + bool bm = env->no_bracematch; + env->no_bracematch = true; + edit_refresh(env,&eb); + env->no_bracematch = bm; + + // save result + char* res; + if ((c == KEY_CTRL_D && sbuf_len(eb.input) == 0) || c == KEY_CTRL_C || c == KEY_EVENT_STOP) { + res = NULL; + } + else if (!tty_is_utf8(env->tty)) { + res = sbuf_strdup_from_utf8(eb.input); + } + else { + res = sbuf_strdup(eb.input); + } + + // update history + history_update(env->history, sbuf_string(eb.input)); + if (res == NULL || sbuf_len(eb.input) <= 1) { ic_history_remove_last(); } // no empty or single-char entries + history_save(env->history); + + // free resources + editstate_done(env->mem, &eb.undo); + editstate_done(env->mem, &eb.redo); + attrbuf_free(eb.attrs); + attrbuf_free(eb.attrs_extra); + sbuf_free(eb.input); + sbuf_free(eb.extra); + sbuf_free(eb.hint); + sbuf_free(eb.hint_help); + + return res; +} + diff --git a/extern/isocline/src/editline_completion.c b/extern/isocline/src/editline_completion.c new file mode 100644 index 00000000..1734ef34 --- /dev/null +++ b/extern/isocline/src/editline_completion.c @@ -0,0 +1,277 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +//------------------------------------------------------------- +// Completion menu: this file is included in editline.c +//------------------------------------------------------------- + +// return true if anything changed +static bool edit_complete(ic_env_t* env, editor_t* eb, ssize_t idx) { + editor_start_modify(eb); + ssize_t newpos = completions_apply(env->completions, idx, eb->input, eb->pos); + if (newpos < 0) { + editor_undo_restore(eb,false); + return false; + } + eb->pos = newpos; + edit_refresh(env,eb); + return true; +} + +static bool edit_complete_longest_prefix(ic_env_t* env, editor_t* eb ) { + editor_start_modify(eb); + ssize_t newpos = completions_apply_longest_prefix( env->completions, eb->input, eb->pos ); + if (newpos < 0) { + editor_undo_restore(eb,false); + return false; + } + eb->pos = newpos; + edit_refresh(env,eb); + return true; +} + +ic_private void sbuf_append_tagged( stringbuf_t* sb, const char* tag, const char* content ) { + sbuf_appendf(sb, "[%s]", tag); + sbuf_append(sb,content); + sbuf_append(sb,"[/]"); +} + +static void editor_append_completion(ic_env_t* env, editor_t* eb, ssize_t idx, ssize_t width, bool numbered, bool selected ) { + const char* help = NULL; + const char* display = completions_get_display(env->completions, idx, &help); + if (display == NULL) return; + if (numbered) { + sbuf_appendf(eb->extra, "[ic-info]%s%zd [/]", (selected ? (tty_is_utf8(env->tty) ? "\xE2\x86\x92" : "*") : " "), 1 + idx); + width -= 3; + } + + if (width > 0) { + sbuf_appendf(eb->extra, "[width=\"%zd;left; ;on\"]", width ); + } + if (selected) { + sbuf_append(eb->extra, "[ic-emphasis]"); + } + sbuf_append(eb->extra, display); + if (selected) { sbuf_append(eb->extra,"[/ic-emphasis]"); } + if (help != NULL) { + sbuf_append(eb->extra, " "); + sbuf_append_tagged(eb->extra, "ic-info", help ); + } + if (width > 0) { sbuf_append(eb->extra,"[/width]"); } +} + +// 2 and 3 column output up to 80 wide +#define IC_DISPLAY2_MAX 34 +#define IC_DISPLAY2_COL (3+IC_DISPLAY2_MAX) +#define IC_DISPLAY2_WIDTH (2*IC_DISPLAY2_COL + 2) // 75 + +#define IC_DISPLAY3_MAX 21 +#define IC_DISPLAY3_COL (3+IC_DISPLAY3_MAX) +#define IC_DISPLAY3_WIDTH (3*IC_DISPLAY3_COL + 2*2) // 76 + +static void editor_append_completion2(ic_env_t* env, editor_t* eb, ssize_t col_width, ssize_t idx1, ssize_t idx2, ssize_t selected ) { + editor_append_completion(env, eb, idx1, col_width, true, (idx1 == selected) ); + sbuf_append( eb->extra, " "); + editor_append_completion(env, eb, idx2, col_width, true, (idx2 == selected) ); +} + +static void editor_append_completion3(ic_env_t* env, editor_t* eb, ssize_t col_width, ssize_t idx1, ssize_t idx2, ssize_t idx3, ssize_t selected ) { + editor_append_completion(env, eb, idx1, col_width, true, (idx1 == selected) ); + sbuf_append( eb->extra, " "); + editor_append_completion(env, eb, idx2, col_width, true, (idx2 == selected)); + sbuf_append( eb->extra, " "); + editor_append_completion(env, eb, idx3, col_width, true, (idx3 == selected) ); +} + +static ssize_t edit_completions_max_width( ic_env_t* env, ssize_t count ) { + ssize_t max_width = 0; + for( ssize_t i = 0; i < count; i++) { + const char* help = NULL; + ssize_t w = bbcode_column_width(env->bbcode, completions_get_display(env->completions, i, &help)); + if (help != NULL) { + w += 2 + bbcode_column_width(env->bbcode, help); + } + if (w > max_width) { + max_width = w; + } + } + return max_width; +} + +static void edit_completion_menu(ic_env_t* env, editor_t* eb, bool more_available) { + ssize_t count = completions_count(env->completions); + ssize_t count_displayed = count; + assert(count > 1); + ssize_t selected = (env->complete_nopreview ? 0 : -1); // select first or none + ssize_t percolumn = count; + +again: + // show first 9 (or 8) completions + sbuf_clear(eb->extra); + ssize_t twidth = term_get_width(env->term) - 1; + ssize_t colwidth; + if (count > 3 && ((colwidth = 3 + edit_completions_max_width(env, 9))*3 + 2*2) < twidth) { + // display as a 3 column block + count_displayed = (count > 9 ? 9 : count); + percolumn = 3; + for (ssize_t rw = 0; rw < percolumn; rw++) { + if (rw > 0) sbuf_append(eb->extra, "\n"); + editor_append_completion3(env, eb, colwidth, rw, percolumn+rw, (2*percolumn)+rw, selected); + } + } + else if (count > 4 && ((colwidth = 3 + edit_completions_max_width(env, 8))*2 + 2) < twidth) { + // display as a 2 column block if some entries are too wide for three columns + count_displayed = (count > 8 ? 8 : count); + percolumn = (count_displayed <= 6 ? 3 : 4); + for (ssize_t rw = 0; rw < percolumn; rw++) { + if (rw > 0) sbuf_append(eb->extra, "\n"); + editor_append_completion2(env, eb, colwidth, rw, percolumn+rw, selected); + } + } + else { + // display as a list + count_displayed = (count > 9 ? 9 : count); + percolumn = count_displayed; + for (ssize_t i = 0; i < count_displayed; i++) { + if (i > 0) sbuf_append(eb->extra, "\n"); + editor_append_completion(env, eb, i, -1, true /* numbered */, selected == i); + } + } + if (count > count_displayed) { + if (more_available) { + sbuf_append(eb->extra, "\n[ic-info](press page-down (or ctrl-j) to see all further completions)[/]"); + } + else { + sbuf_appendf(eb->extra, "\n[ic-info](press page-down (or ctrl-j) to see all %zd completions)[/]", count ); + } + } + if (!env->complete_nopreview && selected >= 0 && selected <= count_displayed) { + edit_complete(env,eb,selected); + editor_undo_restore(eb,false); + } + else { + edit_refresh(env, eb); + } + + // read here; if not a valid key, push it back and return to main event loop + code_t c = tty_read(env->tty); + if (tty_term_resize_event(env->tty)) { + edit_resize(env, eb); + } + sbuf_clear(eb->extra); + + // direct selection? + if (c >= '1' && c <= '9') { + ssize_t i = (c - '1'); + if (i < count) { + selected = i; + c = KEY_ENTER; + } + } + + // process commands + if (c == KEY_DOWN || c == KEY_TAB) { + selected++; + if (selected >= count_displayed) { + //term_beep(env->term); + selected = 0; + } + goto again; + } + else if (c == KEY_UP || c == KEY_SHIFT_TAB) { + selected--; + if (selected < 0) { + selected = count_displayed - 1; + //term_beep(env->term); + } + goto again; + } + else if (c == KEY_F1) { + edit_show_help(env, eb); + goto again; + } + else if (c == KEY_ESC) { + completions_clear(env->completions); + edit_refresh(env,eb); + c = 0; // ignore and return + } + else if (selected >= 0 && (c == KEY_ENTER || c == KEY_RIGHT || c == KEY_END)) /* || c == KEY_TAB*/ { + // select the current entry + assert(selected < count); + c = 0; + edit_complete(env, eb, selected); + if (env->complete_autotab) { + tty_code_pushback(env->tty,KEY_EVENT_AUTOTAB); // immediately try to complete again + } + } + else if (!env->complete_nopreview && !code_is_virt_key(c)) { + // if in preview mode, select the current entry and exit the menu + assert(selected < count); + edit_complete(env, eb, selected); + } + else if ((c == KEY_PAGEDOWN || c == KEY_LINEFEED) && count > 9) { + // show all completions + c = 0; + if (more_available) { + // generate all entries (up to the max (= 1000)) + count = completions_generate(env, env->completions, sbuf_string(eb->input), eb->pos, IC_MAX_COMPLETIONS_TO_SHOW); + } + rowcol_t rc; + edit_get_rowcol(env,eb,&rc); + edit_clear(env,eb); + edit_write_prompt(env,eb,0,false); + term_writeln(env->term, ""); + for(ssize_t i = 0; i < count; i++) { + const char* display = completions_get_display(env->completions, i, NULL); + if (display != NULL) { + bbcode_println(env->bbcode, display); + } + } + if (count >= IC_MAX_COMPLETIONS_TO_SHOW) { + bbcode_println(env->bbcode, "[ic-info]... and more.[/]"); + } + else { + bbcode_printf(env->bbcode, "[ic-info](%zd possible completions)[/]\n", count ); + } + for(ssize_t i = 0; i < rc.row+1; i++) { + term_write(env->term, " \n"); + } + eb->cur_rows = 0; + edit_refresh(env,eb); + } + else { + edit_refresh(env,eb); + } + // done + completions_clear(env->completions); + if (c != 0) tty_code_pushback(env->tty,c); +} + +static void edit_generate_completions(ic_env_t* env, editor_t* eb, bool autotab) { + debug_msg( "edit: complete: %zd: %s\n", eb->pos, sbuf_string(eb->input) ); + if (eb->pos < 0) return; + ssize_t count = completions_generate(env, env->completions, sbuf_string(eb->input), eb->pos, IC_MAX_COMPLETIONS_TO_TRY); + bool more_available = (count >= IC_MAX_COMPLETIONS_TO_TRY); + if (count <= 0) { + // no completions + if (!autotab) { term_beep(env->term); } + } + else if (count == 1) { + // complete if only one match + if (edit_complete(env,eb,0 /*idx*/) && env->complete_autotab) { + tty_code_pushback(env->tty,KEY_EVENT_AUTOTAB); + } + } + else { + //term_beep(env->term); + if (!more_available) { + edit_complete_longest_prefix(env,eb); + } + completions_sort(env->completions); + edit_completion_menu( env, eb, more_available); + } +} diff --git a/extern/isocline/src/editline_help.c b/extern/isocline/src/editline_help.c new file mode 100644 index 00000000..fa07d1db --- /dev/null +++ b/extern/isocline/src/editline_help.c @@ -0,0 +1,140 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +//------------------------------------------------------------- +// Help: this is included into editline.c +//------------------------------------------------------------- + +static const char* help[] = { + "","Navigation:", + "left," + "^b", "go one character to the left", + "right," + "^f", "go one character to the right", + "up", "go one row up, or back in the history", + "down", "go one row down, or forward in the history", + #ifdef __APPLE__ + "shift-left", + #else + "^left", + #endif + "go to the start of the previous word", + #ifdef __APPLE__ + "shift-right", + #else + "^right", + #endif + "go to the end the current word", + "home," + "^a", "go to the start of the current line", + "end," + "^e", "go to the end of the current line", + "pgup," + "^home", "go to the start of the current input", + "pgdn," + "^end", "go to the end of the current input", + "alt-m", "jump to matching brace", + "^p", "go back in the history", + "^n", "go forward in the history", + "^r,^s", "search the history starting with the current word", + "","", + + "", "Deletion:", + "del,^d", "delete the current character", + "backsp,^h", "delete the previous character", + "^w", "delete to preceding white space", + "alt-backsp", "delete to the start of the current word", + "alt-d", "delete to the end of the current word", + "^u", "delete to the start of the current line", + "^k", "delete to the end of the current line", + "esc", "delete the current input, or done with empty input", + "","", + + "", "Editing:", + "enter", "accept current input", + #ifndef __APPLE__ + "^enter, ^j", "", + "shift-tab", + #else + "shift-tab,^j", + #endif + "create a new line for multi-line input", + //" ", "(or type '\\' followed by enter)", + "^l", "clear screen", + "^t", "swap with previous character (move character backward)", + "^z,^_", "undo", + "^y", "redo", + //"^C", "done with empty input", + //"F1", "show this help", + "tab", "try to complete the current input", + "","", + "","In the completion menu:", + "enter,left", "use the currently selected completion", + "1 - 9", "use completion N from the menu", + "tab,down", "select the next completion", + "shift-tab,up","select the previous completion", + "esc", "exit menu without completing", + "pgdn,^j", "show all further possible completions", + "","", + "","In incremental history search:", + "enter", "use the currently found history entry", + "backsp," + "^z", "go back to the previous match (undo)", + "tab," + "^r", "find the next match", + "shift-tab," + "^s", "find an earlier match", + "esc", "exit search", + " ","", + NULL, NULL +}; + +static const char* help_initial = + "[ic-info]" + "Isocline v1.0, copyright (c) 2021 Daan Leijen.\n" + "This is free software; you can redistribute it and/or\n" + "modify it under the terms of the MIT License.\n" + "See <[url]https://github.com/daanx/isocline[/url]> for further information.\n" + "We use ^ as a shorthand for ctrl-.\n" + "\n" + "Overview:\n" + "\n[ansi-lightgray]" + " home,ctrl-a cursor end,ctrl-e\n" + " ┌────────────────┼───────────────┐ (navigate)\n" + //" │ │ │\n" + #ifndef __APPLE__ + " │ ctrl-left │ ctrl-right │\n" + #else + " │ alt-left │ alt-right │\n" + #endif + " │ ┌───────┼──────┐ │ ctrl-r : search history\n" + " ▼ ▼ ▼ ▼ ▼ tab : complete word\n" + " prompt> [ansi-darkgray]it's the quintessential language[/] shift-tab: insert new line\n" + " ▲ ▲ ▲ ▲ esc : delete input, done\n" + " │ └──────────────┘ │ ctrl-z : undo\n" + " │ alt-backsp alt-d │\n" + //" │ │ │\n" + " └────────────────────────────────┘ (delete)\n" + " ctrl-u ctrl-k\n" + "[/ansi-lightgray][/ic-info]\n"; + +static void edit_show_help(ic_env_t* env, editor_t* eb) { + edit_clear(env, eb); + bbcode_println(env->bbcode, help_initial); + for (ssize_t i = 0; help[i] != NULL && help[i+1] != NULL; i += 2) { + if (help[i][0] == 0) { + bbcode_printf(env->bbcode, "[ic-info]%s[/]\n", help[i+1]); + } + else { + bbcode_printf(env->bbcode, " [ic-emphasis]%-13s[/][ansi-lightgray]%s%s[/]\n", help[i], (help[i+1][0] == 0 ? "" : ": "), help[i+1]); + } + } + + eb->cur_rows = 0; + eb->cur_row = 0; + edit_refresh(env, eb); +} diff --git a/extern/isocline/src/editline_history.c b/extern/isocline/src/editline_history.c new file mode 100644 index 00000000..2a0afa1c --- /dev/null +++ b/extern/isocline/src/editline_history.c @@ -0,0 +1,260 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +//------------------------------------------------------------- +// History search: this file is included in editline.c +//------------------------------------------------------------- + +static void edit_history_at(ic_env_t* env, editor_t* eb, int ofs ) +{ + if (eb->modified) { + history_update(env->history, sbuf_string(eb->input)); // update first entry if modified + eb->history_idx = 0; // and start again + eb->modified = false; + } + const char* entry = history_get(env->history,eb->history_idx + ofs); + // debug_msg( "edit: history: at: %d + %d, found: %s\n", eb->history_idx, ofs, entry); + if (entry == NULL) { + term_beep(env->term); + } + else { + eb->history_idx += ofs; + sbuf_replace(eb->input, entry); + if (ofs > 0) { + // at end of first line when scrolling up + ssize_t end = sbuf_find_line_end(eb->input,0); + eb->pos = (end < 0 ? 0 : end); + } + else { + eb->pos = sbuf_len(eb->input); // at end of last line when scrolling down + } + edit_refresh(env, eb); + } +} + +static void edit_history_prev(ic_env_t* env, editor_t* eb) { + edit_history_at(env,eb, 1 ); +} + +static void edit_history_next(ic_env_t* env, editor_t* eb) { + edit_history_at(env,eb, -1 ); +} + +typedef struct hsearch_s { + struct hsearch_s* next; + ssize_t hidx; + ssize_t match_pos; + ssize_t match_len; + bool cinsert; +} hsearch_t; + +static void hsearch_push( alloc_t* mem, hsearch_t** hs, ssize_t hidx, ssize_t mpos, ssize_t mlen, bool cinsert ) { + hsearch_t* h = mem_zalloc_tp( mem, hsearch_t ); + if (h == NULL) return; + h->hidx = hidx; + h->match_pos = mpos; + h->match_len = mlen; + h->cinsert = cinsert; + h->next = *hs; + *hs = h; +} + +static bool hsearch_pop( alloc_t* mem, hsearch_t** hs, ssize_t* hidx, ssize_t* match_pos, ssize_t* match_len, bool* cinsert ) { + hsearch_t* h = *hs; + if (h == NULL) return false; + *hs = h->next; + if (hidx != NULL) *hidx = h->hidx; + if (match_pos != NULL) *match_pos = h->match_pos; + if (match_len != NULL) *match_len = h->match_len; + if (cinsert != NULL) *cinsert = h->cinsert; + mem_free(mem, h); + return true; +} + +static void hsearch_done( alloc_t* mem, hsearch_t* hs ) { + while (hs != NULL) { + hsearch_t* next = hs->next; + mem_free(mem, hs); + hs = next; + } +} + +static void edit_history_search(ic_env_t* env, editor_t* eb, char* initial ) { + if (history_count( env->history ) <= 0) { + term_beep(env->term); + return; + } + + // update history + if (eb->modified) { + history_update(env->history, sbuf_string(eb->input)); // update first entry if modified + eb->history_idx = 0; // and start again + eb->modified = false; + } + + // set a search prompt and remember the previous state + editor_undo_capture(eb); + eb->disable_undo = true; + bool old_hint = ic_enable_hint(false); + const char* prompt_text = eb->prompt_text; + eb->prompt_text = "history search"; + + // search state + hsearch_t* hs = NULL; // search undo + ssize_t hidx = 1; // current history entry + ssize_t match_pos = 0; // current matched position + ssize_t match_len = 0; // length of the match + const char* hentry = NULL; // current history entry + + // Simulate per character searches for each letter in `initial` (so backspace works) + if (initial != NULL) { + const ssize_t initial_len = ic_strlen(initial); + ssize_t ipos = 0; + while( ipos < initial_len ) { + ssize_t next = str_next_ofs( initial, initial_len, ipos, NULL ); + if (next < 0) break; + hsearch_push( eb->mem, &hs, hidx, match_pos, match_len, true); + char c = initial[ipos + next]; // terminate temporarily + initial[ipos + next] = 0; + if (history_search( env->history, hidx, initial, true, &hidx, &match_pos )) { + match_len = ipos + next; + } + else if (ipos + next >= initial_len) { + term_beep(env->term); + } + initial[ipos + next] = c; // restore + ipos += next; + } + sbuf_replace( eb->input, initial); + eb->pos = ipos; + } + else { + sbuf_clear( eb->input ); + eb->pos = 0; + } + + // Incremental search +again: + hentry = history_get(env->history,hidx); + if (hentry != NULL) { + sbuf_appendf(eb->extra, "[ic-info]%zd. [/][ic-diminish][!pre]", hidx); + sbuf_append_n( eb->extra, hentry, match_pos ); + sbuf_append(eb->extra, "[/pre][u ic-emphasis][!pre]" ); + sbuf_append_n( eb->extra, hentry + match_pos, match_len ); + sbuf_append(eb->extra, "[/pre][/u][!pre]" ); + sbuf_append(eb->extra, hentry + match_pos + match_len ); + sbuf_append(eb->extra, "[/pre][/ic-diminish]"); + if (!env->no_help) { + sbuf_append(eb->extra, "\n[ic-info](use tab for the next match)[/]"); + } + sbuf_append(eb->extra, "\n" ); + } + edit_refresh(env, eb); + + // Wait for input + code_t c = (hentry == NULL ? KEY_ESC : tty_read(env->tty)); + if (tty_term_resize_event(env->tty)) { + edit_resize(env, eb); + } + sbuf_clear(eb->extra); + + // Process commands + if (c == KEY_ESC || c == KEY_BELL /* ^G */ || c == KEY_CTRL_C) { + c = 0; + eb->disable_undo = false; + editor_undo_restore(eb, false); + } + else if (c == KEY_ENTER) { + c = 0; + editor_undo_forget(eb); + sbuf_replace( eb->input, hentry ); + eb->pos = sbuf_len(eb->input); + eb->modified = false; + eb->history_idx = hidx; + } + else if (c == KEY_BACKSP || c == KEY_CTRL_Z) { + // undo last search action + bool cinsert; + if (hsearch_pop(env->mem,&hs, &hidx, &match_pos, &match_len, &cinsert)) { + if (cinsert) edit_backspace(env,eb); + } + goto again; + } + else if (c == KEY_CTRL_R || c == KEY_TAB || c == KEY_UP) { + // search backward + hsearch_push(env->mem, &hs, hidx, match_pos, match_len, false); + if (!history_search( env->history, hidx+1, sbuf_string(eb->input), true, &hidx, &match_pos )) { + hsearch_pop(env->mem,&hs,NULL,NULL,NULL,NULL); + term_beep(env->term); + }; + goto again; + } + else if (c == KEY_CTRL_S || c == KEY_SHIFT_TAB || c == KEY_DOWN) { + // search forward + hsearch_push(env->mem, &hs, hidx, match_pos, match_len, false); + if (!history_search( env->history, hidx-1, sbuf_string(eb->input), false, &hidx, &match_pos )) { + hsearch_pop(env->mem, &hs,NULL,NULL,NULL,NULL); + term_beep(env->term); + }; + goto again; + } + else if (c == KEY_F1) { + edit_show_help(env, eb); + goto again; + } + else { + // insert character and search further backward + char chr; + unicode_t uchr; + if (code_is_ascii_char(c,&chr)) { + hsearch_push(env->mem, &hs, hidx, match_pos, match_len, true); + edit_insert_char(env,eb,chr); + } + else if (code_is_unicode(c,&uchr)) { + hsearch_push(env->mem, &hs, hidx, match_pos, match_len, true); + edit_insert_unicode(env,eb,uchr); + } + else { + // ignore command + term_beep(env->term); + goto again; + } + // search for the new input + if (history_search( env->history, hidx, sbuf_string(eb->input), true, &hidx, &match_pos )) { + match_len = sbuf_len(eb->input); + } + else { + term_beep(env->term); + }; + goto again; + } + + // done + eb->disable_undo = false; + hsearch_done(env->mem,hs); + eb->prompt_text = prompt_text; + ic_enable_hint(old_hint); + edit_refresh(env,eb); + if (c != 0) tty_code_pushback(env->tty, c); +} + +// Start an incremental search with the current word +static void edit_history_search_with_current_word(ic_env_t* env, editor_t* eb) { + char* initial = NULL; + ssize_t start = sbuf_find_word_start( eb->input, eb->pos ); + if (start >= 0) { + const ssize_t next = sbuf_next(eb->input, start, NULL); + if (!ic_char_is_idletter(sbuf_string(eb->input) + start, (long)(next - start))) { + start = next; + } + if (start >= 0 && start < eb->pos) { + initial = mem_strndup(eb->mem, sbuf_string(eb->input) + start, eb->pos - start); + } + } + edit_history_search( env, eb, initial); + mem_free(env->mem, initial); +} diff --git a/extern/isocline/src/env.h b/extern/isocline/src/env.h new file mode 100644 index 00000000..edfc1003 --- /dev/null +++ b/extern/isocline/src/env.h @@ -0,0 +1,60 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_ENV_H +#define IC_ENV_H + +#include "../include/isocline.h" +#include "common.h" +#include "term.h" +#include "tty.h" +#include "stringbuf.h" +#include "history.h" +#include "completions.h" +#include "bbcode.h" + +//------------------------------------------------------------- +// Environment +//------------------------------------------------------------- + +struct ic_env_s { + alloc_t* mem; // potential custom allocator + ic_env_t* next; // next environment (used for proper deallocation) + term_t* term; // terminal + tty_t* tty; // keyboard (NULL if stdin is a pipe, file, etc) + completions_t* completions; // current completions + history_t* history; // edit history + bbcode_t* bbcode; // print with bbcodes + const char* prompt_marker; // the prompt marker (defaults to "> ") + const char* cprompt_marker; // prompt marker for continuation lines (defaults to `prompt_marker`) + ic_highlight_fun_t* highlighter; // highlight callback + void* highlighter_arg; // user state for the highlighter. + const char* match_braces; // matching braces, e.g "()[]{}" + const char* auto_braces; // auto insertion braces, e.g "()[]{}\"\"''" + char multiline_eol; // character used for multiline input ("\") (set to 0 to disable) + bool initialized; // are we initialized? + bool noedit; // is rich editing possible (tty != NULL) + bool singleline_only; // allow only single line editing? + bool complete_nopreview; // do not show completion preview for each selection in the completion menu? + bool complete_autotab; // try to keep completing after a completion? + bool no_multiline_indent; // indent continuation lines to line up under the initial prompt + bool no_help; // show short help line for history search etc. + bool no_hint; // allow hinting? + bool no_highlight; // enable highlighting? + bool no_bracematch; // enable brace matching? + bool no_autobrace; // enable automatic brace insertion? + bool no_lscolors; // use LSCOLORS/LS_COLORS to colorize file name completions? + long hint_delay; // delay before displaying a hint in milliseconds +}; + +ic_private char* ic_editline(ic_env_t* env, const char* prompt_text); + +ic_private ic_env_t* ic_get_env(void); +ic_private const char* ic_env_get_auto_braces(ic_env_t* env); +ic_private const char* ic_env_get_match_braces(ic_env_t* env); + +#endif // IC_ENV_H diff --git a/extern/isocline/src/highlight.c b/extern/isocline/src/highlight.c new file mode 100644 index 00000000..59c7255c --- /dev/null +++ b/extern/isocline/src/highlight.c @@ -0,0 +1,259 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +#include +#include "common.h" +#include "term.h" +#include "stringbuf.h" +#include "attr.h" +#include "bbcode.h" + +//------------------------------------------------------------- +// Syntax highlighting +//------------------------------------------------------------- + +struct ic_highlight_env_s { + attrbuf_t* attrs; + const char* input; + ssize_t input_len; + bbcode_t* bbcode; + alloc_t* mem; + ssize_t cached_upos; // cached unicode position + ssize_t cached_cpos; // corresponding utf-8 byte position +}; + + +ic_private void highlight( alloc_t* mem, bbcode_t* bb, const char* s, attrbuf_t* attrs, ic_highlight_fun_t* highlighter, void* arg ) { + const ssize_t len = ic_strlen(s); + if (len <= 0) return; + attrbuf_set_at(attrs,0,len,attr_none()); // fill to length of s + if (highlighter != NULL) { + ic_highlight_env_t henv; + henv.attrs = attrs; + henv.input = s; + henv.input_len = len; + henv.bbcode = bb; + henv.mem = mem; + henv.cached_cpos = 0; + henv.cached_upos = 0; + (*highlighter)( &henv, s, arg ); + } +} + + +//------------------------------------------------------------- +// Client interface +//------------------------------------------------------------- + +static void pos_adjust( ic_highlight_env_t* henv, ssize_t* ppos, ssize_t* plen ) { + ssize_t pos = *ppos; + ssize_t len = *plen; + if (pos >= henv->input_len) return; + if (pos >= 0 && len >= 0) return; // already character positions + if (henv->input == NULL) return; + + if (pos < 0) { + // negative `pos` is used as the unicode character position (for easy interfacing with Haskell) + ssize_t upos = -pos; + ssize_t cpos = 0; + ssize_t ucount = 0; + if (henv->cached_upos <= upos) { // if we have a cached position, start from there + ucount = henv->cached_upos; + cpos = henv->cached_cpos; + } + while ( ucount < upos ) { + ssize_t next = str_next_ofs(henv->input, henv->input_len, cpos, NULL); + if (next <= 0) return; + ucount++; + cpos += next; + } + *ppos = pos = cpos; + // and cache it to avoid quadratic behavior + henv->cached_upos = upos; + henv->cached_cpos = cpos; + } + if (len < 0) { + // negative `len` is used as a unicode character length + len = -len; + ssize_t ucount = 0; + ssize_t clen = 0; + while (ucount < len) { + ssize_t next = str_next_ofs(henv->input, henv->input_len, pos + clen, NULL); + if (next <= 0) return; + ucount++; + clen += next; + } + *plen = len = clen; + // and update cache if possible + if (henv->cached_cpos == pos) { + henv->cached_upos += ucount; + henv->cached_cpos += clen; + } + } +} + +static void highlight_attr(ic_highlight_env_t* henv, ssize_t pos, ssize_t count, attr_t attr ) { + if (henv==NULL) return; + pos_adjust(henv,&pos,&count); + if (pos < 0 || count <= 0) return; + attrbuf_update_at(henv->attrs, pos, count, attr); +} + +ic_public void ic_highlight(ic_highlight_env_t* henv, long pos, long count, const char* style ) { + if (henv == NULL || style==NULL || style[0]==0 || pos < 0) return; + highlight_attr(henv,pos,count,bbcode_style( henv->bbcode, style )); +} + +ic_public void ic_highlight_formatted(ic_highlight_env_t* henv, const char* s, const char* fmt) { + if (s==NULL || s[0] == 0 || fmt==NULL) return; + attrbuf_t* attrs = attrbuf_new(henv->mem); + stringbuf_t* out = sbuf_new(henv->mem); // todo: avoid allocating out? + if (attrs!=NULL && out != NULL) { + bbcode_append( henv->bbcode, fmt, out, attrs); + const ssize_t len = ic_strlen(s); + if (sbuf_len(out) != len) { + debug_msg("highlight: formatted string content differs from the original input:\n original: %s\n formatted: %s\n", s, fmt); + } + for( ssize_t i = 0; i < len; i++) { + attrbuf_update_at(henv->attrs, i, 1, attrbuf_attr_at(attrs,i)); + } + } + sbuf_free(out); + attrbuf_free(attrs); +} + +//------------------------------------------------------------- +// Brace matching +//------------------------------------------------------------- +#define MAX_NESTING (64) + +typedef struct brace_s { + char close; + bool at_cursor; + ssize_t pos; +} brace_t; + +ic_private void highlight_match_braces(const char* s, attrbuf_t* attrs, ssize_t cursor_pos, const char* braces, attr_t match_attr, attr_t error_attr) +{ + brace_t open[MAX_NESTING+1]; + ssize_t nesting = 0; + const ssize_t brace_len = ic_strlen(braces); + for (long i = 0; i < ic_strlen(s); i++) { + const char c = s[i]; + // push open brace + bool found_open = false; + for (ssize_t b = 0; b < brace_len; b += 2) { + if (c == braces[b]) { + // open brace + if (nesting >= MAX_NESTING) return; // give up + open[nesting].close = braces[b+1]; + open[nesting].pos = i; + open[nesting].at_cursor = (i == cursor_pos - 1); + nesting++; + found_open = true; + break; + } + } + if (found_open) continue; + + // pop to closing brace and potentially highlight + for (ssize_t b = 1; b < brace_len; b += 2) { + if (c == braces[b]) { + // close brace + if (nesting <= 0) { + // unmatched close brace + attrbuf_update_at( attrs, i, 1, error_attr); + } + else { + // can we fix an unmatched brace where we can match by popping just one? + if (open[nesting-1].close != c && nesting > 1 && open[nesting-2].close == c) { + // assume previous open brace was wrong + attrbuf_update_at(attrs, open[nesting-1].pos, 1, error_attr); + nesting--; + } + if (open[nesting-1].close != c) { + // unmatched open brace + attrbuf_update_at( attrs, i, 1, error_attr); + } + else { + // matching brace + nesting--; + if (i == cursor_pos - 1 || (open[nesting].at_cursor && open[nesting].pos != i - 1)) { + // highlight matching brace + attrbuf_update_at(attrs, open[nesting].pos, 1, match_attr); + attrbuf_update_at(attrs, i, 1, match_attr); + } + } + } + break; + } + } + } + // note: don't mark further unmatched open braces as in error +} + + +ic_private ssize_t find_matching_brace(const char* s, ssize_t cursor_pos, const char* braces, bool* is_balanced) +{ + if (is_balanced != NULL) { *is_balanced = false; } + bool balanced = true; + ssize_t match = -1; + brace_t open[MAX_NESTING+1]; + ssize_t nesting = 0; + const ssize_t brace_len = ic_strlen(braces); + for (long i = 0; i < ic_strlen(s); i++) { + const char c = s[i]; + // push open brace + bool found_open = false; + for (ssize_t b = 0; b < brace_len; b += 2) { + if (c == braces[b]) { + // open brace + if (nesting >= MAX_NESTING) return -1; // give up + open[nesting].close = braces[b+1]; + open[nesting].pos = i; + open[nesting].at_cursor = (i == cursor_pos - 1); + nesting++; + found_open = true; + break; + } + } + if (found_open) continue; + + // pop to closing brace + for (ssize_t b = 1; b < brace_len; b += 2) { + if (c == braces[b]) { + // close brace + if (nesting <= 0) { + // unmatched close brace + balanced = false; + } + else { + if (open[nesting-1].close != c) { + // unmatched open brace + balanced = false; + } + else { + // matching brace + nesting--; + if (i == cursor_pos - 1) { + // found matching open brace + match = open[nesting].pos + 1; + } + else if (open[nesting].at_cursor) { + // found matching close brace + match = i + 1; + } + } + } + break; + } + } + } + if (nesting != 0) { balanced = false; } + if (is_balanced != NULL) { *is_balanced = balanced; } + return match; +} diff --git a/extern/isocline/src/highlight.h b/extern/isocline/src/highlight.h new file mode 100644 index 00000000..67da02ff --- /dev/null +++ b/extern/isocline/src/highlight.h @@ -0,0 +1,24 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_HIGHLIGHT_H +#define IC_HIGHLIGHT_H + +#include "common.h" +#include "attr.h" +#include "term.h" +#include "bbcode.h" + +//------------------------------------------------------------- +// Syntax highlighting +//------------------------------------------------------------- + +ic_private void highlight( alloc_t* mem, bbcode_t* bb, const char* s, attrbuf_t* attrs, ic_highlight_fun_t* highlighter, void* arg ); +ic_private void highlight_match_braces(const char* s, attrbuf_t* attrs, ssize_t cursor_pos, const char* braces, attr_t match_attr, attr_t error_attr); +ic_private ssize_t find_matching_brace(const char* s, ssize_t cursor_pos, const char* braces, bool* is_balanced); + +#endif // IC_HIGHLIGHT_H diff --git a/extern/isocline/src/history.c b/extern/isocline/src/history.c new file mode 100644 index 00000000..440976aa --- /dev/null +++ b/extern/isocline/src/history.c @@ -0,0 +1,269 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include +#include + +#include "../include/isocline.h" +#include "common.h" +#include "history.h" +#include "stringbuf.h" + +#define IC_MAX_HISTORY (200) + +struct history_s { + ssize_t count; // current number of entries in use + ssize_t len; // size of elems + const char** elems; // history items (up to count) + const char* fname; // history file + alloc_t* mem; + bool allow_duplicates; // allow duplicate entries? +}; + +ic_private history_t* history_new(alloc_t* mem) { + history_t* h = mem_zalloc_tp(mem,history_t); + h->mem = mem; + return h; +} + +ic_private void history_free(history_t* h) { + if (h == NULL) return; + history_clear(h); + if (h->len > 0) { + mem_free( h->mem, h->elems ); + h->elems = NULL; + h->len = 0; + } + mem_free(h->mem, h->fname); + h->fname = NULL; + mem_free(h->mem, h); // free ourselves +} + +ic_private bool history_enable_duplicates( history_t* h, bool enable ) { + bool prev = h->allow_duplicates; + h->allow_duplicates = enable; + return prev; +} + +ic_private ssize_t history_count(const history_t* h) { + return h->count; +} + +//------------------------------------------------------------- +// push/clear +//------------------------------------------------------------- + +ic_private bool history_update( history_t* h, const char* entry ) { + if (entry==NULL) return false; + history_remove_last(h); + history_push(h,entry); + //debug_msg("history: update: with %s; now at %s\n", entry, history_get(h,0)); + return true; +} + +static void history_delete_at( history_t* h, ssize_t idx ) { + if (idx < 0 || idx >= h->count) return; + mem_free(h->mem, h->elems[idx]); + for(ssize_t i = idx+1; i < h->count; i++) { + h->elems[i-1] = h->elems[i]; + } + h->count--; +} + +ic_private bool history_push( history_t* h, const char* entry ) { + if (h->len <= 0 || entry==NULL) return false; + // remove any older duplicate + if (!h->allow_duplicates) { + for( int i = 0; i < h->count; i++) { + if (strcmp(h->elems[i],entry) == 0) { + history_delete_at(h,i); + } + } + } + // insert at front + if (h->count == h->len) { + // delete oldest entry + history_delete_at(h,0); + } + assert(h->count < h->len); + h->elems[h->count] = mem_strdup(h->mem,entry); + h->count++; + return true; +} + + +static void history_remove_last_n( history_t* h, ssize_t n ) { + if (n <= 0) return; + if (n > h->count) n = h->count; + for( ssize_t i = h->count - n; i < h->count; i++) { + mem_free( h->mem, h->elems[i] ); + } + h->count -= n; + assert(h->count >= 0); +} + +ic_private void history_remove_last(history_t* h) { + history_remove_last_n(h,1); +} + +ic_private void history_clear(history_t* h) { + history_remove_last_n( h, h->count ); +} + +ic_private const char* history_get( const history_t* h, ssize_t n ) { + if (n < 0 || n >= h->count) return NULL; + return h->elems[h->count - n - 1]; +} + +ic_private bool history_search( const history_t* h, ssize_t from /*including*/, const char* search, bool backward, ssize_t* hidx, ssize_t* hpos ) { + const char* p = NULL; + ssize_t i; + if (backward) { + for( i = from; i < h->count; i++ ) { + p = strstr( history_get(h,i), search); + if (p != NULL) break; + } + } + else { + for( i = from; i >= 0; i-- ) { + p = strstr( history_get(h,i), search); + if (p != NULL) break; + } + } + if (p == NULL) return false; + if (hidx != NULL) *hidx = i; + if (hpos != NULL) *hpos = (p - history_get(h,i)); + return true; +} + +//------------------------------------------------------------- +// +//------------------------------------------------------------- + +ic_private void history_load_from(history_t* h, const char* fname, long max_entries ) { + history_clear(h); + h->fname = mem_strdup(h->mem,fname); + if (max_entries == 0) { + assert(h->elems == NULL); + return; + } + if (max_entries < 0 || max_entries > IC_MAX_HISTORY) max_entries = IC_MAX_HISTORY; + h->elems = (const char**)mem_zalloc_tp_n(h->mem, char*, max_entries ); + if (h->elems == NULL) return; + h->len = max_entries; + history_load(h); +} + + + + +//------------------------------------------------------------- +// save/load history to file +//------------------------------------------------------------- + +static char from_xdigit( int c ) { + if (c >= '0' && c <= '9') return (char)(c - '0'); + if (c >= 'A' && c <= 'F') return (char)(10 + (c - 'A')); + if (c >= 'a' && c <= 'f') return (char)(10 + (c - 'a')); + return 0; +} + +static char to_xdigit( uint8_t c ) { + if (c <= 9) return ((char)c + '0'); + if (c >= 10 && c <= 15) return ((char)c - 10 + 'A'); + return '0'; +} + +static bool ic_isxdigit( int c ) { + return ((c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') || (c >= '0' && c <= '9')); +} + +static bool history_read_entry( history_t* h, FILE* f, stringbuf_t* sbuf ) { + sbuf_clear(sbuf); + while( !feof(f)) { + int c = fgetc(f); + if (c == EOF || c == '\n') break; + if (c == '\\') { + c = fgetc(f); + if (c == 'n') { sbuf_append(sbuf,"\n"); } + else if (c == 'r') { /* ignore */ } // sbuf_append(sbuf,"\r"); + else if (c == 't') { sbuf_append(sbuf,"\t"); } + else if (c == '\\') { sbuf_append(sbuf,"\\"); } + else if (c == 'x') { + int c1 = fgetc(f); + int c2 = fgetc(f); + if (ic_isxdigit(c1) && ic_isxdigit(c2)) { + char chr = from_xdigit(c1)*16 + from_xdigit(c2); + sbuf_append_char(sbuf,chr); + } + else return false; + } + else return false; + } + else sbuf_append_char(sbuf,(char)c); + } + if (sbuf_len(sbuf)==0 || sbuf_string(sbuf)[0] == '#') return true; + return history_push(h, sbuf_string(sbuf)); +} + +static bool history_write_entry( const char* entry, FILE* f, stringbuf_t* sbuf ) { + sbuf_clear(sbuf); + //debug_msg("history: write: %s\n", entry); + while( entry != NULL && *entry != 0 ) { + char c = *entry++; + if (c == '\\') { sbuf_append(sbuf,"\\\\"); } + else if (c == '\n') { sbuf_append(sbuf,"\\n"); } + else if (c == '\r') { /* ignore */ } // sbuf_append(sbuf,"\\r"); } + else if (c == '\t') { sbuf_append(sbuf,"\\t"); } + else if (c < ' ' || c > '~' || c == '#') { + char c1 = to_xdigit( (uint8_t)c / 16 ); + char c2 = to_xdigit( (uint8_t)c % 16 ); + sbuf_append(sbuf,"\\x"); + sbuf_append_char(sbuf,c1); + sbuf_append_char(sbuf,c2); + } + else sbuf_append_char(sbuf,c); + } + //debug_msg("history: write buf: %s\n", sbuf_string(sbuf)); + + if (sbuf_len(sbuf) > 0) { + sbuf_append(sbuf,"\n"); + fputs(sbuf_string(sbuf),f); + } + return true; +} + +ic_private void history_load( history_t* h ) { + if (h->fname == NULL) return; + FILE* f = fopen(h->fname, "r"); + if (f == NULL) return; + stringbuf_t* sbuf = sbuf_new(h->mem); + if (sbuf != NULL) { + while (!feof(f)) { + if (!history_read_entry(h,f,sbuf)) break; // error + } + sbuf_free(sbuf); + } + fclose(f); +} + +ic_private void history_save( const history_t* h ) { + if (h->fname == NULL) return; + FILE* f = fopen(h->fname, "w"); + if (f == NULL) return; + #ifndef _WIN32 + chmod(h->fname,S_IRUSR|S_IWUSR); + #endif + stringbuf_t* sbuf = sbuf_new(h->mem); + if (sbuf != NULL) { + for( int i = 0; i < h->count; i++ ) { + if (!history_write_entry(h->elems[i],f,sbuf)) break; // error + } + sbuf_free(sbuf); + } + fclose(f); +} diff --git a/extern/isocline/src/history.h b/extern/isocline/src/history.h new file mode 100644 index 00000000..76a37160 --- /dev/null +++ b/extern/isocline/src/history.h @@ -0,0 +1,38 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_HISTORY_H +#define IC_HISTORY_H + +#include "common.h" + +//------------------------------------------------------------- +// History +//------------------------------------------------------------- + +struct history_s; +typedef struct history_s history_t; + +ic_private history_t* history_new(alloc_t* mem); +ic_private void history_free(history_t* h); +ic_private void history_clear(history_t* h); +ic_private bool history_enable_duplicates( history_t* h, bool enable ); +ic_private ssize_t history_count(const history_t* h); + +ic_private void history_load_from(history_t* h, const char* fname, long max_entries); +ic_private void history_load( history_t* h ); +ic_private void history_save( const history_t* h ); + +ic_private bool history_push( history_t* h, const char* entry ); +ic_private bool history_update( history_t* h, const char* entry ); +ic_private const char* history_get( const history_t* h, ssize_t n ); +ic_private void history_remove_last(history_t* h); + +ic_private bool history_search( const history_t* h, ssize_t from, const char* search, bool backward, ssize_t* hidx, ssize_t* hpos); + + +#endif // IC_HISTORY_H diff --git a/extern/isocline/src/isocline.c b/extern/isocline/src/isocline.c new file mode 100644 index 00000000..13278062 --- /dev/null +++ b/extern/isocline/src/isocline.c @@ -0,0 +1,589 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +//------------------------------------------------------------- +// Usually we include all sources one file so no internal +// symbols are public in the libray. +// +// You can compile the entire library just as: +// $ gcc -c src/isocline.c +//------------------------------------------------------------- +#if !defined(IC_SEPARATE_OBJS) +# define _CRT_SECURE_NO_WARNINGS // for msvc +# define _XOPEN_SOURCE 700 // for wcwidth +# define _DEFAULT_SOURCE // ensure usleep stays visible with _XOPEN_SOURCE >= 700 +# include "attr.c" +# include "bbcode.c" +# include "editline.c" +# include "highlight.c" +# include "undo.c" +# include "history.c" +# include "completers.c" +# include "completions.c" +# include "term.c" +# include "tty_esc.c" +# include "tty.c" +# include "stringbuf.c" +# include "common.c" +#endif + +//------------------------------------------------------------- +// includes +//------------------------------------------------------------- +#include +#include +#include +#include + +#include "../include/isocline.h" +#include "common.h" +#include "env.h" + + +//------------------------------------------------------------- +// Readline +//------------------------------------------------------------- + +static char* ic_getline( alloc_t* mem ); + +ic_public char* ic_readline(const char* prompt_text) +{ + ic_env_t* env = ic_get_env(); + if (env == NULL) return NULL; + if (!env->noedit) { + // terminal editing enabled + return ic_editline(env, prompt_text); // in editline.c + } + else { + // no editing capability (pipe, dumb terminal, etc) + if (env->tty != NULL && env->term != NULL) { + // if the terminal is not interactive, but we are reading from the tty (keyboard), we display a prompt + term_start_raw(env->term); // set utf8 mode on windows + if (prompt_text != NULL) { + term_write(env->term, prompt_text); + } + term_write(env->term, env->prompt_marker); + term_end_raw(env->term, false); + } + // read directly from stdin + return ic_getline(env->mem); + } +} + + +//------------------------------------------------------------- +// Read a line from the stdin stream if there is no editing +// support (like from a pipe, file, or dumb terminal). +//------------------------------------------------------------- + +static char* ic_getline(alloc_t* mem) +{ + // read until eof or newline + stringbuf_t* sb = sbuf_new(mem); + int c; + while (true) { + c = fgetc(stdin); + if (c==EOF || c=='\n') { + break; + } + else { + sbuf_append_char(sb, (char)c); + } + } + return sbuf_free_dup(sb); +} + + +//------------------------------------------------------------- +// Formatted output +//------------------------------------------------------------- + + +ic_public void ic_printf(const char* fmt, ...) { + va_list ap; + va_start(ap, fmt); + ic_vprintf(fmt, ap); + va_end(ap); +} + +ic_public void ic_vprintf(const char* fmt, va_list args) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode == NULL) return; + bbcode_vprintf(env->bbcode, fmt, args); +} + +ic_public void ic_print(const char* s) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode==NULL) return; + bbcode_print(env->bbcode, s); +} + +ic_public void ic_println(const char* s) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode==NULL) return; + bbcode_println(env->bbcode, s); +} + +void ic_style_def(const char* name, const char* fmt) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode==NULL) return; + bbcode_style_def(env->bbcode, name, fmt); +} + +void ic_style_open(const char* fmt) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode==NULL) return; + bbcode_style_open(env->bbcode, fmt); +} + +void ic_style_close(void) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode==NULL) return; + bbcode_style_close(env->bbcode, NULL); +} + + +//------------------------------------------------------------- +// Interface +//------------------------------------------------------------- + +ic_public bool ic_async_stop(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + if (env->tty==NULL) return false; + return tty_async_stop(env->tty); +} + +static void set_prompt_marker(ic_env_t* env, const char* prompt_marker, const char* cprompt_marker) { + if (prompt_marker == NULL) prompt_marker = "> "; + if (cprompt_marker == NULL) cprompt_marker = prompt_marker; + mem_free(env->mem, env->prompt_marker); + mem_free(env->mem, env->cprompt_marker); + env->prompt_marker = mem_strdup(env->mem, prompt_marker); + env->cprompt_marker = mem_strdup(env->mem, cprompt_marker); +} + +ic_public const char* ic_get_prompt_marker(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return NULL; + return env->prompt_marker; +} + +ic_public const char* ic_get_continuation_prompt_marker(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return NULL; + return env->cprompt_marker; +} + +ic_public void ic_set_prompt_marker( const char* prompt_marker, const char* cprompt_marker ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + set_prompt_marker(env, prompt_marker, cprompt_marker); +} + +ic_public bool ic_enable_multiline( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->singleline_only; + env->singleline_only = !enable; + return !prev; +} + +ic_public bool ic_enable_beep( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + return term_enable_beep(env->term, enable); +} + +ic_public bool ic_enable_color( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + return term_enable_color( env->term, enable ); +} + +ic_public bool ic_enable_history_duplicates( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + return history_enable_duplicates(env->history, enable); +} + +ic_public void ic_set_history(const char* fname, long max_entries ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + history_load_from(env->history, fname, max_entries ); +} + +ic_public void ic_history_remove_last(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + history_remove_last(env->history); +} + +ic_public void ic_history_add( const char* entry ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + history_push( env->history, entry ); +} + +ic_public void ic_history_clear(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + history_clear(env->history); +} + +ic_public bool ic_enable_auto_tab( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->complete_autotab; + env->complete_autotab = enable; + return prev; +} + +ic_public bool ic_enable_completion_preview( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->complete_nopreview; + env->complete_nopreview = !enable; + return !prev; +} + +ic_public bool ic_enable_multiline_indent(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_multiline_indent; + env->no_multiline_indent = !enable; + return !prev; +} + +ic_public bool ic_enable_hint(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_hint; + env->no_hint = !enable; + return !prev; +} + +ic_public long ic_set_hint_delay(long delay_ms) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + long prev = env->hint_delay; + env->hint_delay = (delay_ms < 0 ? 0 : (delay_ms > 5000 ? 5000 : delay_ms)); + return prev; +} + +ic_public void ic_set_tty_esc_delay(long initial_delay_ms, long followup_delay_ms ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->tty == NULL) return; + tty_set_esc_delay(env->tty, initial_delay_ms, followup_delay_ms); +} + + +ic_public bool ic_enable_highlight(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_highlight; + env->no_highlight = !enable; + return !prev; +} + +ic_public bool ic_enable_inline_help(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_help; + env->no_help = !enable; + return !prev; +} + +ic_public bool ic_enable_brace_matching(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_bracematch; + env->no_bracematch = !enable; + return !prev; +} + +ic_public void ic_set_matching_braces(const char* brace_pairs) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + mem_free(env->mem, env->match_braces); + env->match_braces = NULL; + if (brace_pairs != NULL) { + ssize_t len = ic_strlen(brace_pairs); + if (len > 0 && (len % 2) == 0) { + env->match_braces = mem_strdup(env->mem, brace_pairs); + } + } +} + +ic_public bool ic_enable_brace_insertion(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_autobrace; + env->no_autobrace = !enable; + return !prev; +} + +ic_public void ic_set_insertion_braces(const char* brace_pairs) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + mem_free(env->mem, env->auto_braces); + env->auto_braces = NULL; + if (brace_pairs != NULL) { + ssize_t len = ic_strlen(brace_pairs); + if (len > 0 && (len % 2) == 0) { + env->auto_braces = mem_strdup(env->mem, brace_pairs); + } + } +} + +ic_private const char* ic_env_get_match_braces(ic_env_t* env) { + return (env->match_braces == NULL ? "()[]{}" : env->match_braces); +} + +ic_private const char* ic_env_get_auto_braces(ic_env_t* env) { + return (env->auto_braces == NULL ? "()[]{}\"\"''" : env->auto_braces); +} + +ic_public void ic_set_default_highlighter(ic_highlight_fun_t* highlighter, void* arg) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + env->highlighter = highlighter; + env->highlighter_arg = arg; +} + + +ic_public void ic_free( void* p ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + mem_free(env->mem, p); +} + +ic_public void* ic_malloc(size_t sz) { + ic_env_t* env = ic_get_env(); if (env==NULL) return NULL; + return mem_malloc(env->mem, to_ssize_t(sz)); +} + +ic_public const char* ic_strdup( const char* s ) { + if (s==NULL) return NULL; + ic_env_t* env = ic_get_env(); if (env==NULL) return NULL; + ssize_t len = ic_strlen(s); + char* p = mem_malloc_tp_n( env->mem, char, len + 1 ); + if (p == NULL) return NULL; + ic_memcpy( p, s, len ); + p[len] = 0; + return p; +} + +//------------------------------------------------------------- +// Terminal +//------------------------------------------------------------- + +ic_public void ic_term_init(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term==NULL) return; + term_start_raw(env->term); +} + +ic_public void ic_term_done(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term==NULL) return; + term_end_raw(env->term,false); +} + +ic_public void ic_term_flush(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term==NULL) return; + term_flush(env->term); +} + +ic_public void ic_term_write(const char* s) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term == NULL) return; + term_write(env->term, s); +} + +ic_public void ic_term_writeln(const char* s) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term == NULL) return; + term_writeln(env->term, s); +} + +ic_public void ic_term_writef(const char* fmt, ...) { + va_list ap; + va_start(ap, fmt); + ic_term_vwritef(fmt, ap); + va_end(ap); +} + +ic_public void ic_term_vwritef(const char* fmt, va_list args) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term == NULL) return; + term_vwritef(env->term, fmt, args); +} + +ic_public void ic_term_reset( void ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term == NULL) return; + term_attr_reset(env->term); +} + +ic_public void ic_term_style( const char* style ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term == NULL || env->bbcode == NULL) return; + term_set_attr( env->term, bbcode_style(env->bbcode, style)); +} + +ic_public int ic_term_get_color_bits(void) { + ic_env_t* env = ic_get_env(); + if (env==NULL || env->term==NULL) return 4; + return term_get_color_bits(env->term); +} + +ic_public void ic_term_bold(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + term_bold(env->term, enable); +} + +ic_public void ic_term_underline(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + term_underline(env->term, enable); +} + +ic_public void ic_term_italic(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + term_italic(env->term, enable); +} + +ic_public void ic_term_reverse(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + term_reverse(env->term, enable); +} + +ic_public void ic_term_color_ansi(bool foreground, int ansi_color) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + ic_color_t color = color_from_ansi256(ansi_color); + if (foreground) { term_color(env->term, color); } + else { term_bgcolor(env->term, color); } +} + +ic_public void ic_term_color_rgb(bool foreground, uint32_t hcolor) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + ic_color_t color = ic_rgb(hcolor); + if (foreground) { term_color(env->term, color); } + else { term_bgcolor(env->term, color); } +} + + +//------------------------------------------------------------- +// Readline with temporary completer and highlighter +//------------------------------------------------------------- + +ic_public char* ic_readline_ex(const char* prompt_text, + ic_completer_fun_t* completer, void* completer_arg, + ic_highlight_fun_t* highlighter, void* highlighter_arg ) +{ + ic_env_t* env = ic_get_env(); if (env == NULL) return NULL; + // save previous + ic_completer_fun_t* prev_completer; + void* prev_completer_arg; + completions_get_completer(env->completions, &prev_completer, &prev_completer_arg); + ic_highlight_fun_t* prev_highlighter = env->highlighter; + void* prev_highlighter_arg = env->highlighter_arg; + // call with current + if (completer != NULL) { ic_set_default_completer(completer, completer_arg); } + if (highlighter != NULL) { ic_set_default_highlighter(highlighter, highlighter_arg); } + char* res = ic_readline(prompt_text); + // restore previous + ic_set_default_completer(prev_completer, prev_completer_arg); + ic_set_default_highlighter(prev_highlighter, prev_highlighter_arg); + return res; +} + + +//------------------------------------------------------------- +// Initialize +//------------------------------------------------------------- + +static void ic_atexit(void); + +static void ic_env_free(ic_env_t* env) { + if (env == NULL) return; + history_save(env->history); + history_free(env->history); + completions_free(env->completions); + bbcode_free(env->bbcode); + term_free(env->term); + tty_free(env->tty); + mem_free(env->mem, env->cprompt_marker); + mem_free(env->mem,env->prompt_marker); + mem_free(env->mem, env->match_braces); + mem_free(env->mem, env->auto_braces); + env->prompt_marker = NULL; + + // and deallocate ourselves + alloc_t* mem = env->mem; + mem_free(mem, env); + + // and finally the custom memory allocation structure + mem_free(mem, mem); +} + + +static ic_env_t* ic_env_create( ic_malloc_fun_t* _malloc, ic_realloc_fun_t* _realloc, ic_free_fun_t* _free ) +{ + if (_malloc == NULL) _malloc = &malloc; + if (_realloc == NULL) _realloc = &realloc; + if (_free == NULL) _free = &free; + // allocate + alloc_t* mem = (alloc_t*)_malloc(sizeof(alloc_t)); + if (mem == NULL) return NULL; + mem->malloc = _malloc; + mem->realloc = _realloc; + mem->free = _free; + ic_env_t* env = mem_zalloc_tp(mem, ic_env_t); + if (env==NULL) { + mem->free(mem); + return NULL; + } + env->mem = mem; + + // Initialize + env->tty = tty_new(env->mem, -1); // can return NULL + env->term = term_new(env->mem, env->tty, false, false, -1 ); + env->history = history_new(env->mem); + env->completions = completions_new(env->mem); + env->bbcode = bbcode_new(env->mem, env->term); + env->hint_delay = 400; + + if (env->tty == NULL || env->term==NULL || + env->completions == NULL || env->history == NULL || env->bbcode == NULL || + !term_is_interactive(env->term)) + { + env->noedit = true; + } + env->multiline_eol = '\\'; + + bbcode_style_def(env->bbcode, "ic-prompt", "ansi-green" ); + bbcode_style_def(env->bbcode, "ic-info", "ansi-darkgray" ); + bbcode_style_def(env->bbcode, "ic-diminish", "ansi-lightgray" ); + bbcode_style_def(env->bbcode, "ic-emphasis", "#ffffd7" ); + bbcode_style_def(env->bbcode, "ic-hint", "ansi-darkgray" ); + bbcode_style_def(env->bbcode, "ic-error", "#d70000" ); + bbcode_style_def(env->bbcode, "ic-bracematch","ansi-white"); // color = #F7DC6F" ); + + bbcode_style_def(env->bbcode, "keyword", "#569cd6" ); + bbcode_style_def(env->bbcode, "control", "#c586c0" ); + bbcode_style_def(env->bbcode, "number", "#b5cea8" ); + bbcode_style_def(env->bbcode, "string", "#ce9178" ); + bbcode_style_def(env->bbcode, "comment", "#6A9955" ); + bbcode_style_def(env->bbcode, "type", "darkcyan" ); + bbcode_style_def(env->bbcode, "constant", "#569cd6" ); + + set_prompt_marker(env, NULL, NULL); + return env; +} + +static ic_env_t* rpenv; + +static void ic_atexit(void) { + if (rpenv != NULL) { + ic_env_free(rpenv); + rpenv = NULL; + } +} + +ic_private ic_env_t* ic_get_env(void) { + if (rpenv==NULL) { + rpenv = ic_env_create( NULL, NULL, NULL ); + if (rpenv != NULL) { atexit( &ic_atexit ); } + } + return rpenv; +} + +ic_public void ic_init_custom_malloc( ic_malloc_fun_t* _malloc, ic_realloc_fun_t* _realloc, ic_free_fun_t* _free ) { + assert(rpenv == NULL); + if (rpenv != NULL) { + ic_env_free(rpenv); + rpenv = ic_env_create( _malloc, _realloc, _free ); + } + else { + rpenv = ic_env_create( _malloc, _realloc, _free ); + if (rpenv != NULL) { + atexit( &ic_atexit ); + } + } +} + diff --git a/extern/isocline/src/stringbuf.c b/extern/isocline/src/stringbuf.c new file mode 100644 index 00000000..7bbfad04 --- /dev/null +++ b/extern/isocline/src/stringbuf.c @@ -0,0 +1,1038 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +// get `wcwidth` for the column width of unicode characters +// note: for now the OS provided one is unused as we see quite a bit of variation +// among platforms and including our own seems more reliable. +/* +#if defined(__linux__) || defined(__freebsd__) +// use the system supplied one +#if !defined(_XOPEN_SOURCE) +#define _XOPEN_SOURCE 700 // so wcwidth is visible +#endif +#include +#else +*/ +// use our own (also on APPLE as that fails within vscode) +#define wcwidth(c) mk_wcwidth(c) +#include "wcwidth.c" +// #endif + +#include +#include +#include + +#include "common.h" +#include "stringbuf.h" + +//------------------------------------------------------------- +// In place growable utf-8 strings +//------------------------------------------------------------- + +struct stringbuf_s { + char* buf; + ssize_t buflen; + ssize_t count; + alloc_t* mem; +}; + + +//------------------------------------------------------------- +// String column width +//------------------------------------------------------------- + +// column width of a utf8 single character sequence. +static ssize_t utf8_char_width( const char* s, ssize_t n ) { + if (n <= 0) return 0; + + uint8_t b = (uint8_t)s[0]; + int32_t c; + if (b < ' ') { + return 0; + } + else if (b <= 0x7F) { + return 1; + } + else if (b <= 0xC1) { // invalid continuation byte or invalid 0xC0, 0xC1 (check is strictly not necessary as we don't validate..) + return 1; + } + else if (b <= 0xDF && n >= 2) { // b >= 0xC2 // 2 bytes + c = (((b & 0x1F) << 6) | (s[1] & 0x3F)); + assert(c < 0xD800 || c > 0xDFFF); + int w = wcwidth(c); + return w; + } + else if (b <= 0xEF && n >= 3) { // b >= 0xE0 // 3 bytes + c = (((b & 0x0F) << 12) | ((s[1] & 0x3F) << 6) | (s[2] & 0x3F)); + return wcwidth(c); + } + else if (b <= 0xF4 && n >= 4) { // b >= 0xF0 // 4 bytes + c = (((b & 0x07) << 18) | ((s[1] & 0x3F) << 12) | ((s[2] & 0x3F) << 6) | (s[3] & 0x3F)); + return wcwidth(c); + } + else { + // failed + return 1; + } +} + + +// The column width of a codepoint (0, 1, or 2) +static ssize_t char_column_width( const char* s, ssize_t n ) { + if (s == NULL || n <= 0) return 0; + else if ((uint8_t)(*s) < ' ') return 0; // also for CSI escape sequences + else { + ssize_t w = utf8_char_width(s, n); + #ifdef _WIN32 + return (w <= 0 ? 1 : w); // windows console seems to use at least one column + #else + return w; + #endif + } +} + +static ssize_t str_column_width_n( const char* s, ssize_t len ) { + if (s == NULL || len <= 0) return 0; + ssize_t pos = 0; + ssize_t cwidth = 0; + ssize_t cw; + ssize_t ofs; + while (s[pos] != 0 && (ofs = str_next_ofs(s, len, pos, &cw)) > 0) { + cwidth += cw; + pos += ofs; + } + return cwidth; +} + +ic_private ssize_t str_column_width( const char* s ) { + return str_column_width_n( s, ic_strlen(s) ); +} + +ic_private ssize_t str_skip_until_fit( const char* s, ssize_t max_width ) { + if (s == NULL) return 0; + ssize_t cwidth = str_column_width(s); + ssize_t len = ic_strlen(s); + ssize_t pos = 0; + ssize_t next; + ssize_t cw; + while (cwidth > max_width && (next = str_next_ofs(s, len, pos, &cw)) > 0) { + cwidth -= cw; + pos += next; + } + return pos; +} + +ic_private ssize_t str_take_while_fit( const char* s, ssize_t max_width) { + if (s == NULL) return 0; + const ssize_t len = ic_strlen(s); + ssize_t pos = 0; + ssize_t next; + ssize_t cw; + ssize_t cwidth = 0; + while ((next = str_next_ofs(s, len, pos, &cw)) > 0) { + if (cwidth + cw > max_width) break; + cwidth += cw; + pos += next; + } + return pos; +} + + +//------------------------------------------------------------- +// String navigation +//------------------------------------------------------------- + +// get offset of the previous codepoint. does not skip back over CSI sequences. +ic_private ssize_t str_prev_ofs( const char* s, ssize_t pos, ssize_t* width ) { + ssize_t ofs = 0; + if (s != NULL && pos > 0) { + ofs = 1; + while (pos > ofs) { + uint8_t u = (uint8_t)s[pos - ofs]; + if (u < 0x80 || u > 0xBF) break; // continue while follower + ofs++; + } + } + if (width != NULL) *width = char_column_width( s+(pos-ofs), ofs ); + return ofs; +} + +// skip an escape sequence +// +ic_private bool skip_esc( const char* s, ssize_t len, ssize_t* esclen ) { + if (s == NULL || len <= 1 || s[0] != '\x1B') return false; + if (esclen != NULL) *esclen = 0; + if (strchr("[PX^_]",s[1]) != NULL) { + // CSI (ESC [), DCS (ESC P), SOS (ESC X), PM (ESC ^), APC (ESC _), and OSC (ESC ]): terminated with a special sequence + bool finalCSI = (s[1] == '['); // CSI terminates with 0x40-0x7F; otherwise ST (bell or ESC \) + ssize_t n = 2; + while (len > n) { + char c = s[n++]; + if ((finalCSI && (uint8_t)c >= 0x40 && (uint8_t)c <= 0x7F) || // terminating byte: @A–Z[\]^_`a–z{|}~ + (!finalCSI && c == '\x07') || // bell + (c == '\x02')) // STX terminates as well + { + if (esclen != NULL) *esclen = n; + return true; + } + else if (!finalCSI && c == '\x1B' && len > n && s[n] == '\\') { // ST (ESC \) + n++; + if (esclen != NULL) *esclen = n; + return true; + } + } + } + if (strchr(" #%()*+",s[1]) != NULL) { + // assume escape sequence of length 3 (like ESC % G) + if (esclen != NULL) *esclen = 2; + return true; + } + else { + // assume single character escape code (like ESC 7) + if (esclen != NULL) *esclen = 2; + return true; + } + return false; +} + +// Offset to the next codepoint, treats CSI escape sequences as a single code point. +ic_private ssize_t str_next_ofs( const char* s, ssize_t len, ssize_t pos, ssize_t* cwidth ) { + ssize_t ofs = 0; + if (s != NULL && len > pos) { + if (skip_esc(s+pos,len-pos,&ofs)) { + // skip escape sequence + } + else { + ofs = 1; + // utf8 extended character? + while(len > pos + ofs) { + uint8_t u = (uint8_t)s[pos + ofs]; + if (u < 0x80 || u > 0xBF) break; // break if not a follower + ofs++; + } + } + } + if (cwidth != NULL) *cwidth = char_column_width( s+pos, ofs ); + return ofs; +} + +static ssize_t str_limit_to_length( const char* s, ssize_t n ) { + ssize_t i; + for(i = 0; i < n && s[i] != 0; i++) { /* nothing */ } + return i; +} + + +//------------------------------------------------------------- +// String searching prev/next word, line, ws_word +//------------------------------------------------------------- + + +static ssize_t str_find_backward( const char* s, ssize_t len, ssize_t pos, ic_is_char_class_fun_t* match, bool skip_immediate_matches ) { + if (pos > len) pos = len; + if (pos < 0) pos = 0; + ssize_t i = pos; + // skip matching first (say, whitespace in case of the previous start-of-word) + if (skip_immediate_matches) { + do { + ssize_t prev = str_prev_ofs(s, i, NULL); + if (prev <= 0) break; + assert(i - prev >= 0); + if (!match(s + i - prev, (long)prev)) break; + i -= prev; + } while (i > 0); + } + // find match + do { + ssize_t prev = str_prev_ofs(s, i, NULL); + if (prev <= 0) break; + assert(i - prev >= 0); + if (match(s + i - prev, (long)prev)) { + return i; // found; + } + i -= prev; + } while (i > 0); + return -1; // not found +} + +static ssize_t str_find_forward( const char* s, ssize_t len, ssize_t pos, ic_is_char_class_fun_t* match, bool skip_immediate_matches ) { + if (s == NULL || len < 0) return -1; + if (pos > len) pos = len; + if (pos < 0) pos = 0; + ssize_t i = pos; + ssize_t next; + // skip matching first (say, whitespace in case of the next end-of-word) + if (skip_immediate_matches) { + do { + next = str_next_ofs(s, len, i, NULL); + if (next <= 0) break; + assert( i + next <= len); + if (!match(s + i, (long)next)) break; + i += next; + } while (i < len); + } + // and then look + do { + next = str_next_ofs(s, len, i, NULL); + if (next <= 0) break; + assert( i + next <= len); + if (match(s + i, (long)next)) { + return i; // found + } + i += next; + } while (i < len); + return -1; +} + +static bool char_is_linefeed( const char* s, long n ) { + return (n == 1 && (*s == '\n' || *s == 0)); +} + +static ssize_t str_find_line_start( const char* s, ssize_t len, ssize_t pos) { + ssize_t start = str_find_backward(s,len,pos,&char_is_linefeed,false /* don't skip immediate matches */); + return (start < 0 ? 0 : start); +} + +static ssize_t str_find_line_end( const char* s, ssize_t len, ssize_t pos) { + ssize_t end = str_find_forward(s,len,pos, &char_is_linefeed, false); + return (end < 0 ? len : end); +} + +static ssize_t str_find_word_start( const char* s, ssize_t len, ssize_t pos) { + ssize_t start = str_find_backward(s,len,pos, &ic_char_is_idletter,true /* skip immediate matches */); + return (start < 0 ? 0 : start); +} + +static ssize_t str_find_word_end( const char* s, ssize_t len, ssize_t pos) { + ssize_t end = str_find_forward(s,len,pos,&ic_char_is_idletter,true /* skip immediate matches */); + return (end < 0 ? len : end); +} + +static ssize_t str_find_ws_word_start( const char* s, ssize_t len, ssize_t pos) { + ssize_t start = str_find_backward(s,len,pos,&ic_char_is_white,true /* skip immediate matches */); + return (start < 0 ? 0 : start); +} + +static ssize_t str_find_ws_word_end( const char* s, ssize_t len, ssize_t pos) { + ssize_t end = str_find_forward(s,len,pos,&ic_char_is_white,true /* skip immediate matches */); + return (end < 0 ? len : end); +} + + +//------------------------------------------------------------- +// String row/column iteration +//------------------------------------------------------------- + +// invoke a function for each terminal row; returns total row count. +static ssize_t str_for_each_row( const char* s, ssize_t len, ssize_t termw, ssize_t promptw, ssize_t cpromptw, + row_fun_t* fun, const void* arg, void* res ) +{ + if (s == NULL) s = ""; + ssize_t i; + ssize_t rcount = 0; + ssize_t rcol = 0; + ssize_t rstart = 0; + ssize_t startw = promptw; + for(i = 0; i < len; ) { + ssize_t w; + ssize_t next = str_next_ofs(s, len, i, &w); + if (next <= 0) { + debug_msg("str: foreach row: next<=0: len %zd, i %zd, w %zd, buf %s\n", len, i, w, s ); + assert(false); + break; + } + startw = (rcount == 0 ? promptw : cpromptw); + ssize_t termcol = rcol + w + startw + 1 /* for the cursor */; + if (termw != 0 && i != 0 && termcol >= termw) { + // wrap + if (fun != NULL) { + if (fun(s,rcount,rstart,i - rstart,startw,true,arg,res)) return rcount; + } + rcount++; + rstart = i; + rcol = 0; + } + if (s[i] == '\n') { + // newline + if (fun != NULL) { + if (fun(s,rcount,rstart,i - rstart,startw,false,arg,res)) return rcount; + } + rcount++; + rstart = i+1; + rcol = 0; + } + assert (s[i] != 0); + i += next; + rcol += w; + } + if (fun != NULL) { + if (fun(s,rcount,rstart,i - rstart,startw,false,arg,res)) return rcount; + } + return rcount+1; +} + +//------------------------------------------------------------- +// String: get row/column position +//------------------------------------------------------------- + + +static bool str_get_current_pos_iter( + const char* s, + ssize_t row, ssize_t row_start, ssize_t row_len, + ssize_t startw, bool is_wrap, const void* arg, void* res) +{ + ic_unused(is_wrap); ic_unused(startw); + rowcol_t* rc = (rowcol_t*)res; + ssize_t pos = *((ssize_t*)arg); + + if (pos >= row_start && pos <= (row_start + row_len)) { + // found the cursor row + rc->row_start = row_start; + rc->row_len = row_len; + rc->row = row; + rc->col = str_column_width_n( s + row_start, pos - row_start ); + rc->first_on_row = (pos == row_start); + if (is_wrap) { + // if wrapped, we check if the next character is at row_len + ssize_t next = str_next_ofs(s, row_start + row_len, pos, NULL); + rc->last_on_row = (pos + next >= row_start + row_len); + } + else { + // normal last position is right after the last character + rc->last_on_row = (pos >= row_start + row_len); + } + // debug_msg("edit; pos iter: pos: %zd (%c), row_start: %zd, rowlen: %zd\n", pos, s[pos], row_start, row_len); + } + return false; // always continue to count all rows +} + +static ssize_t str_get_rc_at_pos(const char* s, ssize_t len, ssize_t termw, ssize_t promptw, ssize_t cpromptw, ssize_t pos, rowcol_t* rc) { + memset(rc, 0, sizeof(*rc)); + ssize_t rows = str_for_each_row(s, len, termw, promptw, cpromptw, &str_get_current_pos_iter, &pos, rc); + // debug_msg("edit: current pos: (%d, %d) %s %s\n", rc->row, rc->col, rc->first_on_row ? "first" : "", rc->last_on_row ? "last" : ""); + return rows; +} + + + +//------------------------------------------------------------- +// String: get row/column position for a resized terminal +// with potentially "hard-wrapped" rows +//------------------------------------------------------------- +typedef struct wrapped_arg_s { + ssize_t pos; + ssize_t newtermw; +} wrapped_arg_t; + +typedef struct wrowcol_s { + rowcol_t rc; + ssize_t hrows; // count of hard-wrapped extra rows +} wrowcol_t; + +static bool str_get_current_wrapped_pos_iter( + const char* s, + ssize_t row, ssize_t row_start, ssize_t row_len, + ssize_t startw, bool is_wrap, const void* arg, void* res) +{ + ic_unused(is_wrap); + wrowcol_t* wrc = (wrowcol_t*)res; + const wrapped_arg_t* warg = (const wrapped_arg_t*)arg; + + // iterate through the row and record the postion and hard-wraps + ssize_t hwidth = startw; + ssize_t i = 0; + while( i <= row_len ) { // include rowlen as the cursor position can be just after the last character + // get next position and column width + ssize_t cw; + ssize_t next; + bool is_cursor = (warg->pos == row_start+i); + if (i < row_len) { + next = str_next_ofs(s + row_start, row_len, i, &cw); + } + else { + // end of row: take wrap or cursor into account + // (wrap has width 2 as it displays a back-arrow but also has an invisible newline that wraps) + cw = (is_wrap ? 2 : (is_cursor ? 1 : 0)); + next = 1; + } + + if (next > 0) { + if (hwidth + cw > warg->newtermw) { + // hardwrap + hwidth = 0; + wrc->hrows++; + debug_msg("str: found hardwrap: row: %zd, hrows: %zd\n", row, wrc->hrows); + } + } + else { + next++; // ensure we terminate (as we go up to rowlen) + } + + // did we find our position? + if (is_cursor) { + debug_msg("str: found position: row: %zd, hrows: %zd\n", row, wrc->hrows); + wrc->rc.row_start = row_start; + wrc->rc.row_len = row_len; + wrc->rc.row = wrc->hrows + row; + wrc->rc.col = hwidth; + wrc->rc.first_on_row = (i==0); + wrc->rc.last_on_row = (i+next >= row_len - (is_wrap ? 1 : 0)); + } + + // advance + hwidth += cw; + i += next; + } + return false; // always continue to count all rows +} + + +static ssize_t str_get_wrapped_rc_at_pos(const char* s, ssize_t len, ssize_t termw, ssize_t newtermw, ssize_t promptw, ssize_t cpromptw, ssize_t pos, rowcol_t* rc) { + wrapped_arg_t warg; + warg.pos = pos; + warg.newtermw = newtermw; + wrowcol_t wrc; + memset(&wrc,0,sizeof(wrc)); + ssize_t rows = str_for_each_row(s, len, termw, promptw, cpromptw, &str_get_current_wrapped_pos_iter, &warg, &wrc); + debug_msg("edit: wrapped pos: (%zd,%zd) rows %zd %s %s, hrows: %zd\n", wrc.rc.row, wrc.rc.col, rows, wrc.rc.first_on_row ? "first" : "", wrc.rc.last_on_row ? "last" : "", wrc.hrows); + *rc = wrc.rc; + return (rows + wrc.hrows); +} + + +//------------------------------------------------------------- +// Set position +//------------------------------------------------------------- + +static bool str_set_pos_iter( + const char* s, + ssize_t row, ssize_t row_start, ssize_t row_len, + ssize_t startw, bool is_wrap, const void* arg, void* res) +{ + ic_unused(arg); ic_unused(is_wrap); ic_unused(startw); + rowcol_t* rc = (rowcol_t*)arg; + if (rc->row != row) return false; // keep searching + // we found our row + ssize_t col = 0; + ssize_t i = row_start; + ssize_t end = row_start + row_len; + while (col < rc->col && i < end) { + ssize_t cw; + ssize_t next = str_next_ofs(s, row_start + row_len, i, &cw); + if (next <= 0) break; + i += next; + col += cw; + } + *((ssize_t*)res) = i; + return true; // stop iteration +} + +static ssize_t str_get_pos_at_rc(const char* s, ssize_t len, ssize_t termw, ssize_t promptw, ssize_t cpromptw, ssize_t row, ssize_t col /* without prompt */) { + rowcol_t rc; + memset(&rc,0,ssizeof(rc)); + rc.row = row; + rc.col = col; + ssize_t pos = -1; + str_for_each_row(s,len,termw,promptw,cpromptw,&str_set_pos_iter,&rc,&pos); + return pos; +} + + +//------------------------------------------------------------- +// String buffer +//------------------------------------------------------------- +static bool sbuf_ensure_extra(stringbuf_t* s, ssize_t extra) +{ + if (s->buflen >= s->count + extra) return true; + // reallocate; pick good initial size and multiples to increase reuse on allocation + ssize_t newlen = (s->buflen <= 0 ? 120 : (s->buflen > 1000 ? s->buflen + 1000 : 2*s->buflen)); + if (newlen < s->count + extra) newlen = s->count + extra; + if (s->buflen > 0) { + debug_msg("stringbuf: reallocate: old %zd, new %zd\n", s->buflen, newlen); + } + char* newbuf = mem_realloc_tp(s->mem, char, s->buf, newlen+1); // one more for terminating zero + if (newbuf == NULL) { + assert(false); + return false; + } + s->buf = newbuf; + s->buflen = newlen; + s->buf[s->count] = s->buf[s->buflen] = 0; + assert(s->buflen >= s->count + extra); + return true; +} + +static void sbuf_init( stringbuf_t* sbuf, alloc_t* mem ) { + sbuf->mem = mem; + sbuf->buf = NULL; + sbuf->buflen = 0; + sbuf->count = 0; +} + +static void sbuf_done( stringbuf_t* sbuf ) { + mem_free( sbuf->mem, sbuf->buf ); + sbuf->buf = NULL; + sbuf->buflen = 0; + sbuf->count = 0; +} + + +ic_private void sbuf_free( stringbuf_t* sbuf ) { + if (sbuf==NULL) return; + sbuf_done(sbuf); + mem_free(sbuf->mem, sbuf); +} + +ic_private stringbuf_t* sbuf_new( alloc_t* mem ) { + stringbuf_t* sbuf = mem_zalloc_tp(mem,stringbuf_t); + if (sbuf == NULL) return NULL; + sbuf_init(sbuf,mem); + return sbuf; +} + +// free the sbuf and return the current string buffer as the result +ic_private char* sbuf_free_dup(stringbuf_t* sbuf) { + if (sbuf == NULL) return NULL; + char* s = NULL; + if (sbuf->buf != NULL) { + s = mem_realloc_tp(sbuf->mem, char, sbuf->buf, sbuf_len(sbuf)+1); + if (s == NULL) { s = sbuf->buf; } + sbuf->buf = 0; + sbuf->buflen = 0; + sbuf->count = 0; + } + sbuf_free(sbuf); + return s; +} + +ic_private const char* sbuf_string_at( stringbuf_t* sbuf, ssize_t pos ) { + if (pos < 0 || sbuf->count < pos) return NULL; + if (sbuf->buf == NULL) return ""; + assert(sbuf->buf[sbuf->count] == 0); + return sbuf->buf + pos; +} + +ic_private const char* sbuf_string( stringbuf_t* sbuf ) { + return sbuf_string_at( sbuf, 0 ); +} + +ic_private char sbuf_char_at(stringbuf_t* sbuf, ssize_t pos) { + if (sbuf->buf == NULL || pos < 0 || sbuf->count < pos) return 0; + return sbuf->buf[pos]; +} + +ic_private char* sbuf_strdup_at( stringbuf_t* sbuf, ssize_t pos ) { + return mem_strdup(sbuf->mem, sbuf_string_at(sbuf,pos)); +} + +ic_private char* sbuf_strdup( stringbuf_t* sbuf ) { + return mem_strdup(sbuf->mem, sbuf_string(sbuf)); +} + +ic_private ssize_t sbuf_len(const stringbuf_t* s) { + if (s == NULL) return 0; + return s->count; +} + +ic_private ssize_t sbuf_append_vprintf(stringbuf_t* sb, const char* fmt, va_list args) { + const ssize_t min_needed = ic_strlen(fmt); + if (!sbuf_ensure_extra(sb,min_needed + 16)) return sb->count; + ssize_t avail = sb->buflen - sb->count; + va_list args0; + va_copy(args0, args); + ssize_t needed = vsnprintf(sb->buf + sb->count, to_size_t(avail), fmt, args0); + if (needed > avail) { + sb->buf[sb->count] = 0; + if (!sbuf_ensure_extra(sb, needed)) return sb->count; + avail = sb->buflen - sb->count; + needed = vsnprintf(sb->buf + sb->count, to_size_t(avail), fmt, args); + } + assert(needed <= avail); + sb->count += (needed > avail ? avail : (needed >= 0 ? needed : 0)); + assert(sb->count <= sb->buflen); + sb->buf[sb->count] = 0; + return sb->count; +} + +ic_private ssize_t sbuf_appendf(stringbuf_t* sb, const char* fmt, ...) { + va_list args; + va_start( args, fmt); + ssize_t res = sbuf_append_vprintf( sb, fmt, args ); + va_end(args); + return res; +} + + +ic_private ssize_t sbuf_insert_at_n(stringbuf_t* sbuf, const char* s, ssize_t n, ssize_t pos ) { + if (pos < 0 || pos > sbuf->count || s == NULL) return pos; + n = str_limit_to_length(s,n); + if (n <= 0 || !sbuf_ensure_extra(sbuf,n)) return pos; + ic_memmove(sbuf->buf + pos + n, sbuf->buf + pos, sbuf->count - pos); + ic_memcpy(sbuf->buf + pos, s, n); + sbuf->count += n; + sbuf->buf[sbuf->count] = 0; + return (pos + n); +} + +ic_private stringbuf_t* sbuf_split_at( stringbuf_t* sb, ssize_t pos ) { + stringbuf_t* res = sbuf_new(sb->mem); + if (res==NULL || pos < 0) return NULL; + if (pos < sb->count) { + sbuf_append_n(res, sb->buf + pos, sb->count - pos); + sb->count = pos; + } + return res; +} + +ic_private ssize_t sbuf_insert_at(stringbuf_t* sbuf, const char* s, ssize_t pos ) { + return sbuf_insert_at_n( sbuf, s, ic_strlen(s), pos ); +} + +ic_private ssize_t sbuf_insert_char_at(stringbuf_t* sbuf, char c, ssize_t pos ) { + char s[2]; + s[0] = c; + s[1] = 0; + return sbuf_insert_at_n( sbuf, s, 1, pos); +} + +ic_private ssize_t sbuf_insert_unicode_at(stringbuf_t* sbuf, unicode_t u, ssize_t pos) { + uint8_t s[5]; + unicode_to_qutf8(u, s); + return sbuf_insert_at(sbuf, (const char*)s, pos); +} + + + +ic_private void sbuf_delete_at( stringbuf_t* sbuf, ssize_t pos, ssize_t count ) { + if (pos < 0 || pos >= sbuf->count) return; + if (pos + count > sbuf->count) count = sbuf->count - pos; + ic_memmove(sbuf->buf + pos, sbuf->buf + pos + count, sbuf->count - pos - count); + sbuf->count -= count; + sbuf->buf[sbuf->count] = 0; +} + +ic_private void sbuf_delete_from_to( stringbuf_t* sbuf, ssize_t pos, ssize_t end ) { + if (end <= pos) return; + sbuf_delete_at( sbuf, pos, end - pos); +} + +ic_private void sbuf_delete_from(stringbuf_t* sbuf, ssize_t pos ) { + sbuf_delete_at(sbuf, pos, sbuf_len(sbuf) - pos ); +} + + +ic_private void sbuf_clear( stringbuf_t* sbuf ) { + sbuf_delete_at(sbuf, 0, sbuf_len(sbuf)); +} + +ic_private ssize_t sbuf_append_n( stringbuf_t* sbuf, const char* s, ssize_t n ) { + return sbuf_insert_at_n( sbuf, s, n, sbuf_len(sbuf)); +} + +ic_private ssize_t sbuf_append( stringbuf_t* sbuf, const char* s ) { + return sbuf_insert_at( sbuf, s, sbuf_len(sbuf)); +} + +ic_private ssize_t sbuf_append_char( stringbuf_t* sbuf, char c ) { + char buf[2]; + buf[0] = c; + buf[1] = 0; + return sbuf_append( sbuf, buf ); +} + +ic_private void sbuf_replace(stringbuf_t* sbuf, const char* s) { + sbuf_clear(sbuf); + sbuf_append(sbuf,s); +} + +ic_private ssize_t sbuf_next_ofs( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth ) { + return str_next_ofs( sbuf->buf, sbuf->count, pos, cwidth); +} + +ic_private ssize_t sbuf_prev_ofs( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth ) { + return str_prev_ofs( sbuf->buf, pos, cwidth); +} + +ic_private ssize_t sbuf_next( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth) { + ssize_t ofs = sbuf_next_ofs(sbuf,pos,cwidth); + if (ofs <= 0) return -1; + assert(pos + ofs <= sbuf->count); + return pos + ofs; +} + +ic_private ssize_t sbuf_prev( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth) { + ssize_t ofs = sbuf_prev_ofs(sbuf,pos,cwidth); + if (ofs <= 0) return -1; + assert(pos - ofs >= 0); + return pos - ofs; +} + +ic_private ssize_t sbuf_delete_char_before( stringbuf_t* sbuf, ssize_t pos ) { + ssize_t n = sbuf_prev_ofs(sbuf, pos, NULL); + if (n <= 0) return 0; + assert( pos - n >= 0 ); + sbuf_delete_at(sbuf, pos - n, n); + return pos - n; +} + +ic_private void sbuf_delete_char_at( stringbuf_t* sbuf, ssize_t pos ) { + ssize_t n = sbuf_next_ofs(sbuf, pos, NULL); + if (n <= 0) return; + assert( pos + n <= sbuf->count ); + sbuf_delete_at(sbuf, pos, n); + return; +} + +ic_private ssize_t sbuf_swap_char( stringbuf_t* sbuf, ssize_t pos ) { + ssize_t next = sbuf_next_ofs(sbuf, pos, NULL); + if (next <= 0) return 0; + ssize_t prev = sbuf_prev_ofs(sbuf, pos, NULL); + if (prev <= 0) return 0; + char buf[64]; + if (prev >= 63) return 0; + ic_memcpy(buf, sbuf->buf + pos - prev, prev ); + ic_memmove(sbuf->buf + pos - prev, sbuf->buf + pos, next); + ic_memmove(sbuf->buf + pos - prev + next, buf, prev); + return pos - prev; +} + +ic_private ssize_t sbuf_find_line_start( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_line_start( sbuf->buf, sbuf->count, pos); +} + +ic_private ssize_t sbuf_find_line_end( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_line_end( sbuf->buf, sbuf->count, pos); +} + +ic_private ssize_t sbuf_find_word_start( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_word_start( sbuf->buf, sbuf->count, pos); +} + +ic_private ssize_t sbuf_find_word_end( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_word_end( sbuf->buf, sbuf->count, pos); +} + +ic_private ssize_t sbuf_find_ws_word_start( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_ws_word_start( sbuf->buf, sbuf->count, pos); +} + +ic_private ssize_t sbuf_find_ws_word_end( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_ws_word_end( sbuf->buf, sbuf->count, pos); +} + +// find row/col position +ic_private ssize_t sbuf_get_pos_at_rc( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, ssize_t row, ssize_t col ) { + return str_get_pos_at_rc( sbuf->buf, sbuf->count, termw, promptw, cpromptw, row, col); +} + +// get row/col for a given position +ic_private ssize_t sbuf_get_rc_at_pos( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, ssize_t pos, rowcol_t* rc ) { + return str_get_rc_at_pos( sbuf->buf, sbuf->count, termw, promptw, cpromptw, pos, rc); +} + +ic_private ssize_t sbuf_get_wrapped_rc_at_pos( stringbuf_t* sbuf, ssize_t termw, ssize_t newtermw, ssize_t promptw, ssize_t cpromptw, ssize_t pos, rowcol_t* rc ) { + return str_get_wrapped_rc_at_pos( sbuf->buf, sbuf->count, termw, newtermw, promptw, cpromptw, pos, rc); +} + +ic_private ssize_t sbuf_for_each_row( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, row_fun_t* fun, void* arg, void* res ) { + if (sbuf == NULL) return 0; + return str_for_each_row( sbuf->buf, sbuf->count, termw, promptw, cpromptw, fun, arg, res); +} + + +// Duplicate and decode from utf-8 (for non-utf8 terminals) +ic_private char* sbuf_strdup_from_utf8(stringbuf_t* sbuf) { + ssize_t len = sbuf_len(sbuf); + if (sbuf == NULL || len <= 0) return NULL; + char* s = mem_zalloc_tp_n(sbuf->mem, char, len); + if (s == NULL) return NULL; + ssize_t dest = 0; + for (ssize_t i = 0; i < len; ) { + ssize_t ofs = sbuf_next_ofs(sbuf, i, NULL); + if (ofs <= 0) { + // invalid input + break; + } + else if (ofs == 1) { + // regular character + s[dest++] = sbuf->buf[i]; + } + else if (sbuf->buf[i] == '\x1B') { + // skip escape sequences + } + else { + // decode unicode + ssize_t nread; + unicode_t uchr = unicode_from_qutf8( (const uint8_t*)(sbuf->buf + i), ofs, &nread); + uint8_t c; + if (unicode_is_raw(uchr, &c)) { + // raw byte, output as is (this will take care of locale specific input) + s[dest++] = (char)c; + } + else if (uchr <= 0x7F) { + // allow ascii + s[dest++] = (char)uchr; + } + else { + // skip unknown unicode characters.. + // todo: convert according to locale? + } + } + i += ofs; + } + assert(dest <= len); + s[dest] = 0; + return s; +} + +//------------------------------------------------------------- +// String helpers +//------------------------------------------------------------- + +ic_public long ic_prev_char( const char* s, long pos ) { + ssize_t len = ic_strlen(s); + if (pos < 0 || pos > len) return -1; + ssize_t ofs = str_prev_ofs( s, pos, NULL ); + if (ofs <= 0) return -1; + return (long)(pos - ofs); +} + +ic_public long ic_next_char( const char* s, long pos ) { + ssize_t len = ic_strlen(s); + if (pos < 0 || pos > len) return -1; + ssize_t ofs = str_next_ofs( s, len, pos, NULL ); + if (ofs <= 0) return -1; + return (long)(pos + ofs); +} + + +// parse a decimal (leave pi unchanged on error) +ic_private bool ic_atoz(const char* s, ssize_t* pi) { + return (sscanf(s, "%zd", pi) == 1); +} + +// parse two decimals separated by a semicolon +ic_private bool ic_atoz2(const char* s, ssize_t* pi, ssize_t* pj) { + return (sscanf(s, "%zd;%zd", pi, pj) == 2); +} + +// parse unsigned 32-bit (leave pu unchanged on error) +ic_private bool ic_atou32(const char* s, uint32_t* pu) { + return (sscanf(s, "%" SCNu32, pu) == 1); +} + + +// Convenience: character class for whitespace `[ \t\r\n]`. +ic_public bool ic_char_is_white(const char* s, long len) { + if (s == NULL || len != 1) return false; + const char c = *s; + return (c==' ' || c == '\t' || c == '\n' || c == '\r'); +} + +// Convenience: character class for non-whitespace `[^ \t\r\n]`. +ic_public bool ic_char_is_nonwhite(const char* s, long len) { + return !ic_char_is_white(s, len); +} + +// Convenience: character class for separators `[ \t\r\n,.;:/\\\(\)\{\}\[\]]`. +ic_public bool ic_char_is_separator(const char* s, long len) { + if (s == NULL || len != 1) return false; + const char c = *s; + return (strchr(" \t\r\n,.;:/\\(){}[]", c) != NULL); +} + +// Convenience: character class for non-separators. +ic_public bool ic_char_is_nonseparator(const char* s, long len) { + return !ic_char_is_separator(s, len); +} + + +// Convenience: character class for digits (`[0-9]`). +ic_public bool ic_char_is_digit(const char* s, long len) { + if (s == NULL || len != 1) return false; + const char c = *s; + return (c >= '0' && c <= '9'); +} + +// Convenience: character class for hexadecimal digits (`[A-Fa-f0-9]`). +ic_public bool ic_char_is_hexdigit(const char* s, long len) { + if (s == NULL || len != 1) return false; + const char c = *s; + return ((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')); +} + +// Convenience: character class for letters (`[A-Za-z]` and any unicode > 0x80). +ic_public bool ic_char_is_letter(const char* s, long len) { + if (s == NULL || len <= 0) return false; + const char c = *s; + return ((uint8_t)c >= 0x80 || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')); +} + +// Convenience: character class for identifier letters (`[A-Za-z0-9_-]` and any unicode > 0x80). +ic_public bool ic_char_is_idletter(const char* s, long len) { + if (s == NULL || len <= 0) return false; + const char c = *s; + return ((uint8_t)c >= 0x80 || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || (c == '_') || (c == '-')); +} + +// Convenience: character class for filename letters (`[^ \t\r\n`@$><=;|&{(]`). +ic_public bool ic_char_is_filename_letter(const char* s, long len) { + if (s == NULL || len <= 0) return false; + const char c = *s; + return ((uint8_t)c >= 0x80 || (strchr(" \t\r\n`@$><=;|&{}()[]", c) == NULL)); +} + +// Convenience: If this is a token start, returns the length (or <= 0 if not found). +ic_public long ic_is_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char) { + if (s == NULL || pos < 0 || is_token_char == NULL) return -1; + ssize_t len = ic_strlen(s); + if (pos >= len) return -1; + if (pos > 0 && is_token_char(s + pos -1, 1)) return -1; // token start? + ssize_t i = pos; + while ( i < len ) { + ssize_t next = str_next_ofs(s, len, i, NULL); + if (next <= 0) return -1; + if (!is_token_char(s + i, (long)next)) break; + i += next; + } + return (long)(i - pos); +} + + +static int ic_strncmp(const char* s1, const char* s2, ssize_t n) { + return strncmp(s1, s2, to_size_t(n)); +} + +// Convenience: Does this match the specified token? +// Ensures not to match prefixes or suffixes, and returns the length of the match (in bytes). +// E.g. `ic_match_token("function",0,&ic_char_is_letter,"fun")` returns 0. +ic_public long ic_match_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char, const char* token) { + long n = ic_is_token(s, pos, is_token_char); + if (n > 0 && token != NULL && n == ic_strlen(token) && ic_strncmp(s + pos, token, n) == 0) { + return n; + } + else { + return 0; + } +} + + +// Convenience: Do any of the specified tokens match? +// Ensures not to match prefixes or suffixes, and returns the length of the match (in bytes). +// Ensures not to match prefixes or suffixes. +// E.g. `ic_match_any_token("function",0,&ic_char_is_letter,{"fun","func",NULL})` returns 0. +ic_public long ic_match_any_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char, const char** tokens) { + long n = ic_is_token(s, pos, is_token_char); + if (n <= 0 || tokens == NULL) return 0; + for (const char** token = tokens; *token != NULL; token++) { + if (n == ic_strlen(*token) && ic_strncmp(s + pos, *token, n) == 0) { + return n; + } + } + return 0; +} + diff --git a/extern/isocline/src/stringbuf.h b/extern/isocline/src/stringbuf.h new file mode 100644 index 00000000..39b21ea4 --- /dev/null +++ b/extern/isocline/src/stringbuf.h @@ -0,0 +1,121 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_STRINGBUF_H +#define IC_STRINGBUF_H + +#include +#include "common.h" + +//------------------------------------------------------------- +// string buffer +// in-place modified buffer with edit operations +// that grows on demand. +//------------------------------------------------------------- + +// abstract string buffer +struct stringbuf_s; +typedef struct stringbuf_s stringbuf_t; + +ic_private stringbuf_t* sbuf_new( alloc_t* mem ); +ic_private void sbuf_free( stringbuf_t* sbuf ); +ic_private char* sbuf_free_dup(stringbuf_t* sbuf); +ic_private ssize_t sbuf_len(const stringbuf_t* s); + +ic_private const char* sbuf_string_at( stringbuf_t* sbuf, ssize_t pos ); +ic_private const char* sbuf_string( stringbuf_t* sbuf ); +ic_private char sbuf_char_at(stringbuf_t* sbuf, ssize_t pos); +ic_private char* sbuf_strdup_at( stringbuf_t* sbuf, ssize_t pos ); +ic_private char* sbuf_strdup( stringbuf_t* sbuf ); +ic_private char* sbuf_strdup_from_utf8(stringbuf_t* sbuf); // decode to locale + + +ic_private ssize_t sbuf_appendf(stringbuf_t* sb, const char* fmt, ...); +ic_private ssize_t sbuf_append_vprintf(stringbuf_t* sb, const char* fmt, va_list args); + +ic_private stringbuf_t* sbuf_split_at( stringbuf_t* sb, ssize_t pos ); + +// primitive edit operations (inserts return the new position) +ic_private void sbuf_clear(stringbuf_t* sbuf); +ic_private void sbuf_replace(stringbuf_t* sbuf, const char* s); +ic_private void sbuf_delete_at(stringbuf_t* sbuf, ssize_t pos, ssize_t count); +ic_private void sbuf_delete_from_to(stringbuf_t* sbuf, ssize_t pos, ssize_t end); +ic_private void sbuf_delete_from(stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_insert_at_n(stringbuf_t* sbuf, const char* s, ssize_t n, ssize_t pos ); +ic_private ssize_t sbuf_insert_at(stringbuf_t* sbuf, const char* s, ssize_t pos ); +ic_private ssize_t sbuf_insert_char_at(stringbuf_t* sbuf, char c, ssize_t pos ); +ic_private ssize_t sbuf_insert_unicode_at(stringbuf_t* sbuf, unicode_t u, ssize_t pos); +ic_private ssize_t sbuf_append_n(stringbuf_t* sbuf, const char* s, ssize_t n); +ic_private ssize_t sbuf_append(stringbuf_t* sbuf, const char* s); +ic_private ssize_t sbuf_append_char(stringbuf_t* sbuf, char c); + +// high level edit operations (return the new position) +ic_private ssize_t sbuf_next( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth ); +ic_private ssize_t sbuf_prev( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth ); +ic_private ssize_t sbuf_next_ofs(stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth); + +ic_private ssize_t sbuf_delete_char_before( stringbuf_t* sbuf, ssize_t pos ); +ic_private void sbuf_delete_char_at( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_swap_char( stringbuf_t* sbuf, ssize_t pos ); + +ic_private ssize_t sbuf_find_line_start( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_find_line_end( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_find_word_start( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_find_word_end( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_find_ws_word_start( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_find_ws_word_end( stringbuf_t* sbuf, ssize_t pos ); + +// parse a decimal +ic_private bool ic_atoz(const char* s, ssize_t* i); +// parse two decimals separated by a semicolon +ic_private bool ic_atoz2(const char* s, ssize_t* i, ssize_t* j); +ic_private bool ic_atou32(const char* s, uint32_t* pu); + +// row/column info +typedef struct rowcol_s { + ssize_t row; + ssize_t col; + ssize_t row_start; + ssize_t row_len; + bool first_on_row; + bool last_on_row; +} rowcol_t; + +// find row/col position +ic_private ssize_t sbuf_get_pos_at_rc( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, + ssize_t row, ssize_t col ); +// get row/col for a given position +ic_private ssize_t sbuf_get_rc_at_pos( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, + ssize_t pos, rowcol_t* rc ); + +ic_private ssize_t sbuf_get_wrapped_rc_at_pos( stringbuf_t* sbuf, ssize_t termw, ssize_t newtermw, ssize_t promptw, ssize_t cpromptw, + ssize_t pos, rowcol_t* rc ); + +// row iteration +typedef bool (row_fun_t)(const char* s, + ssize_t row, ssize_t row_start, ssize_t row_len, + ssize_t startw, // prompt width + bool is_wrap, const void* arg, void* res); + +ic_private ssize_t sbuf_for_each_row( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, + row_fun_t* fun, void* arg, void* res ); + + +//------------------------------------------------------------- +// Strings +//------------------------------------------------------------- + +// skip a single CSI sequence (ESC [ ...) +ic_private bool skip_csi_esc( const char* s, ssize_t len, ssize_t* esclen ); // used in term.c + +ic_private ssize_t str_column_width( const char* s ); +ic_private ssize_t str_prev_ofs( const char* s, ssize_t pos, ssize_t* cwidth ); +ic_private ssize_t str_next_ofs( const char* s, ssize_t len, ssize_t pos, ssize_t* cwidth ); +ic_private ssize_t str_skip_until_fit( const char* s, ssize_t max_width); // tail that fits +ic_private ssize_t str_take_while_fit( const char* s, ssize_t max_width); // prefix that fits + +#endif // IC_STRINGBUF_H diff --git a/extern/isocline/src/term.c b/extern/isocline/src/term.c new file mode 100644 index 00000000..c55d9ae5 --- /dev/null +++ b/extern/isocline/src/term.c @@ -0,0 +1,1124 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include +#include +#include // getenv +#include + +#include "common.h" +#include "tty.h" +#include "term.h" +#include "stringbuf.h" // str_next_ofs + +#if defined(_WIN32) +#include +#define STDOUT_FILENO 1 +#else +#include +#include +#include +#if defined(__linux__) +#include +#endif +#endif + +#define IC_CSI "\x1B[" + +// color support; colors are auto mapped smaller palettes if needed. (see `term_color.c`) +typedef enum palette_e { + MONOCHROME, // no color + ANSI8, // only basic 8 ANSI color (ESC[m, idx: 30-37, +10 for background) + ANSI16, // basic + bright ANSI colors (ESC[m, idx: 30-37, 90-97, +10 for background) + ANSI256, // ANSI 256 color palette (ESC[38;5;m, idx: 0-15 standard color, 16-231 6x6x6 rbg colors, 232-255 gray shades) + ANSIRGB // direct rgb colors supported (ESC[38;2;;;m) +} palette_t; + +// The terminal screen +struct term_s { + int fd_out; // output handle + ssize_t width; // screen column width + ssize_t height; // screen row height + ssize_t raw_enabled; // is raw mode active? counted by start/end pairs + bool nocolor; // show colors? + bool silent; // enable beep? + bool is_utf8; // utf-8 output? determined by the tty + attr_t attr; // current text attributes + palette_t palette; // color support + buffer_mode_t bufmode; // buffer mode + stringbuf_t* buf; // buffer for buffered output + tty_t* tty; // used on posix to get the cursor position + alloc_t* mem; // allocator + #ifdef _WIN32 + HANDLE hcon; // output console handler + WORD hcon_default_attr; // default text attributes + WORD hcon_orig_attr; // original text attributes + DWORD hcon_orig_mode; // original console mode + DWORD hcon_mode; // used console mode + UINT hcon_orig_cp; // original console code-page (locale) + COORD hcon_save_cursor; // saved cursor position (for escape sequence emulation) + #endif +}; + +static bool term_write_direct(term_t* term, const char* s, ssize_t n ); +static void term_append_buf(term_t* term, const char* s, ssize_t n); + +//------------------------------------------------------------- +// Colors +//------------------------------------------------------------- + +#include "term_color.c" + +//------------------------------------------------------------- +// Helpers +//------------------------------------------------------------- + +ic_private void term_left(term_t* term, ssize_t n) { + if (n <= 0) return; + term_writef( term, IC_CSI "%zdD", n ); +} + +ic_private void term_right(term_t* term, ssize_t n) { + if (n <= 0) return; + term_writef( term, IC_CSI "%zdC", n ); +} + +ic_private void term_up(term_t* term, ssize_t n) { + if (n <= 0) return; + term_writef( term, IC_CSI "%zdA", n ); +} + +ic_private void term_down(term_t* term, ssize_t n) { + if (n <= 0) return; + term_writef( term, IC_CSI "%zdB", n ); +} + +ic_private void term_clear_line(term_t* term) { + term_write( term, "\r" IC_CSI "K"); +} + +ic_private void term_clear_to_end_of_line(term_t* term) { + term_write(term, IC_CSI "K"); +} + +ic_private void term_start_of_line(term_t* term) { + term_write( term, "\r" ); +} + +ic_private ssize_t term_get_width(term_t* term) { + return term->width; +} + +ic_private ssize_t term_get_height(term_t* term) { + return term->height; +} + +ic_private void term_attr_reset(term_t* term) { + term_write(term, IC_CSI "m" ); +} + +ic_private void term_underline(term_t* term, bool on) { + term_write(term, on ? IC_CSI "4m" : IC_CSI "24m" ); +} + +ic_private void term_reverse(term_t* term, bool on) { + term_write(term, on ? IC_CSI "7m" : IC_CSI "27m"); +} + +ic_private void term_bold(term_t* term, bool on) { + term_write(term, on ? IC_CSI "1m" : IC_CSI "22m" ); +} + +ic_private void term_italic(term_t* term, bool on) { + term_write(term, on ? IC_CSI "3m" : IC_CSI "23m" ); +} + +ic_private void term_writeln(term_t* term, const char* s) { + term_write(term,s); + term_write(term,"\n"); +} + +ic_private void term_write_char(term_t* term, char c) { + char buf[2]; + buf[0] = c; + buf[1] = 0; + term_write_n(term, buf, 1 ); +} + +ic_private attr_t term_get_attr( const term_t* term ) { + return term->attr; +} + +ic_private void term_set_attr( term_t* term, attr_t attr ) { + if (term->nocolor) return; + if (attr.x.color != term->attr.x.color && attr.x.color != IC_COLOR_NONE) { + term_color(term,attr.x.color); + if (term->palette < ANSIRGB && color_is_rgb(attr.x.color)) { + term->attr.x.color = attr.x.color; // actual color may have been approximated but we keep the actual color to avoid updating every time + } + } + if (attr.x.bgcolor != term->attr.x.bgcolor && attr.x.bgcolor != IC_COLOR_NONE) { + term_bgcolor(term,attr.x.bgcolor); + if (term->palette < ANSIRGB && color_is_rgb(attr.x.bgcolor)) { + term->attr.x.bgcolor = attr.x.bgcolor; + } + } + if (attr.x.bold != term->attr.x.bold && attr.x.bold != IC_NONE) { + term_bold(term,attr.x.bold == IC_ON); + } + if (attr.x.underline != term->attr.x.underline && attr.x.underline != IC_NONE) { + term_underline(term,attr.x.underline == IC_ON); + } + if (attr.x.reverse != term->attr.x.reverse && attr.x.reverse != IC_NONE) { + term_reverse(term,attr.x.reverse == IC_ON); + } + if (attr.x.italic != term->attr.x.italic && attr.x.italic != IC_NONE) { + term_italic(term,attr.x.italic == IC_ON); + } + assert(attr.x.color == term->attr.x.color || attr.x.color == IC_COLOR_NONE); + assert(attr.x.bgcolor == term->attr.x.bgcolor || attr.x.bgcolor == IC_COLOR_NONE); + assert(attr.x.bold == term->attr.x.bold || attr.x.bold == IC_NONE); + assert(attr.x.reverse == term->attr.x.reverse || attr.x.reverse == IC_NONE); + assert(attr.x.underline == term->attr.x.underline || attr.x.underline == IC_NONE); + assert(attr.x.italic == term->attr.x.italic || attr.x.italic == IC_NONE); +} + + +/* +ic_private void term_clear_lines_to_end(term_t* term) { + term_write(term, "\r" IC_CSI "J"); +} + +ic_private void term_show_cursor(term_t* term, bool on) { + term_write(term, on ? IC_CSI "?25h" : IC_CSI "?25l"); +} +*/ + +//------------------------------------------------------------- +// Formatted output +//------------------------------------------------------------- + +ic_private void term_writef(term_t* term, const char* fmt, ...) { + va_list ap; + va_start(ap, fmt); + term_vwritef(term,fmt,ap); + va_end(ap); +} + +ic_private void term_vwritef(term_t* term, const char* fmt, va_list args ) { + sbuf_append_vprintf(term->buf, fmt, args); +} + +ic_private void term_write_formatted( term_t* term, const char* s, const attr_t* attrs ) { + term_write_formatted_n( term, s, attrs, ic_strlen(s)); +} + +ic_private void term_write_formatted_n( term_t* term, const char* s, const attr_t* attrs, ssize_t len ) { + if (attrs == NULL) { + // write directly + term_write(term,s); + } + else { + // ensure raw mode from now on + if (term->raw_enabled <= 0) { + term_start_raw(term); + } + // and output with text attributes + const attr_t default_attr = term_get_attr(term); + attr_t attr = attr_none(); + ssize_t i = 0; + ssize_t n = 0; + while( i+n < len && s[i+n] != 0 ) { + if (!attr_is_eq(attr,attrs[i+n])) { + if (n > 0) { + term_write_n( term, s+i, n ); + i += n; + n = 0; + } + attr = attrs[i]; + term_set_attr( term, attr_update_with(default_attr,attr) ); + } + n++; + } + if (n > 0) { + term_write_n( term, s+i, n ); + i += n; + n = 0; + } + assert(s[i] != 0 || i == len); + term_set_attr(term, default_attr); + } +} + +//------------------------------------------------------------- +// Write to the terminal +// The buffered functions are used to reduce cursor flicker +// during refresh +//------------------------------------------------------------- + +ic_private void term_beep(term_t* term) { + if (term->silent) return; + fprintf(stderr,"\x7"); + fflush(stderr); +} + +ic_private void term_write_repeat(term_t* term, const char* s, ssize_t count) { + for (; count > 0; count--) { + term_write(term, s); + } +} + +ic_private void term_write(term_t* term, const char* s) { + if (s == NULL || s[0] == 0) return; + ssize_t n = ic_strlen(s); + term_write_n(term,s,n); +} + +// Primitive terminal write; all writes go through here +ic_private void term_write_n(term_t* term, const char* s, ssize_t n) { + if (s == NULL || n <= 0) return; + // write to buffer to reduce flicker and to process escape sequences (this may flush too) + term_append_buf(term, s, n); +} + + +//------------------------------------------------------------- +// Buffering +//------------------------------------------------------------- + + +ic_private void term_flush(term_t* term) { + if (sbuf_len(term->buf) > 0) { + //term_show_cursor(term,false); + term_write_direct(term, sbuf_string(term->buf), sbuf_len(term->buf)); + //term_show_cursor(term,true); + sbuf_clear(term->buf); + } +} + +ic_private buffer_mode_t term_set_buffer_mode(term_t* term, buffer_mode_t mode) { + buffer_mode_t oldmode = term->bufmode; + if (oldmode != mode) { + if (mode == UNBUFFERED) { + term_flush(term); + } + term->bufmode = mode; + } + return oldmode; +} + +static void term_check_flush(term_t* term, bool contains_nl) { + if (term->bufmode == UNBUFFERED || + sbuf_len(term->buf) > 4000 || + (term->bufmode == LINEBUFFERED && contains_nl)) + { + term_flush(term); + } +} + +//------------------------------------------------------------- +// Init +//------------------------------------------------------------- + +static void term_init_raw(term_t* term); + +ic_private term_t* term_new(alloc_t* mem, tty_t* tty, bool nocolor, bool silent, int fd_out ) +{ + term_t* term = mem_zalloc_tp(mem, term_t); + if (term == NULL) return NULL; + + term->fd_out = (fd_out < 0 ? STDOUT_FILENO : fd_out); + term->nocolor = nocolor || (isatty(term->fd_out) == 0); + term->silent = silent; + term->mem = mem; + term->tty = tty; // can be NULL + term->width = 80; + term->height = 25; + term->is_utf8 = tty_is_utf8(tty); + term->palette = ANSI16; // almost universally supported + term->buf = sbuf_new(mem); + term->bufmode = LINEBUFFERED; + term->attr = attr_default(); + + // respect NO_COLOR + if (getenv("NO_COLOR") != NULL) { + term->nocolor = true; + } + if (!term->nocolor) { + // detect color palette + // COLORTERM takes precedence + const char* colorterm = getenv("COLORTERM"); + const char* eterm = getenv("TERM"); + if (ic_contains(colorterm,"24bit") || ic_contains(colorterm,"truecolor") || ic_contains(colorterm,"direct")) { + term->palette = ANSIRGB; + } + else if (ic_contains(colorterm,"8bit") || ic_contains(colorterm,"256color")) { term->palette = ANSI256; } + else if (ic_contains(colorterm,"4bit") || ic_contains(colorterm,"16color")) { term->palette = ANSI16; } + else if (ic_contains(colorterm,"3bit") || ic_contains(colorterm,"8color")) { term->palette = ANSI8; } + else if (ic_contains(colorterm,"1bit") || ic_contains(colorterm,"nocolor") || ic_contains(colorterm,"monochrome")) { + term->palette = MONOCHROME; + } + // otherwise check for some specific terminals + else if (getenv("WT_SESSION") != NULL) { term->palette = ANSIRGB; } // Windows terminal + else if (getenv("ITERM_SESSION_ID") != NULL) { term->palette = ANSIRGB; } // iTerm2 terminal + else if (getenv("VSCODE_PID") != NULL) { term->palette = ANSIRGB; } // vscode terminal + else { + // and otherwise fall back to checking TERM + if (ic_contains(eterm,"truecolor") || ic_contains(eterm,"direct") || ic_contains(colorterm,"24bit")) { + term->palette = ANSIRGB; + } + else if (ic_contains(eterm,"alacritty") || ic_contains(eterm,"kitty")) { + term->palette = ANSIRGB; + } + else if (ic_contains(eterm,"256color") || ic_contains(eterm,"gnome")) { + term->palette = ANSI256; + } + else if (ic_contains(eterm,"16color")){ term->palette = ANSI16; } + else if (ic_contains(eterm,"8color")) { term->palette = ANSI8; } + else if (ic_contains(eterm,"monochrome") || ic_contains(eterm,"nocolor") || ic_contains(eterm,"dumb")) { + term->palette = MONOCHROME; + } + } + debug_msg("term: color-bits: %d (COLORTERM=%s, TERM=%s)\n", term_get_color_bits(term), colorterm, eterm); + } + + // read COLUMS/LINES from the environment for a better initial guess. + const char* env_columns = getenv("COLUMNS"); + if (env_columns != NULL) { ic_atoz(env_columns, &term->width); } + const char* env_lines = getenv("LINES"); + if (env_lines != NULL) { ic_atoz(env_lines, &term->height); } + + // initialize raw terminal output and terminal dimensions + term_init_raw(term); + term_update_dim(term); + term_attr_reset(term); // ensure we are at default settings + + return term; +} + +ic_private bool term_is_interactive(const term_t* term) { + ic_unused(term); + // check dimensions (0 is used for debuggers) + // if (term->width <= 0) return false; + + // check editing support + const char* eterm = getenv("TERM"); + debug_msg("term: TERM=%s\n", eterm); + if (eterm != NULL && + (strstr("dumb|DUMB|cons25|CONS25|emacs|EMACS",eterm) != NULL)) { + return false; + } + + return true; +} + +ic_private bool term_enable_beep(term_t* term, bool enable) { + bool prev = term->silent; + term->silent = !enable; + return prev; +} + +ic_private bool term_enable_color(term_t* term, bool enable) { + bool prev = !term->nocolor; + term->nocolor = !enable; + return prev; +} + +ic_private void term_free(term_t* term) { + if (term == NULL) return; + term_flush(term); + term_end_raw(term, true); + sbuf_free(term->buf); term->buf = NULL; + mem_free(term->mem, term); +} + +//------------------------------------------------------------- +// For best portability and applications inserting CSI SGR (ESC[ .. m) +// codes themselves in strings, we interpret these at the +// lowest level so we can have a `term_get_attr` function which +// is needed for bracketed styles etc. +//------------------------------------------------------------- + +static void term_append_esc(term_t* term, const char* const s, ssize_t len) { + if (s[1]=='[' && s[len-1] == 'm') { + // it is a CSI SGR sequence: ESC[ ... m + if (term->nocolor) return; // ignore escape sequences if nocolor is set + term->attr = attr_update_with(term->attr, attr_from_esc_sgr(s,len)); + } + // and write out the escape sequence as-is + sbuf_append_n(term->buf, s, len); +} + + +static void term_append_utf8(term_t* term, const char* s, ssize_t len) { + ssize_t nread; + unicode_t uchr = unicode_from_qutf8((const uint8_t*)s, len, &nread); + uint8_t c; + if (unicode_is_raw(uchr, &c)) { + // write bytes as is; this also ensure that on non-utf8 terminals characters between 0x80-0xFF + // go through _as is_ due to the qutf8 encoding. + sbuf_append_char(term->buf,(char)c); + } + else if (!term->is_utf8) { + // on non-utf8 terminals still send utf-8 and hope for the best + // todo: we could try to convert to the locale first? + sbuf_append_n(term->buf, s, len); + // sbuf_appendf(term->buf, "\x1B[%" PRIu32 "u", uchr); // unicode escape code + } + else { + // write utf-8 as is + sbuf_append_n(term->buf, s, len); + } +} + +static void term_append_buf( term_t* term, const char* s, ssize_t len ) { + ssize_t pos = 0; + bool newline = false; + while (pos < len) { + // handle ascii sequences in bulk + ssize_t ascii = 0; + ssize_t next; + while ((next = str_next_ofs(s, len, pos+ascii, NULL)) > 0 && + (uint8_t)s[pos + ascii] > '\x1B' && (uint8_t)s[pos + ascii] <= 0x7F ) + { + ascii += next; + } + if (ascii > 0) { + sbuf_append_n(term->buf, s+pos, ascii); + pos += ascii; + } + if (next <= 0) break; + + const uint8_t c = (uint8_t)s[pos]; + // handle utf8 sequences (for non-utf8 terminals) + if (c >= 0x80) { + term_append_utf8(term, s+pos, next); + } + // handle escape sequence (note: str_next_ofs considers whole CSI escape sequences at a time) + else if (next > 1 && c == '\x1B') { + term_append_esc(term, s+pos, next); + } + else if (c < ' ' && c != 0 && (c < '\x07' || c > '\x0D')) { + // ignore control characters except \a, \b, \t, \n, \r, and form-feed and vertical tab. + } + else { + if (c == '\n') { newline = true; } + sbuf_append_n(term->buf, s+pos, next); + } + pos += next; + } + // possibly flush + term_check_flush(term, newline); +} + +//------------------------------------------------------------- +// Platform dependent: Write directly to the terminal +//------------------------------------------------------------- + +#if !defined(_WIN32) + +// write to the console without further processing +static bool term_write_direct(term_t* term, const char* s, ssize_t n) { + ssize_t count = 0; + while( count < n ) { + ssize_t nwritten = write(term->fd_out, s + count, to_size_t(n - count)); + if (nwritten > 0) { + count += nwritten; + } + else if (errno != EINTR && errno != EAGAIN) { + debug_msg("term: write failed: length %i, errno %i: \"%s\"\n", n, errno, s); + return false; + } + } + return true; +} + +#else + +//---------------------------------------------------------------------------------- +// On windows we use the new virtual terminal processing if it is available (Windows Terminal) +// but fall back to ansi escape emulation on older systems but also for example +// the PS terminal +// +// note: we use row/col as 1-based ANSI escape while windows X/Y coords are 0-based. +//----------------------------------------------------------------------------------- + +#if !defined(ENABLE_VIRTUAL_TERMINAL_PROCESSING) +#define ENABLE_VIRTUAL_TERMINAL_PROCESSING (0) +#endif +#if !defined(ENABLE_LVB_GRID_WORLDWIDE) +#define ENABLE_LVB_GRID_WORLDWIDE (0) +#endif + +// direct write to the console without further processing +static bool term_write_console(term_t* term, const char* s, ssize_t n ) { + DWORD written; + // WriteConsoleA(term->hcon, s, (DWORD)(to_size_t(n)), &written, NULL); + WriteFile(term->hcon, s, (DWORD)(to_size_t(n)), &written, NULL); // so it can be redirected + return (written == (DWORD)(to_size_t(n))); +} + +static bool term_get_cursor_pos( term_t* term, ssize_t* row, ssize_t* col) { + *row = 0; + *col = 0; + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo(term->hcon, &info)) return false; + *row = (ssize_t)info.dwCursorPosition.Y + 1; + *col = (ssize_t)info.dwCursorPosition.X + 1; + return true; +} + +static void term_move_cursor_to( term_t* term, ssize_t row, ssize_t col ) { + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo( term->hcon, &info )) return; + if (col > info.dwSize.X) col = info.dwSize.X; + if (row > info.dwSize.Y) row = info.dwSize.Y; + if (col <= 0) col = 1; + if (row <= 0) row = 1; + COORD coord; + coord.X = (SHORT)col - 1; + coord.Y = (SHORT)row - 1; + SetConsoleCursorPosition( term->hcon, coord); +} + +static void term_cursor_save(term_t* term) { + memset(&term->hcon_save_cursor, 0, sizeof(term->hcon_save_cursor)); + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo(term->hcon, &info)) return; + term->hcon_save_cursor = info.dwCursorPosition; +} + +static void term_cursor_restore(term_t* term) { + if (term->hcon_save_cursor.X == 0) return; + SetConsoleCursorPosition(term->hcon, term->hcon_save_cursor); +} + +static void term_move_cursor( term_t* term, ssize_t drow, ssize_t dcol, ssize_t n ) { + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo( term->hcon, &info )) return; + COORD cur = info.dwCursorPosition; + ssize_t col = (ssize_t)cur.X + 1 + n*dcol; + ssize_t row = (ssize_t)cur.Y + 1 + n*drow; + term_move_cursor_to( term, row, col ); +} + +static void term_cursor_visible( term_t* term, bool visible ) { + CONSOLE_CURSOR_INFO info; + if (!GetConsoleCursorInfo(term->hcon,&info)) return; + info.bVisible = visible; + SetConsoleCursorInfo(term->hcon,&info); +} + +static void term_erase_line( term_t* term, ssize_t mode ) { + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo( term->hcon, &info )) return; + DWORD written; + COORD start; + ssize_t length; + if (mode == 2) { + // entire line + start.X = 0; + start.Y = info.dwCursorPosition.Y; + length = (ssize_t)info.srWindow.Right + 1; + } + else if (mode == 1) { + // to start of line + start.X = 0; + start.Y = info.dwCursorPosition.Y; + length = info.dwCursorPosition.X; + } + else { + // to end of line + length = (ssize_t)info.srWindow.Right - info.dwCursorPosition.X + 1; + start = info.dwCursorPosition; + } + FillConsoleOutputAttribute( term->hcon, term->hcon_default_attr, (DWORD)length, start, &written ); + FillConsoleOutputCharacterA( term->hcon, ' ', (DWORD)length, start, &written ); +} + +static void term_clear_screen(term_t* term, ssize_t mode) { + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo(term->hcon, &info)) return; + COORD start; + start.X = 0; + start.Y = 0; + ssize_t length; + ssize_t width = (ssize_t)info.dwSize.X; + if (mode == 2) { + // entire screen + length = width * info.dwSize.Y; + } + else if (mode == 1) { + // to cursor + length = (width * ((ssize_t)info.dwCursorPosition.Y - 1)) + info.dwCursorPosition.X; + } + else { + // from cursor + start = info.dwCursorPosition; + length = (width * ((ssize_t)info.dwSize.Y - info.dwCursorPosition.Y)) + (width - info.dwCursorPosition.X + 1); + } + DWORD written; + FillConsoleOutputAttribute(term->hcon, term->hcon_default_attr, (DWORD)length, start, &written); + FillConsoleOutputCharacterA(term->hcon, ' ', (DWORD)length, start, &written); +} + +static WORD attr_color[8] = { + 0, // black + FOREGROUND_RED, // maroon + FOREGROUND_GREEN, // green + FOREGROUND_RED | FOREGROUND_GREEN, // orange + FOREGROUND_BLUE, // navy + FOREGROUND_RED | FOREGROUND_BLUE, // purple + FOREGROUND_GREEN | FOREGROUND_BLUE, // teal + FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE, // light gray +}; + +static void term_set_win_attr( term_t* term, attr_t ta ) { + WORD def_attr = term->hcon_default_attr; + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo( term->hcon, &info )) return; + WORD cur_attr = info.wAttributes; + WORD attr = cur_attr; + if (ta.x.color != IC_COLOR_NONE) { + if (ta.x.color >= IC_ANSI_BLACK && ta.x.color <= IC_ANSI_SILVER) { + attr = (attr & 0xFFF0) | attr_color[ta.x.color - IC_ANSI_BLACK]; + } + else if (ta.x.color >= IC_ANSI_GRAY && ta.x.color <= IC_ANSI_WHITE) { + attr = (attr & 0xFFF0) | attr_color[ta.x.color - IC_ANSI_GRAY] | FOREGROUND_INTENSITY; + } + else if (ta.x.color == IC_ANSI_DEFAULT) { + attr = (attr & 0xFFF0) | (def_attr & 0x000F); + } + } + if (ta.x.bgcolor != IC_COLOR_NONE) { + if (ta.x.bgcolor >= IC_ANSI_BLACK && ta.x.bgcolor <= IC_ANSI_SILVER) { + attr = (attr & 0xFF0F) | (WORD)(attr_color[ta.x.bgcolor - IC_ANSI_BLACK] << 4); + } + else if (ta.x.bgcolor >= IC_ANSI_GRAY && ta.x.bgcolor <= IC_ANSI_WHITE) { + attr = (attr & 0xFF0F) | (WORD)(attr_color[ta.x.bgcolor - IC_ANSI_GRAY] << 4) | BACKGROUND_INTENSITY; + } + else if (ta.x.bgcolor == IC_ANSI_DEFAULT) { + attr = (attr & 0xFF0F) | (def_attr & 0x00F0); + } + } + if (ta.x.underline != IC_NONE) { + attr = (attr & ~COMMON_LVB_UNDERSCORE) | (ta.x.underline == IC_ON ? COMMON_LVB_UNDERSCORE : 0); + } + if (ta.x.reverse != IC_NONE) { + attr = (attr & ~COMMON_LVB_REVERSE_VIDEO) | (ta.x.reverse == IC_ON ? COMMON_LVB_REVERSE_VIDEO : 0); + } + if (attr != cur_attr) { + SetConsoleTextAttribute(term->hcon, attr); + } +} + +static ssize_t esc_param( const char* s, ssize_t def ) { + if (*s == '?') s++; + ssize_t n = def; + ic_atoz(s, &n); + return n; +} + +static void esc_param2( const char* s, ssize_t* p1, ssize_t* p2, ssize_t def ) { + if (*s == '?') s++; + *p1 = def; + *p2 = def; + ic_atoz2(s, p1, p2); +} + +// Emulate escape sequences on older windows. +static void term_write_esc( term_t* term, const char* s, ssize_t len ) { + ssize_t row; + ssize_t col; + + if (s[1] == '[') { + switch (s[len-1]) { + case 'A': + term_move_cursor(term, -1, 0, esc_param(s+2, 1)); + break; + case 'B': + term_move_cursor(term, 1, 0, esc_param(s+2, 1)); + break; + case 'C': + term_move_cursor(term, 0, 1, esc_param(s+2, 1)); + break; + case 'D': + term_move_cursor(term, 0, -1, esc_param(s+2, 1)); + break; + case 'H': + esc_param2(s+2, &row, &col, 1); + term_move_cursor_to(term, row, col); + break; + case 'K': + term_erase_line(term, esc_param(s+2, 0)); + break; + case 'm': + term_set_win_attr( term, attr_from_esc_sgr(s,len) ); + break; + + // support some less standard escape codes (currently not used by isocline) + case 'E': // line down + term_get_cursor_pos(term, &row, &col); + row += esc_param(s+2, 1); + term_move_cursor_to(term, row, 1); + break; + case 'F': // line up + term_get_cursor_pos(term, &row, &col); + row -= esc_param(s+2, 1); + term_move_cursor_to(term, row, 1); + break; + case 'G': // absolute column + term_get_cursor_pos(term, &row, &col); + col = esc_param(s+2, 1); + term_move_cursor_to(term, row, col); + break; + case 'J': + term_clear_screen(term, esc_param(s+2, 0)); + break; + case 'h': + if (strncmp(s+2, "?25h", 4) == 0) { + term_cursor_visible(term, true); + } + break; + case 'l': + if (strncmp(s+2, "?25l", 4) == 0) { + term_cursor_visible(term, false); + } + break; + case 's': + term_cursor_save(term); + break; + case 'u': + term_cursor_restore(term); + break; + // otherwise ignore + } + } + else if (s[1] == '7') { + term_cursor_save(term); + } + else if (s[1] == '8') { + term_cursor_restore(term); + } + else { + // otherwise ignore + } +} + +static bool term_write_direct(term_t* term, const char* s, ssize_t len ) { + term_cursor_visible(term,false); // reduce flicker + ssize_t pos = 0; + if ((term->hcon_mode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) != 0) { + // use the builtin virtual terminal processing. (enables truecolor for example) + term_write_console(term, s, len); + pos = len; + } + else { + // emulate escape sequences + while( pos < len ) { + // handle non-control in bulk (including utf-8 sequences) + // (We don't need to handle utf-8 separately as we set the codepage to always be in utf-8 mode) + ssize_t nonctrl = 0; + ssize_t next; + while( (next = str_next_ofs( s, len, pos+nonctrl, NULL )) > 0 && + (uint8_t)s[pos + nonctrl] >= ' ' && (uint8_t)s[pos + nonctrl] <= 0x7F) { + nonctrl += next; + } + if (nonctrl > 0) { + term_write_console(term, s+pos, nonctrl); + pos += nonctrl; + } + if (next <= 0) break; + + if ((uint8_t)s[pos] >= 0x80) { + // utf8 is already processed + term_write_console(term, s+pos, next); + } + else if (next > 1 && s[pos] == '\x1B') { + // handle control (note: str_next_ofs considers whole CSI escape sequences at a time) + term_write_esc(term, s+pos, next); + } + else if (next == 1 && (s[pos] == '\r' || s[pos] == '\n' || s[pos] == '\t' || s[pos] == '\b')) { + term_write_console( term, s+pos, next); + } + else { + // ignore + } + pos += next; + } + } + term_cursor_visible(term,true); + assert(pos == len); + return (pos == len); + +} +#endif + + + +//------------------------------------------------------------- +// Update terminal dimensions +//------------------------------------------------------------- + +#if !defined(_WIN32) + +// send escape query that may return a response on the tty +static bool term_esc_query_raw( term_t* term, const char* query, char* buf, ssize_t buflen ) +{ + if (buf==NULL || buflen <= 0 || query[0] == 0) return false; + bool osc = (query[1] == ']'); + if (!term_write_direct(term, query, ic_strlen(query))) return false; + debug_msg("term: read tty query response to: ESC %s\n", query + 1); + return tty_read_esc_response( term->tty, query[1], osc, buf, buflen ); +} + +static bool term_esc_query( term_t* term, const char* query, char* buf, ssize_t buflen ) +{ + if (!tty_start_raw(term->tty)) return false; + bool ok = term_esc_query_raw(term,query,buf,buflen); + tty_end_raw(term->tty); + return ok; +} + +// get the cursor position via an ESC[6n +static bool term_get_cursor_pos( term_t* term, ssize_t* row, ssize_t* col) +{ + // send escape query + char buf[128]; + if (!term_esc_query(term,"\x1B[6n",buf,128)) return false; + if (!ic_atoz2(buf,row,col)) return false; + return true; +} + +static void term_set_cursor_pos( term_t* term, ssize_t row, ssize_t col ) { + term_writef( term, IC_CSI "%zd;%zdH", row, col ); +} + +ic_private bool term_update_dim(term_t* term) { + ssize_t cols = 0; + ssize_t rows = 0; + struct winsize ws; + if (ioctl(term->fd_out, TIOCGWINSZ, &ws) >= 0) { + // ioctl succeeded + cols = ws.ws_col; // debuggers return 0 for the column + rows = ws.ws_row; + } + else { + // determine width by querying the cursor position + debug_msg("term: ioctl term-size failed: %d,%d\n", ws.ws_row, ws.ws_col); + ssize_t col0 = 0; + ssize_t row0 = 0; + if (term_get_cursor_pos(term,&row0,&col0)) { + term_set_cursor_pos(term,999,999); + ssize_t col1 = 0; + ssize_t row1 = 0; + if (term_get_cursor_pos(term,&row1,&col1)) { + cols = col1; + rows = row1; + } + term_set_cursor_pos(term,row0,col0); + } + else { + // cannot query position + // return 0 column + } + } + + // update width and return whether it changed. + bool changed = (term->width != cols || term->height != rows); + debug_msg("terminal dim: %zd,%zd: %s\n", rows, cols, changed ? "changed" : "unchanged"); + if (cols > 0) { + term->width = cols; + term->height = rows; + } + return changed; +} + +#else + +ic_private bool term_update_dim(term_t* term) { + if (term->hcon == 0) { + term->hcon = GetConsoleWindow(); + } + ssize_t rows = 0; + ssize_t cols = 0; + CONSOLE_SCREEN_BUFFER_INFO sbinfo; + if (GetConsoleScreenBufferInfo(term->hcon, &sbinfo)) { + cols = (ssize_t)sbinfo.srWindow.Right - (ssize_t)sbinfo.srWindow.Left + 1; + rows = (ssize_t)sbinfo.srWindow.Bottom - (ssize_t)sbinfo.srWindow.Top + 1; + } + bool changed = (term->width != cols || term->height != rows); + term->width = cols; + term->height = rows; + debug_msg("term: update dim: %zd, %zd\n", term->height, term->width ); + return changed; +} + +#endif + + + +//------------------------------------------------------------- +// Enable/disable terminal raw mode +//------------------------------------------------------------- + +#if !defined(_WIN32) + +// On non-windows, the terminal is set in raw mode by the tty. + +ic_private void term_start_raw(term_t* term) { + term->raw_enabled++; +} + +ic_private void term_end_raw(term_t* term, bool force) { + if (term->raw_enabled <= 0) return; + if (!force) { + term->raw_enabled--; + } + else { + term->raw_enabled = 0; + } +} + +static bool term_esc_query_color_raw(term_t* term, int color_idx, uint32_t* color ) { + char buf[128+1]; + snprintf(buf,128,"\x1B]4;%d;?\x1B\\", color_idx); + if (!term_esc_query_raw( term, buf, buf, 128 )) { + debug_msg("esc query response not received\n"); + return false; + } + if (buf[0] != '4') return false; + const char* rgb = strchr(buf,':'); + if (rgb==NULL) return false; + rgb++; // skip ':' + unsigned int r,g,b; + if (sscanf(rgb,"%x/%x/%x",&r,&g,&b) != 3) return false; + if (rgb[2]!='/') { // 48-bit rgb, hexadecimal round to 24-bit + r = (r+0x7F)/0x100; // note: can "overflow", e.g. 0xFFFF -> 0x100. (and we need `ic_cap8` to convert.) + g = (g+0x7F)/0x100; + b = (b+0x7F)/0x100; + } + *color = (ic_cap8(r)<<16) | (ic_cap8(g)<<8) | ic_cap8(b); + debug_msg("color query: %02x,%02x,%02x: %06x\n", r, g, b, *color); + return true; +} + +// update ansi 16 color palette for better color approximation +static void term_update_ansi16(term_t* term) { + debug_msg("update ansi colors\n"); + #if defined(GIO_CMAP) + // try ioctl first (on Linux) + uint8_t cmap[48]; + memset(cmap,0,48); + if (ioctl(term->fd_out,GIO_CMAP,&cmap) >= 0) { + // success + for(ssize_t i = 0; i < 48; i+=3) { + uint32_t color = ((uint32_t)(cmap[i]) << 16) | ((uint32_t)(cmap[i+1]) << 8) | cmap[i+2]; + debug_msg("term (ioctl) ansi color %d: 0x%06x\n", i, color); + ansi256[i] = color; + } + return; + } + else { + debug_msg("ioctl GIO_CMAP failed: entry 1: 0x%02x%02x%02x\n", cmap[3], cmap[4], cmap[5]); + } + #endif + // this seems to be unreliable on some systems (Ubuntu+Gnome terminal) so only enable when known ok. + #if __APPLE__ + // otherwise use OSC 4 escape sequence query + if (tty_start_raw(term->tty)) { + for(ssize_t i = 0; i < 16; i++) { + uint32_t color; + if (!term_esc_query_color_raw(term, i, &color)) break; + debug_msg("term ansi color %d: 0x%06x\n", i, color); + ansi256[i] = color; + } + tty_end_raw(term->tty); + } + #endif +} + +static void term_init_raw(term_t* term) { + if (term->palette < ANSIRGB) { + term_update_ansi16(term); + } +} + +#else + +ic_private void term_start_raw(term_t* term) { + if (term->raw_enabled++ > 0) return; + CONSOLE_SCREEN_BUFFER_INFO info; + if (GetConsoleScreenBufferInfo(term->hcon, &info)) { + term->hcon_orig_attr = info.wAttributes; + } + term->hcon_orig_cp = GetConsoleOutputCP(); + SetConsoleOutputCP(CP_UTF8); + if (term->hcon_mode == 0) { + // first time initialization + DWORD mode = ENABLE_PROCESSED_OUTPUT | ENABLE_WRAP_AT_EOL_OUTPUT | ENABLE_LVB_GRID_WORLDWIDE; // for \r \n and \b + // use escape sequence handling if available and the terminal supports it (so we can use rgb colors in Windows terminal) + // Unfortunately, in plain powershell, we can successfully enable terminal processing + // but it still fails to render correctly; so we require the palette be large enough (like in Windows Terminal) + if (term->palette >= ANSI256 && SetConsoleMode(term->hcon, mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + term->hcon_mode = mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING; + debug_msg("term: console mode: virtual terminal processing enabled\n"); + } + // no virtual terminal processing, emulate instead + else if (SetConsoleMode(term->hcon, mode)) { + term->hcon_mode = mode; + term->palette = ANSI16; + } + GetConsoleMode(term->hcon, &mode); + debug_msg("term: console mode: orig: 0x%x, new: 0x%x, current 0x%x\n", term->hcon_orig_mode, term->hcon_mode, mode); + } + else { + SetConsoleMode(term->hcon, term->hcon_mode); + } +} + +ic_private void term_end_raw(term_t* term, bool force) { + if (term->raw_enabled <= 0) return; + if (!force && term->raw_enabled > 1) { + term->raw_enabled--; + } + else { + term->raw_enabled = 0; + SetConsoleMode(term->hcon, term->hcon_orig_mode); + SetConsoleOutputCP(term->hcon_orig_cp); + SetConsoleTextAttribute(term->hcon, term->hcon_orig_attr); + } +} + +static void term_init_raw(term_t* term) { + term->hcon = GetStdHandle(STD_OUTPUT_HANDLE); + GetConsoleMode(term->hcon, &term->hcon_orig_mode); + CONSOLE_SCREEN_BUFFER_INFOEX info; + memset(&info, 0, sizeof(info)); + info.cbSize = sizeof(info); + if (GetConsoleScreenBufferInfoEx(term->hcon, &info)) { + // store default attributes + term->hcon_default_attr = info.wAttributes; + // update our color table with the actual colors used. + for (unsigned i = 0; i < 16; i++) { + COLORREF cr = info.ColorTable[i]; + uint32_t color = (ic_cap8(GetRValue(cr))<<16) | (ic_cap8(GetGValue(cr))<<8) | ic_cap8(GetBValue(cr)); // COLORREF = BGR + // index is also in reverse in the bits 0 and 2 + unsigned j = (i&0x08) | ((i&0x04)>>2) | (i&0x02) | (i&0x01)<<2; + debug_msg("term: ansi color %d is 0x%06x\n", j, color); + ansi256[j] = color; + } + } + else { + DWORD err = GetLastError(); + debug_msg("term: cannot get console screen buffer: %d %x", err, err); + } + term_start_raw(term); // initialize the hcon_mode + term_end_raw(term,false); +} + +#endif diff --git a/extern/isocline/src/term.h b/extern/isocline/src/term.h new file mode 100644 index 00000000..50bfd968 --- /dev/null +++ b/extern/isocline/src/term.h @@ -0,0 +1,85 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_TERM_H +#define IC_TERM_H + +#include "common.h" +#include "tty.h" +#include "stringbuf.h" +#include "attr.h" + +struct term_s; +typedef struct term_s term_t; + +typedef enum buffer_mode_e { + UNBUFFERED, + LINEBUFFERED, + BUFFERED, +} buffer_mode_t; + +// Primitives +ic_private term_t* term_new(alloc_t* mem, tty_t* tty, bool nocolor, bool silent, int fd_out); +ic_private void term_free(term_t* term); + +ic_private bool term_is_interactive(const term_t* term); +ic_private void term_start_raw(term_t* term); +ic_private void term_end_raw(term_t* term, bool force); + +ic_private bool term_enable_beep(term_t* term, bool enable); +ic_private bool term_enable_color(term_t* term, bool enable); + +ic_private void term_flush(term_t* term); +ic_private buffer_mode_t term_set_buffer_mode(term_t* term, buffer_mode_t mode); + +ic_private void term_write_n(term_t* term, const char* s, ssize_t n); +ic_private void term_write(term_t* term, const char* s); +ic_private void term_writeln(term_t* term, const char* s); +ic_private void term_write_char(term_t* term, char c); + +ic_private void term_write_repeat(term_t* term, const char* s, ssize_t count ); +ic_private void term_beep(term_t* term); + +ic_private bool term_update_dim(term_t* term); + +ic_private ssize_t term_get_width(term_t* term); +ic_private ssize_t term_get_height(term_t* term); +ic_private int term_get_color_bits(term_t* term); + +// Helpers +ic_private void term_writef(term_t* term, const char* fmt, ...); +ic_private void term_vwritef(term_t* term, const char* fmt, va_list args); + +ic_private void term_left(term_t* term, ssize_t n); +ic_private void term_right(term_t* term, ssize_t n); +ic_private void term_up(term_t* term, ssize_t n); +ic_private void term_down(term_t* term, ssize_t n); +ic_private void term_start_of_line(term_t* term ); +ic_private void term_clear_line(term_t* term); +ic_private void term_clear_to_end_of_line(term_t* term); +// ic_private void term_clear_lines_to_end(term_t* term); + + +ic_private void term_attr_reset(term_t* term); +ic_private void term_underline(term_t* term, bool on); +ic_private void term_reverse(term_t* term, bool on); +ic_private void term_bold(term_t* term, bool on); +ic_private void term_italic(term_t* term, bool on); + +ic_private void term_color(term_t* term, ic_color_t color); +ic_private void term_bgcolor(term_t* term, ic_color_t color); + +// Formatted output + +ic_private attr_t term_get_attr( const term_t* term ); +ic_private void term_set_attr( term_t* term, attr_t attr ); +ic_private void term_write_formatted( term_t* term, const char* s, const attr_t* attrs ); +ic_private void term_write_formatted_n( term_t* term, const char* s, const attr_t* attrs, ssize_t n ); + +ic_private ic_color_t color_from_ansi256(ssize_t i); + +#endif // IC_TERM_H diff --git a/extern/isocline/src/term_color.c b/extern/isocline/src/term_color.c new file mode 100644 index 00000000..98af3cf4 --- /dev/null +++ b/extern/isocline/src/term_color.c @@ -0,0 +1,371 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +// This file is included in "term.c" + +//------------------------------------------------------------- +// Standard ANSI palette for 256 colors +//------------------------------------------------------------- + +static uint32_t ansi256[256] = { + // not const as on some platforms (e.g. Windows, xterm) we update the first 16 entries with the actual used colors. + // 0, standard ANSI + 0x000000, 0x800000, 0x008000, 0x808000, 0x000080, 0x800080, + 0x008080, 0xc0c0c0, + // 8, bright ANSI + 0x808080, 0xff0000, 0x00ff00, 0xffff00, 0x0000ff, 0xff00ff, + 0x00ffff, 0xffffff, + // 6x6x6 RGB colors + // 16 + 0x000000, 0x00005f, 0x000087, 0x0000af, 0x0000d7, 0x0000ff, + 0x005f00, 0x005f5f, 0x005f87, 0x005faf, 0x005fd7, 0x005fff, + 0x008700, 0x00875f, 0x008787, 0x0087af, 0x0087d7, 0x0087ff, + 0x00af00, 0x00af5f, 0x00af87, 0x00afaf, 0x00afd7, 0x00afff, + 0x00d700, 0x00d75f, 0x00d787, 0x00d7af, 0x00d7d7, 0x00d7ff, + 0x00ff00, 0x00ff5f, 0x00ff87, 0x00ffaf, 0x00ffd7, 0x00ffff, + // 52 + 0x5f0000, 0x5f005f, 0x5f0087, 0x5f00af, 0x5f00d7, 0x5f00ff, + 0x5f5f00, 0x5f5f5f, 0x5f5f87, 0x5f5faf, 0x5f5fd7, 0x5f5fff, + 0x5f8700, 0x5f875f, 0x5f8787, 0x5f87af, 0x5f87d7, 0x5f87ff, + 0x5faf00, 0x5faf5f, 0x5faf87, 0x5fafaf, 0x5fafd7, 0x5fafff, + 0x5fd700, 0x5fd75f, 0x5fd787, 0x5fd7af, 0x5fd7d7, 0x5fd7ff, + 0x5fff00, 0x5fff5f, 0x5fff87, 0x5fffaf, 0x5fffd7, 0x5fffff, + // 88 + 0x870000, 0x87005f, 0x870087, 0x8700af, 0x8700d7, 0x8700ff, + 0x875f00, 0x875f5f, 0x875f87, 0x875faf, 0x875fd7, 0x875fff, + 0x878700, 0x87875f, 0x878787, 0x8787af, 0x8787d7, 0x8787ff, + 0x87af00, 0x87af5f, 0x87af87, 0x87afaf, 0x87afd7, 0x87afff, + 0x87d700, 0x87d75f, 0x87d787, 0x87d7af, 0x87d7d7, 0x87d7ff, + 0x87ff00, 0x87ff5f, 0x87ff87, 0x87ffaf, 0x87ffd7, 0x87ffff, + // 124 + 0xaf0000, 0xaf005f, 0xaf0087, 0xaf00af, 0xaf00d7, 0xaf00ff, + 0xaf5f00, 0xaf5f5f, 0xaf5f87, 0xaf5faf, 0xaf5fd7, 0xaf5fff, + 0xaf8700, 0xaf875f, 0xaf8787, 0xaf87af, 0xaf87d7, 0xaf87ff, + 0xafaf00, 0xafaf5f, 0xafaf87, 0xafafaf, 0xafafd7, 0xafafff, + 0xafd700, 0xafd75f, 0xafd787, 0xafd7af, 0xafd7d7, 0xafd7ff, + 0xafff00, 0xafff5f, 0xafff87, 0xafffaf, 0xafffd7, 0xafffff, + // 160 + 0xd70000, 0xd7005f, 0xd70087, 0xd700af, 0xd700d7, 0xd700ff, + 0xd75f00, 0xd75f5f, 0xd75f87, 0xd75faf, 0xd75fd7, 0xd75fff, + 0xd78700, 0xd7875f, 0xd78787, 0xd787af, 0xd787d7, 0xd787ff, + 0xd7af00, 0xd7af5f, 0xd7af87, 0xd7afaf, 0xd7afd7, 0xd7afff, + 0xd7d700, 0xd7d75f, 0xd7d787, 0xd7d7af, 0xd7d7d7, 0xd7d7ff, + 0xd7ff00, 0xd7ff5f, 0xd7ff87, 0xd7ffaf, 0xd7ffd7, 0xd7ffff, + // 196 + 0xff0000, 0xff005f, 0xff0087, 0xff00af, 0xff00d7, 0xff00ff, + 0xff5f00, 0xff5f5f, 0xff5f87, 0xff5faf, 0xff5fd7, 0xff5fff, + 0xff8700, 0xff875f, 0xff8787, 0xff87af, 0xff87d7, 0xff87ff, + 0xffaf00, 0xffaf5f, 0xffaf87, 0xffafaf, 0xffafd7, 0xffafff, + 0xffd700, 0xffd75f, 0xffd787, 0xffd7af, 0xffd7d7, 0xffd7ff, + 0xffff00, 0xffff5f, 0xffff87, 0xffffaf, 0xffffd7, 0xffffff, + // 232, gray scale + 0x080808, 0x121212, 0x1c1c1c, 0x262626, 0x303030, 0x3a3a3a, + 0x444444, 0x4e4e4e, 0x585858, 0x626262, 0x6c6c6c, 0x767676, + 0x808080, 0x8a8a8a, 0x949494, 0x9e9e9e, 0xa8a8a8, 0xb2b2b2, + 0xbcbcbc, 0xc6c6c6, 0xd0d0d0, 0xdadada, 0xe4e4e4, 0xeeeeee +}; + + +//------------------------------------------------------------- +// Create colors +//------------------------------------------------------------- + +// Create a color from a 24-bit color value. +ic_private ic_color_t ic_rgb(uint32_t hex) { + return (ic_color_t)(0x1000000 | (hex & 0xFFFFFF)); +} + +// Limit an int to values between 0 and 255. +static uint32_t ic_cap8(ssize_t i) { + return (i < 0 ? 0 : (i > 255 ? 255 : (uint32_t)i)); +} + +// Create a color from a 24-bit color value. +ic_private ic_color_t ic_rgbx(ssize_t r, ssize_t g, ssize_t b) { + return ic_rgb( (ic_cap8(r)<<16) | (ic_cap8(g)<<8) | ic_cap8(b) ); +} + + +//------------------------------------------------------------- +// Match an rgb color to a ansi8, ansi16, or ansi256 +//------------------------------------------------------------- + +static bool color_is_rgb( ic_color_t color ) { + return (color >= IC_RGB(0)); // bit 24 is set for rgb colors +} + +static void color_to_rgb(ic_color_t color, int* r, int* g, int* b) { + assert(color_is_rgb(color)); + *r = ((color >> 16) & 0xFF); + *g = ((color >> 8) & 0xFF); + *b = (color & 0xFF); +} + +ic_private ic_color_t color_from_ansi256(ssize_t i) { + if (i >= 0 && i < 8) { + return (IC_ANSI_BLACK + (uint32_t)i); + } + else if (i >= 8 && i < 16) { + return (IC_ANSI_DARKGRAY + (uint32_t)(i - 8)); + } + else if (i >= 16 && i <= 255) { + return ic_rgb( ansi256[i] ); + } + else if (i == 256) { + return IC_ANSI_DEFAULT; + } + else { + return IC_ANSI_DEFAULT; + } +} + +static bool is_grayish(int r, int g, int b) { + return (abs(r-g) <= 4) && (abs((r+g)/2 - b) <= 4); +} + +static bool is_grayish_color( uint32_t rgb ) { + int r, g, b; + color_to_rgb(IC_RGB(rgb),&r,&g,&b); + return is_grayish(r,g,b); +} + +static int_least32_t sqr(int_least32_t x) { + return x*x; +} + +// Approximation to delta-E CIE color distance using much +// simpler calculations. See . +// This is essentialy weighted euclidean distance but the weight distribution +// depends on how big the "red" component of the color is. +// We do not take the square root as we only need to find +// the minimal distance (and multiply by 256 to increase precision). +// Needs at least 28-bit signed integers to avoid overflow. +static int_least32_t rgb_distance_rmean( uint32_t color, int r2, int g2, int b2 ) { + int r1, g1, b1; + color_to_rgb(IC_RGB(color),&r1,&g1,&b1); + int_least32_t rmean = (r1 + r2) / 2; + int_least32_t dr2 = sqr(r1 - r2); + int_least32_t dg2 = sqr(g1 - g2); + int_least32_t db2 = sqr(b1 - b2); + int_least32_t dist = ((512+rmean)*dr2) + 1024*dg2 + ((767-rmean)*db2); + return dist; +} + +// Another approximation to delta-E CIE color distance using +// simpler calculations. Similar to `rmean` but adds an adjustment factor +// based on the "red/blue" difference. +static int_least32_t rgb_distance_rbmean( uint32_t color, int r2, int g2, int b2 ) { + int r1, g1, b1; + color_to_rgb(IC_RGB(color),&r1,&g1,&b1); + int_least32_t rmean = (r1 + r2) / 2; + int_least32_t dr2 = sqr(r1 - r2); + int_least32_t dg2 = sqr(g1 - g2); + int_least32_t db2 = sqr(b1 - b2); + int_least32_t dist = 2*dr2 + 4*dg2 + 3*db2 + ((rmean*(dr2 - db2))/256); + return dist; +} + + +// Maintain a small cache of recently used colors. Should be short enough to be effectively constant time. +// If we ever use a more expensive color distance method, we may increase the size a bit (64?) +// (Initial zero initialized cache is valid.) +#define RGB_CACHE_LEN (16) +typedef struct rgb_cache_s { + int last; + int indices[RGB_CACHE_LEN]; + ic_color_t colors[RGB_CACHE_LEN]; +} rgb_cache_t; + +// remember a color in the LRU cache +void rgb_remember( rgb_cache_t* cache, ic_color_t color, int idx ) { + if (cache == NULL) return; + cache->colors[cache->last] = color; + cache->indices[cache->last] = idx; + cache->last++; + if (cache->last >= RGB_CACHE_LEN) { cache->last = 0; } +} + +// quick lookup in cache; -1 on failure +int rgb_lookup( const rgb_cache_t* cache, ic_color_t color ) { + if (cache != NULL) { + for(int i = 0; i < RGB_CACHE_LEN; i++) { + if (cache->colors[i] == color) return cache->indices[i]; + } + } + return -1; +} + +// return the index of the closest matching color +static int rgb_match( uint32_t* palette, int start, int len, rgb_cache_t* cache, ic_color_t color ) { + assert(color_is_rgb(color)); + // in cache? + int min = rgb_lookup(cache,color); + if (min >= 0) { + return min; + } + // otherwise find closest color match in the palette + int r, g, b; + color_to_rgb(color,&r,&g,&b); + min = start; + int_least32_t mindist = (INT_LEAST32_MAX)/4; + for(int i = start; i < len; i++) { + //int_least32_t dist = rgb_distance_rbmean(palette[i],r,g,b); + int_least32_t dist = rgb_distance_rmean(palette[i],r,g,b); + if (is_grayish_color(palette[i]) != is_grayish(r, g, b)) { + // with few colors, make it less eager to substitute a gray for a non-gray (or the other way around) + if (len <= 16) { + dist *= 4; + } + else { + dist = (dist/4)*5; + } + } + if (dist < mindist) { + min = i; + mindist = dist; + } + } + rgb_remember(cache,color,min); + return min; +} + + +// Match RGB to an index in the ANSI 256 color table +static int rgb_to_ansi256(ic_color_t color) { + static rgb_cache_t ansi256_cache; + int c = rgb_match(ansi256, 16, 256, &ansi256_cache, color); // not the first 16 ANSI colors as those may be different + //debug_msg("term: rgb %x -> ansi 256: %d\n", color, c ); + return c; +} + +// Match RGB to an ANSI 16 color code (30-37, 90-97) +static int color_to_ansi16(ic_color_t color) { + if (!color_is_rgb(color)) { + return (int)color; + } + else { + static rgb_cache_t ansi16_cache; + int c = rgb_match(ansi256, 0, 16, &ansi16_cache, color); + //debug_msg("term: rgb %x -> ansi 16: %d\n", color, c ); + return (c < 8 ? 30 + c : 90 + c - 8); + } +} + +// Match RGB to an ANSI 16 color code (30-37, 90-97) +// but assuming the bright colors are simulated using 'bold'. +static int color_to_ansi8(ic_color_t color) { + if (!color_is_rgb(color)) { + return (int)color; + } + else { + // match to basic 8 colors first + static rgb_cache_t ansi8_cache; + int c = 30 + rgb_match(ansi256, 0, 8, &ansi8_cache, color); + // and then adjust for brightness + int r, g, b; + color_to_rgb(color,&r,&g,&b); + if (r>=196 || g>=196 || b>=196) c += 60; + //debug_msg("term: rgb %x -> ansi 8: %d\n", color, c ); + return c; + } +} + + +//------------------------------------------------------------- +// Emit color escape codes based on the terminal capability +//------------------------------------------------------------- + +static void fmt_color_ansi8( char* buf, ssize_t len, ic_color_t color, bool bg ) { + int c = color_to_ansi8(color) + (bg ? 10 : 0); + if (c >= 90) { + snprintf(buf, to_size_t(len), IC_CSI "1;%dm", c - 60); + } + else { + snprintf(buf, to_size_t(len), IC_CSI "22;%dm", c ); + } +} + +static void fmt_color_ansi16( char* buf, ssize_t len, ic_color_t color, bool bg ) { + snprintf( buf, to_size_t(len), IC_CSI "%dm", color_to_ansi16(color) + (bg ? 10 : 0) ); +} + +static void fmt_color_ansi256( char* buf, ssize_t len, ic_color_t color, bool bg ) { + if (!color_is_rgb(color)) { + fmt_color_ansi16(buf,len,color,bg); + } + else { + snprintf( buf, to_size_t(len), IC_CSI "%d;5;%dm", (bg ? 48 : 38), rgb_to_ansi256(color) ); + } +} + +static void fmt_color_rgb( char* buf, ssize_t len, ic_color_t color, bool bg ) { + if (!color_is_rgb(color)) { + fmt_color_ansi16(buf,len,color,bg); + } + else { + int r,g,b; + color_to_rgb(color, &r,&g,&b); + snprintf( buf, to_size_t(len), IC_CSI "%d;2;%d;%d;%dm", (bg ? 48 : 38), r, g, b ); + } +} + +static void fmt_color_ex(char* buf, ssize_t len, palette_t palette, ic_color_t color, bool bg) { + if (color == IC_COLOR_NONE || palette == MONOCHROME) return; + if (palette == ANSI8) { + fmt_color_ansi8(buf,len,color,bg); + } + else if (!color_is_rgb(color) || palette == ANSI16) { + fmt_color_ansi16(buf,len,color,bg); + } + else if (palette == ANSI256) { + fmt_color_ansi256(buf,len,color,bg); + } + else { + fmt_color_rgb(buf,len,color,bg); + } +} + +static void term_color_ex(term_t* term, ic_color_t color, bool bg) { + char buf[128+1]; + fmt_color_ex(buf,128,term->palette,color,bg); + term_write(term,buf); +} + +//------------------------------------------------------------- +// Main API functions +//------------------------------------------------------------- + +ic_private void term_color(term_t* term, ic_color_t color) { + term_color_ex(term,color,false); +} + +ic_private void term_bgcolor(term_t* term, ic_color_t color) { + term_color_ex(term,color,true); +} + +ic_private void term_append_color(term_t* term, stringbuf_t* sbuf, ic_color_t color) { + char buf[128+1]; + fmt_color_ex(buf,128,term->palette,color,false); + sbuf_append(sbuf,buf); +} + +ic_private void term_append_bgcolor(term_t* term, stringbuf_t* sbuf, ic_color_t color) { + char buf[128+1]; + fmt_color_ex(buf, 128, term->palette, color, true); + sbuf_append(sbuf, buf); +} + +ic_private int term_get_color_bits(term_t* term) { + switch (term->palette) { + case MONOCHROME: return 1; + case ANSI8: return 3; + case ANSI16: return 4; + case ANSI256: return 8; + case ANSIRGB: return 24; + default: return 4; + } +} diff --git a/extern/isocline/src/tty.c b/extern/isocline/src/tty.c new file mode 100644 index 00000000..09f7aedd --- /dev/null +++ b/extern/isocline/src/tty.c @@ -0,0 +1,889 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include +#include +#include +#include + +#include "tty.h" + +#if defined(_WIN32) +#include +#include +#define isatty(fd) _isatty(fd) +#define read(fd,s,n) _read(fd,s,n) +#define STDIN_FILENO 0 +#if (_WIN32_WINNT < 0x0600) +WINBASEAPI ULONGLONG WINAPI GetTickCount64(VOID); +#endif +#else +#include +#include +#include +#include +#include +#include +#if !defined(FIONREAD) +#include +#endif +#endif + +#define TTY_PUSH_MAX (32) + +struct tty_s { + int fd_in; // input handle + bool raw_enabled; // is raw mode enabled? + bool is_utf8; // is the input stream in utf-8 mode? + bool has_term_resize_event; // are resize events generated? + bool term_resize_event; // did a term resize happen? + alloc_t* mem; // memory allocator + code_t pushbuf[TTY_PUSH_MAX]; // push back buffer for full key codes + ssize_t push_count; + uint8_t cpushbuf[TTY_PUSH_MAX]; // low level push back buffer for bytes + ssize_t cpush_count; + long esc_initial_timeout; // initial ms wait to see if ESC starts an escape sequence + long esc_timeout; // follow up delay for characters in an escape sequence + #if defined(_WIN32) + HANDLE hcon; // console input handle + DWORD hcon_orig_mode; // original console mode + #else + struct termios orig_ios; // original terminal settings + struct termios raw_ios; // raw terminal settings + #endif +}; + + +//------------------------------------------------------------- +// Forward declarations of platform dependent primitives below +//------------------------------------------------------------- + +ic_private bool tty_readc_noblock(tty_t* tty, uint8_t* c, long timeout_ms); // does not modify `c` when no input (false is returned) + +//------------------------------------------------------------- +// Key code helpers +//------------------------------------------------------------- + +ic_private bool code_is_ascii_char(code_t c, char* chr ) { + if (c >= ' ' && c <= 0x7F) { + if (chr != NULL) *chr = (char)c; + return true; + } + else { + if (chr != NULL) *chr = 0; + return false; + } +} + +ic_private bool code_is_unicode(code_t c, unicode_t* uchr) { + if (c <= KEY_UNICODE_MAX) { + if (uchr != NULL) *uchr = c; + return true; + } + else { + if (uchr != NULL) *uchr = 0; + return false; + } +} + +ic_private bool code_is_virt_key(code_t c ) { + return (KEY_NO_MODS(c) <= 0x20 || KEY_NO_MODS(c) >= KEY_VIRT); +} + + +//------------------------------------------------------------- +// Read a key code +//------------------------------------------------------------- +static code_t modify_code( code_t code ); + +static code_t tty_read_utf8( tty_t* tty, uint8_t c0 ) { + uint8_t buf[5]; + memset(buf, 0, 5); + + // try to read as many bytes as potentially needed + buf[0] = c0; + ssize_t count = 1; + if (c0 > 0x7F) { + if (tty_readc_noblock(tty, buf+count, tty->esc_timeout)) { + count++; + if (c0 > 0xDF) { + if (tty_readc_noblock(tty, buf+count, tty->esc_timeout)) { + count++; + if (c0 > 0xEF) { + if (tty_readc_noblock(tty, buf+count, tty->esc_timeout)) { + count++; + } + } + } + } + } + } + + buf[count] = 0; + debug_msg("tty: read utf8: count: %zd: %02x,%02x,%02x,%02x", count, buf[0], buf[1], buf[2], buf[3]); + + // decode the utf8 to unicode + ssize_t read = 0; + code_t code = key_unicode(unicode_from_qutf8(buf, count, &read)); + + // push back unused bytes (in the case of invalid utf8) + while (count > read) { + count--; + if (count >= 0 && count <= 4) { // to help the static analyzer + tty_cpush_char(tty, buf[count]); + } + } + return code; +} + +// pop a code from the pushback buffer. +static bool tty_code_pop(tty_t* tty, code_t* code); + + +// read a single char/key +ic_private bool tty_read_timeout(tty_t* tty, long timeout_ms, code_t* code) +{ + // is there a push_count back code? + if (tty_code_pop(tty,code)) { + return code; + } + + // read a single char/byte from a character stream + uint8_t c; + if (!tty_readc_noblock(tty, &c, timeout_ms)) return false; + + if (c == KEY_ESC) { + // escape sequence? + *code = tty_read_esc(tty, tty->esc_initial_timeout, tty->esc_timeout); + } + else if (c <= 0x7F) { + // ascii + *code = key_unicode(c); + } + else if (tty->is_utf8) { + // utf8 sequence + *code = tty_read_utf8(tty,c); + } + else { + // c >= 0x80 but tty is not utf8; use raw plane so we can translate it back in the end + *code = key_unicode( unicode_from_raw(c) ); + } + + *code = modify_code(*code); + return true; +} + +// Transform virtual keys to be more portable across platforms +static code_t modify_code( code_t code ) { + code_t key = KEY_NO_MODS(code); + code_t mods = KEY_MODS(code); + debug_msg( "tty: readc %s%s%s 0x%03x ('%c')\n", + mods&KEY_MOD_SHIFT ? "shift+" : "", mods&KEY_MOD_CTRL ? "ctrl+" : "", mods&KEY_MOD_ALT ? "alt+" : "", + key, (key >= ' ' && key <= '~' ? key : ' ')); + + // treat KEY_RUBOUT (0x7F) as KEY_BACKSP + if (key == KEY_RUBOUT) { + code = KEY_BACKSP | mods; + } + // ctrl+'_' is translated to '\x1F' on Linux, translate it back + else if (key == key_char('\x1F') && (mods & KEY_MOD_ALT) == 0) { + key = '_'; + code = WITH_CTRL(key_char('_')); + } + // treat ctrl/shift + enter always as KEY_LINEFEED for portability + else if (key == KEY_ENTER && (mods == KEY_MOD_SHIFT || mods == KEY_MOD_ALT || mods == KEY_MOD_CTRL)) { + code = KEY_LINEFEED; + } + // treat ctrl+tab always as shift+tab for portability + else if (code == WITH_CTRL(KEY_TAB)) { + code = KEY_SHIFT_TAB; + } + // treat ctrl+end/alt+>/alt-down and ctrl+home/alt+') || code == WITH_CTRL(KEY_END)) { + code = KEY_PAGEDOWN; + } + else if (code == WITH_ALT(KEY_UP) || code == WITH_ALT('<') || code == WITH_CTRL(KEY_HOME)) { + code = KEY_PAGEUP; + } + + // treat C0 codes without KEY_MOD_CTRL + if (key < ' ' && (mods&KEY_MOD_CTRL) != 0) { + code &= ~KEY_MOD_CTRL; + } + + return code; +} + + +// read a single char/key +ic_private code_t tty_read(tty_t* tty) +{ + code_t code; + if (!tty_read_timeout(tty, -1, &code)) return KEY_NONE; + return code; +} + +//------------------------------------------------------------- +// Read back an ANSI query response +//------------------------------------------------------------- + +ic_private bool tty_read_esc_response(tty_t* tty, char esc_start, bool final_st, char* buf, ssize_t buflen ) +{ + buf[0] = 0; + ssize_t len = 0; + uint8_t c = 0; + if (!tty_readc_noblock(tty, &c, 2*tty->esc_initial_timeout) || c != '\x1B') { + debug_msg("initial esc response failed: 0x%02x\n", c); + return false; + } + if (!tty_readc_noblock(tty, &c, tty->esc_timeout) || (c != esc_start)) return false; + while( len < buflen ) { + if (!tty_readc_noblock(tty, &c, tty->esc_timeout)) return false; + if (final_st) { + // OSC is terminated by BELL, or ESC \ (ST) (and STX) + if (c=='\x07' || c=='\x02') { + break; + } + else if (c=='\x1B') { + uint8_t c1; + if (!tty_readc_noblock(tty, &c1, tty->esc_timeout)) return false; + if (c1=='\\') break; + tty_cpush_char(tty,c1); + } + } + else { + if (c == '\x02') { // STX + break; + } + else if (!((c >= '0' && c <= '9') || strchr("<=>?;:",c) != NULL)) { + buf[len++] = (char)c; // for non-OSC save the terminating character + break; + } + } + buf[len++] = (char)c; + } + buf[len] = 0; + debug_msg("tty: escape query response: %s\n", buf); + return true; +} + +//------------------------------------------------------------- +// High level code pushback +//------------------------------------------------------------- + +static bool tty_code_pop( tty_t* tty, code_t* code ) { + if (tty->push_count <= 0) return false; + tty->push_count--; + *code = tty->pushbuf[tty->push_count]; + return true; +} + +ic_private void tty_code_pushback( tty_t* tty, code_t c ) { + // note: must be signal safe + if (tty->push_count >= TTY_PUSH_MAX) return; + tty->pushbuf[tty->push_count] = c; + tty->push_count++; +} + + +//------------------------------------------------------------- +// low-level character pushback (for escape sequences and windows) +//------------------------------------------------------------- + +ic_private bool tty_cpop(tty_t* tty, uint8_t* c) { + if (tty->cpush_count <= 0) { // do not modify c on failure (see `tty_decode_unicode`) + return false; + } + else { + tty->cpush_count--; + *c = tty->cpushbuf[tty->cpush_count]; + return true; + } +} + +static void tty_cpush(tty_t* tty, const char* s) { + ssize_t len = ic_strlen(s); + if (tty->push_count + len > TTY_PUSH_MAX) { + debug_msg("tty: cpush buffer full! (pushing %s)\n", s); + assert(false); + return; + } + for (ssize_t i = 0; i < len; i++) { + tty->cpushbuf[tty->cpush_count + i] = (uint8_t)( s[len - i - 1] ); + } + tty->cpush_count += len; + return; +} + +// convenience function for small sequences +static void tty_cpushf(tty_t* tty, const char* fmt, ...) { + va_list args; + va_start(args,fmt); + char buf[TTY_PUSH_MAX+1]; + vsnprintf(buf,TTY_PUSH_MAX,fmt,args); + buf[TTY_PUSH_MAX] = 0; + tty_cpush(tty,buf); + va_end(args); + return; +} + +ic_private void tty_cpush_char(tty_t* tty, uint8_t c) { + uint8_t buf[2]; + buf[0] = c; + buf[1] = 0; + tty_cpush(tty, (const char*)buf); +} + + +//------------------------------------------------------------- +// Push escape codes (used on Windows to insert keys) +//------------------------------------------------------------- + +static unsigned csi_mods(code_t mods) { + unsigned m = 1; + if (mods&KEY_MOD_SHIFT) m += 1; + if (mods&KEY_MOD_ALT) m += 2; + if (mods&KEY_MOD_CTRL) m += 4; + return m; +} + +// Push ESC [ ; ~ +static void tty_cpush_csi_vt( tty_t* tty, code_t mods, uint32_t vtcode ) { + tty_cpushf(tty,"\x1B[%u;%u~", vtcode, csi_mods(mods) ); +} + +// push ESC [ 1 ; +static void tty_cpush_csi_xterm( tty_t* tty, code_t mods, char xcode ) { + tty_cpushf(tty,"\x1B[1;%u%c", csi_mods(mods), xcode ); +} + +// push ESC [ ; u +static void tty_cpush_csi_unicode( tty_t* tty, code_t mods, uint32_t unicode ) { + if ((unicode < 0x80 && mods == 0) || + (mods == KEY_MOD_CTRL && unicode < ' ' && unicode != KEY_TAB && unicode != KEY_ENTER + && unicode != KEY_LINEFEED && unicode != KEY_BACKSP) || + (mods == KEY_MOD_SHIFT && unicode >= ' ' && unicode <= KEY_RUBOUT)) { + tty_cpush_char(tty,(uint8_t)unicode); + } + else { + tty_cpushf(tty,"\x1B[%u;%uu", unicode, csi_mods(mods) ); + } +} + +//------------------------------------------------------------- +// Init +//------------------------------------------------------------- + +static bool tty_init_raw(tty_t* tty); +static void tty_done_raw(tty_t* tty); + +static bool tty_init_utf8(tty_t* tty) { + #ifdef _WIN32 + tty->is_utf8 = true; + #else + const char* loc = setlocale(LC_ALL,""); + tty->is_utf8 = (ic_icontains(loc,"UTF-8") || ic_icontains(loc,"utf8") || ic_stricmp(loc,"C") == 0); + debug_msg("tty: utf8: %s (loc=%s)\n", tty->is_utf8 ? "true" : "false", loc); + #endif + return true; +} + +ic_private tty_t* tty_new(alloc_t* mem, int fd_in) +{ + tty_t* tty = mem_zalloc_tp(mem, tty_t); + tty->mem = mem; + tty->fd_in = (fd_in < 0 ? STDIN_FILENO : fd_in); + #if defined(__APPLE__) + tty->esc_initial_timeout = 200; // apple use ESC+ for alt- + #else + tty->esc_initial_timeout = 100; + #endif + tty->esc_timeout = 10; + if (!(isatty(tty->fd_in) && tty_init_raw(tty) && tty_init_utf8(tty))) { + tty_free(tty); + return NULL; + } + return tty; +} + +ic_private void tty_free(tty_t* tty) { + if (tty==NULL) return; + tty_end_raw(tty); + tty_done_raw(tty); + mem_free(tty->mem,tty); +} + +ic_private bool tty_is_utf8(const tty_t* tty) { + if (tty == NULL) return true; + return (tty->is_utf8); +} + +ic_private bool tty_term_resize_event(tty_t* tty) { + if (tty == NULL) return true; + if (tty->has_term_resize_event) { + if (!tty->term_resize_event) return false; + tty->term_resize_event = false; // reset. + } + return true; // always return true on systems without a resize event (more expensive but still ok) +} + +ic_private void tty_set_esc_delay(tty_t* tty, long initial_delay_ms, long followup_delay_ms) { + tty->esc_initial_timeout = (initial_delay_ms < 0 ? 0 : (initial_delay_ms > 1000 ? 1000 : initial_delay_ms)); + tty->esc_timeout = (followup_delay_ms < 0 ? 0 : (followup_delay_ms > 1000 ? 1000 : followup_delay_ms)); +} + +//------------------------------------------------------------- +// Unix +//------------------------------------------------------------- +#if !defined(_WIN32) + +static bool tty_readc_blocking(tty_t* tty, uint8_t* c) { + if (tty_cpop(tty,c)) return true; + *c = 0; + ssize_t nread = read(tty->fd_in, (char*)c, 1); + if (nread < 0 && errno == EINTR) { + // can happen on SIGWINCH signal for terminal resize + } + return (nread == 1); +} + + +// non blocking read -- with a small timeout used for reading escape sequences. +ic_private bool tty_readc_noblock(tty_t* tty, uint8_t* c, long timeout_ms) +{ + // in our pushback buffer? + if (tty_cpop(tty, c)) return true; + + // blocking read? + if (timeout_ms < 0) { + return tty_readc_blocking(tty,c); + } + + // if supported, peek first if any char is available. + #if defined(FIONREAD) + { int navail = 0; + if (ioctl(0, FIONREAD, &navail) == 0) { + if (navail >= 1) { + return tty_readc_blocking(tty, c); + } + else if (timeout_ms == 0) { + return false; // return early if there is no input available (with a zero timeout) + } + } + } + #endif + + // otherwise block for at most timeout milliseconds + #if defined(FD_SET) + // we can use select to detect when input becomes available + fd_set readset; + struct timeval time; + FD_ZERO(&readset); + FD_SET(tty->fd_in, &readset); + time.tv_sec = (timeout_ms > 0 ? timeout_ms / 1000 : 0); + time.tv_usec = (timeout_ms > 0 ? 1000*(timeout_ms % 1000) : 0); + if (select(tty->fd_in + 1, &readset, NULL, NULL, &time) == 1) { + // input available + return tty_readc_blocking(tty, c); + } + #else + // no select, we cannot timeout; use usleeps :-( + // todo: this seems very rare nowadays; should be even support this? + do { + // peek ahead if possible + #if defined(FIONREAD) + int navail = 0; + if (ioctl(0, FIONREAD, &navail) == 0 && navail >= 1) { + return tty_readc_blocking(tty, c); + } + #elif defined(O_NONBLOCK) + // use a temporary non-blocking read mode + int fstatus = fcntl(tty->fd_in, F_GETFL, 0); + if (fstatus != -1) { + if (fcntl(tty->fd_in, F_SETFL, (fstatus | O_NONBLOCK)) != -1) { + char buf[2] = { 0, 0 }; + ssize_t nread = read(tty->fd_in, buf, 1); + fcntl(tty->fd_in, F_SETFL, fstatus); + if (nread >= 1) { + *c = (uint8_t)buf[0]; + return true; + } + } + } + #else + #error "define an nonblocking read for this platform" + #endif + // and sleep a bit + if (timeout_ms > 0) { + usleep(50*1000L); // sleep at most 0.05s at a time + timeout_ms -= 100; + if (timeout_ms < 0) { timeout_ms = 0; } + } + } + while (timeout_ms > 0); + #endif + return false; +} + +#if defined(TIOCSTI) +ic_private bool tty_async_stop(const tty_t* tty) { + // insert ^C in the input stream + char c = KEY_CTRL_C; + return (ioctl(tty->fd_in, TIOCSTI, &c) >= 0); +} +#else +ic_private bool tty_async_stop(const tty_t* tty) { + return false; +} +#endif + +// We install various signal handlers to restore the terminal settings +// in case of a terminating signal. This is also used to catch terminal window resizes. +// This is not strictly needed so this can be disabled on +// (older) platforms that do not support signal handling well. +#if defined(SIGWINCH) && defined(SA_RESTART) // ensure basic signal functionality is defined + +// store the tty in a global so we access it on unexpected termination +static tty_t* sig_tty; // = NULL + +// Catch all termination signals (and SIGWINCH) +typedef struct signal_handler_s { + int signum; + union { + int _avoid_warning; + struct sigaction previous; + } action; +} signal_handler_t; + +static signal_handler_t sighandlers[] = { + { SIGWINCH, {0} }, + { SIGTERM , {0} }, + { SIGINT , {0} }, + { SIGQUIT , {0} }, + { SIGHUP , {0} }, + { SIGSEGV , {0} }, + { SIGTRAP , {0} }, + { SIGBUS , {0} }, + { SIGTSTP , {0} }, + { SIGTTIN , {0} }, + { SIGTTOU , {0} }, + { 0 , {0} } +}; + +static bool sigaction_is_valid( struct sigaction* sa ) { + return (sa->sa_sigaction != NULL && sa->sa_handler != SIG_DFL && sa->sa_handler != SIG_IGN); +} + +// Generic signal handler +static void sig_handler(int signum, siginfo_t* siginfo, void* uap ) { + if (signum == SIGWINCH) { + if (sig_tty != NULL) { + sig_tty->term_resize_event = true; + } + } + else { + // the rest are termination signals; restore the terminal mode. (`tcsetattr` is signal-safe) + if (sig_tty != NULL && sig_tty->raw_enabled) { + tcsetattr(sig_tty->fd_in, TCSAFLUSH, &sig_tty->orig_ios); + sig_tty->raw_enabled = false; + } + } + // call previous handler + signal_handler_t* sh = sighandlers; + while( sh->signum != 0 && sh->signum != signum) { sh++; } + if (sh->signum == signum) { + if (sigaction_is_valid(&sh->action.previous)) { + (sh->action.previous.sa_sigaction)(signum, siginfo, uap); + } + } +} + +static void signals_install(tty_t* tty) { + sig_tty = tty; + // generic signal handler + struct sigaction handler; + memset(&handler,0,sizeof(handler)); + sigemptyset(&handler.sa_mask); + handler.sa_sigaction = &sig_handler; + handler.sa_flags = SA_RESTART; + // install for all signals + for( signal_handler_t* sh = sighandlers; sh->signum != 0; sh++ ) { + if (sigaction( sh->signum, NULL, &sh->action.previous) == 0) { // get previous + if (sh->action.previous.sa_handler != SIG_IGN) { // if not to be ignored + if (sigaction( sh->signum, &handler, &sh->action.previous ) < 0) { // install our handler + sh->action.previous.sa_sigaction = NULL; // do not restore on error + } + else if (sh->signum == SIGWINCH) { + sig_tty->has_term_resize_event = true; + }; + } + } + } +} + +static void signals_restore(void) { + // restore all signal handlers + for( signal_handler_t* sh = sighandlers; sh->signum != 0; sh++ ) { + if (sigaction_is_valid(&sh->action.previous)) { + sigaction( sh->signum, &sh->action.previous, NULL ); + }; + } + sig_tty = NULL; +} + +#else +static void signals_install(tty_t* tty) { + ic_unused(tty); + // nothing +} +static void signals_restore(void) { + // nothing +} + +#endif + +ic_private bool tty_start_raw(tty_t* tty) { + if (tty == NULL) return false; + if (tty->raw_enabled) return true; + if (tcsetattr(tty->fd_in,TCSAFLUSH,&tty->raw_ios) < 0) return false; + tty->raw_enabled = true; + return true; +} + +ic_private void tty_end_raw(tty_t* tty) { + if (tty == NULL) return; + if (!tty->raw_enabled) return; + tty->cpush_count = 0; + if (tcsetattr(tty->fd_in,TCSAFLUSH,&tty->orig_ios) < 0) return; + tty->raw_enabled = false; +} + +static bool tty_init_raw(tty_t* tty) +{ + // Set input to raw mode. See . + if (tcgetattr(tty->fd_in,&tty->orig_ios) == -1) return false; + tty->raw_ios = tty->orig_ios; + // input: no break signal, no \r to \n, no parity check, no 8-bit to 7-bit, no flow control + tty->raw_ios.c_iflag &= ~(unsigned long)(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + // control: allow 8-bit + tty->raw_ios.c_cflag |= CS8; + // local: no echo, no line-by-line (canonical), no extended input processing, no signals for ^z,^c + tty->raw_ios.c_lflag &= ~(unsigned long)(ECHO | ICANON | IEXTEN | ISIG); + // 1 byte at a time, no delay + tty->raw_ios.c_cc[VTIME] = 0; + tty->raw_ios.c_cc[VMIN] = 1; + + // store in global so our signal handlers can restore the terminal mode + signals_install(tty); + + return true; +} + +static void tty_done_raw(tty_t* tty) { + ic_unused(tty); + signals_restore(); +} + + +#else + +//------------------------------------------------------------- +// Windows +// For best portability we push CSI escape sequences directly +// to the character stream (instead of returning key codes). +//------------------------------------------------------------- + +static void tty_waitc_console(tty_t* tty, long timeout_ms); + +ic_private bool tty_readc_noblock(tty_t* tty, uint8_t* c, long timeout_ms) { // don't modify `c` if there is no input + // in our pushback buffer? + if (tty_cpop(tty, c)) return true; + // any events in the input queue? + tty_waitc_console(tty, timeout_ms); + return tty_cpop(tty, c); +} + +// Read from the console input events and push escape codes into the tty cbuffer. +static void tty_waitc_console(tty_t* tty, long timeout_ms) +{ + // wait for a key down event + INPUT_RECORD inp; + DWORD count; + uint32_t surrogate_hi = 0; + while (true) { + // check if there are events if in non-blocking timeout mode + if (timeout_ms >= 0) { + // first peek ahead + if (!GetNumberOfConsoleInputEvents(tty->hcon, &count)) return; + if (count == 0) { + if (timeout_ms == 0) { + // out of time + return; + } + else { + // wait for input events for at most timeout milli seconds + ULONGLONG start_ms = GetTickCount64(); + DWORD res = WaitForSingleObject(tty->hcon, (DWORD)timeout_ms); + switch (res) { + case WAIT_OBJECT_0: { + // input is available, decrease our timeout + ULONGLONG waited_ms = (GetTickCount64() - start_ms); + timeout_ms -= (long)waited_ms; + if (timeout_ms < 0) { + timeout_ms = 0; + } + break; + } + case WAIT_TIMEOUT: + case WAIT_ABANDONED: + case WAIT_FAILED: + default: + return; + } + } + } + } + + // (blocking) Read from the input + if (!ReadConsoleInputW(tty->hcon, &inp, 1, &count)) return; + if (count != 1) return; + + // resize event? + if (inp.EventType == WINDOW_BUFFER_SIZE_EVENT) { + tty->term_resize_event = true; + continue; + } + + // wait for key down events + if (inp.EventType != KEY_EVENT) continue; + + // the modifier state + DWORD modstate = inp.Event.KeyEvent.dwControlKeyState; + + // we need to handle shift up events separately + if (!inp.Event.KeyEvent.bKeyDown && inp.Event.KeyEvent.wVirtualKeyCode == VK_SHIFT) { + modstate &= (DWORD)~SHIFT_PRESSED; + } + + // ignore AltGr + DWORD altgr = LEFT_CTRL_PRESSED | RIGHT_ALT_PRESSED; + if ((modstate & altgr) == altgr) { modstate &= ~altgr; } + + + // get modifiers + code_t mods = 0; + if ((modstate & ( RIGHT_CTRL_PRESSED | LEFT_CTRL_PRESSED )) != 0) mods |= KEY_MOD_CTRL; + if ((modstate & ( RIGHT_ALT_PRESSED | LEFT_ALT_PRESSED )) != 0) mods |= KEY_MOD_ALT; + if ((modstate & SHIFT_PRESSED) != 0) mods |= KEY_MOD_SHIFT; + + // virtual keys + uint32_t chr = (uint32_t)inp.Event.KeyEvent.uChar.UnicodeChar; + WORD virt = inp.Event.KeyEvent.wVirtualKeyCode; + debug_msg("tty: console %s: %s%s%s virt 0x%04x, chr 0x%04x ('%c')\n", inp.Event.KeyEvent.bKeyDown ? "down" : "up", mods&KEY_MOD_CTRL ? "ctrl-" : "", mods&KEY_MOD_ALT ? "alt-" : "", mods&KEY_MOD_SHIFT ? "shift-" : "", virt, chr, chr); + + // only process keydown events (except for Alt-up which is used for unicode pasting...) + if (!inp.Event.KeyEvent.bKeyDown && virt != VK_MENU) { + continue; + } + + if (chr == 0) { + switch (virt) { + case VK_UP: tty_cpush_csi_xterm(tty, mods, 'A'); return; + case VK_DOWN: tty_cpush_csi_xterm(tty, mods, 'B'); return; + case VK_RIGHT: tty_cpush_csi_xterm(tty, mods, 'C'); return; + case VK_LEFT: tty_cpush_csi_xterm(tty, mods, 'D'); return; + case VK_END: tty_cpush_csi_xterm(tty, mods, 'F'); return; + case VK_HOME: tty_cpush_csi_xterm(tty, mods, 'H'); return; + case VK_DELETE: tty_cpush_csi_vt(tty,mods,3); return; + case VK_PRIOR: tty_cpush_csi_vt(tty,mods,5); return; //page up + case VK_NEXT: tty_cpush_csi_vt(tty,mods,6); return; //page down + case VK_TAB: tty_cpush_csi_unicode(tty,mods,9); return; + case VK_RETURN: tty_cpush_csi_unicode(tty,mods,13); return; + default: { + uint32_t vtcode = 0; + if (virt >= VK_F1 && virt <= VK_F5) { + vtcode = 10 + (virt - VK_F1); + } + else if (virt >= VK_F6 && virt <= VK_F10) { + vtcode = 17 + (virt - VK_F6); + } + else if (virt >= VK_F11 && virt <= VK_F12) { + vtcode = 13 + (virt - VK_F11); + } + if (vtcode > 0) { + tty_cpush_csi_vt(tty,mods,vtcode); + return; + } + } + } + // ignore other control keys (shift etc). + } + // high surrogate pair + else if (chr >= 0xD800 && chr <= 0xDBFF) { + surrogate_hi = (chr - 0xD800); + } + // low surrogate pair + else if (chr >= 0xDC00 && chr <= 0xDFFF) { + chr = ((surrogate_hi << 10) + (chr - 0xDC00) + 0x10000); + tty_cpush_csi_unicode(tty,mods,chr); + surrogate_hi = 0; + return; + } + // regular character + else { + tty_cpush_csi_unicode(tty,mods,chr); + return; + } + } +} + +ic_private bool tty_async_stop(const tty_t* tty) { + // send ^c + INPUT_RECORD events[2]; + memset(events, 0, 2*sizeof(INPUT_RECORD)); + events[0].EventType = KEY_EVENT; + events[0].Event.KeyEvent.bKeyDown = TRUE; + events[0].Event.KeyEvent.uChar.AsciiChar = KEY_CTRL_C; + events[1] = events[0]; + events[1].Event.KeyEvent.bKeyDown = FALSE; + DWORD nwritten = 0; + WriteConsoleInput(tty->hcon, events, 2, &nwritten); + return (nwritten == 2); +} + +ic_private bool tty_start_raw(tty_t* tty) { + if (tty->raw_enabled) return true; + GetConsoleMode(tty->hcon,&tty->hcon_orig_mode); + DWORD mode = ENABLE_QUICK_EDIT_MODE // cut&paste allowed + | ENABLE_WINDOW_INPUT // to catch resize events + // | ENABLE_VIRTUAL_TERMINAL_INPUT + // | ENABLE_PROCESSED_INPUT + ; + SetConsoleMode(tty->hcon, mode ); + tty->raw_enabled = true; + return true; +} + +ic_private void tty_end_raw(tty_t* tty) { + if (!tty->raw_enabled) return; + SetConsoleMode(tty->hcon, tty->hcon_orig_mode ); + tty->raw_enabled = false; +} + +static bool tty_init_raw(tty_t* tty) { + tty->hcon = GetStdHandle( STD_INPUT_HANDLE ); + tty->has_term_resize_event = true; + return true; +} + +static void tty_done_raw(tty_t* tty) { + ic_unused(tty); +} + +#endif + + diff --git a/extern/isocline/src/tty.h b/extern/isocline/src/tty.h new file mode 100644 index 00000000..a0062bf3 --- /dev/null +++ b/extern/isocline/src/tty.h @@ -0,0 +1,160 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_TTY_H +#define IC_TTY_H + +#include "common.h" + +//------------------------------------------------------------- +// TTY/Keyboard input +//------------------------------------------------------------- + +// Key code +typedef uint32_t code_t; + +// TTY interface +struct tty_s; +typedef struct tty_s tty_t; + + +ic_private tty_t* tty_new(alloc_t* mem, int fd_in); +ic_private void tty_free(tty_t* tty); + +ic_private bool tty_is_utf8(const tty_t* tty); +ic_private bool tty_start_raw(tty_t* tty); +ic_private void tty_end_raw(tty_t* tty); +ic_private code_t tty_read(tty_t* tty); +ic_private bool tty_read_timeout(tty_t* tty, long timeout_ms, code_t* c ); + +ic_private void tty_code_pushback( tty_t* tty, code_t c ); +ic_private bool code_is_ascii_char(code_t c, char* chr ); +ic_private bool code_is_unicode(code_t c, unicode_t* uchr); +ic_private bool code_is_virt_key(code_t c ); + +ic_private bool tty_term_resize_event(tty_t* tty); // did the terminal resize? +ic_private bool tty_async_stop(const tty_t* tty); // unblock the read asynchronously +ic_private void tty_set_esc_delay(tty_t* tty, long initial_delay_ms, long followup_delay_ms); + +// shared between tty.c and tty_esc.c: low level character push +ic_private void tty_cpush_char(tty_t* tty, uint8_t c); +ic_private bool tty_cpop(tty_t* tty, uint8_t* c); +ic_private bool tty_readc_noblock(tty_t* tty, uint8_t* c, long timeout_ms); +ic_private code_t tty_read_esc(tty_t* tty, long esc_initial_timeout, long esc_timeout); // in tty_esc.c + +// used by term.c to read back ANSI escape responses +ic_private bool tty_read_esc_response(tty_t* tty, char esc_start, bool final_st, char* buf, ssize_t buflen ); + + +//------------------------------------------------------------- +// Key codes: a code_t is 32 bits. +// we use the bottom 24 (nah, 21) bits for unicode (up to x0010FFFF) +// The codes after x01000000 are for virtual keys +// and events use x02000000. +// The top 4 bits are used for modifiers. +//------------------------------------------------------------- + +static inline code_t key_char( char c ) { + // careful about signed character conversion (negative char ~> 0x80 - 0xFF) + return ((uint8_t)c); +} + +static inline code_t key_unicode( unicode_t u ) { + return u; +} + + +#define KEY_MOD_SHIFT (0x10000000U) +#define KEY_MOD_ALT (0x20000000U) +#define KEY_MOD_CTRL (0x40000000U) + +#define KEY_NO_MODS(k) (k & 0x0FFFFFFFU) +#define KEY_MODS(k) (k & 0xF0000000U) + +#define WITH_SHIFT(x) (x | KEY_MOD_SHIFT) +#define WITH_ALT(x) (x | KEY_MOD_ALT) +#define WITH_CTRL(x) (x | KEY_MOD_CTRL) + +#define KEY_NONE (0) +#define KEY_CTRL_A (1) +#define KEY_CTRL_B (2) +#define KEY_CTRL_C (3) +#define KEY_CTRL_D (4) +#define KEY_CTRL_E (5) +#define KEY_CTRL_F (6) +#define KEY_BELL (7) +#define KEY_BACKSP (8) +#define KEY_TAB (9) +#define KEY_LINEFEED (10) // ctrl/shift + enter is considered KEY_LINEFEED +#define KEY_CTRL_K (11) +#define KEY_CTRL_L (12) +#define KEY_ENTER (13) +#define KEY_CTRL_N (14) +#define KEY_CTRL_O (15) +#define KEY_CTRL_P (16) +#define KEY_CTRL_Q (17) +#define KEY_CTRL_R (18) +#define KEY_CTRL_S (19) +#define KEY_CTRL_T (20) +#define KEY_CTRL_U (21) +#define KEY_CTRL_V (22) +#define KEY_CTRL_W (23) +#define KEY_CTRL_X (24) +#define KEY_CTRL_Y (25) +#define KEY_CTRL_Z (26) +#define KEY_ESC (27) +#define KEY_SPACE (32) +#define KEY_RUBOUT (127) // always translated to KEY_BACKSP +#define KEY_UNICODE_MAX (0x0010FFFFU) + + +#define KEY_VIRT (0x01000000U) +#define KEY_UP (KEY_VIRT+0) +#define KEY_DOWN (KEY_VIRT+1) +#define KEY_LEFT (KEY_VIRT+2) +#define KEY_RIGHT (KEY_VIRT+3) +#define KEY_HOME (KEY_VIRT+4) +#define KEY_END (KEY_VIRT+5) +#define KEY_DEL (KEY_VIRT+6) +#define KEY_PAGEUP (KEY_VIRT+7) +#define KEY_PAGEDOWN (KEY_VIRT+8) +#define KEY_INS (KEY_VIRT+9) + +#define KEY_F1 (KEY_VIRT+11) +#define KEY_F2 (KEY_VIRT+12) +#define KEY_F3 (KEY_VIRT+13) +#define KEY_F4 (KEY_VIRT+14) +#define KEY_F5 (KEY_VIRT+15) +#define KEY_F6 (KEY_VIRT+16) +#define KEY_F7 (KEY_VIRT+17) +#define KEY_F8 (KEY_VIRT+18) +#define KEY_F9 (KEY_VIRT+19) +#define KEY_F10 (KEY_VIRT+20) +#define KEY_F11 (KEY_VIRT+21) +#define KEY_F12 (KEY_VIRT+22) +#define KEY_F(n) (KEY_F1 + (n) - 1) + +#define KEY_EVENT_BASE (0x02000000U) +#define KEY_EVENT_RESIZE (KEY_EVENT_BASE+1) +#define KEY_EVENT_AUTOTAB (KEY_EVENT_BASE+2) +#define KEY_EVENT_STOP (KEY_EVENT_BASE+3) + +// Convenience +#define KEY_CTRL_UP (WITH_CTRL(KEY_UP)) +#define KEY_CTRL_DOWN (WITH_CTRL(KEY_DOWN)) +#define KEY_CTRL_LEFT (WITH_CTRL(KEY_LEFT)) +#define KEY_CTRL_RIGHT (WITH_CTRL(KEY_RIGHT)) +#define KEY_CTRL_HOME (WITH_CTRL(KEY_HOME)) +#define KEY_CTRL_END (WITH_CTRL(KEY_END)) +#define KEY_CTRL_DEL (WITH_CTRL(KEY_DEL)) +#define KEY_CTRL_PAGEUP (WITH_CTRL(KEY_PAGEUP)) +#define KEY_CTRL_PAGEDOWN (WITH_CTRL(KEY_PAGEDOWN))) +#define KEY_CTRL_INS (WITH_CTRL(KEY_INS)) + +#define KEY_SHIFT_TAB (WITH_SHIFT(KEY_TAB)) + +#endif // IC_TTY_H diff --git a/extern/isocline/src/tty_esc.c b/extern/isocline/src/tty_esc.c new file mode 100644 index 00000000..0ac8761d --- /dev/null +++ b/extern/isocline/src/tty_esc.c @@ -0,0 +1,401 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include "tty.h" + +/*------------------------------------------------------------- +Decoding escape sequences to key codes. +This is a bit tricky there are many variants to encode keys as escape sequences, see for example: +- . +- +- +- +- + +Generally, for our purposes we accept a subset of escape sequences as: + + escseq ::= ESC + | ESC char + | ESC start special? (number (';' modifiers)?)? final + +where: + char ::= [\x00-\xFF] # any character + special ::= [:<=>?] + number ::= [0-9+] + modifiers ::= [1-9] + intermediate ::= [\x20-\x2F] # !"#$%&'()*+,-./ + final ::= [\x40-\x7F] # @A–Z[\]^_`a–z{|}~ + ESC ::= '\x1B' + CSI ::= ESC '[' + SS3 ::= ESC 'O' + +In ECMA48 `special? (number (';' modifiers)?)?` is the more liberal `[\x30-\x3F]*` +but that seems never used for key codes. If the number (vtcode or unicode) or the +modifiers are not given, we assume these are '1'. +We then accept the following key sequences: + + key ::= ESC # lone ESC + | ESC char # Alt+char + | ESC '[' special? vtcode ';' modifiers '~' # vt100 codes + | ESC '[' special? '1' ';' modifiers [A-Z] # xterm codes + | ESC 'O' special? '1' ';' modifiers [A-Za-z] # SS3 codes + | ESC '[' special? unicode ';' modifiers 'u' # direct unicode code + +Moreover, we translate the following special cases that do not fit into the above grammar. +First we translate away special starter sequences: +--------------------------------------------------------------------- + ESC '[' '[' .. ~> ESC '[' .. # Linux sometimes uses extra '[' for CSI + ESC '[' 'O' .. ~> ESC 'O' .. # Linux sometimes uses extra '[' for SS3 + ESC 'o' .. ~> ESC 'O' .. # Eterm: ctrl + SS3 + ESC '?' .. ~> ESC 'O' .. # vt52 treated as SS3 + +And then translate the following special cases into a standard form: +--------------------------------------------------------------------- + ESC '[' .. '@' ~> ESC '[' '3' '~' # Del on Mach + ESC '[' .. '9' ~> ESC '[' '2' '~' # Ins on Mach + ESC .. [^@$] ~> ESC .. '~' # ETerm,xrvt,urxt: ^ = ctrl, $ = shift, @ = alt + ESC '[' [a-d] ~> ESC '[' '1' ';' '2' [A-D] # Eterm shift+ + ESC 'O' [1-9] final ~> ESC 'O' '1' ';' [1-9] final # modifiers as parameter 1 (like on Haiku) + ESC '[' [1-9] [^~u] ~> ESC 'O' '1' ';' [1-9] final # modifiers as parameter 1 + +The modifier keys are encoded as "(modifiers-1) & mask" where the +shift mask is 0x01, alt 0x02 and ctrl 0x04. Therefore: +------------------------------------------------------------ + 1: - 5: ctrl 9: alt (for minicom) + 2: shift 6: shift+ctrl + 3: alt 7: alt+ctrl + 4: shift+alt 8: shift+alt+ctrl + +The different encodings fox vt100, xterm, and SS3 are: + +vt100: ESC [ vtcode ';' modifiers '~' +-------------------------------------- + 1: Home 10-15: F1-F5 + 2: Ins 16 : F5 + 3: Del 17-21: F6-F10 + 4: End 23-26: F11-F14 + 5: PageUp 28 : F15 + 6: PageDn 29 : F16 + 7: Home 31-34: F17-F20 + 8: End + +xterm: ESC [ 1 ';' modifiers [A-Z] +----------------------------------- + A: Up N: F2 + B: Down O: F3 + C: Right P: F4 + D: Left Q: F5 + E: '5' R: F6 + F: End S: F7 + G: T: F8 + H: Home U: PageDn + I: PageUp V: PageUp + J: W: F11 + K: X: F12 + L: Ins Y: End + M: F1 Z: shift+Tab + +SS3: ESC 'O' 1 ';' modifiers [A-Za-z] +--------------------------------------- + (normal) (numpad) + A: Up N: a: Up n: + B: Down O: b: Down o: + C: Right P: F1 c: Right p: Ins + D: Left Q: F2 d: Left q: End + E: '5' R: F3 e: r: Down + F: End S: F4 f: s: PageDn + G: T: F5 g: t: Left + H: Home U: F6 h: u: '5' + I: Tab V: F7 i: v: Right + J: W: F8 j: '*' w: Home + K: X: F9 k: '+' x: Up + L: Y: F10 l: ',' y: PageUp + M: \x0A '\n' Z: shift+Tab m: '-' z: + +-------------------------------------------------------------*/ + +//------------------------------------------------------------- +// Decode escape sequences +//------------------------------------------------------------- + +static code_t esc_decode_vt(uint32_t vt_code ) { + switch(vt_code) { + case 1: return KEY_HOME; + case 2: return KEY_INS; + case 3: return KEY_DEL; + case 4: return KEY_END; + case 5: return KEY_PAGEUP; + case 6: return KEY_PAGEDOWN; + case 7: return KEY_HOME; + case 8: return KEY_END; + default: + if (vt_code >= 10 && vt_code <= 15) return KEY_F(1 + (vt_code - 10)); + if (vt_code == 16) return KEY_F5; // minicom + if (vt_code >= 17 && vt_code <= 21) return KEY_F(6 + (vt_code - 17)); + if (vt_code >= 23 && vt_code <= 26) return KEY_F(11 + (vt_code - 23)); + if (vt_code >= 28 && vt_code <= 29) return KEY_F(15 + (vt_code - 28)); + if (vt_code >= 31 && vt_code <= 34) return KEY_F(17 + (vt_code - 31)); + } + return KEY_NONE; +} + +static code_t esc_decode_xterm( uint8_t xcode ) { + // ESC [ + switch(xcode) { + case 'A': return KEY_UP; + case 'B': return KEY_DOWN; + case 'C': return KEY_RIGHT; + case 'D': return KEY_LEFT; + case 'E': return '5'; // numpad 5 + case 'F': return KEY_END; + case 'H': return KEY_HOME; + case 'Z': return KEY_TAB | KEY_MOD_SHIFT; + // Freebsd: + case 'I': return KEY_PAGEUP; + case 'L': return KEY_INS; + case 'M': return KEY_F1; + case 'N': return KEY_F2; + case 'O': return KEY_F3; + case 'P': return KEY_F4; // note: differs from + case 'Q': return KEY_F5; + case 'R': return KEY_F6; + case 'S': return KEY_F7; + case 'T': return KEY_F8; + case 'U': return KEY_PAGEDOWN; // Mach + case 'V': return KEY_PAGEUP; // Mach + case 'W': return KEY_F11; + case 'X': return KEY_F12; + case 'Y': return KEY_END; // Mach + } + return KEY_NONE; +} + +static code_t esc_decode_ss3( uint8_t ss3_code ) { + // ESC O + switch(ss3_code) { + case 'A': return KEY_UP; + case 'B': return KEY_DOWN; + case 'C': return KEY_RIGHT; + case 'D': return KEY_LEFT; + case 'E': return '5'; // numpad 5 + case 'F': return KEY_END; + case 'H': return KEY_HOME; + case 'I': return KEY_TAB; + case 'Z': return KEY_TAB | KEY_MOD_SHIFT; + case 'M': return KEY_LINEFEED; + case 'P': return KEY_F1; + case 'Q': return KEY_F2; + case 'R': return KEY_F3; + case 'S': return KEY_F4; + // on Mach + case 'T': return KEY_F5; + case 'U': return KEY_F6; + case 'V': return KEY_F7; + case 'W': return KEY_F8; + case 'X': return KEY_F9; // '=' on vt220 + case 'Y': return KEY_F10; + // numpad + case 'a': return KEY_UP; + case 'b': return KEY_DOWN; + case 'c': return KEY_RIGHT; + case 'd': return KEY_LEFT; + case 'j': return '*'; + case 'k': return '+'; + case 'l': return ','; + case 'm': return '-'; + case 'n': return KEY_DEL; // '.' + case 'o': return '/'; + case 'p': return KEY_INS; + case 'q': return KEY_END; + case 'r': return KEY_DOWN; + case 's': return KEY_PAGEDOWN; + case 't': return KEY_LEFT; + case 'u': return '5'; + case 'v': return KEY_RIGHT; + case 'w': return KEY_HOME; + case 'x': return KEY_UP; + case 'y': return KEY_PAGEUP; + } + return KEY_NONE; +} + +static void tty_read_csi_num(tty_t* tty, uint8_t* ppeek, uint32_t* num, long esc_timeout) { + *num = 1; // default + ssize_t count = 0; + uint32_t i = 0; + while (*ppeek >= '0' && *ppeek <= '9' && count < 16) { + uint8_t digit = *ppeek - '0'; + if (!tty_readc_noblock(tty,ppeek,esc_timeout)) break; // peek is not modified in this case + count++; + i = 10*i + digit; + } + if (count > 0) *num = i; +} + +static code_t tty_read_csi(tty_t* tty, uint8_t c1, uint8_t peek, code_t mods0, long esc_timeout) { + // CSI starts with 0x9b (c1=='[') | ESC [ (c1=='[') | ESC [Oo?] (c1 == 'O') /* = SS3 */ + + // check for extra starter '[' (Linux sends ESC [ [ 15 ~ for F5 for example) + if (c1 == '[' && strchr("[Oo", (char)peek) != NULL) { + uint8_t cx = peek; + if (tty_readc_noblock(tty,&peek,esc_timeout)) { + c1 = cx; + } + } + + // "special" characters ('?' is used for private sequences) + uint8_t special = 0; + if (strchr(":<=>?",(char)peek) != NULL) { + special = peek; + if (!tty_readc_noblock(tty,&peek,esc_timeout)) { + tty_cpush_char(tty,special); // recover + return (key_unicode(c1) | KEY_MOD_ALT); // Alt+ + } + } + + // up to 2 parameters that default to 1 + uint32_t num1 = 1; + uint32_t num2 = 1; + tty_read_csi_num(tty,&peek,&num1,esc_timeout); + if (peek == ';') { + if (!tty_readc_noblock(tty,&peek,esc_timeout)) return KEY_NONE; + tty_read_csi_num(tty,&peek,&num2,esc_timeout); + } + + // the final character (we do not allow 'intermediate characters') + uint8_t final = peek; + code_t modifiers = mods0; + + debug_msg("tty: escape sequence: ESC %c %c %d;%d %c\n", c1, (special == 0 ? '_' : special), num1, num2, final); + + // Adjust special cases into standard ones. + if ((final == '@' || final == '9') && c1 == '[' && num1 == 1) { + // ESC [ @, ESC [ 9 : on Mach + if (final == '@') num1 = 3; // DEL + else if (final == '9') num1 = 2; // INS + final = '~'; + } + else if (final == '^' || final == '$' || final == '@') { + // Eterm/rxvt/urxt + if (final=='^') modifiers |= KEY_MOD_CTRL; + if (final=='$') modifiers |= KEY_MOD_SHIFT; + if (final=='@') modifiers |= KEY_MOD_SHIFT | KEY_MOD_CTRL; + final = '~'; + } + else if (c1 == '[' && final >= 'a' && final <= 'd') { // note: do not catch ESC [ .. u (for unicode) + // ESC [ [a-d] : on Eterm for shift+ cursor + modifiers |= KEY_MOD_SHIFT; + final = 'A' + (final - 'a'); + } + + if (((c1 == 'O') || (c1=='[' && final != '~' && final != 'u')) && + (num2 == 1 && num1 > 1 && num1 <= 8)) + { + // on haiku the modifier can be parameter 1, make it parameter 2 instead + num2 = num1; + num1 = 1; + } + + // parameter 2 determines the modifiers + if (num2 > 1 && num2 <= 9) { + if (num2 == 9) num2 = 3; // iTerm2 in xterm mode + num2--; + if (num2 & 0x1) modifiers |= KEY_MOD_SHIFT; + if (num2 & 0x2) modifiers |= KEY_MOD_ALT; + if (num2 & 0x4) modifiers |= KEY_MOD_CTRL; + } + + // and translate + code_t code = KEY_NONE; + if (final == '~') { + // vt codes + code = esc_decode_vt(num1); + } + else if (c1 == '[' && final == 'u') { + // unicode + code = key_unicode(num1); + } + else if (c1 == 'O' && ((final >= 'A' && final <= 'Z') || (final >= 'a' && final <= 'z'))) { + // ss3 + code = esc_decode_ss3(final); + } + else if (num1 == 1 && final >= 'A' && final <= 'Z') { + // xterm + code = esc_decode_xterm(final); + } + else if (c1 == '[' && final == 'R') { + // cursor position + code = KEY_NONE; + } + + if (code == KEY_NONE && final != 'R') { + debug_msg("tty: ignore escape sequence: ESC %c %zu;%zu %c\n", c1, num1, num2, final); + } + return (code != KEY_NONE ? (code | modifiers) : KEY_NONE); +} + +static code_t tty_read_osc( tty_t* tty, uint8_t* ppeek, long esc_timeout ) { + debug_msg("discard OSC response..\n"); + // keep reading until termination: OSC is terminated by BELL, or ESC \ (ST) (and STX) + while (true) { + uint8_t c = *ppeek; + if (c <= '\x07') { // BELL and anything below (STX, ^C, ^D) + if (c != '\x07') { tty_cpush_char( tty, c ); } + break; + } + else if (c=='\x1B') { + uint8_t c1; + if (!tty_readc_noblock(tty, &c1, esc_timeout)) break; + if (c1=='\\') break; + tty_cpush_char(tty,c1); + } + if (!tty_readc_noblock(tty, ppeek, esc_timeout)) break; + } + return KEY_NONE; +} + +ic_private code_t tty_read_esc(tty_t* tty, long esc_initial_timeout, long esc_timeout) { + code_t mods = 0; + uint8_t peek = 0; + + // lone ESC? + if (!tty_readc_noblock(tty, &peek, esc_initial_timeout)) return KEY_ESC; + + // treat ESC ESC as Alt modifier (macOS sends ESC ESC [ [A-D] for alt-) + if (peek == KEY_ESC) { + if (!tty_readc_noblock(tty, &peek, esc_timeout)) goto alt; + mods |= KEY_MOD_ALT; + } + + // CSI ? + if (peek == '[') { + if (!tty_readc_noblock(tty, &peek, esc_timeout)) goto alt; + return tty_read_csi(tty, '[', peek, mods, esc_timeout); // ESC [ ... + } + + // SS3? + if (peek == 'O' || peek == 'o' || peek == '?' /*vt52*/) { + uint8_t c1 = peek; + if (!tty_readc_noblock(tty, &peek, esc_timeout)) goto alt; + if (c1 == 'o') { + // ETerm uses this for ctrl+ + mods |= KEY_MOD_CTRL; + } + // treat all as standard SS3 'O' + return tty_read_csi(tty,'O',peek,mods, esc_timeout); // ESC [Oo?] ... + } + + // OSC: we may get a delayed query response; ensure it is ignored + if (peek == ']') { + if (!tty_readc_noblock(tty, &peek, esc_timeout)) goto alt; + return tty_read_osc(tty, &peek, esc_timeout); // ESC ] ... + } + +alt: + // Alt+ + return (key_unicode(peek) | KEY_MOD_ALT); // ESC +} diff --git a/extern/isocline/src/undo.c b/extern/isocline/src/undo.c new file mode 100644 index 00000000..eefc318d --- /dev/null +++ b/extern/isocline/src/undo.c @@ -0,0 +1,67 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include + +#include "../include/isocline.h" +#include "common.h" +#include "env.h" +#include "stringbuf.h" +#include "completions.h" +#include "undo.h" + + + +//------------------------------------------------------------- +// edit state +//------------------------------------------------------------- +struct editstate_s { + struct editstate_s* next; + const char* input; // input + ssize_t pos; // cursor position +}; + +ic_private void editstate_init( editstate_t** es ) { + *es = NULL; +} + +ic_private void editstate_done( alloc_t* mem, editstate_t** es ) { + while (*es != NULL) { + editstate_t* next = (*es)->next; + mem_free(mem, (*es)->input); + mem_free(mem, *es ); + *es = next; + } + *es = NULL; +} + +ic_private void editstate_capture( alloc_t* mem, editstate_t** es, const char* input, ssize_t pos) { + if (input==NULL) input = ""; + // alloc + editstate_t* entry = mem_zalloc_tp(mem, editstate_t); + if (entry == NULL) return; + // initialize + entry->input = mem_strdup( mem, input); + entry->pos = pos; + if (entry->input == NULL) { mem_free(mem, entry); return; } + // and push + entry->next = *es; + *es = entry; +} + +// caller should free *input +ic_private bool editstate_restore( alloc_t* mem, editstate_t** es, const char** input, ssize_t* pos ) { + if (*es == NULL) return false; + // pop + editstate_t* entry = *es; + *es = entry->next; + *input = entry->input; + *pos = entry->pos; + mem_free(mem, entry); + return true; +} + diff --git a/extern/isocline/src/undo.h b/extern/isocline/src/undo.h new file mode 100644 index 00000000..576cf977 --- /dev/null +++ b/extern/isocline/src/undo.h @@ -0,0 +1,24 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_UNDO_H +#define IC_UNDO_H + +#include "common.h" + +//------------------------------------------------------------- +// Edit state +//------------------------------------------------------------- +struct editstate_s; +typedef struct editstate_s editstate_t; + +ic_private void editstate_init( editstate_t** es ); +ic_private void editstate_done( alloc_t* mem, editstate_t** es ); +ic_private void editstate_capture( alloc_t* mem, editstate_t** es, const char* input, ssize_t pos); +ic_private bool editstate_restore( alloc_t* mem, editstate_t** es, const char** input, ssize_t* pos ); // caller needs to free input + +#endif // IC_UNDO_H diff --git a/extern/isocline/src/wcwidth.c b/extern/isocline/src/wcwidth.c new file mode 100644 index 00000000..85187d41 --- /dev/null +++ b/extern/isocline/src/wcwidth.c @@ -0,0 +1,292 @@ +// include in "stringbuf.c" +/* + * This is an implementation of wcwidth() and wcswidth() (defined in + * IEEE Std 1002.1-2001) for Unicode. + * + * http://www.opengroup.org/onlinepubs/007904975/functions/wcwidth.html + * http://www.opengroup.org/onlinepubs/007904975/functions/wcswidth.html + * + * In fixed-width output devices, Latin characters all occupy a single + * "cell" position of equal width, whereas ideographic CJK characters + * occupy two such cells. Interoperability between terminal-line + * applications and (teletype-style) character terminals using the + * UTF-8 encoding requires agreement on which character should advance + * the cursor by how many cell positions. No established formal + * standards exist at present on which Unicode character shall occupy + * how many cell positions on character terminals. These routines are + * a first attempt of defining such behavior based on simple rules + * applied to data provided by the Unicode Consortium. + * + * For some graphical characters, the Unicode standard explicitly + * defines a character-cell width via the definition of the East Asian + * FullWidth (F), Wide (W), Half-width (H), and Narrow (Na) classes. + * In all these cases, there is no ambiguity about which width a + * terminal shall use. For characters in the East Asian Ambiguous (A) + * class, the width choice depends purely on a preference of backward + * compatibility with either historic CJK or Western practice. + * Choosing single-width for these characters is easy to justify as + * the appropriate long-term solution, as the CJK practice of + * displaying these characters as double-width comes from historic + * implementation simplicity (8-bit encoded characters were displayed + * single-width and 16-bit ones double-width, even for Greek, + * Cyrillic, etc.) and not any typographic considerations. + * + * Much less clear is the choice of width for the Not East Asian + * (Neutral) class. Existing practice does not dictate a width for any + * of these characters. It would nevertheless make sense + * typographically to allocate two character cells to characters such + * as for instance EM SPACE or VOLUME INTEGRAL, which cannot be + * represented adequately with a single-width glyph. The following + * routines at present merely assign a single-cell width to all + * neutral characters, in the interest of simplicity. This is not + * entirely satisfactory and should be reconsidered before + * establishing a formal standard in this area. At the moment, the + * decision which Not East Asian (Neutral) characters should be + * represented by double-width glyphs cannot yet be answered by + * applying a simple rule from the Unicode database content. Setting + * up a proper standard for the behavior of UTF-8 character terminals + * will require a careful analysis not only of each Unicode character, + * but also of each presentation form, something the author of these + * routines has avoided to do so far. + * + * http://www.unicode.org/unicode/reports/tr11/ + * + * Markus Kuhn -- 2007-05-26 (Unicode 5.0) + * + * Permission to use, copy, modify, and distribute this software + * for any purpose and without fee is hereby granted. The author + * disclaims all warranties with regard to this software. + * + * Latest version: http://www.cl.cam.ac.uk/~mgk25/ucs/wcwidth.c + */ + +#include +#include + +struct interval { + int32_t first; + int32_t last; +}; + +/* auxiliary function for binary search in interval table */ +static int bisearch(int32_t ucs, const struct interval *table, int max) { + int min = 0; + int mid; + + if (ucs < table[0].first || ucs > table[max].last) + return 0; + while (max >= min) { + mid = (min + max) / 2; + if (ucs > table[mid].last) + min = mid + 1; + else if (ucs < table[mid].first) + max = mid - 1; + else + return 1; + } + + return 0; +} + + +/* The following two functions define the column width of an ISO 10646 + * character as follows: + * + * - The null character (U+0000) has a column width of 0. + * + * - Other C0/C1 control characters and DEL will lead to a return + * value of -1. + * + * - Non-spacing and enclosing combining characters (general + * category code Mn or Me in the Unicode database) have a + * column width of 0. + * + * - SOFT HYPHEN (U+00AD) has a column width of 1. + * + * - Other format characters (general category code Cf in the Unicode + * database) and ZERO WIDTH SPACE (U+200B) have a column width of 0. + * + * - Hangul Jamo medial vowels and final consonants (U+1160-U+11FF) + * have a column width of 0. + * + * - Spacing characters in the East Asian Wide (W) or East Asian + * Full-width (F) category as defined in Unicode Technical + * Report #11 have a column width of 2. + * + * - All remaining characters (including all printable + * ISO 8859-1 and WGL4 characters, Unicode control characters, + * etc.) have a column width of 1. + * + * This implementation assumes that wchar_t characters are encoded + * in ISO 10646. + */ + +static int mk_is_wide_char(int32_t ucs) { + static const struct interval wide[] = { + {0x1100, 0x115f}, {0x231a, 0x231b}, {0x2329, 0x232a}, + {0x23e9, 0x23ec}, {0x23f0, 0x23f0}, {0x23f3, 0x23f3}, + {0x25fd, 0x25fe}, {0x2614, 0x2615}, {0x2648, 0x2653}, + {0x267f, 0x267f}, {0x2693, 0x2693}, {0x26a1, 0x26a1}, + {0x26aa, 0x26ab}, {0x26bd, 0x26be}, {0x26c4, 0x26c5}, + {0x26ce, 0x26ce}, {0x26d4, 0x26d4}, {0x26ea, 0x26ea}, + {0x26f2, 0x26f3}, {0x26f5, 0x26f5}, {0x26fa, 0x26fa}, + {0x26fd, 0x26fd}, {0x2705, 0x2705}, {0x270a, 0x270b}, + {0x2728, 0x2728}, {0x274c, 0x274c}, {0x274e, 0x274e}, + {0x2753, 0x2755}, {0x2757, 0x2757}, {0x2795, 0x2797}, + {0x27b0, 0x27b0}, {0x27bf, 0x27bf}, {0x2b1b, 0x2b1c}, + {0x2b50, 0x2b50}, {0x2b55, 0x2b55}, {0x2e80, 0x2fdf}, + {0x2ff0, 0x303e}, {0x3040, 0x3247}, {0x3250, 0x4dbf}, + {0x4e00, 0xa4cf}, {0xa960, 0xa97f}, {0xac00, 0xd7a3}, + {0xf900, 0xfaff}, {0xfe10, 0xfe19}, {0xfe30, 0xfe6f}, + {0xff01, 0xff60}, {0xffe0, 0xffe6}, {0x16fe0, 0x16fe1}, + {0x17000, 0x18aff}, {0x1b000, 0x1b12f}, {0x1b170, 0x1b2ff}, + {0x1f004, 0x1f004}, {0x1f0cf, 0x1f0cf}, {0x1f18e, 0x1f18e}, + {0x1f191, 0x1f19a}, {0x1f200, 0x1f202}, {0x1f210, 0x1f23b}, + {0x1f240, 0x1f248}, {0x1f250, 0x1f251}, {0x1f260, 0x1f265}, + {0x1f300, 0x1f320}, {0x1f32d, 0x1f335}, {0x1f337, 0x1f37c}, + {0x1f37e, 0x1f393}, {0x1f3a0, 0x1f3ca}, {0x1f3cf, 0x1f3d3}, + {0x1f3e0, 0x1f3f0}, {0x1f3f4, 0x1f3f4}, {0x1f3f8, 0x1f43e}, + {0x1f440, 0x1f440}, {0x1f442, 0x1f4fc}, {0x1f4ff, 0x1f53d}, + {0x1f54b, 0x1f54e}, {0x1f550, 0x1f567}, {0x1f57a, 0x1f57a}, + {0x1f595, 0x1f596}, {0x1f5a4, 0x1f5a4}, {0x1f5fb, 0x1f64f}, + {0x1f680, 0x1f6c5}, {0x1f6cc, 0x1f6cc}, {0x1f6d0, 0x1f6d2}, + {0x1f6eb, 0x1f6ec}, {0x1f6f4, 0x1f6f8}, {0x1f910, 0x1f93e}, + {0x1f940, 0x1f94c}, {0x1f950, 0x1f96b}, {0x1f980, 0x1f997}, + {0x1f9c0, 0x1f9c0}, {0x1f9d0, 0x1f9e6}, {0x20000, 0x2fffd}, + {0x30000, 0x3fffd}, + }; + + if ( bisearch(ucs, wide, sizeof(wide) / sizeof(struct interval) - 1) ) { + return 1; + } + + return 0; +} + +static int mk_wcwidth(int32_t ucs) { + /* sorted list of non-overlapping intervals of non-spacing characters */ + /* generated by "uniset +cat=Me +cat=Mn +cat=Cf -00AD +1160-11FF +200B c" */ + static const struct interval combining[] = { + {0x00ad, 0x00ad}, {0x0300, 0x036f}, {0x0483, 0x0489}, + {0x0591, 0x05bd}, {0x05bf, 0x05bf}, {0x05c1, 0x05c2}, + {0x05c4, 0x05c5}, {0x05c7, 0x05c7}, {0x0610, 0x061a}, + {0x061c, 0x061c}, {0x064b, 0x065f}, {0x0670, 0x0670}, + {0x06d6, 0x06dc}, {0x06df, 0x06e4}, {0x06e7, 0x06e8}, + {0x06ea, 0x06ed}, {0x0711, 0x0711}, {0x0730, 0x074a}, + {0x07a6, 0x07b0}, {0x07eb, 0x07f3}, {0x0816, 0x0819}, + {0x081b, 0x0823}, {0x0825, 0x0827}, {0x0829, 0x082d}, + {0x0859, 0x085b}, {0x08d4, 0x08e1}, {0x08e3, 0x0902}, + {0x093a, 0x093a}, {0x093c, 0x093c}, {0x0941, 0x0948}, + {0x094d, 0x094d}, {0x0951, 0x0957}, {0x0962, 0x0963}, + {0x0981, 0x0981}, {0x09bc, 0x09bc}, {0x09c1, 0x09c4}, + {0x09cd, 0x09cd}, {0x09e2, 0x09e3}, {0x0a01, 0x0a02}, + {0x0a3c, 0x0a3c}, {0x0a41, 0x0a42}, {0x0a47, 0x0a48}, + {0x0a4b, 0x0a4d}, {0x0a51, 0x0a51}, {0x0a70, 0x0a71}, + {0x0a75, 0x0a75}, {0x0a81, 0x0a82}, {0x0abc, 0x0abc}, + {0x0ac1, 0x0ac5}, {0x0ac7, 0x0ac8}, {0x0acd, 0x0acd}, + {0x0ae2, 0x0ae3}, {0x0afa, 0x0aff}, {0x0b01, 0x0b01}, + {0x0b3c, 0x0b3c}, {0x0b3f, 0x0b3f}, {0x0b41, 0x0b44}, + {0x0b4d, 0x0b4d}, {0x0b56, 0x0b56}, {0x0b62, 0x0b63}, + {0x0b82, 0x0b82}, {0x0bc0, 0x0bc0}, {0x0bcd, 0x0bcd}, + {0x0c00, 0x0c00}, {0x0c3e, 0x0c40}, {0x0c46, 0x0c48}, + {0x0c4a, 0x0c4d}, {0x0c55, 0x0c56}, {0x0c62, 0x0c63}, + {0x0c81, 0x0c81}, {0x0cbc, 0x0cbc}, {0x0cbf, 0x0cbf}, + {0x0cc6, 0x0cc6}, {0x0ccc, 0x0ccd}, {0x0ce2, 0x0ce3}, + {0x0d00, 0x0d01}, {0x0d3b, 0x0d3c}, {0x0d41, 0x0d44}, + {0x0d4d, 0x0d4d}, {0x0d62, 0x0d63}, {0x0dca, 0x0dca}, + {0x0dd2, 0x0dd4}, {0x0dd6, 0x0dd6}, {0x0e31, 0x0e31}, + {0x0e34, 0x0e3a}, {0x0e47, 0x0e4e}, {0x0eb1, 0x0eb1}, + {0x0eb4, 0x0eb9}, {0x0ebb, 0x0ebc}, {0x0ec8, 0x0ecd}, + {0x0f18, 0x0f19}, {0x0f35, 0x0f35}, {0x0f37, 0x0f37}, + {0x0f39, 0x0f39}, {0x0f71, 0x0f7e}, {0x0f80, 0x0f84}, + {0x0f86, 0x0f87}, {0x0f8d, 0x0f97}, {0x0f99, 0x0fbc}, + {0x0fc6, 0x0fc6}, {0x102d, 0x1030}, {0x1032, 0x1037}, + {0x1039, 0x103a}, {0x103d, 0x103e}, {0x1058, 0x1059}, + {0x105e, 0x1060}, {0x1071, 0x1074}, {0x1082, 0x1082}, + {0x1085, 0x1086}, {0x108d, 0x108d}, {0x109d, 0x109d}, + {0x1160, 0x11ff}, {0x135d, 0x135f}, {0x1712, 0x1714}, + {0x1732, 0x1734}, {0x1752, 0x1753}, {0x1772, 0x1773}, + {0x17b4, 0x17b5}, {0x17b7, 0x17bd}, {0x17c6, 0x17c6}, + {0x17c9, 0x17d3}, {0x17dd, 0x17dd}, {0x180b, 0x180e}, + {0x1885, 0x1886}, {0x18a9, 0x18a9}, {0x1920, 0x1922}, + {0x1927, 0x1928}, {0x1932, 0x1932}, {0x1939, 0x193b}, + {0x1a17, 0x1a18}, {0x1a1b, 0x1a1b}, {0x1a56, 0x1a56}, + {0x1a58, 0x1a5e}, {0x1a60, 0x1a60}, {0x1a62, 0x1a62}, + {0x1a65, 0x1a6c}, {0x1a73, 0x1a7c}, {0x1a7f, 0x1a7f}, + {0x1ab0, 0x1abe}, {0x1b00, 0x1b03}, {0x1b34, 0x1b34}, + {0x1b36, 0x1b3a}, {0x1b3c, 0x1b3c}, {0x1b42, 0x1b42}, + {0x1b6b, 0x1b73}, {0x1b80, 0x1b81}, {0x1ba2, 0x1ba5}, + {0x1ba8, 0x1ba9}, {0x1bab, 0x1bad}, {0x1be6, 0x1be6}, + {0x1be8, 0x1be9}, {0x1bed, 0x1bed}, {0x1bef, 0x1bf1}, + {0x1c2c, 0x1c33}, {0x1c36, 0x1c37}, {0x1cd0, 0x1cd2}, + {0x1cd4, 0x1ce0}, {0x1ce2, 0x1ce8}, {0x1ced, 0x1ced}, + {0x1cf4, 0x1cf4}, {0x1cf8, 0x1cf9}, {0x1dc0, 0x1df9}, + {0x1dfb, 0x1dff}, {0x200b, 0x200f}, {0x202a, 0x202e}, + {0x2060, 0x2064}, {0x2066, 0x206f}, {0x20d0, 0x20f0}, + {0x2cef, 0x2cf1}, {0x2d7f, 0x2d7f}, {0x2de0, 0x2dff}, + {0x302a, 0x302d}, {0x3099, 0x309a}, {0xa66f, 0xa672}, + {0xa674, 0xa67d}, {0xa69e, 0xa69f}, {0xa6f0, 0xa6f1}, + {0xa802, 0xa802}, {0xa806, 0xa806}, {0xa80b, 0xa80b}, + {0xa825, 0xa826}, {0xa8c4, 0xa8c5}, {0xa8e0, 0xa8f1}, + {0xa926, 0xa92d}, {0xa947, 0xa951}, {0xa980, 0xa982}, + {0xa9b3, 0xa9b3}, {0xa9b6, 0xa9b9}, {0xa9bc, 0xa9bc}, + {0xa9e5, 0xa9e5}, {0xaa29, 0xaa2e}, {0xaa31, 0xaa32}, + {0xaa35, 0xaa36}, {0xaa43, 0xaa43}, {0xaa4c, 0xaa4c}, + {0xaa7c, 0xaa7c}, {0xaab0, 0xaab0}, {0xaab2, 0xaab4}, + {0xaab7, 0xaab8}, {0xaabe, 0xaabf}, {0xaac1, 0xaac1}, + {0xaaec, 0xaaed}, {0xaaf6, 0xaaf6}, {0xabe5, 0xabe5}, + {0xabe8, 0xabe8}, {0xabed, 0xabed}, {0xfb1e, 0xfb1e}, + {0xfe00, 0xfe0f}, {0xfe20, 0xfe2f}, {0xfeff, 0xfeff}, + {0xfff9, 0xfffb}, {0x101fd, 0x101fd}, {0x102e0, 0x102e0}, + {0x10376, 0x1037a}, {0x10a01, 0x10a03}, {0x10a05, 0x10a06}, + {0x10a0c, 0x10a0f}, {0x10a38, 0x10a3a}, {0x10a3f, 0x10a3f}, + {0x10ae5, 0x10ae6}, {0x11001, 0x11001}, {0x11038, 0x11046}, + {0x1107f, 0x11081}, {0x110b3, 0x110b6}, {0x110b9, 0x110ba}, + {0x11100, 0x11102}, {0x11127, 0x1112b}, {0x1112d, 0x11134}, + {0x11173, 0x11173}, {0x11180, 0x11181}, {0x111b6, 0x111be}, + {0x111ca, 0x111cc}, {0x1122f, 0x11231}, {0x11234, 0x11234}, + {0x11236, 0x11237}, {0x1123e, 0x1123e}, {0x112df, 0x112df}, + {0x112e3, 0x112ea}, {0x11300, 0x11301}, {0x1133c, 0x1133c}, + {0x11340, 0x11340}, {0x11366, 0x1136c}, {0x11370, 0x11374}, + {0x11438, 0x1143f}, {0x11442, 0x11444}, {0x11446, 0x11446}, + {0x114b3, 0x114b8}, {0x114ba, 0x114ba}, {0x114bf, 0x114c0}, + {0x114c2, 0x114c3}, {0x115b2, 0x115b5}, {0x115bc, 0x115bd}, + {0x115bf, 0x115c0}, {0x115dc, 0x115dd}, {0x11633, 0x1163a}, + {0x1163d, 0x1163d}, {0x1163f, 0x11640}, {0x116ab, 0x116ab}, + {0x116ad, 0x116ad}, {0x116b0, 0x116b5}, {0x116b7, 0x116b7}, + {0x1171d, 0x1171f}, {0x11722, 0x11725}, {0x11727, 0x1172b}, + {0x11a01, 0x11a06}, {0x11a09, 0x11a0a}, {0x11a33, 0x11a38}, + {0x11a3b, 0x11a3e}, {0x11a47, 0x11a47}, {0x11a51, 0x11a56}, + {0x11a59, 0x11a5b}, {0x11a8a, 0x11a96}, {0x11a98, 0x11a99}, + {0x11c30, 0x11c36}, {0x11c38, 0x11c3d}, {0x11c3f, 0x11c3f}, + {0x11c92, 0x11ca7}, {0x11caa, 0x11cb0}, {0x11cb2, 0x11cb3}, + {0x11cb5, 0x11cb6}, {0x11d31, 0x11d36}, {0x11d3a, 0x11d3a}, + {0x11d3c, 0x11d3d}, {0x11d3f, 0x11d45}, {0x11d47, 0x11d47}, + {0x16af0, 0x16af4}, {0x16b30, 0x16b36}, {0x16f8f, 0x16f92}, + {0x1bc9d, 0x1bc9e}, {0x1bca0, 0x1bca3}, {0x1d167, 0x1d169}, + {0x1d173, 0x1d182}, {0x1d185, 0x1d18b}, {0x1d1aa, 0x1d1ad}, + {0x1d242, 0x1d244}, {0x1da00, 0x1da36}, {0x1da3b, 0x1da6c}, + {0x1da75, 0x1da75}, {0x1da84, 0x1da84}, {0x1da9b, 0x1da9f}, + {0x1daa1, 0x1daaf}, {0x1e000, 0x1e006}, {0x1e008, 0x1e018}, + {0x1e01b, 0x1e021}, {0x1e023, 0x1e024}, {0x1e026, 0x1e02a}, + {0x1e8d0, 0x1e8d6}, {0x1e944, 0x1e94a}, {0xe0001, 0xe0001}, + {0xe0020, 0xe007f}, {0xe0100, 0xe01ef}, + }; + + /* test for 8-bit control characters */ + if ( ucs == 0 ) { + return 0; + } + if ( ( ucs < 32 ) || ( ( ucs >= 0x7f ) && ( ucs < 0xa0 ) ) ) { + return -1; + } + + /* binary search in table of non-spacing characters */ + if ( bisearch( ucs, combining, sizeof( combining ) / sizeof( struct interval ) - 1 ) ) { + return 0; + } + + /* if we arrive here, ucs is not a combining or C0/C1 control character */ + return ( mk_is_wide_char( ucs ) ? 2 : 1 ); +} + diff --git a/extern/linenoise.hpp b/extern/linenoise.hpp deleted file mode 100644 index ae36eb0b..00000000 --- a/extern/linenoise.hpp +++ /dev/null @@ -1,2415 +0,0 @@ -/* - * linenoise.hpp -- Multi-platfrom C++ header-only linenoise library. - * - * All credits and commendations have to go to the authors of the - * following excellent libraries. - * - * - linenoise.h and linenose.c (https://github.com/antirez/linenoise) - * - ANSI.c (https://github.com/adoxa/ansicon) - * - Win32_ANSI.h and Win32_ANSI.c (https://github.com/MSOpenTech/redis) - * - * ------------------------------------------------------------------------ - * - * Copyright (c) 2015 yhirose - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -/* linenoise.h -- guerrilla line editing library against the idea that a - * line editing lib needs to be 20,000 lines of C code. - * - * See linenoise.c for more information. - * - * ------------------------------------------------------------------------ - * - * Copyright (c) 2010, Salvatore Sanfilippo - * Copyright (c) 2010, Pieter Noordhuis - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -/* - * ANSI.c - ANSI escape sequence console driver. - * - * Copyright (C) 2005-2014 Jason Hood - * This software is provided 'as-is', without any express or implied - * warranty. In no event will the author be held liable for any damages - * arising from the use of this software. - * - * Permission is granted to anyone to use this software for any purpose, - * including commercial applications, and to alter it and redistribute it - * freely, subject to the following restrictions: - * - * 1. The origin of this software must not be misrepresented; you must not - * claim that you wrote the original software. If you use this software - * in a product, an acknowledgment in the product documentation would be - * appreciated but is not required. - * 2. Altered source versions must be plainly marked as such, and must not be - * misrepresented as being the original software. - * 3. This notice may not be removed or altered from any source distribution. - * - * Jason Hood - * jadoxa@yahoo.com.au - */ - -/* - * Win32_ANSI.h and Win32_ANSI.c - * - * Derived from ANSI.c by Jason Hood, from his ansicon project (https://github.com/adoxa/ansicon), with modifications. - * - * Copyright (c), Microsoft Open Technologies, Inc. - * All rights reserved. - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef LINENOISE_HPP -#define LINENOISE_HPP - -#ifndef _WIN32 -#include -#include -#include -#else -#ifndef NOMINMAX -#define NOMINMAX -#endif -#include -#include -#ifndef STDIN_FILENO -#define STDIN_FILENO (_fileno(stdin)) -#endif -#ifndef STDOUT_FILENO -#define STDOUT_FILENO 1 -#endif -#define isatty _isatty -#define write win32_write -#define read _read -#pragma warning(push) -#pragma warning(disable : 4996) -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace linenoise { - -typedef std::function&)> CompletionCallback; - -#ifdef _WIN32 - -namespace ansi { - -#define lenof(array) (sizeof(array)/sizeof(*(array))) - -typedef struct -{ - BYTE foreground; // ANSI base color (0 to 7; add 30) - BYTE background; // ANSI base color (0 to 7; add 40) - BYTE bold; // console FOREGROUND_INTENSITY bit - BYTE underline; // console BACKGROUND_INTENSITY bit - BYTE rvideo; // swap foreground/bold & background/underline - BYTE concealed; // set foreground/bold to background/underline - BYTE reverse; // swap console foreground & background attributes -} GRM, *PGRM; // Graphic Rendition Mode - - -inline bool is_digit(char c) { return '0' <= c && c <= '9'; } - -// ========== Global variables and constants - -HANDLE hConOut; // handle to CONOUT$ - -const char ESC = '\x1B'; // ESCape character -const char BEL = '\x07'; -const char SO = '\x0E'; // Shift Out -const char SI = '\x0F'; // Shift In - -const int MAX_ARG = 16; // max number of args in an escape sequence -int state; // automata state -WCHAR prefix; // escape sequence prefix ( '[', ']' or '(' ); -WCHAR prefix2; // secondary prefix ( '?' or '>' ); -WCHAR suffix; // escape sequence suffix -int es_argc; // escape sequence args count -int es_argv[MAX_ARG]; // escape sequence args -WCHAR Pt_arg[MAX_PATH * 2]; // text parameter for Operating System Command -int Pt_len; -BOOL shifted; - - -// DEC Special Graphics Character Set from -// http://vt100.net/docs/vt220-rm/table2-4.html -// Some of these may not look right, depending on the font and code page (in -// particular, the Control Pictures probably won't work at all). -const WCHAR G1[] = -{ - ' ', // _ - blank - L'\x2666', // ` - Black Diamond Suit - L'\x2592', // a - Medium Shade - L'\x2409', // b - HT - L'\x240c', // c - FF - L'\x240d', // d - CR - L'\x240a', // e - LF - L'\x00b0', // f - Degree Sign - L'\x00b1', // g - Plus-Minus Sign - L'\x2424', // h - NL - L'\x240b', // i - VT - L'\x2518', // j - Box Drawings Light Up And Left - L'\x2510', // k - Box Drawings Light Down And Left - L'\x250c', // l - Box Drawings Light Down And Right - L'\x2514', // m - Box Drawings Light Up And Right - L'\x253c', // n - Box Drawings Light Vertical And Horizontal - L'\x00af', // o - SCAN 1 - Macron - L'\x25ac', // p - SCAN 3 - Black Rectangle - L'\x2500', // q - SCAN 5 - Box Drawings Light Horizontal - L'_', // r - SCAN 7 - Low Line - L'_', // s - SCAN 9 - Low Line - L'\x251c', // t - Box Drawings Light Vertical And Right - L'\x2524', // u - Box Drawings Light Vertical And Left - L'\x2534', // v - Box Drawings Light Up And Horizontal - L'\x252c', // w - Box Drawings Light Down And Horizontal - L'\x2502', // x - Box Drawings Light Vertical - L'\x2264', // y - Less-Than Or Equal To - L'\x2265', // z - Greater-Than Or Equal To - L'\x03c0', // { - Greek Small Letter Pi - L'\x2260', // | - Not Equal To - L'\x00a3', // } - Pound Sign - L'\x00b7', // ~ - Middle Dot -}; - -#define FIRST_G1 '_' -#define LAST_G1 '~' - - -// color constants - -#define FOREGROUND_BLACK 0 -#define FOREGROUND_WHITE FOREGROUND_RED|FOREGROUND_GREEN|FOREGROUND_BLUE - -#define BACKGROUND_BLACK 0 -#define BACKGROUND_WHITE BACKGROUND_RED|BACKGROUND_GREEN|BACKGROUND_BLUE - -const BYTE foregroundcolor[8] = - { - FOREGROUND_BLACK, // black foreground - FOREGROUND_RED, // red foreground - FOREGROUND_GREEN, // green foreground - FOREGROUND_RED | FOREGROUND_GREEN, // yellow foreground - FOREGROUND_BLUE, // blue foreground - FOREGROUND_BLUE | FOREGROUND_RED, // magenta foreground - FOREGROUND_BLUE | FOREGROUND_GREEN, // cyan foreground - FOREGROUND_WHITE // white foreground - }; - -const BYTE backgroundcolor[8] = - { - BACKGROUND_BLACK, // black background - BACKGROUND_RED, // red background - BACKGROUND_GREEN, // green background - BACKGROUND_RED | BACKGROUND_GREEN, // yellow background - BACKGROUND_BLUE, // blue background - BACKGROUND_BLUE | BACKGROUND_RED, // magenta background - BACKGROUND_BLUE | BACKGROUND_GREEN, // cyan background - BACKGROUND_WHITE, // white background - }; - -const BYTE attr2ansi[8] = // map console attribute to ANSI number -{ - 0, // black - 4, // blue - 2, // green - 6, // cyan - 1, // red - 5, // magenta - 3, // yellow - 7 // white -}; - -GRM grm; - -// saved cursor position -COORD SavePos; - -// ========== Print Buffer functions - -#define BUFFER_SIZE 2048 - -int nCharInBuffer; -WCHAR ChBuffer[BUFFER_SIZE]; - -//----------------------------------------------------------------------------- -// FlushBuffer() -// Writes the buffer to the console and empties it. -//----------------------------------------------------------------------------- - -inline void FlushBuffer(void) -{ - DWORD nWritten; - if (nCharInBuffer <= 0) return; - WriteConsoleW(hConOut, ChBuffer, nCharInBuffer, &nWritten, NULL); - nCharInBuffer = 0; -} - -//----------------------------------------------------------------------------- -// PushBuffer( WCHAR c ) -// Adds a character in the buffer. -//----------------------------------------------------------------------------- - -inline void PushBuffer(WCHAR c) -{ - if (shifted && c >= FIRST_G1 && c <= LAST_G1) - c = G1[c - FIRST_G1]; - ChBuffer[nCharInBuffer] = c; - if (++nCharInBuffer == BUFFER_SIZE) - FlushBuffer(); -} - -//----------------------------------------------------------------------------- -// SendSequence( LPCWSTR seq ) -// Send the string to the input buffer. -//----------------------------------------------------------------------------- - -inline void SendSequence(LPCWSTR seq) -{ - DWORD out; - INPUT_RECORD in; - HANDLE hStdIn = GetStdHandle(STD_INPUT_HANDLE); - - in.EventType = KEY_EVENT; - in.Event.KeyEvent.bKeyDown = TRUE; - in.Event.KeyEvent.wRepeatCount = 1; - in.Event.KeyEvent.wVirtualKeyCode = 0; - in.Event.KeyEvent.wVirtualScanCode = 0; - in.Event.KeyEvent.dwControlKeyState = 0; - for (; *seq; ++seq) - { - in.Event.KeyEvent.uChar.UnicodeChar = *seq; - WriteConsoleInput(hStdIn, &in, 1, &out); - } -} - -// ========== Print functions - -//----------------------------------------------------------------------------- -// InterpretEscSeq() -// Interprets the last escape sequence scanned by ParseAndPrintANSIString -// prefix escape sequence prefix -// es_argc escape sequence args count -// es_argv[] escape sequence args array -// suffix escape sequence suffix -// -// for instance, with \e[33;45;1m we have -// prefix = '[', -// es_argc = 3, es_argv[0] = 33, es_argv[1] = 45, es_argv[2] = 1 -// suffix = 'm' -//----------------------------------------------------------------------------- - -inline void InterpretEscSeq(void) -{ - int i; - WORD attribut; - CONSOLE_SCREEN_BUFFER_INFO Info; - CONSOLE_CURSOR_INFO CursInfo; - DWORD len, NumberOfCharsWritten; - COORD Pos; - SMALL_RECT Rect; - CHAR_INFO CharInfo; - - if (prefix == '[') - { - if (prefix2 == '?' && (suffix == 'h' || suffix == 'l')) - { - if (es_argc == 1 && es_argv[0] == 25) - { - GetConsoleCursorInfo(hConOut, &CursInfo); - CursInfo.bVisible = (suffix == 'h'); - SetConsoleCursorInfo(hConOut, &CursInfo); - return; - } - } - // Ignore any other \e[? or \e[> sequences. - if (prefix2 != 0) - return; - - GetConsoleScreenBufferInfo(hConOut, &Info); - switch (suffix) - { - case 'm': - if (es_argc == 0) es_argv[es_argc++] = 0; - for (i = 0; i < es_argc; i++) - { - if (30 <= es_argv[i] && es_argv[i] <= 37) - grm.foreground = es_argv[i] - 30; - else if (40 <= es_argv[i] && es_argv[i] <= 47) - grm.background = es_argv[i] - 40; - else switch (es_argv[i]) - { - case 0: - case 39: - case 49: - { - WCHAR def[4]; - int a; - *def = '7'; def[1] = '\0'; - GetEnvironmentVariableW(L"ANSICON_DEF", def, lenof(def)); - a = wcstol(def, NULL, 16); - grm.reverse = FALSE; - if (a < 0) - { - grm.reverse = TRUE; - a = -a; - } - if (es_argv[i] != 49) - grm.foreground = attr2ansi[a & 7]; - if (es_argv[i] != 39) - grm.background = attr2ansi[(a >> 4) & 7]; - if (es_argv[i] == 0) - { - if (es_argc == 1) - { - grm.bold = a & FOREGROUND_INTENSITY; - grm.underline = a & BACKGROUND_INTENSITY; - } - else - { - grm.bold = 0; - grm.underline = 0; - } - grm.rvideo = 0; - grm.concealed = 0; - } - } - break; - - case 1: grm.bold = FOREGROUND_INTENSITY; break; - case 5: // blink - case 4: grm.underline = BACKGROUND_INTENSITY; break; - case 7: grm.rvideo = 1; break; - case 8: grm.concealed = 1; break; - case 21: // oops, this actually turns on double underline - case 22: grm.bold = 0; break; - case 25: - case 24: grm.underline = 0; break; - case 27: grm.rvideo = 0; break; - case 28: grm.concealed = 0; break; - } - } - if (grm.concealed) - { - if (grm.rvideo) - { - attribut = foregroundcolor[grm.foreground] - | backgroundcolor[grm.foreground]; - if (grm.bold) - attribut |= FOREGROUND_INTENSITY | BACKGROUND_INTENSITY; - } - else - { - attribut = foregroundcolor[grm.background] - | backgroundcolor[grm.background]; - if (grm.underline) - attribut |= FOREGROUND_INTENSITY | BACKGROUND_INTENSITY; - } - } - else if (grm.rvideo) - { - attribut = foregroundcolor[grm.background] - | backgroundcolor[grm.foreground]; - if (grm.bold) - attribut |= BACKGROUND_INTENSITY; - if (grm.underline) - attribut |= FOREGROUND_INTENSITY; - } - else - attribut = foregroundcolor[grm.foreground] | grm.bold - | backgroundcolor[grm.background] | grm.underline; - if (grm.reverse) - attribut = ((attribut >> 4) & 15) | ((attribut & 15) << 4); - SetConsoleTextAttribute(hConOut, attribut); - return; - - case 'J': - if (es_argc == 0) es_argv[es_argc++] = 0; // ESC[J == ESC[0J - if (es_argc != 1) return; - switch (es_argv[0]) - { - case 0: // ESC[0J erase from cursor to end of display - len = (Info.dwSize.Y - Info.dwCursorPosition.Y - 1) * Info.dwSize.X - + Info.dwSize.X - Info.dwCursorPosition.X - 1; - FillConsoleOutputCharacter(hConOut, ' ', len, - Info.dwCursorPosition, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, - Info.dwCursorPosition, - &NumberOfCharsWritten); - return; - - case 1: // ESC[1J erase from start to cursor. - Pos.X = 0; - Pos.Y = 0; - len = Info.dwCursorPosition.Y * Info.dwSize.X - + Info.dwCursorPosition.X + 1; - FillConsoleOutputCharacter(hConOut, ' ', len, Pos, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, Pos, - &NumberOfCharsWritten); - return; - - case 2: // ESC[2J Clear screen and home cursor - Pos.X = 0; - Pos.Y = 0; - len = Info.dwSize.X * Info.dwSize.Y; - FillConsoleOutputCharacter(hConOut, ' ', len, Pos, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, Pos, - &NumberOfCharsWritten); - SetConsoleCursorPosition(hConOut, Pos); - return; - - default: - return; - } - - case 'K': - if (es_argc == 0) es_argv[es_argc++] = 0; // ESC[K == ESC[0K - if (es_argc != 1) return; - switch (es_argv[0]) - { - case 0: // ESC[0K Clear to end of line - len = Info.dwSize.X - Info.dwCursorPosition.X + 1; - FillConsoleOutputCharacter(hConOut, ' ', len, - Info.dwCursorPosition, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, - Info.dwCursorPosition, - &NumberOfCharsWritten); - return; - - case 1: // ESC[1K Clear from start of line to cursor - Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y; - FillConsoleOutputCharacter(hConOut, ' ', - Info.dwCursorPosition.X + 1, Pos, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, - Info.dwCursorPosition.X + 1, Pos, - &NumberOfCharsWritten); - return; - - case 2: // ESC[2K Clear whole line. - Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y; - FillConsoleOutputCharacter(hConOut, ' ', Info.dwSize.X, Pos, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, - Info.dwSize.X, Pos, - &NumberOfCharsWritten); - return; - - default: - return; - } - - case 'X': // ESC[#X Erase # characters. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[X == ESC[1X - if (es_argc != 1) return; - FillConsoleOutputCharacter(hConOut, ' ', es_argv[0], - Info.dwCursorPosition, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, es_argv[0], - Info.dwCursorPosition, - &NumberOfCharsWritten); - return; - - case 'L': // ESC[#L Insert # blank lines. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[L == ESC[1L - if (es_argc != 1) return; - Rect.Left = 0; - Rect.Top = Info.dwCursorPosition.Y; - Rect.Right = Info.dwSize.X - 1; - Rect.Bottom = Info.dwSize.Y - 1; - Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y + es_argv[0]; - CharInfo.Char.UnicodeChar = ' '; - CharInfo.Attributes = Info.wAttributes; - ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Pos, &CharInfo); - return; - - case 'M': // ESC[#M Delete # lines. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[M == ESC[1M - if (es_argc != 1) return; - if (es_argv[0] > Info.dwSize.Y - Info.dwCursorPosition.Y) - es_argv[0] = Info.dwSize.Y - Info.dwCursorPosition.Y; - Rect.Left = 0; - Rect.Top = Info.dwCursorPosition.Y + es_argv[0]; - Rect.Right = Info.dwSize.X - 1; - Rect.Bottom = Info.dwSize.Y - 1; - Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y; - CharInfo.Char.UnicodeChar = ' '; - CharInfo.Attributes = Info.wAttributes; - ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Pos, &CharInfo); - return; - - case 'P': // ESC[#P Delete # characters. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[P == ESC[1P - if (es_argc != 1) return; - if (Info.dwCursorPosition.X + es_argv[0] > Info.dwSize.X - 1) - es_argv[0] = Info.dwSize.X - Info.dwCursorPosition.X; - Rect.Left = Info.dwCursorPosition.X + es_argv[0]; - Rect.Top = Info.dwCursorPosition.Y; - Rect.Right = Info.dwSize.X - 1; - Rect.Bottom = Info.dwCursorPosition.Y; - CharInfo.Char.UnicodeChar = ' '; - CharInfo.Attributes = Info.wAttributes; - ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Info.dwCursorPosition, - &CharInfo); - return; - - case '@': // ESC[#@ Insert # blank characters. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[@ == ESC[1@ - if (es_argc != 1) return; - if (Info.dwCursorPosition.X + es_argv[0] > Info.dwSize.X - 1) - es_argv[0] = Info.dwSize.X - Info.dwCursorPosition.X; - Rect.Left = Info.dwCursorPosition.X; - Rect.Top = Info.dwCursorPosition.Y; - Rect.Right = Info.dwSize.X - 1 - es_argv[0]; - Rect.Bottom = Info.dwCursorPosition.Y; - Pos.X = Info.dwCursorPosition.X + es_argv[0]; - Pos.Y = Info.dwCursorPosition.Y; - CharInfo.Char.UnicodeChar = ' '; - CharInfo.Attributes = Info.wAttributes; - ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Pos, &CharInfo); - return; - - case 'k': // ESC[#k - case 'A': // ESC[#A Moves cursor up # lines - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[A == ESC[1A - if (es_argc != 1) return; - Pos.Y = Info.dwCursorPosition.Y - es_argv[0]; - if (Pos.Y < 0) Pos.Y = 0; - Pos.X = Info.dwCursorPosition.X; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'e': // ESC[#e - case 'B': // ESC[#B Moves cursor down # lines - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[B == ESC[1B - if (es_argc != 1) return; - Pos.Y = Info.dwCursorPosition.Y + es_argv[0]; - if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; - Pos.X = Info.dwCursorPosition.X; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'a': // ESC[#a - case 'C': // ESC[#C Moves cursor forward # spaces - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[C == ESC[1C - if (es_argc != 1) return; - Pos.X = Info.dwCursorPosition.X + es_argv[0]; - if (Pos.X >= Info.dwSize.X) Pos.X = Info.dwSize.X - 1; - Pos.Y = Info.dwCursorPosition.Y; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'j': // ESC[#j - case 'D': // ESC[#D Moves cursor back # spaces - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[D == ESC[1D - if (es_argc != 1) return; - Pos.X = Info.dwCursorPosition.X - es_argv[0]; - if (Pos.X < 0) Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'E': // ESC[#E Moves cursor down # lines, column 1. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[E == ESC[1E - if (es_argc != 1) return; - Pos.Y = Info.dwCursorPosition.Y + es_argv[0]; - if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; - Pos.X = 0; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'F': // ESC[#F Moves cursor up # lines, column 1. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[F == ESC[1F - if (es_argc != 1) return; - Pos.Y = Info.dwCursorPosition.Y - es_argv[0]; - if (Pos.Y < 0) Pos.Y = 0; - Pos.X = 0; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case '`': // ESC[#` - case 'G': // ESC[#G Moves cursor column # in current row. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[G == ESC[1G - if (es_argc != 1) return; - Pos.X = es_argv[0] - 1; - if (Pos.X >= Info.dwSize.X) Pos.X = Info.dwSize.X - 1; - if (Pos.X < 0) Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'd': // ESC[#d Moves cursor row #, current column. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[d == ESC[1d - if (es_argc != 1) return; - Pos.Y = es_argv[0] - 1; - if (Pos.Y < 0) Pos.Y = 0; - if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'f': // ESC[#;#f - case 'H': // ESC[#;#H Moves cursor to line #, column # - if (es_argc == 0) - es_argv[es_argc++] = 1; // ESC[H == ESC[1;1H - if (es_argc == 1) - es_argv[es_argc++] = 1; // ESC[#H == ESC[#;1H - if (es_argc > 2) return; - Pos.X = es_argv[1] - 1; - if (Pos.X < 0) Pos.X = 0; - if (Pos.X >= Info.dwSize.X) Pos.X = Info.dwSize.X - 1; - Pos.Y = es_argv[0] - 1; - if (Pos.Y < 0) Pos.Y = 0; - if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 's': // ESC[s Saves cursor position for recall later - if (es_argc != 0) return; - SavePos = Info.dwCursorPosition; - return; - - case 'u': // ESC[u Return to saved cursor position - if (es_argc != 0) return; - SetConsoleCursorPosition(hConOut, SavePos); - return; - - case 'n': // ESC[#n Device status report - if (es_argc != 1) return; // ESC[n == ESC[0n -> ignored - switch (es_argv[0]) - { - case 5: // ESC[5n Report status - SendSequence(L"\33[0n"); // "OK" - return; - - case 6: // ESC[6n Report cursor position - { - WCHAR buf[32]; - swprintf(buf, 32, L"\33[%d;%dR", Info.dwCursorPosition.Y + 1, - Info.dwCursorPosition.X + 1); - SendSequence(buf); - } - return; - - default: - return; - } - - case 't': // ESC[#t Window manipulation - if (es_argc != 1) return; - if (es_argv[0] == 21) // ESC[21t Report xterm window's title - { - WCHAR buf[MAX_PATH * 2]; - DWORD len = GetConsoleTitleW(buf + 3, lenof(buf) - 3 - 2); - // Too bad if it's too big or fails. - buf[0] = ESC; - buf[1] = ']'; - buf[2] = 'l'; - buf[3 + len] = ESC; - buf[3 + len + 1] = '\\'; - buf[3 + len + 2] = '\0'; - SendSequence(buf); - } - return; - - default: - return; - } - } - else // (prefix == ']') - { - // Ignore any \e]? or \e]> sequences. - if (prefix2 != 0) - return; - - if (es_argc == 1 && es_argv[0] == 0) // ESC]0;titleST - { - SetConsoleTitleW(Pt_arg); - } - } -} - -//----------------------------------------------------------------------------- -// ParseAndPrintANSIString(hDev, lpBuffer, nNumberOfBytesToWrite) -// Parses the string lpBuffer, interprets the escapes sequences and prints the -// characters in the device hDev (console). -// The lexer is a three states automata. -// If the number of arguments es_argc > MAX_ARG, only the MAX_ARG-1 firsts and -// the last arguments are processed (no es_argv[] overflow). -//----------------------------------------------------------------------------- - -inline BOOL ParseAndPrintANSIString(HANDLE hDev, LPCVOID lpBuffer, DWORD nNumberOfBytesToWrite, LPDWORD lpNumberOfBytesWritten) -{ - DWORD i; - LPCSTR s; - - if (hDev != hConOut) // reinit if device has changed - { - hConOut = hDev; - state = 1; - shifted = FALSE; - } - for (i = nNumberOfBytesToWrite, s = (LPCSTR)lpBuffer; i > 0; i--, s++) - { - if (state == 1) - { - if (*s == ESC) state = 2; - else if (*s == SO) shifted = TRUE; - else if (*s == SI) shifted = FALSE; - else PushBuffer(*s); - } - else if (state == 2) - { - if (*s == ESC); // \e\e...\e == \e - else if ((*s == '[') || (*s == ']')) - { - FlushBuffer(); - prefix = *s; - prefix2 = 0; - state = 3; - Pt_len = 0; - *Pt_arg = '\0'; - } - else if (*s == ')' || *s == '(') state = 6; - else state = 1; - } - else if (state == 3) - { - if (is_digit(*s)) - { - es_argc = 0; - es_argv[0] = *s - '0'; - state = 4; - } - else if (*s == ';') - { - es_argc = 1; - es_argv[0] = 0; - es_argv[1] = 0; - state = 4; - } - else if (*s == '?' || *s == '>') - { - prefix2 = *s; - } - else - { - es_argc = 0; - suffix = *s; - InterpretEscSeq(); - state = 1; - } - } - else if (state == 4) - { - if (is_digit(*s)) - { - es_argv[es_argc] = 10 * es_argv[es_argc] + (*s - '0'); - } - else if (*s == ';') - { - if (es_argc < MAX_ARG - 1) es_argc++; - es_argv[es_argc] = 0; - if (prefix == ']') - state = 5; - } - else - { - es_argc++; - suffix = *s; - InterpretEscSeq(); - state = 1; - } - } - else if (state == 5) - { - if (*s == BEL) - { - Pt_arg[Pt_len] = '\0'; - InterpretEscSeq(); - state = 1; - } - else if (*s == '\\' && Pt_len > 0 && Pt_arg[Pt_len - 1] == ESC) - { - Pt_arg[--Pt_len] = '\0'; - InterpretEscSeq(); - state = 1; - } - else if (Pt_len < lenof(Pt_arg) - 1) - Pt_arg[Pt_len++] = *s; - } - else if (state == 6) - { - // Ignore it (ESC ) 0 is implicit; nothing else is supported). - state = 1; - } - } - FlushBuffer(); - if (lpNumberOfBytesWritten != NULL) - *lpNumberOfBytesWritten = nNumberOfBytesToWrite - i; - return (i == 0); -} - -} // namespace ansi - -HANDLE hOut; -HANDLE hIn; -DWORD consolemodeIn = 0; - -inline int win32read(int *c) { - DWORD foo; - INPUT_RECORD b; - KEY_EVENT_RECORD e; - BOOL altgr; - - while (1) { - if (!ReadConsoleInput(hIn, &b, 1, &foo)) return 0; - if (!foo) return 0; - - if (b.EventType == KEY_EVENT && b.Event.KeyEvent.bKeyDown) { - - e = b.Event.KeyEvent; - *c = b.Event.KeyEvent.uChar.AsciiChar; - - altgr = e.dwControlKeyState & (LEFT_CTRL_PRESSED | RIGHT_ALT_PRESSED); - - if (e.dwControlKeyState & (LEFT_CTRL_PRESSED | RIGHT_CTRL_PRESSED) && !altgr) { - - /* Ctrl+Key */ - switch (*c) { - case 'D': - *c = 4; - return 1; - case 'C': - *c = 3; - return 1; - case 'H': - *c = 8; - return 1; - case 'T': - *c = 20; - return 1; - case 'B': /* ctrl-b, left_arrow */ - *c = 2; - return 1; - case 'F': /* ctrl-f right_arrow*/ - *c = 6; - return 1; - case 'P': /* ctrl-p up_arrow*/ - *c = 16; - return 1; - case 'N': /* ctrl-n down_arrow*/ - *c = 14; - return 1; - case 'U': /* Ctrl+u, delete the whole line. */ - *c = 21; - return 1; - case 'K': /* Ctrl+k, delete from current to end of line. */ - *c = 11; - return 1; - case 'A': /* Ctrl+a, go to the start of the line */ - *c = 1; - return 1; - case 'E': /* ctrl+e, go to the end of the line */ - *c = 5; - return 1; - } - - /* Other Ctrl+KEYs ignored */ - } else { - - switch (e.wVirtualKeyCode) { - - case VK_ESCAPE: /* ignore - send ctrl-c, will return -1 */ - *c = 3; - return 1; - case VK_RETURN: /* enter */ - *c = 13; - return 1; - case VK_LEFT: /* left */ - *c = 2; - return 1; - case VK_RIGHT: /* right */ - *c = 6; - return 1; - case VK_UP: /* up */ - *c = 16; - return 1; - case VK_DOWN: /* down */ - *c = 14; - return 1; - case VK_HOME: - *c = 1; - return 1; - case VK_END: - *c = 5; - return 1; - case VK_BACK: - *c = 8; - return 1; - case VK_DELETE: - *c = 4; /* same as Ctrl+D above */ - return 1; - default: - if (*c) return 1; - } - } - } - } - - return -1; /* Makes compiler happy */ -} - -inline int win32_write(int fd, const void *buffer, unsigned int count) { - if (fd == _fileno(stdout)) { - DWORD bytesWritten = 0; - if (FALSE != ansi::ParseAndPrintANSIString(GetStdHandle(STD_OUTPUT_HANDLE), buffer, (DWORD)count, &bytesWritten)) { - return (int)bytesWritten; - } else { - errno = GetLastError(); - return 0; - } - } else if (fd == _fileno(stderr)) { - DWORD bytesWritten = 0; - if (FALSE != ansi::ParseAndPrintANSIString(GetStdHandle(STD_ERROR_HANDLE), buffer, (DWORD)count, &bytesWritten)) { - return (int)bytesWritten; - } else { - errno = GetLastError(); - return 0; - } - } else { - return _write(fd, buffer, count); - } -} -#endif // _WIN32 - -#define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 -#define LINENOISE_MAX_LINE 4096 -static const char *unsupported_term[] = {"dumb","cons25","emacs",NULL}; -static CompletionCallback completionCallback; - -#ifndef _WIN32 -static struct termios orig_termios; /* In order to restore at exit.*/ -#endif -static bool rawmode = false; /* For atexit() function to check if restore is needed*/ -static bool mlmode = false; /* Multi line mode. Default is single line. */ -static bool atexit_registered = false; /* Register atexit just 1 time. */ -static size_t history_max_len = LINENOISE_DEFAULT_HISTORY_MAX_LEN; -static std::vector history; - -/* The linenoiseState structure represents the state during line editing. - * We pass this state to functions implementing specific editing - * functionalities. */ -struct linenoiseState { - int ifd; /* Terminal stdin file descriptor. */ - int ofd; /* Terminal stdout file descriptor. */ - char *buf; /* Edited line buffer. */ - int buflen; /* Edited line buffer size. */ - std::string prompt; /* Prompt to display. */ - int pos; /* Current cursor position. */ - int oldcolpos; /* Previous refresh cursor column position. */ - int len; /* Current edited line length. */ - int cols; /* Number of columns in terminal. */ - int maxrows; /* Maximum num of rows used so far (multiline mode) */ - int history_index; /* The history index we are currently editing. */ -}; - -enum KEY_ACTION { - KEY_NULL = 0, /* NULL */ - CTRL_A = 1, /* Ctrl+a */ - CTRL_B = 2, /* Ctrl-b */ - CTRL_C = 3, /* Ctrl-c */ - CTRL_D = 4, /* Ctrl-d */ - CTRL_E = 5, /* Ctrl-e */ - CTRL_F = 6, /* Ctrl-f */ - CTRL_H = 8, /* Ctrl-h */ - TAB = 9, /* Tab */ - CTRL_K = 11, /* Ctrl+k */ - CTRL_L = 12, /* Ctrl+l */ - ENTER = 13, /* Enter */ - CTRL_N = 14, /* Ctrl-n */ - CTRL_P = 16, /* Ctrl-p */ - CTRL_T = 20, /* Ctrl-t */ - CTRL_U = 21, /* Ctrl+u */ - CTRL_W = 23, /* Ctrl+w */ - ESC = 27, /* Escape */ - BACKSPACE = 127 /* Backspace */ -}; - -void linenoiseAtExit(void); -bool AddHistory(const char *line); -void refreshLine(struct linenoiseState *l); - -/* ============================ UTF8 utilities ============================== */ - -static unsigned long unicodeWideCharTable[][2] = { - { 0x1100, 0x115F }, { 0x2329, 0x232A }, { 0x2E80, 0x2E99, }, { 0x2E9B, 0x2EF3, }, - { 0x2F00, 0x2FD5, }, { 0x2FF0, 0x2FFB, }, { 0x3000, 0x303E, }, { 0x3041, 0x3096, }, - { 0x3099, 0x30FF, }, { 0x3105, 0x312D, }, { 0x3131, 0x318E, }, { 0x3190, 0x31BA, }, - { 0x31C0, 0x31E3, }, { 0x31F0, 0x321E, }, { 0x3220, 0x3247, }, { 0x3250, 0x4DBF, }, - { 0x4E00, 0xA48C, }, { 0xA490, 0xA4C6, }, { 0xA960, 0xA97C, }, { 0xAC00, 0xD7A3, }, - { 0xF900, 0xFAFF, }, { 0xFE10, 0xFE19, }, { 0xFE30, 0xFE52, }, { 0xFE54, 0xFE66, }, - { 0xFE68, 0xFE6B, }, { 0xFF01, 0xFFE6, }, - { 0x1B000, 0x1B001, }, { 0x1F200, 0x1F202, }, { 0x1F210, 0x1F23A, }, - { 0x1F240, 0x1F248, }, { 0x1F250, 0x1F251, }, { 0x20000, 0x3FFFD, }, -}; - -static int unicodeWideCharTableSize = sizeof(unicodeWideCharTable) / sizeof(unicodeWideCharTable[0]); - -static int unicodeIsWideChar(unsigned long cp) -{ - int i; - for (i = 0; i < unicodeWideCharTableSize; i++) { - if (unicodeWideCharTable[i][0] <= cp && cp <= unicodeWideCharTable[i][1]) { - return 1; - } - } - return 0; -} - -static unsigned long unicodeCombiningCharTable[] = { - 0x0300,0x0301,0x0302,0x0303,0x0304,0x0305,0x0306,0x0307, - 0x0308,0x0309,0x030A,0x030B,0x030C,0x030D,0x030E,0x030F, - 0x0310,0x0311,0x0312,0x0313,0x0314,0x0315,0x0316,0x0317, - 0x0318,0x0319,0x031A,0x031B,0x031C,0x031D,0x031E,0x031F, - 0x0320,0x0321,0x0322,0x0323,0x0324,0x0325,0x0326,0x0327, - 0x0328,0x0329,0x032A,0x032B,0x032C,0x032D,0x032E,0x032F, - 0x0330,0x0331,0x0332,0x0333,0x0334,0x0335,0x0336,0x0337, - 0x0338,0x0339,0x033A,0x033B,0x033C,0x033D,0x033E,0x033F, - 0x0340,0x0341,0x0342,0x0343,0x0344,0x0345,0x0346,0x0347, - 0x0348,0x0349,0x034A,0x034B,0x034C,0x034D,0x034E,0x034F, - 0x0350,0x0351,0x0352,0x0353,0x0354,0x0355,0x0356,0x0357, - 0x0358,0x0359,0x035A,0x035B,0x035C,0x035D,0x035E,0x035F, - 0x0360,0x0361,0x0362,0x0363,0x0364,0x0365,0x0366,0x0367, - 0x0368,0x0369,0x036A,0x036B,0x036C,0x036D,0x036E,0x036F, - 0x0483,0x0484,0x0485,0x0486,0x0487,0x0591,0x0592,0x0593, - 0x0594,0x0595,0x0596,0x0597,0x0598,0x0599,0x059A,0x059B, - 0x059C,0x059D,0x059E,0x059F,0x05A0,0x05A1,0x05A2,0x05A3, - 0x05A4,0x05A5,0x05A6,0x05A7,0x05A8,0x05A9,0x05AA,0x05AB, - 0x05AC,0x05AD,0x05AE,0x05AF,0x05B0,0x05B1,0x05B2,0x05B3, - 0x05B4,0x05B5,0x05B6,0x05B7,0x05B8,0x05B9,0x05BA,0x05BB, - 0x05BC,0x05BD,0x05BF,0x05C1,0x05C2,0x05C4,0x05C5,0x05C7, - 0x0610,0x0611,0x0612,0x0613,0x0614,0x0615,0x0616,0x0617, - 0x0618,0x0619,0x061A,0x064B,0x064C,0x064D,0x064E,0x064F, - 0x0650,0x0651,0x0652,0x0653,0x0654,0x0655,0x0656,0x0657, - 0x0658,0x0659,0x065A,0x065B,0x065C,0x065D,0x065E,0x065F, - 0x0670,0x06D6,0x06D7,0x06D8,0x06D9,0x06DA,0x06DB,0x06DC, - 0x06DF,0x06E0,0x06E1,0x06E2,0x06E3,0x06E4,0x06E7,0x06E8, - 0x06EA,0x06EB,0x06EC,0x06ED,0x0711,0x0730,0x0731,0x0732, - 0x0733,0x0734,0x0735,0x0736,0x0737,0x0738,0x0739,0x073A, - 0x073B,0x073C,0x073D,0x073E,0x073F,0x0740,0x0741,0x0742, - 0x0743,0x0744,0x0745,0x0746,0x0747,0x0748,0x0749,0x074A, - 0x07A6,0x07A7,0x07A8,0x07A9,0x07AA,0x07AB,0x07AC,0x07AD, - 0x07AE,0x07AF,0x07B0,0x07EB,0x07EC,0x07ED,0x07EE,0x07EF, - 0x07F0,0x07F1,0x07F2,0x07F3,0x0816,0x0817,0x0818,0x0819, - 0x081B,0x081C,0x081D,0x081E,0x081F,0x0820,0x0821,0x0822, - 0x0823,0x0825,0x0826,0x0827,0x0829,0x082A,0x082B,0x082C, - 0x082D,0x0859,0x085A,0x085B,0x08E3,0x08E4,0x08E5,0x08E6, - 0x08E7,0x08E8,0x08E9,0x08EA,0x08EB,0x08EC,0x08ED,0x08EE, - 0x08EF,0x08F0,0x08F1,0x08F2,0x08F3,0x08F4,0x08F5,0x08F6, - 0x08F7,0x08F8,0x08F9,0x08FA,0x08FB,0x08FC,0x08FD,0x08FE, - 0x08FF,0x0900,0x0901,0x0902,0x093A,0x093C,0x0941,0x0942, - 0x0943,0x0944,0x0945,0x0946,0x0947,0x0948,0x094D,0x0951, - 0x0952,0x0953,0x0954,0x0955,0x0956,0x0957,0x0962,0x0963, - 0x0981,0x09BC,0x09C1,0x09C2,0x09C3,0x09C4,0x09CD,0x09E2, - 0x09E3,0x0A01,0x0A02,0x0A3C,0x0A41,0x0A42,0x0A47,0x0A48, - 0x0A4B,0x0A4C,0x0A4D,0x0A51,0x0A70,0x0A71,0x0A75,0x0A81, - 0x0A82,0x0ABC,0x0AC1,0x0AC2,0x0AC3,0x0AC4,0x0AC5,0x0AC7, - 0x0AC8,0x0ACD,0x0AE2,0x0AE3,0x0B01,0x0B3C,0x0B3F,0x0B41, - 0x0B42,0x0B43,0x0B44,0x0B4D,0x0B56,0x0B62,0x0B63,0x0B82, - 0x0BC0,0x0BCD,0x0C00,0x0C3E,0x0C3F,0x0C40,0x0C46,0x0C47, - 0x0C48,0x0C4A,0x0C4B,0x0C4C,0x0C4D,0x0C55,0x0C56,0x0C62, - 0x0C63,0x0C81,0x0CBC,0x0CBF,0x0CC6,0x0CCC,0x0CCD,0x0CE2, - 0x0CE3,0x0D01,0x0D41,0x0D42,0x0D43,0x0D44,0x0D4D,0x0D62, - 0x0D63,0x0DCA,0x0DD2,0x0DD3,0x0DD4,0x0DD6,0x0E31,0x0E34, - 0x0E35,0x0E36,0x0E37,0x0E38,0x0E39,0x0E3A,0x0E47,0x0E48, - 0x0E49,0x0E4A,0x0E4B,0x0E4C,0x0E4D,0x0E4E,0x0EB1,0x0EB4, - 0x0EB5,0x0EB6,0x0EB7,0x0EB8,0x0EB9,0x0EBB,0x0EBC,0x0EC8, - 0x0EC9,0x0ECA,0x0ECB,0x0ECC,0x0ECD,0x0F18,0x0F19,0x0F35, - 0x0F37,0x0F39,0x0F71,0x0F72,0x0F73,0x0F74,0x0F75,0x0F76, - 0x0F77,0x0F78,0x0F79,0x0F7A,0x0F7B,0x0F7C,0x0F7D,0x0F7E, - 0x0F80,0x0F81,0x0F82,0x0F83,0x0F84,0x0F86,0x0F87,0x0F8D, - 0x0F8E,0x0F8F,0x0F90,0x0F91,0x0F92,0x0F93,0x0F94,0x0F95, - 0x0F96,0x0F97,0x0F99,0x0F9A,0x0F9B,0x0F9C,0x0F9D,0x0F9E, - 0x0F9F,0x0FA0,0x0FA1,0x0FA2,0x0FA3,0x0FA4,0x0FA5,0x0FA6, - 0x0FA7,0x0FA8,0x0FA9,0x0FAA,0x0FAB,0x0FAC,0x0FAD,0x0FAE, - 0x0FAF,0x0FB0,0x0FB1,0x0FB2,0x0FB3,0x0FB4,0x0FB5,0x0FB6, - 0x0FB7,0x0FB8,0x0FB9,0x0FBA,0x0FBB,0x0FBC,0x0FC6,0x102D, - 0x102E,0x102F,0x1030,0x1032,0x1033,0x1034,0x1035,0x1036, - 0x1037,0x1039,0x103A,0x103D,0x103E,0x1058,0x1059,0x105E, - 0x105F,0x1060,0x1071,0x1072,0x1073,0x1074,0x1082,0x1085, - 0x1086,0x108D,0x109D,0x135D,0x135E,0x135F,0x1712,0x1713, - 0x1714,0x1732,0x1733,0x1734,0x1752,0x1753,0x1772,0x1773, - 0x17B4,0x17B5,0x17B7,0x17B8,0x17B9,0x17BA,0x17BB,0x17BC, - 0x17BD,0x17C6,0x17C9,0x17CA,0x17CB,0x17CC,0x17CD,0x17CE, - 0x17CF,0x17D0,0x17D1,0x17D2,0x17D3,0x17DD,0x180B,0x180C, - 0x180D,0x18A9,0x1920,0x1921,0x1922,0x1927,0x1928,0x1932, - 0x1939,0x193A,0x193B,0x1A17,0x1A18,0x1A1B,0x1A56,0x1A58, - 0x1A59,0x1A5A,0x1A5B,0x1A5C,0x1A5D,0x1A5E,0x1A60,0x1A62, - 0x1A65,0x1A66,0x1A67,0x1A68,0x1A69,0x1A6A,0x1A6B,0x1A6C, - 0x1A73,0x1A74,0x1A75,0x1A76,0x1A77,0x1A78,0x1A79,0x1A7A, - 0x1A7B,0x1A7C,0x1A7F,0x1AB0,0x1AB1,0x1AB2,0x1AB3,0x1AB4, - 0x1AB5,0x1AB6,0x1AB7,0x1AB8,0x1AB9,0x1ABA,0x1ABB,0x1ABC, - 0x1ABD,0x1B00,0x1B01,0x1B02,0x1B03,0x1B34,0x1B36,0x1B37, - 0x1B38,0x1B39,0x1B3A,0x1B3C,0x1B42,0x1B6B,0x1B6C,0x1B6D, - 0x1B6E,0x1B6F,0x1B70,0x1B71,0x1B72,0x1B73,0x1B80,0x1B81, - 0x1BA2,0x1BA3,0x1BA4,0x1BA5,0x1BA8,0x1BA9,0x1BAB,0x1BAC, - 0x1BAD,0x1BE6,0x1BE8,0x1BE9,0x1BED,0x1BEF,0x1BF0,0x1BF1, - 0x1C2C,0x1C2D,0x1C2E,0x1C2F,0x1C30,0x1C31,0x1C32,0x1C33, - 0x1C36,0x1C37,0x1CD0,0x1CD1,0x1CD2,0x1CD4,0x1CD5,0x1CD6, - 0x1CD7,0x1CD8,0x1CD9,0x1CDA,0x1CDB,0x1CDC,0x1CDD,0x1CDE, - 0x1CDF,0x1CE0,0x1CE2,0x1CE3,0x1CE4,0x1CE5,0x1CE6,0x1CE7, - 0x1CE8,0x1CED,0x1CF4,0x1CF8,0x1CF9,0x1DC0,0x1DC1,0x1DC2, - 0x1DC3,0x1DC4,0x1DC5,0x1DC6,0x1DC7,0x1DC8,0x1DC9,0x1DCA, - 0x1DCB,0x1DCC,0x1DCD,0x1DCE,0x1DCF,0x1DD0,0x1DD1,0x1DD2, - 0x1DD3,0x1DD4,0x1DD5,0x1DD6,0x1DD7,0x1DD8,0x1DD9,0x1DDA, - 0x1DDB,0x1DDC,0x1DDD,0x1DDE,0x1DDF,0x1DE0,0x1DE1,0x1DE2, - 0x1DE3,0x1DE4,0x1DE5,0x1DE6,0x1DE7,0x1DE8,0x1DE9,0x1DEA, - 0x1DEB,0x1DEC,0x1DED,0x1DEE,0x1DEF,0x1DF0,0x1DF1,0x1DF2, - 0x1DF3,0x1DF4,0x1DF5,0x1DFC,0x1DFD,0x1DFE,0x1DFF,0x20D0, - 0x20D1,0x20D2,0x20D3,0x20D4,0x20D5,0x20D6,0x20D7,0x20D8, - 0x20D9,0x20DA,0x20DB,0x20DC,0x20E1,0x20E5,0x20E6,0x20E7, - 0x20E8,0x20E9,0x20EA,0x20EB,0x20EC,0x20ED,0x20EE,0x20EF, - 0x20F0,0x2CEF,0x2CF0,0x2CF1,0x2D7F,0x2DE0,0x2DE1,0x2DE2, - 0x2DE3,0x2DE4,0x2DE5,0x2DE6,0x2DE7,0x2DE8,0x2DE9,0x2DEA, - 0x2DEB,0x2DEC,0x2DED,0x2DEE,0x2DEF,0x2DF0,0x2DF1,0x2DF2, - 0x2DF3,0x2DF4,0x2DF5,0x2DF6,0x2DF7,0x2DF8,0x2DF9,0x2DFA, - 0x2DFB,0x2DFC,0x2DFD,0x2DFE,0x2DFF,0x302A,0x302B,0x302C, - 0x302D,0x3099,0x309A,0xA66F,0xA674,0xA675,0xA676,0xA677, - 0xA678,0xA679,0xA67A,0xA67B,0xA67C,0xA67D,0xA69E,0xA69F, - 0xA6F0,0xA6F1,0xA802,0xA806,0xA80B,0xA825,0xA826,0xA8C4, - 0xA8E0,0xA8E1,0xA8E2,0xA8E3,0xA8E4,0xA8E5,0xA8E6,0xA8E7, - 0xA8E8,0xA8E9,0xA8EA,0xA8EB,0xA8EC,0xA8ED,0xA8EE,0xA8EF, - 0xA8F0,0xA8F1,0xA926,0xA927,0xA928,0xA929,0xA92A,0xA92B, - 0xA92C,0xA92D,0xA947,0xA948,0xA949,0xA94A,0xA94B,0xA94C, - 0xA94D,0xA94E,0xA94F,0xA950,0xA951,0xA980,0xA981,0xA982, - 0xA9B3,0xA9B6,0xA9B7,0xA9B8,0xA9B9,0xA9BC,0xA9E5,0xAA29, - 0xAA2A,0xAA2B,0xAA2C,0xAA2D,0xAA2E,0xAA31,0xAA32,0xAA35, - 0xAA36,0xAA43,0xAA4C,0xAA7C,0xAAB0,0xAAB2,0xAAB3,0xAAB4, - 0xAAB7,0xAAB8,0xAABE,0xAABF,0xAAC1,0xAAEC,0xAAED,0xAAF6, - 0xABE5,0xABE8,0xABED,0xFB1E,0xFE00,0xFE01,0xFE02,0xFE03, - 0xFE04,0xFE05,0xFE06,0xFE07,0xFE08,0xFE09,0xFE0A,0xFE0B, - 0xFE0C,0xFE0D,0xFE0E,0xFE0F,0xFE20,0xFE21,0xFE22,0xFE23, - 0xFE24,0xFE25,0xFE26,0xFE27,0xFE28,0xFE29,0xFE2A,0xFE2B, - 0xFE2C,0xFE2D,0xFE2E,0xFE2F, - 0x101FD,0x102E0,0x10376,0x10377,0x10378,0x10379,0x1037A,0x10A01, - 0x10A02,0x10A03,0x10A05,0x10A06,0x10A0C,0x10A0D,0x10A0E,0x10A0F, - 0x10A38,0x10A39,0x10A3A,0x10A3F,0x10AE5,0x10AE6,0x11001,0x11038, - 0x11039,0x1103A,0x1103B,0x1103C,0x1103D,0x1103E,0x1103F,0x11040, - 0x11041,0x11042,0x11043,0x11044,0x11045,0x11046,0x1107F,0x11080, - 0x11081,0x110B3,0x110B4,0x110B5,0x110B6,0x110B9,0x110BA,0x11100, - 0x11101,0x11102,0x11127,0x11128,0x11129,0x1112A,0x1112B,0x1112D, - 0x1112E,0x1112F,0x11130,0x11131,0x11132,0x11133,0x11134,0x11173, - 0x11180,0x11181,0x111B6,0x111B7,0x111B8,0x111B9,0x111BA,0x111BB, - 0x111BC,0x111BD,0x111BE,0x111CA,0x111CB,0x111CC,0x1122F,0x11230, - 0x11231,0x11234,0x11236,0x11237,0x112DF,0x112E3,0x112E4,0x112E5, - 0x112E6,0x112E7,0x112E8,0x112E9,0x112EA,0x11300,0x11301,0x1133C, - 0x11340,0x11366,0x11367,0x11368,0x11369,0x1136A,0x1136B,0x1136C, - 0x11370,0x11371,0x11372,0x11373,0x11374,0x114B3,0x114B4,0x114B5, - 0x114B6,0x114B7,0x114B8,0x114BA,0x114BF,0x114C0,0x114C2,0x114C3, - 0x115B2,0x115B3,0x115B4,0x115B5,0x115BC,0x115BD,0x115BF,0x115C0, - 0x115DC,0x115DD,0x11633,0x11634,0x11635,0x11636,0x11637,0x11638, - 0x11639,0x1163A,0x1163D,0x1163F,0x11640,0x116AB,0x116AD,0x116B0, - 0x116B1,0x116B2,0x116B3,0x116B4,0x116B5,0x116B7,0x1171D,0x1171E, - 0x1171F,0x11722,0x11723,0x11724,0x11725,0x11727,0x11728,0x11729, - 0x1172A,0x1172B,0x16AF0,0x16AF1,0x16AF2,0x16AF3,0x16AF4,0x16B30, - 0x16B31,0x16B32,0x16B33,0x16B34,0x16B35,0x16B36,0x16F8F,0x16F90, - 0x16F91,0x16F92,0x1BC9D,0x1BC9E,0x1D167,0x1D168,0x1D169,0x1D17B, - 0x1D17C,0x1D17D,0x1D17E,0x1D17F,0x1D180,0x1D181,0x1D182,0x1D185, - 0x1D186,0x1D187,0x1D188,0x1D189,0x1D18A,0x1D18B,0x1D1AA,0x1D1AB, - 0x1D1AC,0x1D1AD,0x1D242,0x1D243,0x1D244,0x1DA00,0x1DA01,0x1DA02, - 0x1DA03,0x1DA04,0x1DA05,0x1DA06,0x1DA07,0x1DA08,0x1DA09,0x1DA0A, - 0x1DA0B,0x1DA0C,0x1DA0D,0x1DA0E,0x1DA0F,0x1DA10,0x1DA11,0x1DA12, - 0x1DA13,0x1DA14,0x1DA15,0x1DA16,0x1DA17,0x1DA18,0x1DA19,0x1DA1A, - 0x1DA1B,0x1DA1C,0x1DA1D,0x1DA1E,0x1DA1F,0x1DA20,0x1DA21,0x1DA22, - 0x1DA23,0x1DA24,0x1DA25,0x1DA26,0x1DA27,0x1DA28,0x1DA29,0x1DA2A, - 0x1DA2B,0x1DA2C,0x1DA2D,0x1DA2E,0x1DA2F,0x1DA30,0x1DA31,0x1DA32, - 0x1DA33,0x1DA34,0x1DA35,0x1DA36,0x1DA3B,0x1DA3C,0x1DA3D,0x1DA3E, - 0x1DA3F,0x1DA40,0x1DA41,0x1DA42,0x1DA43,0x1DA44,0x1DA45,0x1DA46, - 0x1DA47,0x1DA48,0x1DA49,0x1DA4A,0x1DA4B,0x1DA4C,0x1DA4D,0x1DA4E, - 0x1DA4F,0x1DA50,0x1DA51,0x1DA52,0x1DA53,0x1DA54,0x1DA55,0x1DA56, - 0x1DA57,0x1DA58,0x1DA59,0x1DA5A,0x1DA5B,0x1DA5C,0x1DA5D,0x1DA5E, - 0x1DA5F,0x1DA60,0x1DA61,0x1DA62,0x1DA63,0x1DA64,0x1DA65,0x1DA66, - 0x1DA67,0x1DA68,0x1DA69,0x1DA6A,0x1DA6B,0x1DA6C,0x1DA75,0x1DA84, - 0x1DA9B,0x1DA9C,0x1DA9D,0x1DA9E,0x1DA9F,0x1DAA1,0x1DAA2,0x1DAA3, - 0x1DAA4,0x1DAA5,0x1DAA6,0x1DAA7,0x1DAA8,0x1DAA9,0x1DAAA,0x1DAAB, - 0x1DAAC,0x1DAAD,0x1DAAE,0x1DAAF,0x1E8D0,0x1E8D1,0x1E8D2,0x1E8D3, - 0x1E8D4,0x1E8D5,0x1E8D6,0xE0100,0xE0101,0xE0102,0xE0103,0xE0104, - 0xE0105,0xE0106,0xE0107,0xE0108,0xE0109,0xE010A,0xE010B,0xE010C, - 0xE010D,0xE010E,0xE010F,0xE0110,0xE0111,0xE0112,0xE0113,0xE0114, - 0xE0115,0xE0116,0xE0117,0xE0118,0xE0119,0xE011A,0xE011B,0xE011C, - 0xE011D,0xE011E,0xE011F,0xE0120,0xE0121,0xE0122,0xE0123,0xE0124, - 0xE0125,0xE0126,0xE0127,0xE0128,0xE0129,0xE012A,0xE012B,0xE012C, - 0xE012D,0xE012E,0xE012F,0xE0130,0xE0131,0xE0132,0xE0133,0xE0134, - 0xE0135,0xE0136,0xE0137,0xE0138,0xE0139,0xE013A,0xE013B,0xE013C, - 0xE013D,0xE013E,0xE013F,0xE0140,0xE0141,0xE0142,0xE0143,0xE0144, - 0xE0145,0xE0146,0xE0147,0xE0148,0xE0149,0xE014A,0xE014B,0xE014C, - 0xE014D,0xE014E,0xE014F,0xE0150,0xE0151,0xE0152,0xE0153,0xE0154, - 0xE0155,0xE0156,0xE0157,0xE0158,0xE0159,0xE015A,0xE015B,0xE015C, - 0xE015D,0xE015E,0xE015F,0xE0160,0xE0161,0xE0162,0xE0163,0xE0164, - 0xE0165,0xE0166,0xE0167,0xE0168,0xE0169,0xE016A,0xE016B,0xE016C, - 0xE016D,0xE016E,0xE016F,0xE0170,0xE0171,0xE0172,0xE0173,0xE0174, - 0xE0175,0xE0176,0xE0177,0xE0178,0xE0179,0xE017A,0xE017B,0xE017C, - 0xE017D,0xE017E,0xE017F,0xE0180,0xE0181,0xE0182,0xE0183,0xE0184, - 0xE0185,0xE0186,0xE0187,0xE0188,0xE0189,0xE018A,0xE018B,0xE018C, - 0xE018D,0xE018E,0xE018F,0xE0190,0xE0191,0xE0192,0xE0193,0xE0194, - 0xE0195,0xE0196,0xE0197,0xE0198,0xE0199,0xE019A,0xE019B,0xE019C, - 0xE019D,0xE019E,0xE019F,0xE01A0,0xE01A1,0xE01A2,0xE01A3,0xE01A4, - 0xE01A5,0xE01A6,0xE01A7,0xE01A8,0xE01A9,0xE01AA,0xE01AB,0xE01AC, - 0xE01AD,0xE01AE,0xE01AF,0xE01B0,0xE01B1,0xE01B2,0xE01B3,0xE01B4, - 0xE01B5,0xE01B6,0xE01B7,0xE01B8,0xE01B9,0xE01BA,0xE01BB,0xE01BC, - 0xE01BD,0xE01BE,0xE01BF,0xE01C0,0xE01C1,0xE01C2,0xE01C3,0xE01C4, - 0xE01C5,0xE01C6,0xE01C7,0xE01C8,0xE01C9,0xE01CA,0xE01CB,0xE01CC, - 0xE01CD,0xE01CE,0xE01CF,0xE01D0,0xE01D1,0xE01D2,0xE01D3,0xE01D4, - 0xE01D5,0xE01D6,0xE01D7,0xE01D8,0xE01D9,0xE01DA,0xE01DB,0xE01DC, - 0xE01DD,0xE01DE,0xE01DF,0xE01E0,0xE01E1,0xE01E2,0xE01E3,0xE01E4, - 0xE01E5,0xE01E6,0xE01E7,0xE01E8,0xE01E9,0xE01EA,0xE01EB,0xE01EC, - 0xE01ED,0xE01EE,0xE01EF, -}; - -static int unicodeCombiningCharTableSize = sizeof(unicodeCombiningCharTable) / sizeof(unicodeCombiningCharTable[0]); - -inline int unicodeIsCombiningChar(unsigned long cp) -{ - int i; - for (i = 0; i < unicodeCombiningCharTableSize; i++) { - if (unicodeCombiningCharTable[i] == cp) { - return 1; - } - } - return 0; -} - -/* Get length of previous UTF8 character - */ -inline int unicodePrevUTF8CharLen(char* buf, int pos) -{ - int end = pos--; - while (pos >= 0 && ((unsigned char)buf[pos] & 0xC0) == 0x80) { - pos--; - } - return end - pos; -} - -/* Get length of previous UTF8 character - */ -inline int unicodeUTF8CharLen(char* buf, int buf_len, int pos) -{ - if (pos == buf_len) { return 0; } - unsigned char ch = buf[pos]; - if (ch < 0x80) { return 1; } - else if (ch < 0xE0) { return 2; } - else if (ch < 0xF0) { return 3; } - else { return 4; } -} - -/* Convert UTF8 to Unicode code point - */ -inline int unicodeUTF8CharToCodePoint( - const char* buf, - int len, - int* cp) -{ - if (len) { - unsigned char byte = buf[0]; - if ((byte & 0x80) == 0) { - *cp = byte; - return 1; - } else if ((byte & 0xE0) == 0xC0) { - if (len >= 2) { - *cp = (((unsigned long)(buf[0] & 0x1F)) << 6) | - ((unsigned long)(buf[1] & 0x3F)); - return 2; - } - } else if ((byte & 0xF0) == 0xE0) { - if (len >= 3) { - *cp = (((unsigned long)(buf[0] & 0x0F)) << 12) | - (((unsigned long)(buf[1] & 0x3F)) << 6) | - ((unsigned long)(buf[2] & 0x3F)); - return 3; - } - } else if ((byte & 0xF8) == 0xF0) { - if (len >= 4) { - *cp = (((unsigned long)(buf[0] & 0x07)) << 18) | - (((unsigned long)(buf[1] & 0x3F)) << 12) | - (((unsigned long)(buf[2] & 0x3F)) << 6) | - ((unsigned long)(buf[3] & 0x3F)); - return 4; - } - } - } - return 0; -} - -/* Get length of grapheme - */ -inline int unicodeGraphemeLen(char* buf, int buf_len, int pos) -{ - if (pos == buf_len) { - return 0; - } - int beg = pos; - pos += unicodeUTF8CharLen(buf, buf_len, pos); - while (pos < buf_len) { - int len = unicodeUTF8CharLen(buf, buf_len, pos); - int cp = 0; - unicodeUTF8CharToCodePoint(buf + pos, len, &cp); - if (!unicodeIsCombiningChar(cp)) { - return pos - beg; - } - pos += len; - } - return pos - beg; -} - -/* Get length of previous grapheme - */ -inline int unicodePrevGraphemeLen(char* buf, int pos) -{ - if (pos == 0) { - return 0; - } - int end = pos; - while (pos > 0) { - int len = unicodePrevUTF8CharLen(buf, pos); - pos -= len; - int cp = 0; - unicodeUTF8CharToCodePoint(buf + pos, len, &cp); - if (!unicodeIsCombiningChar(cp)) { - return end - pos; - } - } - return 0; -} - -inline int isAnsiEscape(const char* buf, int buf_len, int* len) -{ - if (buf_len > 2 && !memcmp("\033[", buf, 2)) { - int off = 2; - while (off < buf_len) { - switch (buf[off++]) { - case 'A': case 'B': case 'C': case 'D': - case 'E': case 'F': case 'G': case 'H': - case 'J': case 'K': case 'S': case 'T': - case 'f': case 'm': - *len = off; - return 1; - } - } - } - return 0; -} - -/* Get column position for the single line mode. - */ -inline int unicodeColumnPos(const char* buf, int buf_len) -{ - int ret = 0; - - int off = 0; - while (off < buf_len) { - int len; - if (isAnsiEscape(buf + off, buf_len - off, &len)) { - off += len; - continue; - } - - int cp = 0; - len = unicodeUTF8CharToCodePoint(buf + off, buf_len - off, &cp); - - if (!unicodeIsCombiningChar(cp)) { - ret += unicodeIsWideChar(cp) ? 2 : 1; - } - - off += len; - } - - return ret; -} - -/* Get column position for the multi line mode. - */ -inline int unicodeColumnPosForMultiLine(char* buf, int buf_len, int pos, int cols, int ini_pos) -{ - int ret = 0; - int colwid = ini_pos; - - int off = 0; - while (off < buf_len) { - int cp = 0; - int len = unicodeUTF8CharToCodePoint(buf + off, buf_len - off, &cp); - - int wid = 0; - if (!unicodeIsCombiningChar(cp)) { - wid = unicodeIsWideChar(cp) ? 2 : 1; - } - - int dif = (int)(colwid + wid) - (int)cols; - if (dif > 0) { - ret += dif; - colwid = wid; - } else if (dif == 0) { - colwid = 0; - } else { - colwid += wid; - } - - if (off >= pos) { - break; - } - - off += len; - ret += wid; - } - - return ret; -} - -/* Read UTF8 character from file. - */ -inline int unicodeReadUTF8Char(int fd, char* buf, int* cp) -{ - int nread = read(fd,&buf[0],1); - - if (nread <= 0) { return nread; } - - unsigned char byte = buf[0]; - - if ((byte & 0x80) == 0) { - ; - } else if ((byte & 0xE0) == 0xC0) { - nread = read(fd,&buf[1],1); - if (nread <= 0) { return nread; } - } else if ((byte & 0xF0) == 0xE0) { - nread = read(fd,&buf[1],2); - if (nread <= 0) { return nread; } - } else if ((byte & 0xF8) == 0xF0) { - nread = read(fd,&buf[1],3); - if (nread <= 0) { return nread; } - } else { - return -1; - } - - return unicodeUTF8CharToCodePoint(buf, 4, cp); -} - -/* ======================= Low level terminal handling ====================== */ - -/* Set if to use or not the multi line mode. */ -inline void SetMultiLine(bool ml) { - mlmode = ml; -} - -/* Return true if the terminal name is in the list of terminals we know are - * not able to understand basic escape sequences. */ -inline bool isUnsupportedTerm(void) { -#ifndef _WIN32 - char *term = getenv("TERM"); - int j; - - if (term == NULL) return false; - for (j = 0; unsupported_term[j]; j++) - if (!strcasecmp(term,unsupported_term[j])) return true; -#endif - return false; -} - -/* Raw mode: 1960 magic shit. */ -inline bool enableRawMode(int fd) { -#ifndef _WIN32 - struct termios raw; - - if (!isatty(STDIN_FILENO)) goto fatal; - if (!atexit_registered) { - atexit(linenoiseAtExit); - atexit_registered = true; - } - if (tcgetattr(fd,&orig_termios) == -1) goto fatal; - - raw = orig_termios; /* modify the original mode */ - /* input modes: no break, no CR to NL, no parity check, no strip char, - * no start/stop output control. */ - raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); - /* output modes - disable post processing */ - // NOTE: Multithreaded issue #20 (https://github.com/yhirose/cpp-linenoise/issues/20) - // raw.c_oflag &= ~(OPOST); - /* control modes - set 8 bit chars */ - raw.c_cflag |= (CS8); - /* local modes - echoing off, canonical off, no extended functions, - * no signal chars (^Z,^C) */ - raw.c_lflag &= ~(ECHO | ICANON | IEXTEN | ISIG); - /* control chars - set return condition: min number of bytes and timer. - * We want read to return every single byte, without timeout. */ - raw.c_cc[VMIN] = 1; raw.c_cc[VTIME] = 0; /* 1 byte, no timer */ - - /* put terminal in raw mode after flushing */ - if (tcsetattr(fd,TCSAFLUSH,&raw) < 0) goto fatal; - rawmode = true; -#else - if (!atexit_registered) { - /* Cleanup them at exit */ - atexit(linenoiseAtExit); - atexit_registered = true; - - /* Init windows console handles only once */ - hOut = GetStdHandle(STD_OUTPUT_HANDLE); - if (hOut==INVALID_HANDLE_VALUE) goto fatal; - } - - DWORD consolemodeOut; - if (!GetConsoleMode(hOut, &consolemodeOut)) { - CloseHandle(hOut); - errno = ENOTTY; - return false; - }; - - hIn = GetStdHandle(STD_INPUT_HANDLE); - if (hIn == INVALID_HANDLE_VALUE) { - CloseHandle(hOut); - errno = ENOTTY; - return false; - } - - GetConsoleMode(hIn, &consolemodeIn); - /* Enable raw mode */ - SetConsoleMode(hIn, consolemodeIn & ~ENABLE_PROCESSED_INPUT); - - rawmode = true; -#endif - return true; - -fatal: - errno = ENOTTY; - return false; -} - -inline void disableRawMode(int fd) { -#ifdef _WIN32 - if (consolemodeIn) { - SetConsoleMode(hIn, consolemodeIn); - consolemodeIn = 0; - } - rawmode = false; -#else - /* Don't even check the return value as it's too late. */ - if (rawmode && tcsetattr(fd,TCSAFLUSH,&orig_termios) != -1) - rawmode = false; -#endif -} - -/* Use the ESC [6n escape sequence to query the horizontal cursor position - * and return it. On error -1 is returned, on success the position of the - * cursor. */ -inline int getCursorPosition(int ifd, int ofd) { - char buf[32]; - int cols, rows; - unsigned int i = 0; - - /* Report cursor location */ - if (write(ofd, "\x1b[6n", 4) != 4) return -1; - - /* Read the response: ESC [ rows ; cols R */ - while (i < sizeof(buf)-1) { - if (read(ifd,buf+i,1) != 1) break; - if (buf[i] == 'R') break; - i++; - } - buf[i] = '\0'; - - /* Parse it. */ - if (buf[0] != ESC || buf[1] != '[') return -1; - if (sscanf(buf+2,"%d;%d",&rows,&cols) != 2) return -1; - return cols; -} - -/* Try to get the number of columns in the current terminal, or assume 80 - * if it fails. */ -inline int getColumns(int ifd, int ofd) { -#ifdef _WIN32 - CONSOLE_SCREEN_BUFFER_INFO b; - - if (!GetConsoleScreenBufferInfo(hOut, &b)) return 80; - return b.srWindow.Right - b.srWindow.Left; -#else - struct winsize ws; - - if (ioctl(1, TIOCGWINSZ, &ws) == -1 || ws.ws_col == 0) { - /* ioctl() failed. Try to query the terminal itself. */ - int start, cols; - - /* Get the initial position so we can restore it later. */ - start = getCursorPosition(ifd,ofd); - if (start == -1) goto failed; - - /* Go to right margin and get position. */ - if (write(ofd,"\x1b[999C",6) != 6) goto failed; - cols = getCursorPosition(ifd,ofd); - if (cols == -1) goto failed; - - /* Restore position. */ - if (cols > start) { - char seq[32]; - snprintf(seq,32,"\x1b[%dD",cols-start); - if (write(ofd,seq,strlen(seq)) == -1) { - /* Can't recover... */ - } - } - return cols; - } else { - return ws.ws_col; - } - -failed: - return 80; -#endif -} - -/* Clear the screen. Used to handle ctrl+l */ -inline void linenoiseClearScreen(void) { - if (write(STDOUT_FILENO,"\x1b[H\x1b[2J",7) <= 0) { - /* nothing to do, just to avoid warning. */ - } -} - -/* Beep, used for completion when there is nothing to complete or when all - * the choices were already shown. */ -inline void linenoiseBeep(void) { - fprintf(stderr, "\x7"); - fflush(stderr); -} - -/* ============================== Completion ================================ */ - -/* This is an helper function for linenoiseEdit() and is called when the - * user types the key in order to complete the string currently in the - * input. - * - * The state of the editing is encapsulated into the pointed linenoiseState - * structure as described in the structure definition. */ -inline int completeLine(struct linenoiseState *ls, char *cbuf, int *c) { - std::vector lc; - int nread = 0, nwritten; - *c = 0; - - completionCallback(ls->buf,lc); - if (lc.empty()) { - linenoiseBeep(); - } else { - int stop = 0, i = 0; - - while(!stop) { - /* Show completion or original buffer */ - if (i < static_cast(lc.size())) { - struct linenoiseState saved = *ls; - - ls->len = ls->pos = static_cast(lc[i].size()); - ls->buf = &lc[i][0]; - refreshLine(ls); - ls->len = saved.len; - ls->pos = saved.pos; - ls->buf = saved.buf; - } else { - refreshLine(ls); - } - - //nread = read(ls->ifd,&c,1); -#ifdef _WIN32 - nread = win32read(c); - if (nread == 1) { - cbuf[0] = *c; - } -#else - nread = unicodeReadUTF8Char(ls->ifd,cbuf,c); -#endif - if (nread <= 0) { - *c = -1; - return nread; - } - - switch(*c) { - case 9: /* tab */ - i = (i+1) % (lc.size()+1); - if (i == static_cast(lc.size())) linenoiseBeep(); - break; - case 27: /* escape */ - /* Re-show original buffer */ - if (i < static_cast(lc.size())) refreshLine(ls); - stop = 1; - break; - default: - /* Update buffer and return */ - if (i < static_cast(lc.size())) { - nwritten = snprintf(ls->buf,ls->buflen,"%s",&lc[i][0]); - ls->len = ls->pos = nwritten; - } - stop = 1; - break; - } - } - } - - return nread; -} - -/* Register a callback function to be called for tab-completion. */ -inline void SetCompletionCallback(CompletionCallback fn) { - completionCallback = fn; -} - -/* =========================== Line editing ================================= */ - -/* Single line low level line refresh. - * - * Rewrite the currently edited line accordingly to the buffer content, - * cursor position, and number of columns of the terminal. */ -inline void refreshSingleLine(struct linenoiseState *l) { - char seq[64]; - int pcolwid = unicodeColumnPos(l->prompt.c_str(), static_cast(l->prompt.length())); - int fd = l->ofd; - char *buf = l->buf; - int len = l->len; - int pos = l->pos; - std::string ab; - - while((pcolwid+unicodeColumnPos(buf, pos)) >= l->cols) { - int glen = unicodeGraphemeLen(buf, len, 0); - buf += glen; - len -= glen; - pos -= glen; - } - while (pcolwid+unicodeColumnPos(buf, len) > l->cols) { - len -= unicodePrevGraphemeLen(buf, len); - } - - /* Cursor to left edge */ - snprintf(seq,64,"\r"); - ab += seq; - /* Write the prompt and the current buffer content */ - ab += l->prompt; - ab.append(buf, len); - /* Erase to right */ - snprintf(seq,64,"\x1b[0K"); - ab += seq; - /* Move cursor to original position. */ - snprintf(seq,64,"\r\x1b[%dC", (int)(unicodeColumnPos(buf, pos)+pcolwid)); - ab += seq; - if (write(fd,ab.c_str(), static_cast(ab.length())) == -1) {} /* Can't recover from write error. */ -} - -/* Multi line low level line refresh. - * - * Rewrite the currently edited line accordingly to the buffer content, - * cursor position, and number of columns of the terminal. */ -inline void refreshMultiLine(struct linenoiseState *l) { - char seq[64]; - int pcolwid = unicodeColumnPos(l->prompt.c_str(), static_cast(l->prompt.length())); - int colpos = unicodeColumnPosForMultiLine(l->buf, l->len, l->len, l->cols, pcolwid); - int colpos2; /* cursor column position. */ - int rows = (pcolwid+colpos+l->cols-1)/l->cols; /* rows used by current buf. */ - int rpos = (pcolwid+l->oldcolpos+l->cols)/l->cols; /* cursor relative row. */ - int rpos2; /* rpos after refresh. */ - int col; /* colum position, zero-based. */ - int old_rows = (int)l->maxrows; - int fd = l->ofd, j; - std::string ab; - - /* Update maxrows if needed. */ - if (rows > (int)l->maxrows) l->maxrows = rows; - - /* First step: clear all the lines used before. To do so start by - * going to the last row. */ - if (old_rows-rpos > 0) { - snprintf(seq,64,"\x1b[%dB", old_rows-rpos); - ab += seq; - } - - /* Now for every row clear it, go up. */ - for (j = 0; j < old_rows-1; j++) { - snprintf(seq,64,"\r\x1b[0K\x1b[1A"); - ab += seq; - } - - /* Clean the top line. */ - snprintf(seq,64,"\r\x1b[0K"); - ab += seq; - - /* Write the prompt and the current buffer content */ - ab += l->prompt; - ab.append(l->buf, l->len); - - /* Get text width to cursor position */ - colpos2 = unicodeColumnPosForMultiLine(l->buf, l->len, l->pos, l->cols, pcolwid); - - /* If we are at the very end of the screen with our prompt, we need to - * emit a newline and move the prompt to the first column. */ - if (l->pos && - l->pos == l->len && - (colpos2+pcolwid) % l->cols == 0) - { - ab += "\n"; - snprintf(seq,64,"\r"); - ab += seq; - rows++; - if (rows > (int)l->maxrows) l->maxrows = rows; - } - - /* Move cursor to right position. */ - rpos2 = (pcolwid+colpos2+l->cols)/l->cols; /* current cursor relative row. */ - - /* Go up till we reach the expected positon. */ - if (rows-rpos2 > 0) { - snprintf(seq,64,"\x1b[%dA", rows-rpos2); - ab += seq; - } - - /* Set column. */ - col = (pcolwid + colpos2) % l->cols; - if (col) - snprintf(seq,64,"\r\x1b[%dC", col); - else - snprintf(seq,64,"\r"); - ab += seq; - - l->oldcolpos = colpos2; - - if (write(fd,ab.c_str(), static_cast(ab.length())) == -1) {} /* Can't recover from write error. */ -} - -/* Calls the two low level functions refreshSingleLine() or - * refreshMultiLine() according to the selected mode. */ -inline void refreshLine(struct linenoiseState *l) { - if (mlmode) - refreshMultiLine(l); - else - refreshSingleLine(l); -} - -/* Insert the character 'c' at cursor current position. - * - * On error writing to the terminal -1 is returned, otherwise 0. */ -inline int linenoiseEditInsert(struct linenoiseState *l, const char* cbuf, int clen) { - if (l->len < l->buflen) { - if (l->len == l->pos) { - memcpy(&l->buf[l->pos],cbuf,clen); - l->pos+=clen; - l->len+=clen;; - l->buf[l->len] = '\0'; - if ((!mlmode && unicodeColumnPos(l->prompt.c_str(), static_cast(l->prompt.length()))+unicodeColumnPos(l->buf,l->len) < l->cols) /* || mlmode */) { - /* Avoid a full update of the line in the - * trivial case. */ - if (write(l->ofd,cbuf,clen) == -1) return -1; - } else { - refreshLine(l); - } - } else { - memmove(l->buf+l->pos+clen,l->buf+l->pos,l->len-l->pos); - memcpy(&l->buf[l->pos],cbuf,clen); - l->pos+=clen; - l->len+=clen; - l->buf[l->len] = '\0'; - refreshLine(l); - } - } - return 0; -} - -/* Move cursor on the left. */ -inline void linenoiseEditMoveLeft(struct linenoiseState *l) { - if (l->pos > 0) { - l->pos -= unicodePrevGraphemeLen(l->buf, l->pos); - refreshLine(l); - } -} - -/* Move cursor on the right. */ -inline void linenoiseEditMoveRight(struct linenoiseState *l) { - if (l->pos != l->len) { - l->pos += unicodeGraphemeLen(l->buf, l->len, l->pos); - refreshLine(l); - } -} - -/* Move cursor to the start of the line. */ -inline void linenoiseEditMoveHome(struct linenoiseState *l) { - if (l->pos != 0) { - l->pos = 0; - refreshLine(l); - } -} - -/* Move cursor to the end of the line. */ -inline void linenoiseEditMoveEnd(struct linenoiseState *l) { - if (l->pos != l->len) { - l->pos = l->len; - refreshLine(l); - } -} - -/* Substitute the currently edited line with the next or previous history - * entry as specified by 'dir'. */ -#define LINENOISE_HISTORY_NEXT 0 -#define LINENOISE_HISTORY_PREV 1 -inline void linenoiseEditHistoryNext(struct linenoiseState *l, int dir) { - if (history.size() > 1) { - /* Update the current history entry before to - * overwrite it with the next one. */ - history[history.size() - 1 - l->history_index] = l->buf; - /* Show the new entry */ - l->history_index += (dir == LINENOISE_HISTORY_PREV) ? 1 : -1; - if (l->history_index < 0) { - l->history_index = 0; - return; - } else if (l->history_index >= (int)history.size()) { - l->history_index = static_cast(history.size())-1; - return; - } - memset(l->buf, 0, l->buflen); - strcpy(l->buf,history[history.size() - 1 - l->history_index].c_str()); - l->len = l->pos = static_cast(strlen(l->buf)); - refreshLine(l); - } -} - -/* Delete the character at the right of the cursor without altering the cursor - * position. Basically this is what happens with the "Delete" keyboard key. */ -inline void linenoiseEditDelete(struct linenoiseState *l) { - if (l->len > 0 && l->pos < l->len) { - int glen = unicodeGraphemeLen(l->buf,l->len,l->pos); - memmove(l->buf+l->pos,l->buf+l->pos+glen,l->len-l->pos-glen); - l->len-=glen; - l->buf[l->len] = '\0'; - refreshLine(l); - } -} - -/* Backspace implementation. */ -inline void linenoiseEditBackspace(struct linenoiseState *l) { - if (l->pos > 0 && l->len > 0) { - int glen = unicodePrevGraphemeLen(l->buf,l->pos); - memmove(l->buf+l->pos-glen,l->buf+l->pos,l->len-l->pos); - l->pos-=glen; - l->len-=glen; - l->buf[l->len] = '\0'; - refreshLine(l); - } -} - -/* Delete the previosu word, maintaining the cursor at the start of the - * current word. */ -inline void linenoiseEditDeletePrevWord(struct linenoiseState *l) { - int old_pos = l->pos; - int diff; - - while (l->pos > 0 && l->buf[l->pos-1] == ' ') - l->pos--; - while (l->pos > 0 && l->buf[l->pos-1] != ' ') - l->pos--; - diff = old_pos - l->pos; - memmove(l->buf+l->pos,l->buf+old_pos,l->len-old_pos+1); - l->len -= diff; - refreshLine(l); -} - -/* This function is the core of the line editing capability of linenoise. - * It expects 'fd' to be already in "raw mode" so that every key pressed - * will be returned ASAP to read(). - * - * The resulting string is put into 'buf' when the user type enter, or - * when ctrl+d is typed. - * - * The function returns the length of the current buffer. */ -inline int linenoiseEdit(int stdin_fd, int stdout_fd, char *buf, int buflen, const char *prompt) -{ - struct linenoiseState l; - - /* Populate the linenoise state that we pass to functions implementing - * specific editing functionalities. */ - l.ifd = stdin_fd; - l.ofd = stdout_fd; - l.buf = buf; - l.buflen = buflen; - l.prompt = prompt; - l.oldcolpos = l.pos = 0; - l.len = 0; - l.cols = getColumns(stdin_fd, stdout_fd); - l.maxrows = 0; - l.history_index = 0; - - /* Buffer starts empty. */ - l.buf[0] = '\0'; - l.buflen--; /* Make sure there is always space for the nulterm */ - - /* The latest history entry is always our current buffer, that - * initially is just an empty string. */ - AddHistory(""); - - if (write(l.ofd,prompt, static_cast(l.prompt.length())) == -1) return -1; - while(1) { - int c; - char cbuf[4]; - int nread; - char seq[3]; - -#ifdef _WIN32 - nread = win32read(&c); - if (nread == 1) { - cbuf[0] = c; - } -#else - nread = unicodeReadUTF8Char(l.ifd,cbuf,&c); -#endif - if (nread <= 0) return (int)l.len; - - /* Only autocomplete when the callback is set. It returns < 0 when - * there was an error reading from fd. Otherwise it will return the - * character that should be handled next. */ - if (c == 9 && completionCallback != NULL) { - nread = completeLine(&l,cbuf,&c); - /* Return on errors */ - if (c < 0) return l.len; - /* Read next character when 0 */ - if (c == 0) continue; - } - - switch(c) { - case ENTER: /* enter */ - if (!history.empty()) history.pop_back(); - if (mlmode) linenoiseEditMoveEnd(&l); - return (int)l.len; - case CTRL_C: /* ctrl-c */ - errno = EAGAIN; - return -1; - case BACKSPACE: /* backspace */ - case 8: /* ctrl-h */ - linenoiseEditBackspace(&l); - break; - case CTRL_D: /* ctrl-d, remove char at right of cursor, or if the - line is empty, act as end-of-file. */ - if (l.len > 0) { - linenoiseEditDelete(&l); - } else { - history.pop_back(); - return -1; - } - break; - case CTRL_T: /* ctrl-t, swaps current character with previous. */ - if (l.pos > 0 && l.pos < l.len) { - char aux = buf[l.pos-1]; - buf[l.pos-1] = buf[l.pos]; - buf[l.pos] = aux; - if (l.pos != l.len-1) l.pos++; - refreshLine(&l); - } - break; - case CTRL_B: /* ctrl-b */ - linenoiseEditMoveLeft(&l); - break; - case CTRL_F: /* ctrl-f */ - linenoiseEditMoveRight(&l); - break; - case CTRL_P: /* ctrl-p */ - linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_PREV); - break; - case CTRL_N: /* ctrl-n */ - linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_NEXT); - break; - case ESC: /* escape sequence */ - /* Read the next two bytes representing the escape sequence. - * Use two calls to handle slow terminals returning the two - * chars at different times. */ - if (read(l.ifd,seq,1) == -1) break; - if (read(l.ifd,seq+1,1) == -1) break; - - /* ESC [ sequences. */ - if (seq[0] == '[') { - if (seq[1] >= '0' && seq[1] <= '9') { - /* Extended escape, read additional byte. */ - if (read(l.ifd,seq+2,1) == -1) break; - if (seq[2] == '~') { - switch(seq[1]) { - case '3': /* Delete key. */ - linenoiseEditDelete(&l); - break; - } - } - } else { - switch(seq[1]) { - case 'A': /* Up */ - linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_PREV); - break; - case 'B': /* Down */ - linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_NEXT); - break; - case 'C': /* Right */ - linenoiseEditMoveRight(&l); - break; - case 'D': /* Left */ - linenoiseEditMoveLeft(&l); - break; - case 'H': /* Home */ - linenoiseEditMoveHome(&l); - break; - case 'F': /* End*/ - linenoiseEditMoveEnd(&l); - break; - } - } - } - - /* ESC O sequences. */ - else if (seq[0] == 'O') { - switch(seq[1]) { - case 'H': /* Home */ - linenoiseEditMoveHome(&l); - break; - case 'F': /* End*/ - linenoiseEditMoveEnd(&l); - break; - } - } - break; - default: - if (linenoiseEditInsert(&l,cbuf,nread)) return -1; - break; - case CTRL_U: /* Ctrl+u, delete the whole line. */ - buf[0] = '\0'; - l.pos = l.len = 0; - refreshLine(&l); - break; - case CTRL_K: /* Ctrl+k, delete from current to end of line. */ - buf[l.pos] = '\0'; - l.len = l.pos; - refreshLine(&l); - break; - case CTRL_A: /* Ctrl+a, go to the start of the line */ - linenoiseEditMoveHome(&l); - break; - case CTRL_E: /* ctrl+e, go to the end of the line */ - linenoiseEditMoveEnd(&l); - break; - case CTRL_L: /* ctrl+l, clear screen */ - linenoiseClearScreen(); - refreshLine(&l); - break; - case CTRL_W: /* ctrl+w, delete previous word */ - linenoiseEditDeletePrevWord(&l); - break; - } - } - return l.len; -} - -/* This function calls the line editing function linenoiseEdit() using - * the STDIN file descriptor set in raw mode. */ -inline bool linenoiseRaw(const char *prompt, std::string& line) { - bool quit = false; - - if (!isatty(STDIN_FILENO)) { - /* Not a tty: read from file / pipe. */ - std::getline(std::cin, line); - } else { - /* Interactive editing. */ - if (enableRawMode(STDIN_FILENO) == false) { - return quit; - } - - char buf[LINENOISE_MAX_LINE]; - auto count = linenoiseEdit(STDIN_FILENO, STDOUT_FILENO, buf, LINENOISE_MAX_LINE, prompt); - if (count == -1) { - quit = true; - } else { - line.assign(buf, count); - } - - disableRawMode(STDIN_FILENO); - printf("\n"); - } - return quit; -} - -/* The high level function that is the main API of the linenoise library. - * This function checks if the terminal has basic capabilities, just checking - * for a blacklist of stupid terminals, and later either calls the line - * editing function or uses dummy fgets() so that you will be able to type - * something even in the most desperate of the conditions. */ -inline bool Readline(const char *prompt, std::string& line) { - if (isUnsupportedTerm()) { - printf("%s",prompt); - fflush(stdout); - std::getline(std::cin, line); - return false; - } else { - return linenoiseRaw(prompt, line); - } -} - -inline std::string Readline(const char *prompt, bool& quit) { - std::string line; - quit = Readline(prompt, line); - return line; -} - -inline std::string Readline(const char *prompt) { - bool quit; // dummy - return Readline(prompt, quit); -} - -/* ================================ History ================================= */ - -/* At exit we'll try to fix the terminal to the initial conditions. */ -inline void linenoiseAtExit(void) { - disableRawMode(STDIN_FILENO); -} - -/* This is the API call to add a new entry in the linenoise history. - * It uses a fixed array of char pointers that are shifted (memmoved) - * when the history max length is reached in order to remove the older - * entry and make room for the new one, so it is not exactly suitable for huge - * histories, but will work well for a few hundred of entries. - * - * Using a circular buffer is smarter, but a bit more complex to handle. */ -inline bool AddHistory(const char* line) { - if (history_max_len == 0) return false; - - /* Don't add duplicated lines. */ - if (!history.empty() && history.back() == line) return false; - - /* If we reached the max length, remove the older line. */ - if (history.size() == history_max_len) { - history.erase(history.begin()); - } - history.push_back(line); - - return true; -} - -/* Set the maximum length for the history. This function can be called even - * if there is already some history, the function will make sure to retain - * just the latest 'len' elements if the new history length value is smaller - * than the amount of items already inside the history. */ -inline bool SetHistoryMaxLen(size_t len) { - if (len < 1) return false; - history_max_len = len; - if (len < history.size()) { - history.resize(len); - } - return true; -} - -/* Save the history in the specified file. On success *true* is returned - * otherwise *false* is returned. */ -inline bool SaveHistory(const char* path) { - std::ofstream f(path); // TODO: need 'std::ios::binary'? - if (!f) return false; - for (const auto& h: history) { - f << h << std::endl; - } - return true; -} - -/* Load the history from the specified file. If the file does not exist - * zero is returned and no operation is performed. - * - * If the file exists and the operation succeeded *true* is returned, otherwise - * on error *false* is returned. */ -inline bool LoadHistory(const char* path) { - std::ifstream f(path); - if (!f) return false; - std::string line; - while (std::getline(f, line)) { - AddHistory(line.c_str()); - } - return true; -} - -inline const std::vector& GetHistory() { - return history; -} - -} // namespace linenoise - -#ifdef _WIN32 -#undef isatty -#undef write -#undef read -#pragma warning(pop) -#endif - -#endif /* __LINENOISE_HPP */ diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index e8e3b315..59e12574 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2536,7 +2536,7 @@ TEST_CASE("autocomplete_documentation_symbols") TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") { - check(R"( + check(R"( local temp = false local even = true; local a = true @@ -2551,63 +2551,63 @@ a = if temp then even elseif true then temp e@8 a = if temp then even elseif true then temp else e@9 )"); - auto ac = autocomplete('1'); - CHECK(ac.entryMap.count("temp")); - CHECK(ac.entryMap.count("true")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); - ac = autocomplete('2'); - CHECK(ac.entryMap.count("temp") == 0); - CHECK(ac.entryMap.count("true") == 0); - CHECK(ac.entryMap.count("then")); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + ac = autocomplete('2'); + CHECK(ac.entryMap.count("temp") == 0); + CHECK(ac.entryMap.count("true") == 0); + CHECK(ac.entryMap.count("then")); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); - ac = autocomplete('3'); - CHECK(ac.entryMap.count("even")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + ac = autocomplete('3'); + CHECK(ac.entryMap.count("even")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); - ac = autocomplete('4'); - CHECK(ac.entryMap.count("even") == 0); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else")); - CHECK(ac.entryMap.count("elseif")); + ac = autocomplete('4'); + CHECK(ac.entryMap.count("even") == 0); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); - ac = autocomplete('5'); - CHECK(ac.entryMap.count("temp")); - CHECK(ac.entryMap.count("true")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + ac = autocomplete('5'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); - ac = autocomplete('6'); - CHECK(ac.entryMap.count("temp") == 0); - CHECK(ac.entryMap.count("true") == 0); - CHECK(ac.entryMap.count("then")); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + ac = autocomplete('6'); + CHECK(ac.entryMap.count("temp") == 0); + CHECK(ac.entryMap.count("true") == 0); + CHECK(ac.entryMap.count("then")); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); - ac = autocomplete('7'); - CHECK(ac.entryMap.count("temp")); - CHECK(ac.entryMap.count("true")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + ac = autocomplete('7'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); - ac = autocomplete('8'); - CHECK(ac.entryMap.count("even") == 0); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else")); - CHECK(ac.entryMap.count("elseif")); + ac = autocomplete('8'); + CHECK(ac.entryMap.count("even") == 0); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); - ac = autocomplete('9'); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + ac = autocomplete('9'); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 4a28bdde..d8af94db 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -611,7 +611,8 @@ TEST_CASE("TableLiteralsIndexConstant") CHECK_EQ("\n" + compileFunction0(R"( local a, b = "key", "value" return {[a] = 42, [b] = 0} -)"), R"( +)"), + R"( NEWTABLE R0 2 0 LOADN R1 42 SETTABLEKS R1 R0 K0 @@ -624,7 +625,8 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0(R"( local a, b = 1, 2 return {[a] = 42, [b] = 0} -)"), R"( +)"), + R"( NEWTABLE R0 0 2 LOADN R1 42 SETTABLEN R1 R0 1 @@ -789,8 +791,6 @@ RETURN R0 1 TEST_CASE("TableSizePredictionLoop") { - ScopedFastFlag sff("LuauPredictTableSizeLoop", true); - CHECK_EQ("\n" + compileFunction0(R"( local t = {} for i=1,4 do @@ -2827,7 +2827,7 @@ RETURN R1 -1 TEST_CASE("FastcallSelect") { - ScopedFastFlag sff("LuauCompileSelectBuiltin", true); + ScopedFastFlag sff("LuauCompileSelectBuiltin2", true); // select(_, ...) compiles to a builtin call CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"( @@ -2846,7 +2846,8 @@ for i=1, select('#', ...) do sum += select(i, ...) end return sum -)"), R"( +)"), + R"( LOADN R0 0 LOADN R3 1 LOADK R5 K0 @@ -2856,13 +2857,14 @@ GETVARARGS R6 -1 CALL R4 -1 1 MOVE R1 R4 LOADN R2 1 -FORNPREP R1 +7 -FASTCALL1 57 R3 +3 +FORNPREP R1 +8 +FASTCALL1 57 R3 +4 GETIMPORT R4 2 +MOVE R5 R3 GETVARARGS R6 -1 CALL R4 -1 1 ADD R0 R0 R4 -FORNLOOP R1 -7 +FORNLOOP R1 -8 RETURN R0 1 )"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 914b881f..e580949f 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -492,7 +492,6 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { - ScopedFastFlag sffr("LuauBytecodeV2Read", true); ScopedFastFlag sffw("LuauBytecodeV2Write", true); runConformance("debug.lua"); diff --git a/tests/LValue.test.cpp b/tests/LValue.test.cpp index 8a092779..606f6de3 100644 --- a/tests/LValue.test.cpp +++ b/tests/LValue.test.cpp @@ -38,8 +38,6 @@ TEST_SUITE_BEGIN("LValue"); TEST_CASE("Luau_merge_hashmap_order") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - std::string a = "a"; std::string b = "b"; std::string c = "c"; @@ -58,20 +56,18 @@ TEST_CASE("Luau_merge_hashmap_order") TypeArena arena; merge(arena, m, other); - REQUIRE_EQ(3, m.NEW_refinements.size()); - REQUIRE(m.NEW_refinements.count(mkSymbol(a))); - REQUIRE(m.NEW_refinements.count(mkSymbol(b))); - REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + REQUIRE_EQ(3, m.size()); + REQUIRE(m.count(mkSymbol(a))); + REQUIRE(m.count(mkSymbol(b))); + REQUIRE(m.count(mkSymbol(c))); - CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); - CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)])); - CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)])); + CHECK_EQ("string", toString(m[mkSymbol(a)])); + CHECK_EQ("string", toString(m[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m[mkSymbol(c)])); } TEST_CASE("Luau_merge_hashmap_order2") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - std::string a = "a"; std::string b = "b"; std::string c = "c"; @@ -90,20 +86,18 @@ TEST_CASE("Luau_merge_hashmap_order2") TypeArena arena; merge(arena, m, other); - REQUIRE_EQ(3, m.NEW_refinements.size()); - REQUIRE(m.NEW_refinements.count(mkSymbol(a))); - REQUIRE(m.NEW_refinements.count(mkSymbol(b))); - REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + REQUIRE_EQ(3, m.size()); + REQUIRE(m.count(mkSymbol(a))); + REQUIRE(m.count(mkSymbol(b))); + REQUIRE(m.count(mkSymbol(c))); - CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); - CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)])); - CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)])); + CHECK_EQ("string", toString(m[mkSymbol(a)])); + CHECK_EQ("string", toString(m[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m[mkSymbol(c)])); } TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - std::string a = "a"; std::string b = "b"; std::string c = "c"; @@ -125,18 +119,18 @@ TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start") TypeArena arena; merge(arena, m, other); - REQUIRE_EQ(5, m.NEW_refinements.size()); - REQUIRE(m.NEW_refinements.count(mkSymbol(a))); - REQUIRE(m.NEW_refinements.count(mkSymbol(b))); - REQUIRE(m.NEW_refinements.count(mkSymbol(c))); - REQUIRE(m.NEW_refinements.count(mkSymbol(d))); - REQUIRE(m.NEW_refinements.count(mkSymbol(e))); + REQUIRE_EQ(5, m.size()); + REQUIRE(m.count(mkSymbol(a))); + REQUIRE(m.count(mkSymbol(b))); + REQUIRE(m.count(mkSymbol(c))); + REQUIRE(m.count(mkSymbol(d))); + REQUIRE(m.count(mkSymbol(e))); - CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); - CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(b)])); - CHECK_EQ("boolean | string", toString(m.NEW_refinements[mkSymbol(c)])); - CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(d)])); - CHECK_EQ("boolean", toString(m.NEW_refinements[mkSymbol(e)])); + CHECK_EQ("string", toString(m[mkSymbol(a)])); + CHECK_EQ("number", toString(m[mkSymbol(b)])); + CHECK_EQ("boolean | string", toString(m[mkSymbol(c)])); + CHECK_EQ("number", toString(m[mkSymbol(d)])); + CHECK_EQ("boolean", toString(m[mkSymbol(e)])); } TEST_CASE("hashing_lvalue_global_prop_access") @@ -159,7 +153,7 @@ TEST_CASE("hashing_lvalue_global_prop_access") CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); - NEW_RefinementMap m; + RefinementMap m; m[t_x1] = getSingletonTypes().stringType; m[t_x2] = getSingletonTypes().numberType; @@ -188,7 +182,7 @@ TEST_CASE("hashing_lvalue_local_prop_access") CHECK_NE(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); - NEW_RefinementMap m; + RefinementMap m; m[t_x1] = getSingletonTypes().stringType; m[t_x2] = getSingletonTypes().numberType; diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 5ad06f0d..d1cc49b2 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -54,6 +54,17 @@ return _ CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); } +TEST_CASE_FIXTURE(Fixture, "PlaceholderReadGlobal") +{ + LintResult result = lint(R"( +_ = 5 +print(_) +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); +} + TEST_CASE_FIXTURE(Fixture, "PlaceholderWrite") { LintResult result = lint(R"( @@ -853,7 +864,7 @@ string.format("%Y") local _ = ("%"):format() -- correct format strings, just to uh make sure -string.format("hello %d %f", 4, 5) +string.format("hello %+10d %.02f %%", 4, 5) )"); CHECK_EQ(result.warnings.size(), 4); @@ -1078,16 +1089,18 @@ TEST_CASE_FIXTURE(Fixture, "FormatStringDate") os.date("%") os.date("%L") os.date("%?") +os.date("\0") -- correct formats os.date("it's %c now") os.date("!*t") )"); - CHECK_EQ(result.warnings.size(), 3); + CHECK_EQ(result.warnings.size(), 4); CHECK_EQ(result.warnings[0].text, "Invalid date format: unfinished replacement"); CHECK_EQ(result.warnings[1].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); CHECK_EQ(result.warnings[2].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); + CHECK_EQ(result.warnings[3].text, "Invalid date format: date format can not contain null characters"); } TEST_CASE_FIXTURE(Fixture, "FormatStringTyped") @@ -1396,8 +1409,6 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { - ScopedFastFlag sff("LuauLintTableCreateTable", true); - LintResult result = lintTyped(R"( local t = {} local tt = {} @@ -1435,8 +1446,10 @@ table.create(42, {} :: {}) "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); - CHECK_EQ(result.warnings[8].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); - CHECK_EQ(result.warnings[9].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + CHECK_EQ( + result.warnings[8].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + CHECK_EQ( + result.warnings[9].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 90831ee9..c1a8887b 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -7,8 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauFixAmbiguousErrorRecoveryInAssign) - using namespace Luau; namespace @@ -1639,10 +1637,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_confusing_function_call") "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " "statements"); - if (FFlag::LuauFixAmbiguousErrorRecoveryInAssign) - CHECK(result4.errors.size() == 1); - else - CHECK(result4.errors.size() == 5); + CHECK(result4.errors.size() == 1); } TEST_CASE_FIXTURE(Fixture, "parse_error_varargs") diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 86165814..572b882d 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -209,8 +209,6 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_does_not_propagate_type_info") TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") { - ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; - CheckResult result = check(R"( local a = 55 :: number? local b = a :: number @@ -224,7 +222,6 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") TEST_CASE_FIXTURE(Fixture, "as_expr_warns_on_unrelated_cast") { - ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; ScopedFastFlag sff2{"LuauErrorRecoveryType", true}; CheckResult result = check(R"( diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 5e08654a..6730bedb 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -889,4 +889,55 @@ TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") REQUIRE(gtv->definition); } +TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") +{ + ScopedFastFlag sff[]{ + {"LuauAssertStripsFalsyTypes", true}, + {"LuauDiscriminableUnions", true}, + }; + + CheckResult result = check(R"( + local function f(x: (number | boolean)?) + return assert(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") +{ + ScopedFastFlag sff[]{ + {"LuauAssertStripsFalsyTypes", true}, + {"LuauDiscriminableUnions", true}, + }; + + CheckResult result = check(R"( + local function f(...: number?) + return assert(...) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(...number?) -> (number, ...number?)", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") +{ + ScopedFastFlag sff[]{ + {"LuauAssertStripsFalsyTypes", true}, + {"LuauDiscriminableUnions", true}, + }; + + CheckResult result = check(R"( + local function f(x: nil) + return assert(x, "hmm") + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(nil) -> nil", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 47c13be9..e5eb0dca 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -176,19 +176,6 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a REQUIRE_EQ("{| [any]: any, x: number, y: number |}", toString(requireType("b"))); } -TEST_CASE_FIXTURE(Fixture, "normal_conditional_expression_has_refinements") -{ - CheckResult result = check(R"( - local foo: {x: number}? = nil - local bar = foo and foo.x -- TODO: Geez. We are inferring the wrong types here. Should be 'number?'. - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Binary and/or return types are straight up wrong. JIRA: CLI-40300 - CHECK_EQ("boolean | number", toString(requireType("bar"))); -} - // Luau currently doesn't yet know how to allow assignments when the binding was refined. TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") { diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f346ddfd..3a610c3a 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -939,8 +939,6 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - CheckResult result = check(R"( local foo: string? = "hi" assert(foo) @@ -955,8 +953,6 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_pre TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - CheckResult result = check(R"( type T = {x: string | number} local t: T? = {x = "hi"} @@ -974,8 +970,6 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined2") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - CheckResult result = check(R"( type T = { x: { y: number }? } @@ -993,8 +987,6 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") { - ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; - CheckResult result = check(R"( type T = { [string]: { prop: number }? } local t: T = {} @@ -1061,27 +1053,62 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag") CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); } -TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") +TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") { - ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; - CheckResult result = check(R"( - type T = { [string]: { prop: number }? } - local t: T = {} - - if t["hello"] then - local foo = t["hello"].prop + local function len(a: {any}) + return a and #a or nil end )"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") +TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") { + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauDiscriminableUnions", true}, + {"LuauAssertStripsFalsyTypes", true}, + }; + CheckResult result = check(R"( - local function len(a: {any}) - return a and #a or nil + local function is_true(b: true) end + local function is_false(b: false) end + + local function f(x: boolean) + if x then + is_true(x) + else + is_false(x) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauDiscriminableUnions", true}, + {"LuauAssertStripsFalsyTypes", true}, + }; + + CheckResult result = check(R"( + type Ok = { ok: true, value: T } + type Err = { ok: false, error: E } + type Result = Ok | Err + + local function apply(t: Result, f: (T) -> (), g: (E) -> ()) + if t.ok then + f(t.value) + else + g(t.error) + end end )"); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 48310921..f19cb618 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1482,7 +1482,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_i REQUIRE(tm); CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); // Should t now have an indexer? - // It would if the assignment to rt was correctly typed. + // It would if the assignment to rt was correctly typed. CHECK_EQ("{ [string]: string, foo: number }", toString(tm->givenType, o)); } @@ -2082,7 +2082,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, @@ -2103,7 +2103,7 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, @@ -2131,7 +2131,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index c9b30e1a..ead3d762 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -2429,21 +2429,6 @@ TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowi LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "x_or_y_forces_both_x_and_y_to_be_of_same_type_if_either_is_free") -{ - CheckResult result = check(R"( - local function f(x, y) return x or y end - - local x = f(1, 2) - local y = f(3, "foo") - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*requireType("x"), *typeChecker.numberType); - - CHECK_EQ(result.errors[0], (TypeError{Location{{4, 23}, {4, 28}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); -} - TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocate_memory") { CheckResult result = check(R"( @@ -4509,7 +4494,7 @@ f(function(x) print(x) end) } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") -{ +{ ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; CheckResult result = check(R"( @@ -4777,7 +4762,7 @@ local a: X = if true then {"1", 2, 3} else {4, 5, 6} TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_2") { ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; - ScopedFastFlag luauIfElseBranchTypeUnion{ "LuauIfElseBranchTypeUnion", true }; + ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true}; CheckResult result = check(R"( local a: number? = if true then 1 else nil @@ -5012,16 +4997,14 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ( - toString(result.errors[0]), R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' + CHECK_EQ(toString(result.errors[0]), + R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' caused by: Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); } TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - CheckResult result = check(R"( local function f(thing: any | string) local foo = thing.SomeRandomKey @@ -5120,4 +5103,65 @@ end )"); } +TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") +{ + ScopedFastFlag committingTxnLog{"LuauUseCommittingTxnLog", true}; + ScopedFastFlag subtypingVariance{"LuauTableSubtypingVariance2", true}; + + CheckResult result = check(R"( + --!strict + --!nolint + + type FieldSpecifier = { + fieldName: string, + } + + type ReadFieldOptions = FieldSpecifier & { from: number? } + + type Policies = { + getStoreFieldName: (self: Policies, fieldSpec: FieldSpecifier) -> string, + } + + local Policies = {} + + local function foo(p: Policies) + end + + function Policies:getStoreFieldName(specifier: FieldSpecifier): string + return "" + end + + function Policies:readField(options: ReadFieldOptions) + local _ = self:getStoreFieldName(options) + --[[ + Type error: + TypeError { "MainModule", Location { { line = 25, col = 16 }, { line = 25, col = 20 } }, TypeMismatch { Policies, {- getStoreFieldName: (tp1) -> (a, b...) -} } } + ]] + foo(self) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types") +{ + ScopedFastFlag noSealedTypeMod{"LuauNoSealedTypeMod", true}; + + fileResolver.source["game/A"] = R"( +export type Type = { unrelated: boolean } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = {} +function x:Destroy(): () end + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 1e790eba..0aeca096 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -260,4 +260,17 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") CHECK(unifyErrors.size() == 0); } +TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") +{ + ScopedFastFlag luauUnionTagMatchFix{"LuauUnionTagMatchFix", true}; + + TypeVar redirect{FreeTypeVar{TypeLevel{}}}; + TypeVar table{TableTypeVar{}}; + TypeVar metatable{MetatableTypeVar{&redirect, &table}}; + redirect = BoundTypeVar{&metatable}; // Now we have a metatable that is recursive on the table type + TypeVar variant{UnionTypeVar{{&metatable, typeChecker.numberType}}}; + + state.tryUnify(&metatable, &variant); +} + TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 188b8ebc..de091632 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -118,6 +118,10 @@ assert((function() return #_G end)() == 0) assert((function() return #{1,2} end)() == 2) assert((function() return #'g' end)() == 1) +local ud = newproxy(true) +getmetatable(ud).__len = function() return 42 end +assert((function() return #ud end)() == 42) + assert((function() local a = 1 a = -a return a end)() == -1) -- while/repeat diff --git a/tests/conformance/vararg.lua b/tests/conformance/vararg.lua index 5aaa422b..d05f9577 100644 --- a/tests/conformance/vararg.lua +++ b/tests/conformance/vararg.lua @@ -105,6 +105,45 @@ assert(a==5 and b==4 and c==3 and d==2 and e==1) a,b,c,d,e = f(4) assert(a==nil and b==nil and c==nil and d==nil and e==nil) +-- select tests +a = {select(3, unpack{10,20,30,40})} +assert(table.getn(a) == 2 and a[1] == 30 and a[2] == 40) +a = {select(1)} +assert(next(a) == nil) +a = {select(-1, 3, 5, 7)} +assert(a[1] == 7 and a[2] == nil) +a = {select(-2, 3, 5, 7)} +assert(a[1] == 5 and a[2] == 7 and a[3] == nil) +pcall(select, 10000) +pcall(select, -10000) + +-- select(_, ...) has special optimizations so it needs extra testing +function selectone(n, ...) + local e = select(n, ...) + return e +end + +function selectmany(n, ...) + return table.concat({select(n, ...)}, ',') +end + +assert(selectone('#') == 0) +assert(selectmany('#') == "0") + +assert(selectone('#', 10, 20, 30) == 3) +assert(selectmany('#', 10, 20, 30) == "3") + +assert(selectone(1, 10, 20, 30) == 10) +assert(selectmany(1, 10, 20, 30) == "10,20,30") + +assert(selectone(2, 10, 20, 30) == 20) +assert(selectmany(2, 10, 20, 30) == "20,30") + +assert(selectone(-2, 10, 20, 30) == 20) +assert(selectmany(-2, 10, 20, 30) == "20,30") + +assert(selectone('3', 10, 20, 30) == 30) +assert(selectmany('3', 10, 20, 30) == "30") -- varargs for main chunks f = loadstring[[ return {...} ]] @@ -122,16 +161,5 @@ f = loadstring[[ assert(f("a", "b", nil, {}, assert)) assert(f()) -a = {select(3, unpack{10,20,30,40})} -assert(table.getn(a) == 2 and a[1] == 30 and a[2] == 40) -a = {select(1)} -assert(next(a) == nil) -a = {select(-1, 3, 5, 7)} -assert(a[1] == 7 and a[2] == nil) -a = {select(-2, 3, 5, 7)} -assert(a[1] == 5 and a[2] == 7 and a[3] == nil) -pcall(select, 10000) -pcall(select, -10000) - return('OK') From 4e60eec1fc97132bee2f5bb09664b08f235f2c55 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 3 Feb 2022 16:31:50 -0800 Subject: [PATCH 022/102] Apply fix to the crash --- Analysis/src/TypeInfer.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4d25fe2e..b9096d2e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -4977,6 +4977,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (notEnoughParameters && hasDefaultParameters) { // 'applyTypeFunction' is used to substitute default types that reference previous generic types + applyTypeFunction.log = TxnLog::empty(); applyTypeFunction.typeArguments.clear(); applyTypeFunction.typePackArguments.clear(); applyTypeFunction.currentModule = currentModule; From 4748777ce850813d639e97338679357171d98289 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 3 Feb 2022 16:43:44 -0800 Subject: [PATCH 023/102] Fix isocline warnings --- extern/isocline/src/isocline.c | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/extern/isocline/src/isocline.c b/extern/isocline/src/isocline.c index 13278062..8b6055cf 100644 --- a/extern/isocline/src/isocline.c +++ b/extern/isocline/src/isocline.c @@ -13,7 +13,12 @@ // $ gcc -c src/isocline.c //------------------------------------------------------------- #if !defined(IC_SEPARATE_OBJS) -# define _CRT_SECURE_NO_WARNINGS // for msvc +# ifndef _CRT_NONSTDC_NO_WARNINGS +# define _CRT_NONSTDC_NO_WARNINGS // for msvc +# endif +# ifndef _CRT_SECURE_NO_WARNINGS +# define _CRT_SECURE_NO_WARNINGS // for msvc +# endif # define _XOPEN_SOURCE 700 // for wcwidth # define _DEFAULT_SOURCE // ensure usleep stays visible with _XOPEN_SOURCE >= 700 # include "attr.c" From bbae46600635e43d875cc95392a1e8481f445524 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 4 Feb 2022 12:31:19 -0800 Subject: [PATCH 024/102] Sync to upstream/release/513 This takes the extra bug fix for generic name confusion --- Analysis/include/Luau/TypeInfer.h | 3 ++- Analysis/src/TypeInfer.cpp | 7 ++++--- tests/TypeInfer.generics.test.cpp | 25 +++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 90dc9f42..f61ecbf5 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -346,7 +346,8 @@ private: // Note: `scope` must be a fresh scope. GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, const AstArray& genericPackNames); + const AstArray& genericNames, const AstArray& genericPackNames, + bool useCache = false); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4d25fe2e..e1987937 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -29,6 +29,7 @@ LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as fals LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) +LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) @@ -1199,7 +1200,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (FFlag::LuauProperTypeLevels) aliasScope->level.subLevel = subLevel; - auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); + auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); @@ -5361,7 +5362,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, const AstArray& genericPackNames) + const AstArray& genericNames, const AstArray& genericPackNames, bool useCache) { LUAU_ASSERT(scope->parent); @@ -5387,7 +5388,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st } TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) + if (FFlag::LuauRecursiveTypeParameterRestriction && (!FFlag::LuauGenericFunctionsDontCacheTypeParams || useCache)) { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index a7f27551..8a2c6f27 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -673,4 +673,29 @@ local d: D = c R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); } +TEST_CASE_FIXTURE(Fixture, "generic_functions_dont_cache_type_parameters") +{ + ScopedFastFlag sff{"LuauGenericFunctionsDontCacheTypeParams", true}; + + CheckResult result = check(R"( +-- See https://github.com/Roblox/luau/issues/332 +-- This function has a type parameter with the same name as clones, +-- so if we cache type parameter names for functions these get confused. +-- function id(x : Z) : Z +function id(x : X) : X + return x +end + +function clone(dict: {[X]:Y}): {[X]:Y} + local copy = {} + for k, v in pairs(dict) do + copy[k] = v + end + return copy +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); From e9bf182585e4cfc3bdf9bf71fc74ca947027b8dd Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 11 Feb 2022 10:43:14 -0800 Subject: [PATCH 025/102] Sync to upstream/release/514 --- Analysis/include/Luau/Autocomplete.h | 2 + Analysis/include/Luau/Linter.h | 1 + Analysis/include/Luau/Substitution.h | 2 +- Analysis/include/Luau/TxnLog.h | 4 +- Analysis/include/Luau/TypeInfer.h | 11 ++ Analysis/include/Luau/TypePack.h | 3 - Analysis/include/Luau/TypeVar.h | 3 - Analysis/src/EmbeddedBuiltinDefinitions.cpp | 12 +- Analysis/src/Linter.cpp | 103 ++++++++++++-- Analysis/src/Module.cpp | 28 ++-- Analysis/src/Scope.cpp | 4 + Analysis/src/TxnLog.cpp | 8 ++ Analysis/src/TypeInfer.cpp | 101 +++++++++++--- Analysis/src/TypeVar.cpp | 4 +- Analysis/src/Unifier.cpp | 125 +++++++++++------ Ast/src/Parser.cpp | 2 +- CLI/Ast.cpp | 86 ++++++++++++ CLI/Repl.cpp | 68 ++++++--- CLI/Repl.h | 7 +- CMakeLists.txt | 14 +- Compiler/src/BytecodeBuilder.cpp | 8 +- Compiler/src/Compiler.cpp | 62 ++------- Sources.cmake | 8 ++ VM/src/lbuiltins.cpp | 2 +- VM/src/lgcdebug.cpp | 4 + VM/src/lmem.cpp | 127 ++++++++++++----- VM/src/lobject.cpp | 4 +- VM/src/lvmexecute.cpp | 3 +- extern/isocline/src/isocline.c | 7 +- tests/Compiler.test.cpp | 4 - tests/Conformance.test.cpp | 2 - tests/Linter.test.cpp | 36 ++++- tests/Repl.test.cpp | 92 ++++++++++++ tests/TypeInfer.aliases.test.cpp | 61 ++++++++ tests/TypeInfer.builtins.test.cpp | 15 +- tests/TypeInfer.provisional.test.cpp | 74 +++++++++- tests/TypeInfer.refinements.test.cpp | 22 +-- tests/TypeInfer.test.cpp | 147 ++++++++++++++++++++ tests/TypeInfer.tryUnify.test.cpp | 17 +++ tests/conformance/basic.lua | 10 +- tests/conformance/debug.lua | 1 + tests/conformance/errors.lua | 16 ++- tests/conformance/gc.lua | 7 +- tests/conformance/math.lua | 1 + tests/conformance/vararg.lua | 6 + tests/conformance/vector.lua | 9 ++ tools/heapgraph.py | 33 +++-- tools/svg.py | 2 +- 48 files changed, 1100 insertions(+), 268 deletions(-) create mode 100644 CLI/Ast.cpp diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index 58534293..65b788d3 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -86,6 +86,8 @@ struct OwningAutocompleteResult }; AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); + +// Deprecated, do not use in new work. OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback); } // namespace Luau diff --git a/Analysis/include/Luau/Linter.h b/Analysis/include/Luau/Linter.h index 1f7f7f9d..ec3c124d 100644 --- a/Analysis/include/Luau/Linter.h +++ b/Analysis/include/Luau/Linter.h @@ -49,6 +49,7 @@ struct LintWarning Code_DeprecatedApi = 22, Code_TableOperations = 23, Code_DuplicateCondition = 24, + Code_MisleadingAndOr = 25, Code__Count }; diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 4f3307cd..f85b4269 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -93,7 +93,7 @@ struct Tarjan // This should never be null; ensure you initialize it before calling // substitution methods. - const TxnLog* log; + const TxnLog* log = nullptr; std::vector edgesTy; std::vector edgesTp; diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 02b87374..f238e258 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -307,8 +307,8 @@ private: // // We can't use a DenseHashMap here because we need a non-const iterator // over the map when we concatenate. - std::unordered_map> typeVarChanges; - std::unordered_map> typePackChanges; + std::unordered_map, DenseHashPointer> typeVarChanges; + std::unordered_map, DenseHashPointer> typePackChanges; TxnLog* parent = nullptr; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index f61ecbf5..5592fa1f 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -103,6 +103,11 @@ struct GenericTypeDefinitions std::vector genericPacks; }; +struct HashBoolNamePair +{ + size_t operator()(const std::pair& pair) const; +}; + // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -411,6 +416,12 @@ public: private: int checkRecursionCount = 0; int recursionCount = 0; + + /** + * We use this to avoid doing second-pass analysis of type aliases that are duplicates. We record a pair + * (exported, name) to properly deal with the case where the two duplicates do not have the same export status. + */ + DenseHashSet, HashBoolNamePair> duplicateTypeAliases; }; // Unit test hook diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index ca588ccb..c74bad11 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -54,9 +54,6 @@ struct TypePackVar bool persistent = false; // Pointer to the type arena that allocated this type. - // Do not depend on the value of this under any circumstances. This is for - // debugging purposes only. This is only set in debug builds; it is nullptr - // in all other environments. TypeArena* owningArena = nullptr; }; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 11dc9377..8d1a9fa6 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -449,9 +449,6 @@ struct TypeVar final std::optional documentationSymbol; // Pointer to the type arena that allocated this type. - // Do not depend on the value of this under any circumstances. This is for - // debugging purposes only. This is only set in debug builds; it is nullptr - // in all other environments. TypeArena* owningArena = nullptr; bool operator==(const TypeVar& rhs) const; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 24982506..f3ef88fc 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAGVARIABLE(LuauFixTonumberReturnType, false) - namespace Luau { @@ -115,6 +113,7 @@ declare function gcinfo(): number declare function error(message: T, level: number?) declare function tostring(value: T): string + declare function tonumber(value: T, radix: number?): number? declare function rawequal(a: T1, b: T2): boolean declare function rawget(tab: {[K]: V}, k: K): V @@ -200,14 +199,7 @@ declare function gcinfo(): number std::string getBuiltinDefinitionSource() { - std::string result = kBuiltinDefinitionLuaSrc; - - if (FFlag::LuauFixTonumberReturnType) - result += "declare function tonumber(value: T, radix: number?): number?\n"; - else - result += "declare function tonumber(value: T, radix: number?): number\n"; - - return result; + return kBuiltinDefinitionLuaSrc; } } // namespace Luau diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 57a33e93..2ba6a0fc 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -43,6 +43,7 @@ static const char* kWarningNames[] = { "DeprecatedApi", "TableOperations", "DuplicateCondition", + "MisleadingAndOr", }; // clang-format on @@ -2040,18 +2041,28 @@ private: const Property* prop = lookupClassProp(cty, node->index.value); if (prop && prop->deprecated) - { - if (!prop->deprecatedSuggestion.empty()) - emitWarning(*context, LintWarning::Code_DeprecatedApi, node->location, "Member '%s.%s' is deprecated, use '%s' instead", - cty->name.c_str(), node->index.value, prop->deprecatedSuggestion.c_str()); - else - emitWarning(*context, LintWarning::Code_DeprecatedApi, node->location, "Member '%s.%s' is deprecated", cty->name.c_str(), - node->index.value); - } + report(node->location, *prop, cty->name.c_str(), node->index.value); + } + else if (const TableTypeVar* tty = get(follow(*ty))) + { + auto prop = tty->props.find(node->index.value); + + if (prop != tty->props.end() && prop->second.deprecated) + report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); } return true; } + + void report(const Location& location, const Property& prop, const char* container, const char* field) + { + std::string suggestion = prop.deprecatedSuggestion.empty() ? "" : format(", use '%s' instead", prop.deprecatedSuggestion.c_str()); + + if (container) + emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s.%s' is deprecated%s", container, field, suggestion.c_str()); + else + emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s' is deprecated%s", field, suggestion.c_str()); + } }; class LintTableOperations : AstVisitor @@ -2257,6 +2268,39 @@ private: return false; } + bool visit(AstExprIfElse* expr) override + { + if (!expr->falseExpr->is()) + return true; + + // if..elseif chain detected, we need to unroll it + std::vector conditions; + conditions.reserve(2); + + AstExprIfElse* head = expr; + while (head) + { + head->condition->visit(this); + head->trueExpr->visit(this); + + conditions.push_back(head->condition); + + if (head->falseExpr->is()) + { + head = head->falseExpr->as(); + continue; + } + + head->falseExpr->visit(this); + break; + } + + detectDuplicates(conditions); + + // block recursive visits so that we only analyze each chain once + return false; + } + bool visit(AstExprBinary* expr) override { if (expr->op != AstExprBinary::And && expr->op != AstExprBinary::Or) @@ -2418,6 +2462,46 @@ private: } }; +class LintMisleadingAndOr : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintMisleadingAndOr pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + bool visit(AstExprBinary* node) override + { + if (node->op != AstExprBinary::Or) + return true; + + AstExprBinary* and_ = node->left->as(); + if (!and_ || and_->op != AstExprBinary::And) + return true; + + const char* alt = nullptr; + + if (and_->right->is()) + alt = "nil"; + else if (AstExprConstantBool* c = and_->right->as(); c && c->value == false) + alt = "false"; + + if (alt) + emitWarning(*context, LintWarning::Code_MisleadingAndOr, node->location, + "The and-or expression always evaluates to the second alternative because the first alternative is %s; consider using if-then-else " + "expression instead", + alt); + + return true; + } +}; + static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, const ScopePtr& env) { ScopePtr current = env; @@ -2522,6 +2606,9 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_DuplicateLocal)) LintDuplicateLocal::process(context); + if (context.warningEnabled(LintWarning::Code_MisleadingAndOr)) + LintMisleadingAndOr::process(context); + std::sort(context.result.begin(), context.result.end(), WarningComparator()); return context.result; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 4fdff8f7..817a33e9 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -12,10 +12,10 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) +LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) // Remove with FFlagLuauImmutableTypes LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypeAliasDefaults) - +LUAU_FASTFLAG(LuauImmutableTypes) LUAU_FASTFLAGVARIABLE(LuauPrepopulateUnionOptionsBeforeAllocation, false) namespace Luau @@ -66,7 +66,7 @@ TypeId TypeArena::addTV(TypeVar&& tv) { TypeId allocated = typeVars.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -76,7 +76,7 @@ TypeId TypeArena::freshType(TypeLevel level) { TypeId allocated = typeVars.allocate(FreeTypeVar{level}); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -86,7 +86,7 @@ TypePackId TypeArena::addTypePack(std::initializer_list types) { TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -96,7 +96,7 @@ TypePackId TypeArena::addTypePack(std::vector types) { TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -106,7 +106,7 @@ TypePackId TypeArena::addTypePack(TypePack tp) { TypePackId allocated = typePacks.allocate(std::move(tp)); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -116,7 +116,7 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) { TypePackId allocated = typePacks.allocate(std::move(tp)); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -454,8 +454,16 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - // TODO: Make this work when the arena of 'res' might be frozen - asMutable(res)->documentationSymbol = typeId->documentationSymbol; + if (FFlag::LuauImmutableTypes) + { + // Persistent types are not being cloned and we get the original type back which might be read-only + if (!res->persistent) + asMutable(res)->documentationSymbol = typeId->documentationSymbol; + } + else + { + asMutable(res)->documentationSymbol = typeId->documentationSymbol; + } } return res; diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index c30db9c2..0a362a5e 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -2,6 +2,8 @@ #include "Luau/Scope.h" +LUAU_FASTFLAG(LuauTwoPassAliasDefinitionFix); + namespace Luau { @@ -17,6 +19,8 @@ Scope::Scope(const ScopePtr& parent, int subLevel) , returnType(parent->returnType) , level(parent->level.incr()) { + if (FFlag::LuauTwoPassAliasDefinitionFix) + level = level.incr(); level.subLevel = subLevel; } diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 0968a4c1..00067bdd 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -250,6 +250,10 @@ PendingTypePack* TxnLog::queue(TypePackId tp) PendingType* TxnLog::pending(TypeId ty) const { + // This function will technically work if `this` is nullptr, but this + // indicates a bug, so we explicitly assert. + LUAU_ASSERT(static_cast(this) != nullptr); + for (const TxnLog* current = this; current; current = current->parent) { if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) @@ -261,6 +265,10 @@ PendingType* TxnLog::pending(TypeId ty) const PendingTypePack* TxnLog::pending(TypePackId tp) const { + // This function will technically work if `this` is nullptr, but this + // indicates a bug, so we explicitly assert. + LUAU_ASSERT(static_cast(this) != nullptr); + for (const TxnLog* current = this; current; current = current->parent) { if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e1987937..f1c314cd 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,12 +32,13 @@ LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) +LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) LUAU_FASTFLAGVARIABLE(LuauNoSealedTypeMod, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions, false) +LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) @@ -47,7 +48,10 @@ LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAG(LuauUnionTagMatchFix) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) +LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) +LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. +LUAU_FASTFLAGVARIABLE(LuauAnotherTypeLevelFix, false) namespace Luau { @@ -213,6 +217,11 @@ static bool isMetamethod(const Name& name) name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode"; } +size_t HashBoolNamePair::operator()(const std::pair& pair) const +{ + return std::hash()(pair.first) ^ std::hash()(pair.second); +} + TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler) : resolver(resolver) , iceHandler(iceHandler) @@ -225,6 +234,7 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , anyType(getSingletonTypes().anyType) , optionalNumberType(getSingletonTypes().optionalNumberType) , anyTypePack(getSingletonTypes().anyTypePack) + , duplicateTypeAliases{{false, {}}} { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -291,6 +301,9 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona unifierState.skipCacheForType.clear(); } + if (FFlag::LuauTwoPassAliasDefinitionFix) + duplicateTypeAliases.clear(); + return std::move(currentModule); } @@ -496,6 +509,9 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { + if (FFlag::LuauTwoPassAliasDefinitionFix && typealias->name == Parser::errorName) + continue; + auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; Name name = typealias->name.value; @@ -1176,6 +1192,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // Once with forwardDeclare, and once without. Name name = typealias.name.value; + // If the alias is missing a name, we can't do anything with it. Ignore it. + if (FFlag::LuauTwoPassAliasDefinitionFix && name == Parser::errorName) + return; + std::optional binding; if (auto it = scope->exportedTypeBindings.find(name); it != scope->exportedTypeBindings.end()) binding = it->second; @@ -1192,6 +1212,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; + if (FFlag::LuauTwoPassAliasDefinitionFix) + duplicateTypeAliases.insert({typealias.exported, name}); } else { @@ -1211,6 +1233,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { + // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be + // interesting. + if (FFlag::LuauTwoPassAliasDefinitionFix && duplicateTypeAliases.find({typealias.exported, name})) + return; + if (!binding) ice("Not predeclared"); @@ -1235,7 +1262,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (auto ttv = getMutable(follow(ty))) { // If the table is already named and we want to rename the type function, we have to bind new alias to a copy - if (ttv->name) + // Additionally, we can't modify types that come from other modules + if (ttv->name || (FFlag::LuauImmutableTypes && follow(ty)->owningArena != ¤tModule->internalTypes)) { bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), binding->typeParams.end(), [](auto&& itp, auto&& tp) { @@ -1247,7 +1275,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias }); // Copy can be skipped if this is an identical alias - if (ttv->name != name || !sameTys || !sameTps) + if ((FFlag::LuauImmutableTypes && !ttv->name) || ttv->name != name || !sameTys || !sameTps) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1279,9 +1307,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } } else if (auto mtv = getMutable(follow(ty))) - mtv->syntheticName = name; + { + // We can't modify types that come from other modules + if (!FFlag::LuauImmutableTypes || follow(ty)->owningArena == ¤tModule->internalTypes) + mtv->syntheticName = name; + } - unify(ty, bindingsMap[name].type, typealias.location); + TypeId& bindingType = bindingsMap[name].type; + bool ok = unify(ty, bindingType, typealias.location); + + if (FFlag::LuauTwoPassAliasDefinitionFix && ok) + bindingType = ty; } } @@ -1564,7 +1600,12 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) - ice("Unexpected abstract type pack!", expr.location); + { + if (FFlag::LuauReturnAnyInsteadOfICE) + return {anyType, std::move(result.predicates)}; + else + ice("Unexpected abstract type pack!", expr.location); + } else ice("Unknown TypePack type!", expr.location); } @@ -1614,11 +1655,23 @@ std::optional TypeChecker::getIndexTypeFromType( tablify(type); - const PrimitiveTypeVar* primitiveType = get(type); - if (primitiveType && primitiveType->type == PrimitiveTypeVar::String) + if (FFlag::LuauDiscriminableUnions2) { - if (std::optional mtIndex = findMetatableEntry(type, "__index", location)) + if (isString(type)) + { + std::optional mtIndex = findMetatableEntry(stringType, "__index", location); + LUAU_ASSERT(mtIndex); type = *mtIndex; + } + } + else + { + const PrimitiveTypeVar* primitiveType = get(type); + if (primitiveType && primitiveType->type == PrimitiveTypeVar::String) + { + if (std::optional mtIndex = findMetatableEntry(type, "__index", location)) + type = *mtIndex; + } } if (TableTypeVar* tableType = getMutableTableType(type)) @@ -2476,7 +2529,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); - return {checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhsTy, rhsTy), + return {checkBinaryOperation(FFlag::LuauDiscriminableUnions2 ? scope : innerScope, expr, lhsTy, rhsTy), {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::Or) @@ -2489,7 +2542,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhsTy, rhsTy, lhsPredicates); + TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions2 ? scope : innerScope, expr, lhsTy, rhsTy, lhsPredicates); return {result, {OrPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) @@ -2497,8 +2550,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions); - ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions); + ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions2); + ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions2); PredicateVec predicates; @@ -2785,12 +2838,16 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = freshType(scope); + TypeId resultType = freshType(FFlag::LuauAnotherTypeLevelFix ? exprTable->level : scope->level); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return resultType; } else { + /* + * If we use [] indexing to fetch a property from a sealed table that has no indexer, we have no idea if it will + * work, so we just mint a fresh type, return that, and hope for the best. + */ TypeId resultType = freshType(scope); return resultType; } @@ -4195,6 +4252,9 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return errorRecoveryType(scope); } + if (FFlag::LuauImmutableTypes) + return *moduleType; + SeenTypes seenTypes; SeenTypePacks seenTypePacks; CloneState cloneState; @@ -4978,6 +5038,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (notEnoughParameters && hasDefaultParameters) { // 'applyTypeFunction' is used to substitute default types that reference previous generic types + applyTypeFunction.log = TxnLog::empty(); applyTypeFunction.typeArguments.clear(); applyTypeFunction.typePackArguments.clear(); applyTypeFunction.currentModule = currentModule; @@ -5445,7 +5506,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) { - LUAU_ASSERT(FFlag::LuauDiscriminableUnions); + LUAU_ASSERT(FFlag::LuauDiscriminableUnions2); const LValue* target = &lvalue; std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. @@ -5658,7 +5719,7 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, Refi return std::nullopt; }; - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); if (ty && fromOr) @@ -5771,7 +5832,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return res; }; - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { refineLValue(isaP.lvalue, refis, scope, predicate); } @@ -5846,7 +5907,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) { - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); return; @@ -5868,7 +5929,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec } auto fail = [&](const TypeErrorData& err) { - if (!FFlag::LuauDiscriminableUnions) + if (!FFlag::LuauDiscriminableUnions2) errVec.push_back(TypeError{typeguardP.location, err}); addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); }; @@ -5900,7 +5961,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa return {ty}; }; - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { std::vector rhs = options(eqP.type); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 2321eafd..7e438e31 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -28,6 +28,7 @@ LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauUnionTagMatchFix) +LUAU_FASTFLAG(LuauDiscriminableUnions2) namespace Luau { @@ -393,7 +394,8 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) if (seen.contains(ty)) return true; - if (isPrim(ty, PrimitiveTypeVar::String) || get(ty) || get(ty) || get(ty)) + bool isStr = FFlag::LuauDiscriminableUnions2 ? isString(ty) : isPrim(ty, PrimitiveTypeVar::String); + if (isStr || get(ty) || get(ty) || get(ty)) return true; if (auto uty = get(ty)) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 89e4ae23..a8ad5159 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -15,6 +15,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTFLAGVARIABLE(LuauCommittingTxnLogFreeTpPromote, false) +LUAU_FASTFLAG(LuauImmutableTypes) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); @@ -24,6 +25,7 @@ LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauUnionTagMatchFix, false) +LUAU_FASTFLAGVARIABLE(LuauFollowWithCommittingTxnLogInAnyUnification, false) namespace Luau { @@ -32,11 +34,13 @@ struct PromoteTypeLevels { DEPRECATED_TxnLog& DEPRECATED_log; TxnLog& log; + const TypeArena* typeArena = nullptr; TypeLevel minLevel; - explicit PromoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel) + explicit PromoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel) : DEPRECATED_log(DEPRECATED_log) , log(log) + , typeArena(typeArena) , minLevel(minLevel) { } @@ -65,8 +69,12 @@ struct PromoteTypeLevels } template - bool operator()(TID, const T&) + bool operator()(TID ty, const T&) { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return false; + return true; } @@ -83,12 +91,20 @@ struct PromoteTypeLevels bool operator()(TypeId ty, const FunctionTypeVar&) { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return false; + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } bool operator()(TypeId ty, const TableTypeVar& ttv) { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return false; + if (ttv.state != TableState::Free && ttv.state != TableState::Generic) return true; @@ -108,24 +124,33 @@ struct PromoteTypeLevels } }; -void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel, TypeId ty) +void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) { - PromoteTypeLevels ptl{DEPRECATED_log, log, minLevel}; + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return; + + PromoteTypeLevels ptl{DEPRECATED_log, log, typeArena, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(ty, ptl, seen); } -void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel, TypePackId tp) +void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { - PromoteTypeLevels ptl{DEPRECATED_log, log, minLevel}; + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (FFlag::LuauImmutableTypes && tp->owningArena != typeArena) + return; + + PromoteTypeLevels ptl{DEPRECATED_log, log, typeArena, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(tp, ptl, seen); } struct SkipCacheForType { - SkipCacheForType(const DenseHashMap& skipCacheForType) + SkipCacheForType(const DenseHashMap& skipCacheForType, const TypeArena* typeArena) : skipCacheForType(skipCacheForType) + , typeArena(typeArena) { } @@ -152,6 +177,10 @@ struct SkipCacheForType bool operator()(TypeId ty, const TableTypeVar&) { + // Types from other modules don't contain mutable elements and are ok to cache + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return false; + TableTypeVar& ttv = *getMutable(ty); if (ttv.boundTo) @@ -172,6 +201,10 @@ struct SkipCacheForType template bool operator()(TypeId ty, const T& t) { + // Types from other modules don't contain mutable elements and are ok to cache + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return false; + const bool* prev = skipCacheForType.find(ty); if (prev && *prev) @@ -184,8 +217,12 @@ struct SkipCacheForType } template - bool operator()(TypePackId, const T&) + bool operator()(TypePackId tp, const T&) { + // Types from other modules don't contain mutable elements and are ok to cache + if (FFlag::LuauImmutableTypes && tp->owningArena != typeArena) + return false; + return true; } @@ -208,6 +245,7 @@ struct SkipCacheForType } const DenseHashMap& skipCacheForType; + const TypeArena* typeArena = nullptr; bool result = false; }; @@ -422,13 +460,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { if (FFlag::LuauUseCommittingTxnLog) { - promoteTypeLevels(DEPRECATED_log, log, superLevel, subTy); + promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); log.replace(superTy, BoundTypeVar(subTy)); } else { if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(DEPRECATED_log, log, superLevel, subTy); + promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); else if (auto subLevel = getMutableLevel(subTy)) { if (!subLevel->subsumes(superFree->level)) @@ -466,13 +504,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { if (FFlag::LuauUseCommittingTxnLog) { - promoteTypeLevels(DEPRECATED_log, log, subLevel, superTy); + promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); log.replace(subTy, BoundTypeVar(superTy)); } else { if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(DEPRECATED_log, log, subLevel, superTy); + promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); else if (auto superLevel = getMutableLevel(superTy)) { if (!superLevel->subsumes(subFree->level)) @@ -849,7 +887,7 @@ void Unifier::cacheResult(TypeId subTy, TypeId superTy) return; auto skipCacheFor = [this](TypeId ty) { - SkipCacheForType visitor{sharedState.skipCacheForType}; + SkipCacheForType visitor{sharedState.skipCacheForType, types}; visitTypeVarOnce(ty, visitor, sharedState.seenAny); sharedState.skipCacheForType[ty] = visitor.result; @@ -1637,32 +1675,35 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal tryUnify_(subFunction->retType, superFunction->retType); } - if (FFlag::LuauUseCommittingTxnLog) + if (!FFlag::LuauImmutableTypes) { - if (superFunction->definition && !subFunction->definition && !subTy->persistent) + if (FFlag::LuauUseCommittingTxnLog) { - PendingType* newSubTy = log.queue(subTy); - FunctionTypeVar* newSubFtv = getMutable(newSubTy); - LUAU_ASSERT(newSubFtv); - newSubFtv->definition = superFunction->definition; + if (superFunction->definition && !subFunction->definition && !subTy->persistent) + { + PendingType* newSubTy = log.queue(subTy); + FunctionTypeVar* newSubFtv = getMutable(newSubTy); + LUAU_ASSERT(newSubFtv); + newSubFtv->definition = superFunction->definition; + } + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + { + PendingType* newSuperTy = log.queue(superTy); + FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); + LUAU_ASSERT(newSuperFtv); + newSuperFtv->definition = subFunction->definition; + } } - else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + else { - PendingType* newSuperTy = log.queue(superTy); - FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); - LUAU_ASSERT(newSuperFtv); - newSuperFtv->definition = subFunction->definition; - } - } - else - { - if (superFunction->definition && !subFunction->definition && !subTy->persistent) - { - subFunction->definition = superFunction->definition; - } - else if (!superFunction->definition && subFunction->definition && !superTy->persistent) - { - superFunction->definition = subFunction->definition; + if (superFunction->definition && !subFunction->definition && !subTy->persistent) + { + subFunction->definition = superFunction->definition; + } + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + { + superFunction->definition = subFunction->definition; + } } } @@ -2631,7 +2672,7 @@ static void queueTypePack(std::vector& queue, DenseHashSet& { while (true) { - a = follow(a); + a = FFlag::LuauFollowWithCommittingTxnLogInAnyUnification ? state.log.follow(a) : follow(a); if (seenTypePacks.find(a)) break; @@ -2738,7 +2779,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, - TypeId anyType, TypePackId anyTypePack) + const TypeArena* typeArena, TypeId anyType, TypePackId anyTypePack) { while (!queue.empty()) { @@ -2746,8 +2787,14 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas { TypeId ty = state.log.follow(queue.back()); queue.pop_back(); + + // Types from other modules don't have free types + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + continue; + if (seen.find(ty)) continue; + seen.insert(ty); if (state.log.getMutable(ty)) @@ -2853,7 +2900,7 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, getSingletonTypes().anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, getSingletonTypes().anyType, anyTP); } void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) @@ -2869,7 +2916,7 @@ void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) queueTypePack(queue, sharedState.tempSeenTp, *this, subTy, anyTp); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, anyTp); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, anyTy, anyTp); } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index f559e2e0..30b32f91 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1133,7 +1133,7 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector + +#include "Luau/Common.h" +#include "Luau/Ast.h" +#include "Luau/JsonEncoder.h" +#include "Luau/Parser.h" +#include "Luau/ParseOptions.h" + +#include "FileUtils.h" + +static void displayHelp(const char* argv0) +{ + printf("Usage: %s [file]\n", argv0); +} + +static int assertionHandler(const char* expr, const char* file, int line, const char* function) +{ + printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + return 1; +} + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + if (argc >= 2 && strcmp(argv[1], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (argc < 2) + { + displayHelp(argv[0]); + return 1; + } + + const char* name = argv[1]; + std::optional maybeSource = std::nullopt; + if (strcmp(name, "-") == 0) + { + maybeSource = readStdin(); + } + else + { + maybeSource = readFile(name); + } + + if (!maybeSource) + { + fprintf(stderr, "Couldn't read source %s\n", name); + return 1; + } + + std::string source = *maybeSource; + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + + Luau::ParseOptions options; + options.supportContinueStatement = true; + options.allowTypeAnnotations = true; + options.allowDeclarationSyntax = true; + + Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); + + if (parseResult.errors.size() > 0) + { + fprintf(stderr, "Parse errors were encountered:\n"); + for (const Luau::ParseError& error : parseResult.errors) + { + fprintf(stderr, " %s - %s\n", toString(error.getLocation()).c_str(), error.getMessage().c_str()); + } + fprintf(stderr, "\n"); + } + + printf("%s", Luau::toJson(parseResult.root).c_str()); + + return parseResult.errors.size() > 0 ? 1 : 0; +} + + diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 5af6b508..9a6e25c2 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -1,4 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Repl.h" + #include "lua.h" #include "lualib.h" @@ -38,13 +40,14 @@ enum class CompileFormat struct GlobalOptions { int optimizationLevel = 1; + int debugLevel = 1; } globalOptions; static Luau::CompileOptions copts() { Luau::CompileOptions result = {}; result.optimizationLevel = globalOptions.optimizationLevel; - result.debugLevel = 1; + result.debugLevel = globalOptions.debugLevel; result.coverageLevel = coverageActive() ? 2 : 0; return result; @@ -240,9 +243,8 @@ std::string runCode(lua_State* L, const std::string& source) return std::string(); } -static void completeIndexer(ic_completion_env_t* cenv, const char* editBuffer) +static void completeIndexer(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) { - auto* L = reinterpret_cast(ic_completion_arg(cenv)); std::string_view lookup = editBuffer; char lastSep = 0; @@ -276,7 +278,7 @@ static void completeIndexer(ic_completion_env_t* cenv, const char* editBuffer) // Add an opening paren for function calls by default. completion += "("; } - ic_add_completion_ex(cenv, completion.data(), key.data(), nullptr); + addCompletionCallback(completion, std::string(key)); } } lua_pop(L, 1); @@ -295,10 +297,11 @@ static void completeIndexer(ic_completion_env_t* cenv, const char* editBuffer) { // Replace the string object with the string class to perform further lookups of string functions // Note: We retrieve the string class from _G to prevent issues if the user assigns to `string`. + lua_pop(L, 1); // Pop the string instance lua_getglobal(L, "_G"); lua_pushlstring(L, "string", 6); lua_rawget(L, -2); - lua_remove(L, -2); + lua_remove(L, -2); // Remove the global table LUAU_ASSERT(lua_istable(L, -1)); } else if (!lua_istable(L, -1)) @@ -312,6 +315,26 @@ static void completeIndexer(ic_completion_env_t* cenv, const char* editBuffer) lua_pop(L, 1); } +void getCompletions(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) +{ + // look the value up in current global table first + lua_pushvalue(L, LUA_GLOBALSINDEX); + completeIndexer(L, editBuffer, addCompletionCallback); + + // and in actual global table after that + lua_getglobal(L, "_G"); + completeIndexer(L, editBuffer, addCompletionCallback); +} + +static void icGetCompletions(ic_completion_env_t* cenv, const char* editBuffer) +{ + auto* L = reinterpret_cast(ic_completion_arg(cenv)); + + getCompletions(L, std::string(editBuffer), [cenv](const std::string& completion, const std::string& display) { + ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr); + }); +} + static bool isMethodOrFunctionChar(const char* s, long len) { char c = *s; @@ -320,15 +343,7 @@ static bool isMethodOrFunctionChar(const char* s, long len) static void completeRepl(ic_completion_env_t* cenv, const char* editBuffer) { - auto* L = reinterpret_cast(ic_completion_arg(cenv)); - - // look the value up in current global table first - lua_pushvalue(L, LUA_GLOBALSINDEX); - ic_complete_word(cenv, editBuffer, completeIndexer, isMethodOrFunctionChar); - - // and in actual global table after that - lua_getglobal(L, "_G"); - ic_complete_word(cenv, editBuffer, completeIndexer, isMethodOrFunctionChar); + ic_complete_word(cenv, editBuffer, icGetCompletions, isMethodOrFunctionChar); } struct LinenoiseScopedHistory @@ -372,19 +387,20 @@ static void runReplImpl(lua_State* L) for (;;) { - const char* line = ic_readline(buffer.empty() ? "" : ">"); + const char* prompt = buffer.empty() ? "" : ">"; + std::unique_ptr line(ic_readline(prompt), free); if (!line) break; - if (buffer.empty() && runCode(L, std::string("return ") + line) == std::string()) + if (buffer.empty() && runCode(L, std::string("return ") + line.get()) == std::string()) { - ic_history_add(line); + ic_history_add(line.get()); continue; } if (!buffer.empty()) buffer += "\n"; - buffer += line; + buffer += line.get(); std::string error = runCode(L, buffer); @@ -400,7 +416,6 @@ static void runReplImpl(lua_State* L) ic_history_add(buffer.c_str()); buffer.clear(); - free((void*)line); } } @@ -504,7 +519,7 @@ static bool compileFile(const char* name, CompileFormat format) if (format == CompileFormat::Text) { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals); bcb.setDumpSource(*source); } @@ -549,7 +564,8 @@ static void displayHelp(const char* argv0) printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); printf(" -h, --help: Display this usage message.\n"); printf(" -i, --interactive: Run an interactive REPL after executing the last script specified.\n"); - printf(" -O: use compiler optimization level (n=0-2).\n"); + printf(" -O: compile with optimization level n (default 1, n should be between 0 and 2).\n"); + printf(" -g: compile with debug level n (default 1, n should be between 0 and 2).\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); } @@ -620,6 +636,16 @@ int replMain(int argc, char** argv) } globalOptions.optimizationLevel = level; } + else if (strncmp(argv[i], "-g", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Debug level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.debugLevel = level; + } else if (strcmp(argv[i], "--profile") == 0) { profile = 10000; // default to 10 KHz diff --git a/CLI/Repl.h b/CLI/Repl.h index 11a077ae..cd54b7e0 100644 --- a/CLI/Repl.h +++ b/CLI/Repl.h @@ -3,10 +3,15 @@ #include "lua.h" +#include #include +using AddCompletionCallback = std::function; + // Note: These are internal functions which are being exposed in a header // so they can be included by unit tests. -int replMain(int argc, char** argv); void setupState(lua_State* L); std::string runCode(lua_State* L, const std::string& source); +void getCompletions(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback); + +int replMain(int argc, char** argv); diff --git a/CMakeLists.txt b/CMakeLists.txt index 881d3c3f..c19d2b40 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,13 +5,20 @@ if(EXT_PLATFORM_STRING) endif() cmake_minimum_required(VERSION 3.0) -project(Luau LANGUAGES CXX C) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) +option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) +if(LUAU_STATIC_CRT) + cmake_minimum_required(VERSION 3.15) + cmake_policy(SET CMP0091 NEW) + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") +endif() + +project(Luau LANGUAGES CXX C) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) add_library(Luau.Analysis STATIC) @@ -21,10 +28,12 @@ add_library(isocline STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) add_executable(Luau.Analyze.CLI) + add_executable(Luau.Ast.CLI) # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) + set_target_properties(Luau.Ast.CLI PROPERTIES OUTPUT_NAME luau-ast) endif() if(LUAU_BUILD_TESTS) @@ -98,6 +107,7 @@ endif() if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) + target_compile_options(Luau.Ast.CLI PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include) @@ -111,6 +121,8 @@ if(LUAU_BUILD_CLI) endif() target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + + target_link_libraries(Luau.Ast.CLI PRIVATE Luau.Ast Luau.Analysis) endif() if(LUAU_BUILD_TESTS) diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index e6d02454..09f06b68 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Write, false) - namespace Luau { @@ -510,7 +508,7 @@ uint32_t BytecodeBuilder::getDebugPC() const void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); - bytecode = char(FFlag::LuauBytecodeV2Write ? LBC_VERSION_FUTURE : LBC_VERSION); + bytecode = char(LBC_VERSION_FUTURE); writeStringTable(bytecode); @@ -611,9 +609,7 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const writeVarInt(ss, child); // debug info - if (FFlag::LuauBytecodeV2Write) - writeVarInt(ss, func.debuglinedefined); - + writeVarInt(ss, func.debuglinedefined); writeVarInt(ss, func.debugname); bool hasLines = true; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index e4253adc..656a9926 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -15,7 +15,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileTableIndexOpt, false) LUAU_FASTFLAG(LuauCompileSelectBuiltin2) namespace Luau @@ -1182,18 +1181,9 @@ struct Compiler const AstExprTable::Item& item = expr->items.data[i]; LUAU_ASSERT(item.key); // no list portion => all items have keys - if (FFlag::LuauCompileTableIndexOpt) - { - const Constant* ckey = constants.find(item.key); + const Constant* ckey = constants.find(item.key); - indexSize += (ckey && ckey->type == Constant::Type_Number && ckey->valueNumber == double(indexSize + 1)); - } - else - { - AstExprConstantNumber* ckey = item.key->as(); - - indexSize += (ckey && ckey->value == double(indexSize + 1)); - } + indexSize += (ckey && ckey->type == Constant::Type_Number && ckey->valueNumber == double(indexSize + 1)); } // we only perform the optimization if we don't have any other []-keys @@ -1295,43 +1285,10 @@ struct Compiler { RegScope rsi(this); - if (FFlag::LuauCompileTableIndexOpt) - { - LValue lv = compileLValueIndex(reg, key, rsi); - uint8_t rv = compileExprAuto(value, rsi); + LValue lv = compileLValueIndex(reg, key, rsi); + uint8_t rv = compileExprAuto(value, rsi); - compileAssign(lv, rv); - } - else - { - // Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax - if (AstExprConstantString* ckey = key->as()) - { - BytecodeBuilder::StringRef cname = sref(ckey->value); - int32_t cid = bytecode.addConstantString(cname); - if (cid < 0) - CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - - uint8_t rv = compileExprAuto(value, rsi); - - bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname))); - bytecode.emitAux(cid); - } - else if (AstExprConstantNumber* ckey = key->as(); - ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value) - { - uint8_t rv = compileExprAuto(value, rsi); - - bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1)); - } - else - { - uint8_t rk = compileExprAuto(key, rsi); - uint8_t rv = compileExprAuto(value, rsi); - - bytecode.emitABC(LOP_SETTABLE, rv, reg, rk); - } - } + compileAssign(lv, rv); } // items without a key are set using SETLIST so that we can initialize large arrays quickly else @@ -1439,8 +1396,7 @@ struct Compiler uint8_t rt = compileExprAuto(expr->expr, rs); uint8_t i = uint8_t(int(cv->valueNumber) - 1); - if (FFlag::LuauCompileTableIndexOpt) - setDebugLine(expr->index); + setDebugLine(expr->index); bytecode.emitABC(LOP_GETTABLEN, target, rt, i); } @@ -1453,8 +1409,7 @@ struct Compiler if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - if (FFlag::LuauCompileTableIndexOpt) - setDebugLine(expr->index); + setDebugLine(expr->index); bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); @@ -1853,8 +1808,7 @@ struct Compiler void compileLValueUse(const LValue& lv, uint8_t reg, bool set) { - if (FFlag::LuauCompileTableIndexOpt) - setDebugLine(lv.location); + setDebugLine(lv.location); switch (lv.kind) { diff --git a/Sources.cmake b/Sources.cmake index b36b6db5..773f6f35 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -193,6 +193,14 @@ if(TARGET Luau.Analyze.CLI) CLI/Analyze.cpp) endif() +if(TARGET Luau.Ast.CLI) + target_sources(Luau.Ast.CLI PRIVATE + CLI/Ast.cpp + CLI/FileUtils.h + CLI/FileUtils.cpp + ) +endif() + if(TARGET Luau.UnitTest) # Luau.UnitTest Sources target_sources(Luau.UnitTest PRIVATE diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index ecc14e87..718d387d 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1098,7 +1098,7 @@ static int luauF_select(lua_State* L, StkId res, TValue* arg0, int nresults, Stk int i = int(nvalue(arg0)); // i >= 1 && i <= n - if (unsigned(i - 1) <= unsigned(n)) + if (unsigned(i - 1) < unsigned(n)) { setobj2s(L, res, L->base - n + (i - 1)); return 1; diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index 906fb0d0..ce196520 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -250,6 +250,8 @@ void luaC_validate(lua_State* L) if (FFlag::LuauGcPagedSweep) { + validategco(L, NULL, obj2gco(g->mainthread)); + luaM_visitgco(L, L, validategco); } else @@ -565,6 +567,8 @@ void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* if (FFlag::LuauGcPagedSweep) { + dumpgco(f, NULL, obj2gco(g->mainthread)); + luaM_visitgco(L, f, dumpgco); } else diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index de85cf59..19617b8c 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -8,6 +8,76 @@ #include +/* + * Luau heap uses a size-segregated page structure, with individual pages and large allocations + * allocated using system heap (via frealloc callback). + * + * frealloc callback serves as a general, if slow, allocation callback that can allocate, free or + * resize allocations: + * + * void* frealloc(void* ud, void* ptr, size_t oldsize, size_t newsize); + * + * frealloc(ud, NULL, 0, x) creates a new block of size x + * frealloc(ud, p, x, 0) frees the block p (must return NULL) + * frealloc(ud, NULL, 0, 0) does nothing, equivalent to free(NULL) + * + * frealloc returns NULL if it cannot create or reallocate the area + * (any reallocation to an equal or smaller size cannot fail!) + * + * On top of this, Luau implements heap storage which is split into two types of allocations: + * + * - GCO, short for "garbage collected objects" + * - other objects (for example, arrays stored inside table objects) + * + * The heap layout for these two allocation types is a bit different. + * + * All GCO are allocated in pages, which is a block of memory of ~16K in size that has a page header + * (lua_Page). Each page contains 1..N blocks of the same size, where N is selected to fill the page + * completely. This amortizes the allocation cost and increases locality. Each GCO block starts with + * the GC header (GCheader) which contains the object type, mark bits and other GC metadata. If the + * GCO block is free (not used), then it must have the type set to TNIL; in this case the block can + * be part of the per-page free list, the link for that list is stored after the header (freegcolink). + * + * Importantly, the GCO block doesn't have any back references to the page it's allocated in, so it's + * impossible to free it in isolation - GCO blocks are freed by sweeping the pages they belong to, + * using luaM_freegco which must specify the page; this is called by page sweeper that traverses the + * entire page's worth of objects. For this reason it's also important that freed GCO blocks keep the + * GC header intact and accessible (with type = NIL) so that the sweeper can access it. + * + * Some GCOs are too large to fit in a 16K page without excessive fragmentation (the size threshold is + * currently 512 bytes); in this case, we allocate a dedicated small page with just a single block's worth + * storage space, but that requires allocating an extra page header. In effect large GCOs are a little bit + * less memory efficient, but this allows us to uniformly sweep small and large GCOs using page lists. + * + * All GCO pages are linked in a large intrusive linked list (global_State::allgcopages). Additionally, + * for each block size there's a page free list that contains pages that have at least one free block + * (global_State::freegcopages). This free list is used to make sure object allocation is O(1). + * + * Compared to GCOs, regular allocations have two important differences: they can be freed in isolation, + * and they don't start with a GC header. Because of this, each allocation is prefixed with block metadata, + * which contains the pointer to the page for allocated blocks, and the pointer to the next free block + * inside the page for freed blocks. + * For regular allocations that are too large to fit in a page (using the same threshold of 512 bytes), + * we don't allocate a separate page, instead simply using frealloc to allocate a vanilla block of memory. + * + * Just like GCO pages, we store a page free list (global_State::freepages) that allows O(1) allocation; + * there is no global list for non-GCO pages since we never need to traverse them directly. + * + * In both cases, we pick the page by computing the size class from the block size which rounds the block + * size up to reduce the chance that we'll allocate pages that have very few allocated blocks. The size + * class strategy is determined by SizeClassConfig constructor. + * + * Note that when the last block in a page is freed, we immediately free the page with frealloc - the + * memory manager doesn't currently attempt to keep unused memory around. This can result in excessive + * allocation traffic and can be mitigated by adding a page cache in the future. + * + * For both GCO and non-GCO pages, the per-page block allocation combines bump pointer style allocation + * (lua_Page::freeNext) and per-page free list (lua_Page::freeList). We use the bump allocator to allocate + * the contents of the page, and the free list for further reuse; this allows shorter page setup times + * which results in less variance between allocation cost, as well as tighter sweep bounds for newly + * allocated pages. + */ + LUAU_FASTFLAG(LuauGcPagedSweep) #ifndef __has_feature @@ -56,6 +126,7 @@ static_assert(offsetof(GCObject, ts) == 0, "TString data must be located at the const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; const size_t kPageSize = 16 * 1024 - 24; // slightly under 16KB since that results in less fragmentation due to heap metadata + const size_t kBlockHeader = sizeof(double) > sizeof(void*) ? sizeof(double) : sizeof(void*); // suitable for aligning double & void* on all platforms const size_t kGCOLinkOffset = (sizeof(GCheader) + sizeof(void*) - 1) & ~(sizeof(void*) - 1); // GCO pages contain freelist links after the GC header @@ -107,24 +178,6 @@ const SizeClassConfig kSizeClassConfig; #define metadata(block) (*(void**)(block)) #define freegcolink(block) (*(void**)((char*)block + kGCOLinkOffset)) -/* -** About the realloc function: -** void * frealloc (void *ud, void *ptr, size_t osize, size_t nsize); -** (`osize' is the old size, `nsize' is the new size) -** -** Lua ensures that (ptr == NULL) iff (osize == 0). -** -** * frealloc(ud, NULL, 0, x) creates a new block of size `x' -** -** * frealloc(ud, p, x, 0) frees the block `p' -** (in this specific case, frealloc must return NULL). -** particularly, frealloc(ud, NULL, 0, 0) does nothing -** (which is equivalent to free(NULL) in ANSI C) -** -** frealloc returns NULL if it cannot create or reallocate the area -** (any reallocation to an equal or smaller size cannot fail!) -*/ - struct lua_Page { // list of pages with free blocks @@ -135,13 +188,12 @@ struct lua_Page lua_Page* gcolistprev; lua_Page* gcolistnext; - int busyBlocks; - int blockSize; + int pageSize; // page size in bytes, including page header + int blockSize; // block size in bytes, including block header (for non-GCO) - void* freeList; - int freeNext; - - int pageSize; + void* freeList; // next free block in this page; linked with metadata()/freegcolink() + int freeNext; // next free block offset in this page, in bytes; when negative, freeList is used instead + int busyBlocks; // number of blocks allocated out of this page union { @@ -177,7 +229,7 @@ static lua_Page* newpageold(lua_State* L, uint8_t sizeClass) page->gcolistprev = NULL; page->gcolistnext = NULL; - page->busyBlocks = 0; + page->pageSize = kPageSize; page->blockSize = blockSize; // note: we start with the last block in the page and move downward @@ -185,6 +237,7 @@ static lua_Page* newpageold(lua_State* L, uint8_t sizeClass) // additionally, GC stores objects in singly linked lists, and this way the GC lists end up in increasing pointer order page->freeList = NULL; page->freeNext = (blockCount - 1) * blockSize; + page->busyBlocks = 0; // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) LUAU_ASSERT(!g->freepages[sizeClass]); @@ -214,7 +267,7 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int page->gcolistprev = NULL; page->gcolistnext = NULL; - page->busyBlocks = 0; + page->pageSize = pageSize; page->blockSize = blockSize; // note: we start with the last block in the page and move downward @@ -222,8 +275,7 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int // additionally, GC stores objects in singly linked lists, and this way the GC lists end up in increasing pointer order page->freeList = NULL; page->freeNext = (blockCount - 1) * blockSize; - - page->pageSize = pageSize; + page->busyBlocks = 0; if (gcopageset) { @@ -406,8 +458,7 @@ static void* newgcoblock(lua_State* L, int sizeClass) page->next = NULL; } - // the user data is right after the metadata - return (char*)block; + return block; } static void freeblock(lua_State* L, int sizeClass, void* block) @@ -421,6 +472,7 @@ static void freeblock(lua_State* L, int sizeClass, void* block) lua_Page* page = (lua_Page*)metadata(block); LUAU_ASSERT(page && page->busyBlocks > 0); LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass] + kBlockHeader); + LUAU_ASSERT(block >= page->data && block < (char*)page + page->pageSize); // if the page wasn't in the page free list, it should be now since it got a block! if (!page->freeList && page->freeNext < 0) @@ -455,6 +507,9 @@ static void freeblock(lua_State* L, int sizeClass, void* block) static void freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) { LUAU_ASSERT(FFlag::LuauGcPagedSweep); + LUAU_ASSERT(page && page->busyBlocks > 0); + LUAU_ASSERT(page->blockSize == kSizeClassConfig.sizeOfClass[sizeClass]); + LUAU_ASSERT(block >= page->data && block < (char*)page + page->pageSize); global_State* g = L->global; @@ -575,6 +630,8 @@ void luaM_freegco_(lua_State* L, GCObject* block, size_t osize, uint8_t memcat, else { LUAU_ASSERT(page->busyBlocks == 1); + LUAU_ASSERT(size_t(page->blockSize) == osize); + LUAU_ASSERT((void*)block == page->data); freepage(L, &g->allgcopages, page); } @@ -626,8 +683,12 @@ void luaM_getpagewalkinfo(lua_Page* page, char** start, char** end, int* busyBlo int blockCount = (page->pageSize - offsetof(lua_Page, data)) / page->blockSize; - *start = page->data + page->freeNext + page->blockSize; - *end = page->data + blockCount * page->blockSize; + LUAU_ASSERT(page->freeNext >= -page->blockSize && page->freeNext <= (blockCount - 1) * page->blockSize); + + char* data = page->data; // silences ubsan when indexing page->data + + *start = data + page->freeNext + page->blockSize; + *end = data + blockCount * page->blockSize; *busyBlocks = page->busyBlocks; *blockSize = page->blockSize; } @@ -675,7 +736,7 @@ void luaM_visitgco(lua_State* L, void* context, bool (*visitor)(void* context, l for (lua_Page* curr = g->allgcopages; curr;) { - lua_Page* next = curr->gcolistnext; // page blockvisit might destroy the page + lua_Page* next = curr->gcolistnext; // block visit might destroy the page luaM_visitpage(curr, context, visitor); diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index 370c7b28..d5bd76a8 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -131,7 +131,7 @@ void luaO_chunkid(char* out, const char* source, size_t bufflen) { size_t l; source++; /* skip the `@' */ - bufflen -= sizeof(" '...' "); + bufflen -= sizeof("..."); l = strlen(source); strcpy(out, ""); if (l > bufflen) @@ -144,7 +144,7 @@ void luaO_chunkid(char* out, const char* source, size_t bufflen) else { /* out = [string "string"] */ size_t len = strcspn(source, "\n\r"); /* stop at first newline */ - bufflen -= sizeof(" [string \"...\"] "); + bufflen -= sizeof("[string \"...\"]"); if (len > bufflen) len = bufflen; strcpy(out, "[string \""); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index cba3670a..c3b662a2 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -609,7 +609,8 @@ static void luau_execute(lua_State* L) if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0') { - setnvalue(ra, rb->value.v[ic]); + const float* v = rb->value.v; // silences ubsan when indexing v[] + setnvalue(ra, v[ic]); VM_NEXT(); } diff --git a/extern/isocline/src/isocline.c b/extern/isocline/src/isocline.c index 13278062..8b6055cf 100644 --- a/extern/isocline/src/isocline.c +++ b/extern/isocline/src/isocline.c @@ -13,7 +13,12 @@ // $ gcc -c src/isocline.c //------------------------------------------------------------- #if !defined(IC_SEPARATE_OBJS) -# define _CRT_SECURE_NO_WARNINGS // for msvc +# ifndef _CRT_NONSTDC_NO_WARNINGS +# define _CRT_NONSTDC_NO_WARNINGS // for msvc +# endif +# ifndef _CRT_SECURE_NO_WARNINGS +# define _CRT_SECURE_NO_WARNINGS // for msvc +# endif # define _XOPEN_SOURCE 700 // for wcwidth # define _DEFAULT_SOURCE // ensure usleep stays visible with _XOPEN_SOURCE >= 700 # include "attr.c" diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index d8af94db..cd7a21d8 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -605,8 +605,6 @@ RETURN R0 1 TEST_CASE("TableLiteralsIndexConstant") { - ScopedFastFlag sff("LuauCompileTableIndexOpt", true); - // validate that we use SETTTABLEKS for constant variable keys CHECK_EQ("\n" + compileFunction0(R"( local a, b = "key", "value" @@ -2483,8 +2481,6 @@ return TEST_CASE("DebugLineInfoAssignment") { - ScopedFastFlag sff("LuauCompileTableIndexOpt", true); - Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index e580949f..8b58d2ce 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -492,8 +492,6 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { - ScopedFastFlag sffw("LuauBytecodeV2Write", true); - runConformance("debug.lua"); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index d1cc49b2..577415fc 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1392,19 +1392,31 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") {"DataCost", {typeChecker.numberType, /* deprecated= */ true}}, {"Wait", {typeChecker.anyType, /* deprecated= */ true}}, }; + + TypeId colorType = typeChecker.globalTypes.addType(TableTypeVar{{}, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}); + + getMutable(colorType)->props = { + {"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"} } + }; + + addGlobalBinding(typeChecker, "Color3", Binding{colorType, {}}); + freeze(typeChecker.globalTypes); LintResult result = lintTyped(R"( return function (i: Instance) i:Wait(1.0) print(i.Name) + print(Color3.toHSV()) + print(Color3.doesntexist, i.doesntexist) -- type error, but this verifies we correctly handle non-existent members return i.DataCost end )"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE_EQ(result.warnings.size(), 3); CHECK_EQ(result.warnings[0].text, "Member 'Instance.Wait' is deprecated"); - CHECK_EQ(result.warnings[1].text, "Member 'Instance.DataCost' is deprecated"); + CHECK_EQ(result.warnings[1].text, "Member 'toHSV' is deprecated, use 'Color3:ToHSV' instead"); + CHECK_EQ(result.warnings[2].text, "Member 'Instance.DataCost' is deprecated"); } TEST_CASE_FIXTURE(Fixture, "TableOperations") @@ -1475,9 +1487,11 @@ _ = (true and true) or true _ = (true and false) and (42 and false) _ = true and true or false -- no warning since this is is a common pattern used as a ternary replacement + +_ = if true then 1 elseif true then 2 else 3 )"); - REQUIRE_EQ(result.warnings.size(), 7); + REQUIRE_EQ(result.warnings.size(), 8); CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); CHECK_EQ(result.warnings[0].location.begin.line + 1, 4); CHECK_EQ(result.warnings[1].text, "Condition has already been checked on column 5"); @@ -1487,6 +1501,7 @@ _ = true and true or false -- no warning since this is is a common pattern used CHECK_EQ(result.warnings[5].text, "Condition has already been checked on column 6"); CHECK_EQ(result.warnings[6].text, "Condition has already been checked on column 15"); CHECK_EQ(result.warnings[6].location.begin.line + 1, 19); + CHECK_EQ(result.warnings[7].text, "Condition has already been checked on column 8"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsExpr") @@ -1528,4 +1543,19 @@ return foo, moo, a1, a2 CHECK_EQ(result.warnings[3].text, "Function parameter 'self' already defined implicitly"); } +TEST_CASE_FIXTURE(Fixture, "MisleadingAndOr") +{ + LintResult result = lint(R"( +_ = math.random() < 0.5 and true or 42 +_ = math.random() < 0.5 and false or 42 -- misleading +_ = math.random() < 0.5 and nil or 42 -- misleading +_ = math.random() < 0.5 and 0 or 42 +_ = (math.random() < 0.5 and false) or 42 -- currently ignored +)"); + + REQUIRE_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, "The and-or expression always evaluates to the second alternative because the first alternative is false; consider using if-then-else expression instead"); + CHECK_EQ(result.warnings[1].text, "The and-or expression always evaluates to the second alternative because the first alternative is nil; consider using if-then-else expression instead"); +} + TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index f660bcd3..1f9c9739 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -8,9 +8,22 @@ #include #include +#include #include #include +struct Completion +{ + std::string completion; + std::string display; + + bool operator<(Completion const& other) const + { + return std::tie(completion, display) < std::tie(other.completion, other.display); + } +}; + +using CompletionSet = std::set; class ReplFixture { @@ -34,6 +47,27 @@ public: lua_pop(L, 1); return result; } + + CompletionSet getCompletionSet(const char* inputPrefix) + { + CompletionSet result; + int top = lua_gettop(L); + getCompletions(L, inputPrefix, [&result](const std::string& completion, const std::string& display) { + result.insert(Completion{completion, display}); + }); + // Ensure that generating completions doesn't change the position of luau's stack top. + CHECK(top == lua_gettop(L)); + + return result; + } + + bool checkCompletion(const CompletionSet& completions, const std::string& prefix, const std::string& expected) + { + std::string expectedDisplay(expected.substr(0, expected.find_first_of('('))); + Completion expectedCompletion{prefix + expected, expectedDisplay}; + return completions.count(expectedCompletion) == 1; + } + lua_State* L; private: @@ -115,3 +149,61 @@ TEST_CASE_FIXTURE(ReplFixture, "MultipleArguments") } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("ReplCodeCompletion"); + +TEST_CASE_FIXTURE(ReplFixture, "CompleteGlobalVariables") +{ + runCode(L, R"( + myvariable1 = 5 + myvariable2 = 5 +)"); + CompletionSet completions = getCompletionSet("myvar"); + + std::string prefix = ""; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "myvariable1")); + CHECK(checkCompletion(completions, prefix, "myvariable2")); +} + +TEST_CASE_FIXTURE(ReplFixture, "CompleteTableKeys") +{ + runCode(L, R"( + t = { color = "red", size = 1, shape = "circle" } +)"); + { + CompletionSet completions = getCompletionSet("t."); + + std::string prefix = "t."; + CHECK(completions.size() == 3); + CHECK(checkCompletion(completions, prefix, "color")); + CHECK(checkCompletion(completions, prefix, "size")); + CHECK(checkCompletion(completions, prefix, "shape")); + } + + { + CompletionSet completions = getCompletionSet("t.s"); + + std::string prefix = "t."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "size")); + CHECK(checkCompletion(completions, prefix, "shape")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "StringMethods") +{ + runCode(L, R"( + s = "" +)"); + { + CompletionSet completions = getCompletionSet("s:l"); + + std::string prefix = "s:"; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "len(")); + CHECK(checkCompletion(completions, prefix, "lower(")); + } +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 76ab23b3..a8729268 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -595,4 +595,65 @@ TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_ LUAU_REQUIRE_NO_ERRORS(result); } +/* + * The two-pass alias definition system starts by ascribing a free TypeVar to each alias. It then + * circles back to fill in the actual type later on. + * + * If this free type is unified with something degenerate like `any`, we need to take extra care + * to ensure that the alias actually binds to the type that the user expected. + */ +TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any") +{ + ScopedFastFlag sff[] = { + {"LuauTwoPassAliasDefinitionFix", true} + }; + + CheckResult result = check(R"( + local function x() + local y: FutureType = {}::any + return 1 + end + type FutureType = { foo: typeof(x()) } + local d: FutureType = { smth = true } -- missing error, 'd' is resolved to 'any' + )"); + + CHECK_EQ("{| foo: number |}", toString(requireType("d"), {true})); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") +{ + ScopedFastFlag sff[] = { + {"LuauTwoPassAliasDefinitionFix", true}, + + // We also force these two flags because this surfaced an unfortunate interaction. + {"LuauErrorRecoveryType", true}, + {"LuauQuantifyInPlace2", true}, + }; + + CheckResult result = check(R"( + local B = {} + B.bar = 4 + + function B:smth1() + local self: FutureIntersection = self + self.foo = 4 + return 4 + end + + function B:smth2() + local self: FutureIntersection = self + self.bar = 5 -- error, even though we should have B part with bar + end + + type A = { foo: typeof(B.smth1({foo=3})) } -- trick toposort into sorting functions before types + type B = typeof(B) + + type FutureIntersection = A & B + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 6730bedb..df06884d 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -7,8 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauFixTonumberReturnType) - using namespace Luau; LUAU_FASTFLAG(LuauUseCommittingTxnLog) @@ -850,11 +848,8 @@ TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") local b: number = tonumber('asdf') )"); - if (FFlag::LuauFixTonumberReturnType) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") @@ -893,7 +888,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, }; CheckResult result = check(R"( @@ -910,7 +905,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types_even_from_type_pack_tail_ { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, }; CheckResult result = check(R"( @@ -927,7 +922,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_fir { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index e5eb0dca..2bcd840c 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -262,7 +262,7 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") // Just needs to fully support equality refinement. Which is annoying without type states. TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") { - ScopedFastFlag sff{"LuauDiscriminableUnions", true}; + ScopedFastFlag sff{"LuauDiscriminableUnions2", true}; CheckResult result = check(R"( type T = {x: string, y: number} | {x: nil, y: nil} @@ -616,4 +616,76 @@ local a: Self
CHECK_EQ(toString(requireType("a")), "Table
"); } +TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") +{ + ScopedFastFlag sff[]{ + {"LuauQuantifyInPlace2", true}, + {"LuauReturnAnyInsteadOfICE", true}, + }; + + // In-place quantification causes these types to have the wrong types but only because of nasty interaction with prototyping. + // The type of f is initially () -> free1... + // Then the prototype iterator advances, and checks the function expression assigned to g, which has the type () -> free2... + // In the body it calls f and returns what f() returns. This binds free2... with free1..., causing f and g to have same types. + // We then quantify g, leaving it with the final type () -> a... + // Because free1... and free2... were bound, in combination with in-place quantification, f's return type was also turned into a... + // Then the check iterator catches up, and checks the body of f, and attempts to quantify it too. + // Alas, one of the requirements for quantification is that a type must contain free types. () -> a... has no free types. + // Thus the quantification for f was no-op, which explains why f does not have any type parameters. + // Calling f() will attempt to instantiate the function type, which turns generics in type binders into to free types. + // However, instantiations only converts generics contained within the type binders of a function, so instantiation was also no-op. + // Which means that calling f() simply returned a... rather than an instantiation of it. And since the call site was not in tail position, + // picking first element in a... triggers an ICE because calls returning generic packs are unexpected. + CheckResult result = check(R"( + local function f() end + + local g = function() return f() end + + local x = (f()) -- should error: no return values to assign from the call to f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // f and g should have the type () -> () + CHECK_EQ("() -> (a...)", toString(requireType("f"))); + CHECK_EQ("() -> (a...)", toString(requireType("g"))); + CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now +} + +TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") +{ + CheckResult result = check(R"( + local function id(x) return x end + local n2n: (number) -> number = id + local s2s: (string) -> string = id + )"); + + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. +} + +TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") +{ + ScopedFastFlag sff{"LuauQuantifyInPlace2", true}; + + CheckResult result = check(R"( + local function f() return end + local g = function() return f() end + )"); + + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. +} + +TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") +{ + ScopedFastFlag sff{"LuauQuantifyInPlace2", true}; + + CheckResult result = check(R"( + --!strict + local function f(...) return ... end + local g = function(...) return f(...) end + )"); + + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 3a610c3a..48e6be6a 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -6,7 +6,7 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauDiscriminableUnions) +LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauQuantifyInPlace2) @@ -262,7 +262,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { LUAU_REQUIRE_NO_ERRORS(result); @@ -435,7 +435,7 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauSingletonTypes", true}, }; @@ -485,7 +485,7 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") { - ScopedFastFlag sff{"LuauDiscriminableUnions", true}; + ScopedFastFlag sff{"LuauDiscriminableUnions2", true}; ScopedFastFlag sff2{"LuauWeakEqConstraint", true}; CheckResult result = check(R"( @@ -589,7 +589,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") end )"); - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { LUAU_REQUIRE_NO_ERRORS(result); } @@ -1002,7 +1002,7 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, }; @@ -1028,7 +1028,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") TEST_CASE_FIXTURE(Fixture, "discriminate_tag") { ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, }; @@ -1069,7 +1069,7 @@ TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") ScopedFastFlag sff[]{ {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauAssertStripsFalsyTypes", true}, }; @@ -1094,7 +1094,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_ ScopedFastFlag sff[]{ {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauAssertStripsFalsyTypes", true}, }; @@ -1118,7 +1118,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_ TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") { ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, }; @@ -1157,7 +1157,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") end )"); - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) LUAU_REQUIRE_NO_ERRORS(result); else { diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index ead3d762..531a382f 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5164,4 +5164,151 @@ function x:Destroy(): () end LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") +{ + ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; + + fileResolver.source["game/A"] = R"( +export type Type = { x: { a: number } } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = { x = { a = 2 } } +type Rename = typeof(x.x) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") +{ + ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; + + fileResolver.source["game/A"] = R"( +local y = setmetatable({}, {}) +export type Type = { x: typeof(y) } +return { x = y } + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = types +type Rename = typeof(x.x) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauRefactorTypeVarQuestions", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" then + local x = a:byte() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 22}))); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauRefactorTypeVarQuestions", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" or a == "bye" then + local x = a:byte() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 22}))); +} + +TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauRefactorTypeVarQuestions", true}, + {"LuauSingletonTypes", true}, + {"LuauLengthOnCompositeType", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" then + local x = #a + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); +} + +TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauRefactorTypeVarQuestions", true}, + {"LuauSingletonTypes", true}, + {"LuauLengthOnCompositeType", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" or a == "bye" then + local x = #a + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 23}))); +} + +/* + * When we add new properties to an unsealed table, we should do a level check and promote the property type to be at + * the level of the table. + */ +TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the_same_TypeLevel_of_that_table") +{ + CheckResult result = check(R"( + --!strict + local T = {} + + local function f(prop) + T[1] = { + prop = prop, + } + end + + local function g() + local l = T[1].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 0aeca096..8c7fb79a 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -273,4 +273,21 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") state.tryUnify(&metatable, &variant); } +TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") +{ + ScopedFastFlag sffs[] = { + {"LuauUseCommittingTxnLog", true}, + {"LuauFollowWithCommittingTxnLogInAnyUnification", true}, + }; + + TypePackVar free{FreeTypePack{TypeLevel{}}}; + TypePackVar target{TypePack{}}; + + TypeVar func{FunctionTypeVar{&free, &free}}; + + state.tryUnify(&free, &target); + // Shouldn't assert or error. + state.tryUnify(&func, typeChecker.anyType); +} + TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index de091632..78d90077 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -118,9 +118,7 @@ assert((function() return #_G end)() == 0) assert((function() return #{1,2} end)() == 2) assert((function() return #'g' end)() == 1) -local ud = newproxy(true) -getmetatable(ud).__len = function() return 42 end -assert((function() return #ud end)() == 42) +assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42) assert((function() local a = 1 a = -a return a end)() == -1) @@ -325,6 +323,10 @@ assert((function() local t = {6, 9, 7} t[4.5] = 10 return t[4.5] end)() == 10) assert((function() local t = {6, 9, 7} t['a'] = 11 return t['a'] end)() == 11) assert((function() local t = {6, 9, 7} setmetatable(t, { __newindex = function(t,i,v) rawset(t, i * 10, v) end }) t[1] = 17 t[5] = 1 return concat(t[1],t[5],t[50]) end)() == "17,nil,1") +-- userdata access +assert((function() local ud = newproxy(true) getmetatable(ud).__index = function(ud,i) return i * 10 end return ud[2] end)() == 20) +assert((function() local ud = newproxy(true) getmetatable(ud).__index = function() return function(self, i) return i * 10 end end return ud:meow(2) end)() == 20) + -- and/or -- rhs is a constant assert((function() local a = 1 a = a and 2 return a end)() == 2) @@ -462,7 +464,7 @@ assert((function() a = {} b = {} mt = { __eq = function(l, r) return #l == #r en -- metatable ops local function vec3t(x, y, z) - return setmetatable({ x=x, y=y, z=z}, { + return setmetatable({x=x, y=y, z=z}, { __add = function(l, r) return vec3t(l.x + r.x, l.y + r.y, l.z + r.z) end, __sub = function(l, r) return vec3t(l.x - r.x, l.y - r.y, l.z - r.z) end, __mul = function(l, r) return type(r) == "number" and vec3t(l.x * r, l.y * r, l.z * r) or vec3t(l.x * r.x, l.y * r.y, l.z * r.z) end, diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index 8c96ab33..0e410000 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -37,6 +37,7 @@ coroutine.resume(co2, 0 / 0, 42) assert(debug.traceback(co2) == "debug.lua:31 function halp\n") assert(debug.info(co2, 0, "l") == 31) +assert(debug.info(co2, 0, "f") == halp) -- info errors function qux(...) diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index d5ff215b..751188be 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -260,8 +260,7 @@ local a,b = loadstring(s) assert(not a) --assert(string.find(b, "line 2")) --- Test for CLI-28786 --- The xpcall is intentially going to cause an exception +-- The xpcall is intentionally going to cause an exception -- followed by a forced exception in the error handler. -- If the secondary handler isn't trapped, it will cause -- the unit test to fail. If the xpcall captures the @@ -281,6 +280,19 @@ coroutine.wrap(function() assert(not pcall(debug.getinfo, coroutine.running(), 0, ">")) end)() +-- loadstring chunk truncation +local a,b = loadstring("nope", "@short") +assert(not a and b:match('[^ ]+') == "short:1:") + +local a,b = loadstring("nope", "@" .. string.rep("thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilities", 10)) +assert(not a and b:match('[^ ]+') == "...wontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilities:1:") + +local a,b = loadstring("nope", "=short") +assert(not a and b:match('[^ ]+') == "short:1:") + +local a,b = loadstring("nope", "=" .. string.rep("thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilities", 10)) +assert(not a and b:match('[^ ]+') == "thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbuffe:1:") + -- arith errors function ecall(fn, ...) local ok, err = pcall(fn, ...) diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index 409cd224..5804ea7f 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -180,6 +180,11 @@ x,y,z=nil collectgarbage() assert(next(a) == string.rep('$', 11)) +-- shrinking tables reduce their capacity; confirming the shrinking is difficult but we can at least test the surface level behavior +a = {}; setmetatable(a, {__mode = 'ks'}) +for i=1,lim do a[{}] = i end +collectgarbage() +assert(next(a) == nil) -- testing userdata collectgarbage("stop") -- stop collection @@ -315,8 +320,6 @@ do end collectgarbage() - end - return('OK') diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index bfea0e1f..79ea0fb6 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -289,6 +289,7 @@ assert(math.sqrt("4") == 2) assert(math.tanh("0") == 0) assert(math.tan("0") == 0) assert(math.clamp("0", 2, 3) == 2) +assert(math.clamp("4", 2, 3) == 3) assert(math.sign("2") == 1) assert(math.sign("-2") == -1) assert(math.sign("0") == 0) diff --git a/tests/conformance/vararg.lua b/tests/conformance/vararg.lua index d05f9577..178c56b8 100644 --- a/tests/conformance/vararg.lua +++ b/tests/conformance/vararg.lua @@ -139,6 +139,12 @@ assert(selectmany(1, 10, 20, 30) == "10,20,30") assert(selectone(2, 10, 20, 30) == 20) assert(selectmany(2, 10, 20, 30) == "20,30") +assert(selectone(3, 10, 20, 30) == 30) +assert(selectmany(3, 10, 20, 30) == "30") + +assert(selectone(4, 10, 20, 30) == nil) +assert(selectmany(4, 10, 20, 30) == "") + assert(selectone(-2, 10, 20, 30) == 20) assert(selectmany(-2, 10, 20, 30) == "20,30") diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 7d18bda3..22d6adfc 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -87,9 +87,18 @@ assert(pcall(function() local t = {} rawset(t, vector(0/0, 2, 3), 1) end) == fal -- make sure we cover both builtin and C impl assert(vector(1, 2, 4) == vector("1", "2", "4")) +-- validate component access (both cases) +assert(vector(1, 2, 3).x == 1) +assert(vector(1, 2, 3).X == 1) +assert(vector(1, 2, 3).y == 2) +assert(vector(1, 2, 3).Y == 2) +assert(vector(1, 2, 3).z == 3) +assert(vector(1, 2, 3).Z == 3) + -- additional checks for 4-component vectors if vector_size == 4 then assert(vector(1, 2, 3, 4).w == 4) + assert(vector(1, 2, 3, 4).W == 4) end return 'OK' diff --git a/tools/heapgraph.py b/tools/heapgraph.py index 106db549..d4d29af1 100644 --- a/tools/heapgraph.py +++ b/tools/heapgraph.py @@ -7,10 +7,20 @@ # The result of analysis is a .svg file which can be viewed in a browser # To generate these dumps, use luaC_dump, ideally preceded by luaC_fullgc +import argparse import json import sys import svg +argumentParser = argparse.ArgumentParser(description='Luau heap snapshot analyzer') + +argumentParser.add_argument('--split', dest = 'split', type = str, default = 'none', help = 'Perform additional root split using memory categories', choices = ['none', 'custom', 'all']) + +argumentParser.add_argument('snapshot') +argumentParser.add_argument('snapshotnew', nargs='?') + +arguments = argumentParser.parse_args() + class Node(svg.Node): def __init__(self): svg.Node.__init__(self) @@ -30,14 +40,14 @@ class Node(svg.Node): return "{} ({:,} bytes, {:.1%}); self: {:,} bytes in {:,} objects".format(self.name, self.width, self.width / root.width, self.size, self.count) # load files -if len(sys.argv) == 2: +if arguments.snapshotnew == None: dumpold = None - with open(sys.argv[1]) as f: + with open(arguments.snapshot) as f: dump = json.load(f) else: - with open(sys.argv[1]) as f: + with open(arguments.snapshot) as f: dumpold = json.load(f) - with open(sys.argv[2]) as f: + with open(arguments.snapshotnew) as f: dump = json.load(f) # reachability analysis: how much of the heap is reachable from roots? @@ -111,12 +121,15 @@ while offset < len(queue): if "object" in obj: queue.append((obj["object"], node)) -def annotateContainedCategories(node): +def annotateContainedCategories(node, start): for obj in node.objects: + if obj["cat"] < start: + obj["cat"] = 0 + node.categories.add(obj["cat"]) for child in node.children.values(): - annotateContainedCategories(child) + annotateContainedCategories(child, start) for cat in child.categories: node.categories.add(cat) @@ -172,9 +185,11 @@ def splitIntoCategories(root): return result -# temporarily disabled because it makes FG harder to read, maybe this should be a separate command line option? -if dump["stats"].get("categories") and False: - annotateContainedCategories(root) +if dump["stats"].get("categories") and arguments.split != 'none': + if arguments.split == 'custom': + annotateContainedCategories(root, 128) + else: + annotateContainedCategories(root, 0) root = splitIntoCategories(root) diff --git a/tools/svg.py b/tools/svg.py index 99853fb6..21200eeb 100644 --- a/tools/svg.py +++ b/tools/svg.py @@ -452,7 +452,7 @@ def display(root, title, colors, flip = False): .replace("$gradient-start", gradient_start) .replace("$gradient-end", gradient_end) .replace("$height", str(svgheight)) - .replace("$status", str(svgheight - 16 + 3)) + .replace("$status", str((svgheight - 16 + 3 if flip else 3 * 16 - 3))) .replace("$flip", str(int(flip))) ) From 49304095161c4532a236b8141faf1e0c7dda498f Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 17 Feb 2022 16:41:20 -0800 Subject: [PATCH 026/102] Sync to upstream/release/515 --- Analysis/include/Luau/Documentation.h | 3 + Analysis/include/Luau/Frontend.h | 3 +- Analysis/include/Luau/Linter.h | 7 +- Analysis/include/Luau/Module.h | 4 +- Analysis/include/Luau/Quantify.h | 5 +- Analysis/include/Luau/Substitution.h | 20 +- Analysis/include/Luau/TypeInfer.h | 45 +++- Analysis/src/Autocomplete.cpp | 45 ++-- Analysis/src/Config.cpp | 2 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 2 +- Analysis/src/Frontend.cpp | 35 ++- Analysis/src/JsonEncoder.cpp | 36 ++- Analysis/src/Linter.cpp | 126 ++++++++- Analysis/src/Quantify.cpp | 10 +- Analysis/src/Substitution.cpp | 23 -- Analysis/src/TopoSortStatements.cpp | 3 +- Analysis/src/Transpiler.cpp | 8 +- Analysis/src/TypeAttach.cpp | 4 +- Analysis/src/TypeInfer.cpp | 142 +++-------- Analysis/src/TypeUtils.cpp | 12 + Analysis/src/TypeVar.cpp | 6 +- Analysis/src/TypedAllocator.cpp | 3 + Analysis/src/Unifier.cpp | 16 +- Ast/include/Luau/Ast.h | 17 +- Ast/include/Luau/Lexer.h | 5 + Ast/include/Luau/ParseResult.h | 69 +++++ Ast/include/Luau/Parser.h | 56 +--- Ast/src/Ast.cpp | 29 +-- Ast/src/Lexer.cpp | 12 +- Ast/src/Parser.cpp | 167 +++++++----- Ast/src/TimeTrace.cpp | 6 + CLI/FileUtils.cpp | 7 +- CLI/Repl.cpp | 171 +++++++++---- CMakeLists.txt | 14 + Sources.cmake | 1 + VM/src/lapi.cpp | 25 +- VM/src/ldebug.cpp | 2 +- VM/src/ldo.cpp | 13 +- VM/src/ldo.h | 2 +- VM/src/lgc.cpp | 4 +- VM/src/lgc.h | 2 +- VM/src/lgcdebug.cpp | 10 +- VM/src/lperf.cpp | 6 + VM/src/lstate.cpp | 35 ++- VM/src/lstate.h | 8 +- VM/src/lvmexecute.cpp | 2 +- VM/src/lvmload.cpp | 6 +- VM/src/lvmutils.cpp | 4 +- fuzz/linter.cpp | 2 +- fuzz/proto.cpp | 2 +- tests/Autocomplete.test.cpp | 22 +- tests/Compiler.test.cpp | 14 +- tests/Conformance.test.cpp | 8 + tests/Fixture.cpp | 4 +- tests/Fixture.h | 1 - tests/Frontend.test.cpp | 3 - tests/JsonEncoder.test.cpp | 2 +- tests/Linter.test.cpp | 48 +++- tests/NonstrictMode.test.cpp | 1 - tests/Parser.test.cpp | 51 ++-- tests/Repl.test.cpp | 209 ++++++++++++++- tests/RequireTracer.test.cpp | 2 +- tests/TypeInfer.aliases.test.cpp | 6 +- tests/TypeInfer.annotations.test.cpp | 1 - tests/TypeInfer.builtins.test.cpp | 1 - tests/TypeInfer.classes.test.cpp | 1 - tests/TypeInfer.definitions.test.cpp | 1 - tests/TypeInfer.generics.test.cpp | 1 - tests/TypeInfer.intersectionTypes.test.cpp | 1 - tests/TypeInfer.provisional.test.cpp | 1 - tests/TypeInfer.singletons.test.cpp | 2 - tests/TypeInfer.tables.test.cpp | 43 +++- tests/TypeInfer.test.cpp | 27 +- tests/TypeInfer.tryUnify.test.cpp | 1 - tests/TypeInfer.typePacks.cpp | 3 - tests/TypeInfer.unionTypes.test.cpp | 1 - tests/TypePack.test.cpp | 1 - tests/TypeVar.test.cpp | 1 - tests/conformance/coroutine.lua | 9 + tests/conformance/coverage.lua | 8 + tests/conformance/debug.lua | 3 + tests/main.cpp | 7 +- tools/natvis/Analysis.natvis | 78 ++++++ tools/natvis/Ast.natvis | 25 ++ tools/natvis/VM.natvis | 269 ++++++++++++++++++++ 85 files changed, 1512 insertions(+), 581 deletions(-) create mode 100644 Ast/include/Luau/ParseResult.h create mode 100644 tools/natvis/Analysis.natvis create mode 100644 tools/natvis/Ast.natvis create mode 100644 tools/natvis/VM.natvis diff --git a/Analysis/include/Luau/Documentation.h b/Analysis/include/Luau/Documentation.h index 68ff3a7c..7a2b56ff 100644 --- a/Analysis/include/Luau/Documentation.h +++ b/Analysis/include/Luau/Documentation.h @@ -21,6 +21,7 @@ struct BasicDocumentation { std::string documentation; std::string learnMoreLink; + std::string codeSample; }; struct FunctionParameterDocumentation @@ -37,6 +38,7 @@ struct FunctionDocumentation std::vector parameters; std::vector returns; std::string learnMoreLink; + std::string codeSample; }; struct OverloadedFunctionDocumentation @@ -52,6 +54,7 @@ struct TableDocumentation std::string documentation; Luau::DenseHashMap keys; std::string learnMoreLink; + std::string codeSample; }; using DocumentationDatabase = Luau::DenseHashMap; diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 1f64db30..0bf8f362 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -24,6 +24,7 @@ struct TypeChecker; struct FileResolver; struct ModuleResolver; struct ParseResult; +struct HotComment; struct LoadDefinitionFileResult { @@ -35,7 +36,7 @@ struct LoadDefinitionFileResult LoadDefinitionFileResult loadDefinitionFile( TypeChecker& typeChecker, ScopePtr targetScope, std::string_view definition, const std::string& packageName); -std::optional parseMode(const std::vector& hotcomments); +std::optional parseMode(const std::vector& hotcomments); std::vector parsePathExpr(const AstExpr& pathExpr); diff --git a/Analysis/include/Luau/Linter.h b/Analysis/include/Luau/Linter.h index ec3c124d..6c7ce47f 100644 --- a/Analysis/include/Luau/Linter.h +++ b/Analysis/include/Luau/Linter.h @@ -14,6 +14,7 @@ class AstStat; class AstNameTable; struct TypeChecker; struct Module; +struct HotComment; using ScopePtr = std::shared_ptr; @@ -50,6 +51,7 @@ struct LintWarning Code_TableOperations = 23, Code_DuplicateCondition = 24, Code_MisleadingAndOr = 25, + Code_CommentDirective = 26, Code__Count }; @@ -60,7 +62,7 @@ struct LintWarning static const char* getName(Code code); static Code parseName(const char* name); - static uint64_t parseMask(const std::vector& hotcomments); + static uint64_t parseMask(const std::vector& hotcomments); }; struct LintResult @@ -90,7 +92,8 @@ struct LintOptions void setDefaults(); }; -std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const LintOptions& options); +std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, + const std::vector& hotcomments, const LintOptions& options); std::vector getDeprecatedGlobals(const AstNameTable& names); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 1bf0473c..61200771 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -6,7 +6,7 @@ #include "Luau/TypedAllocator.h" #include "Luau/ParseOptions.h" #include "Luau/Error.h" -#include "Luau/Parser.h" +#include "Luau/ParseResult.h" #include #include @@ -37,8 +37,8 @@ struct SourceModule AstStatBlock* root = nullptr; std::optional mode; - uint64_t ignoreLints = 0; + std::vector hotcomments; std::vector commentLocations; SourceModule() diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index f46df146..e48cad40 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -6,9 +6,6 @@ namespace Luau { -struct Module; -using ModulePtr = std::shared_ptr; - -void quantify(ModulePtr module, TypeId ty, TypeLevel level); +void quantify(TypeId ty, TypeLevel level); } // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index f85b4269..9662d5b3 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -101,9 +101,6 @@ struct Tarjan // This is hot code, so we optimize recursion to a stack. TarjanResult loop(); - // Clear the state - void clear(); - // Find or create the index for a vertex. // Return a boolean which is `true` if it's a freshly created index. std::pair indexify(TypeId ty); @@ -166,7 +163,17 @@ struct FindDirty : Tarjan // and replaces them with clean ones. struct Substitution : FindDirty { - ModulePtr currentModule; +protected: + Substitution(const TxnLog* log_, TypeArena* arena) + : arena(arena) + { + log = log_; + LUAU_ASSERT(log); + LUAU_ASSERT(arena); + } + +public: + TypeArena* arena; DenseHashMap newTypes{nullptr}; DenseHashMap newPacks{nullptr}; @@ -192,12 +199,13 @@ struct Substitution : FindDirty template TypeId addType(const T& tv) { - return currentModule->internalTypes.addType(tv); + return arena->addType(tv); } + template TypePackId addTypePack(const T& tp) { - return currentModule->internalTypes.addTypePack(TypePackVar{tp}); + return arena->addTypePack(TypePackVar{tp}); } }; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 5592fa1f..3c5ded3c 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -5,7 +5,6 @@ #include "Luau/Error.h" #include "Luau/Module.h" #include "Luau/Symbol.h" -#include "Luau/Parser.h" #include "Luau/Substitution.h" #include "Luau/TxnLog.h" #include "Luau/TypePack.h" @@ -37,6 +36,15 @@ struct Unifier; // A substitution which replaces generic types in a given set by free types. struct ReplaceGenerics : Substitution { + ReplaceGenerics( + const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector& generics, const std::vector& genericPacks) + : Substitution(log, arena) + , level(level) + , generics(generics) + , genericPacks(genericPacks) + { + } + TypeLevel level; std::vector generics; std::vector genericPacks; @@ -50,8 +58,13 @@ struct ReplaceGenerics : Substitution // A substitution which replaces generic functions by monomorphic functions struct Instantiation : Substitution { + Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level) + : Substitution(log, arena) + , level(level) + { + } + TypeLevel level; - ReplaceGenerics replaceGenerics; bool ignoreChildren(TypeId ty) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; @@ -62,6 +75,12 @@ struct Instantiation : Substitution // A substitution which replaces free types by generic types. struct Quantification : Substitution { + Quantification(TypeArena* arena, TypeLevel level) + : Substitution(TxnLog::empty(), arena) + , level(level) + { + } + TypeLevel level; std::vector generics; std::vector genericPacks; @@ -74,6 +93,13 @@ struct Quantification : Substitution // A substitution which replaces free types by any struct Anyification : Substitution { + Anyification(TypeArena* arena, TypeId anyType, TypePackId anyTypePack) + : Substitution(TxnLog::empty(), arena) + , anyType(anyType) + , anyTypePack(anyTypePack) + { + } + TypeId anyType; TypePackId anyTypePack; bool isDirty(TypeId ty) override; @@ -85,6 +111,13 @@ struct Anyification : Substitution // A substitution which replaces the type parameters of a type function by arguments struct ApplyTypeFunction : Substitution { + ApplyTypeFunction(TypeArena* arena, TypeLevel level) + : Substitution(TxnLog::empty(), arena) + , level(level) + , encounteredForwardedType(false) + { + } + TypeLevel level; bool encounteredForwardedType; std::unordered_map typeArguments; @@ -351,8 +384,7 @@ private: // Note: `scope` must be a fresh scope. GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, const AstArray& genericPackNames, - bool useCache = false); + const AstArray& genericNames, const AstArray& genericPackNames, bool useCache = false); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -392,11 +424,6 @@ public: ModulePtr currentModule; ModuleName currentModuleName; - Instantiation instantiation; - Quantification quantification; - Anyification anyification; - ApplyTypeFunction applyTypeFunction; - std::function prepareModuleScope; InternalErrorReporter* iceHandler; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 85099e12..5a1ae397 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -7,6 +7,7 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" +#include "Luau/Parser.h" // TODO: only needed for autocompleteSource which is deprecated #include #include @@ -14,9 +15,9 @@ LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); LUAU_FASTFLAGVARIABLE(PreferToCallFunctionsForIntersects, false); +LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -380,7 +381,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId { // We are walking up the class hierarchy, so if we encounter a property that we have // already populated, it takes precedence over the property we found just now. - if (result.count(name) == 0 && name != Parser::errorName) + if (result.count(name) == 0 && name != kParseNameError) { Luau::TypeId type = Luau::follow(prop.type); TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct @@ -948,9 +949,12 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi } } - for (size_t i = 0; i < node->returnAnnotation.types.size; i++) + if (!node->returnAnnotation) + return result; + + for (size_t i = 0; i < node->returnAnnotation->types.size; i++) { - AstType* ret = node->returnAnnotation.types.data[i]; + AstType* ret = node->returnAnnotation->types.data[i]; if (ret->location.containsClosed(position)) { @@ -965,7 +969,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi } } - if (AstTypePack* retTp = node->returnAnnotation.tailType) + if (AstTypePack* retTp = node->returnAnnotation->tailType) { if (auto variadic = retTp->as()) { @@ -1136,7 +1140,7 @@ static AutocompleteEntryMap autocompleteStatement( AstNode* parent = ancestry.rbegin()[1]; if (AstStatIf* statIf = parent->as()) { - if (!statIf->elsebody || (statIf->hasElse && statIf->elseLocation.containsClosed(position))) + if (!statIf->elsebody || (statIf->elseLocation && statIf->elseLocation->containsClosed(position))) { result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); @@ -1164,8 +1168,7 @@ static AutocompleteEntryMap autocompleteStatement( return result; } -// Returns true if completions were generated (completions will be inserted into 'outResult') -// Returns false if no completions were generated +// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) static bool autocompleteIfElseExpression( const AstNode* node, const std::vector& ancestry, const Position& position, AutocompleteEntryMap& outResult) { @@ -1173,6 +1176,13 @@ static bool autocompleteIfElseExpression( if (!parent) return false; + if (FFlag::LuauIfElseExprFixCompletionIssue && node->is()) + { + // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else + // expression. + return true; + } + AstExprIfElse* ifElseExpr = parent->as(); if (!ifElseExpr || ifElseExpr->condition->location.containsClosed(position)) { @@ -1310,7 +1320,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - if (!nodes.back()->is() && (!FFlag::LuauCompleteBrokenStringParams || !nodes.back()->is())) + if (!nodes.back()->is() && !nodes.back()->is()) { return std::nullopt; } @@ -1408,8 +1418,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } else if (auto typeReference = node->as()) { - if (typeReference->hasPrefix) - return {autocompleteModuleTypes(*module, position, typeReference->prefix.value), finder.ancestry}; + if (typeReference->prefix) + return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), finder.ancestry}; else return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; } @@ -1419,9 +1429,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } else if (AstStatLocal* statLocal = node->as()) { - if (statLocal->vars.size == 1 && (!statLocal->hasEqualsSign || position < statLocal->equalsSignLocation.begin)) + if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; - else if (statLocal->hasEqualsSign && position >= statLocal->equalsSignLocation.end) + else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; else return {}; @@ -1449,7 +1459,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (!statForIn->hasIn || position <= statForIn->inLocation.begin) { AstLocal* lastName = statForIn->vars.data[statForIn->vars.size - 1]; - if (lastName->name == Parser::errorName || lastName->location.containsClosed(position)) + if (lastName->name == kParseNameError || lastName->location.containsClosed(position)) { // Here we are either working with a missing binding (as would be the case in a bare "for" keyword) or // the cursor is still touching a binding name. The user is still typing a new name, so we should not offer @@ -1499,7 +1509,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatWhile* statWhile = extractStat(finder.ancestry); statWhile && !statWhile->hasDo) return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; - else if (AstStatIf* statIf = node->as(); statIf && !statIf->hasElse) + else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; @@ -1508,11 +1518,11 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (statIf->condition->is()) return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; - else if (!statIf->hasThen || statIf->thenLocation.containsClosed(position)) + else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; } else if (AstStatIf* statIf = extractStat(finder.ancestry); - statIf && (!statIf->hasThen || statIf->thenLocation.containsClosed(position))) + statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; @@ -1612,6 +1622,7 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback) { + // TODO: Remove #include "Luau/Parser.h" with this function auto sourceModule = std::make_unique(); ParseOptions parseOptions; parseOptions.captureComments = true; diff --git a/Analysis/src/Config.cpp b/Analysis/src/Config.cpp index d9fc44f8..35a2259d 100644 --- a/Analysis/src/Config.cpp +++ b/Analysis/src/Config.cpp @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Config.h" -#include "Luau/Parser.h" +#include "Luau/Lexer.h" #include "Luau/StringUtils.h" namespace diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index f3ef88fc..bf6e1193 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -167,7 +167,7 @@ declare function gcinfo(): number foreach: ({[K]: V}, (K, V) -> ()) -> (), foreachi: ({V}, (number, V) -> ()) -> (), - move: ({V}, number, number, number, {V}?) -> (), + move: ({V}, number, number, number, {V}?) -> {V}, clear: ({[K]: V}) -> (), freeze: ({[K]: V}) -> {[K]: V}, diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 9001b19d..d8906f6e 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" #include "Luau/Config.h" #include "Luau/FileResolver.h" +#include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" #include "Luau/TimeTrace.h" @@ -16,23 +17,25 @@ #include LUAU_FASTFLAG(LuauInferInNoCheckMode) -LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) namespace Luau { -std::optional parseMode(const std::vector& hotcomments) +std::optional parseMode(const std::vector& hotcomments) { - for (const std::string& hc : hotcomments) + for (const HotComment& hc : hotcomments) { - if (hc == "nocheck") + if (!hc.header) + continue; + + if (hc.content == "nocheck") return Mode::NoCheck; - if (hc == "nonstrict") + if (hc.content == "nonstrict") return Mode::Nonstrict; - if (hc == "strict") + if (hc.content == "strict") return Mode::Strict; } @@ -607,13 +610,15 @@ std::pair Frontend::lintFragment(std::string_view sour SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions); + uint64_t ignoreLints = LintWarning::parseMask(sourceModule.hotcomments); + Luau::LintOptions lintOptions = enabledLintWarnings.value_or(config.enabledLint); - lintOptions.warningMask &= sourceModule.ignoreLints; + lintOptions.warningMask &= ~ignoreLints; double timestamp = getTimestamp(); - std::vector warnings = - Luau::lint(sourceModule.root, *sourceModule.names.get(), typeChecker.globalScope, nullptr, enabledLintWarnings.value_or(config.enabledLint)); + std::vector warnings = Luau::lint(sourceModule.root, *sourceModule.names.get(), typeChecker.globalScope, nullptr, + sourceModule.hotcomments, enabledLintWarnings.value_or(config.enabledLint)); stats.timeLint += getTimestamp() - timestamp; @@ -651,8 +656,10 @@ LintResult Frontend::lint(const SourceModule& module, std::optionalgetConfig(module.name); + uint64_t ignoreLints = LintWarning::parseMask(module.hotcomments); + LintOptions options = enabledLintWarnings.value_or(config.enabledLint); - options.warningMask &= ~module.ignoreLints; + options.warningMask &= ~ignoreLints; Mode mode = module.mode.value_or(config.mode); if (mode != Mode::NoCheck) @@ -671,7 +678,7 @@ LintResult Frontend::lint(const SourceModule& module, std::optional warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), options); + std::vector warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), module.hotcomments, options); stats.timeLint += getTimestamp() - timestamp; @@ -839,7 +846,6 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const { sourceModule.root = parseResult.root; sourceModule.mode = parseMode(parseResult.hotcomments); - sourceModule.ignoreLints = LintWarning::parseMask(parseResult.hotcomments); } else { @@ -848,8 +854,13 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const } sourceModule.name = name; + if (parseOptions.captureComments) + { sourceModule.commentLocations = std::move(parseResult.commentLocations); + sourceModule.hotcomments = std::move(parseResult.hotcomments); + } + return sourceModule; } diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 8dd597e1..ec399158 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -150,10 +150,21 @@ struct AstJsonEncoder : public AstVisitor { writeRaw(std::to_string(i)); } + void write(std::nullptr_t) + { + writeRaw("null"); + } void write(std::string_view str) { writeString(str); } + void write(std::optional name) + { + if (name) + write(*name); + else + writeRaw("null"); + } void write(AstName name) { writeString(name.value ? name.value : ""); @@ -177,7 +188,16 @@ struct AstJsonEncoder : public AstVisitor void write(AstLocal* local) { - write(local->name); + writeRaw("{"); + bool c = pushComma(); + if (local->annotation != nullptr) + write("type", local->annotation); + else + write("type", nullptr); + write("name", local->name); + write("location", local->location); + popComma(c); + writeRaw("}"); } void writeNode(AstNode* node) @@ -314,7 +334,7 @@ struct AstJsonEncoder : public AstVisitor if (node->self) PROP(self); PROP(args); - if (node->hasReturnAnnotation) + if (node->returnAnnotation) PROP(returnAnnotation); PROP(vararg); PROP(varargLocation); @@ -328,6 +348,14 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(const std::optional& typeList) + { + if (typeList) + write(*typeList); + else + writeRaw("null"); + } + void write(const AstTypeList& typeList) { writeRaw("{"); @@ -531,7 +559,7 @@ struct AstJsonEncoder : public AstVisitor PROP(thenbody); if (node->elsebody) PROP(elsebody); - PROP(hasThen); + write("hasThen", node->thenLocation.has_value()); PROP(hasEnd); }); } @@ -715,7 +743,7 @@ struct AstJsonEncoder : public AstVisitor void write(class AstTypeReference* node) { writeNode(node, "AstTypeReference", [&]() { - if (node->hasPrefix) + if (node->prefix) PROP(prefix); PROP(name); PROP(parameters); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 2ba6a0fc..8d7d2d97 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) + namespace Luau { @@ -44,6 +46,7 @@ static const char* kWarningNames[] = { "TableOperations", "DuplicateCondition", "MisleadingAndOr", + "CommentDirective", }; // clang-format on @@ -732,13 +735,13 @@ private: bool visit(AstTypeReference* node) override { - if (!node->hasPrefix) + if (!node->prefix) return true; - if (!imports.contains(node->prefix)) + if (!imports.contains(*node->prefix)) return true; - AstLocal* astLocal = imports[node->prefix]; + AstLocal* astLocal = imports[*node->prefix]; Local& local = locals[astLocal]; LUAU_ASSERT(local.import); local.used = true; @@ -2527,13 +2530,108 @@ static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, } } +static const char* fuzzyMatch(std::string_view str, const char** array, size_t size) +{ + if (FInt::LuauSuggestionDistance == 0) + return nullptr; + + size_t bestDistance = FInt::LuauSuggestionDistance; + size_t bestMatch = size; + + for (size_t i = 0; i < size; ++i) + { + size_t ed = editDistance(str, array[i]); + + if (ed <= bestDistance) + { + bestDistance = ed; + bestMatch = i; + } + } + + return bestMatch < size ? array[bestMatch] : nullptr; +} + +static void lintComments(LintContext& context, const std::vector& hotcomments) +{ + bool seenMode = false; + + for (const HotComment& hc : hotcomments) + { + // We reserve --! for various informational (non-directive) comments + if (hc.content.empty() || hc.content[0] == ' ' || hc.content[0] == '\t') + continue; + + if (!hc.header) + { + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "Comment directive is ignored because it is placed after the first non-comment token"); + } + else + { + std::string::size_type space = hc.content.find_first_of(" \t"); + std::string_view first = std::string_view(hc.content).substr(0, space); + + if (first == "nolint") + { + std::string::size_type notspace = hc.content.find_first_not_of(" \t", space); + + if (space == std::string::npos || notspace == std::string::npos) + { + // disables all lints + } + else if (LintWarning::parseName(hc.content.c_str() + notspace) == LintWarning::Code_Unknown) + { + const char* rule = hc.content.c_str() + notspace; + + // skip Unknown + if (const char* suggestion = fuzzyMatch(rule, kWarningNames + 1, LintWarning::Code__Count - 1)) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "nolint directive refers to unknown lint rule '%s'; did you mean '%s'?", rule, suggestion); + else + emitWarning( + context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule); + } + } + else if (first == "nocheck" || first == "nonstrict" || first == "strict") + { + if (space != std::string::npos) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "Comment directive with the type checking mode has extra symbols at the end of the line"); + else if (seenMode) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "Comment directive with the type checking mode has already been used"); + else + seenMode = true; + } + else + { + static const char* kHotComments[] = { + "nolint", + "nocheck", + "nonstrict", + "strict", + }; + + if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments))) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'; did you mean '%s'?", + int(first.size()), first.data(), suggestion); + else + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()), + first.data()); + } + } + } +} + void LintOptions::setDefaults() { // By default, we enable all warnings warningMask = ~0ull; } -std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const LintOptions& options) +std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, + const std::vector& hotcomments, const LintOptions& options) { LintContext context; @@ -2609,6 +2707,9 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_MisleadingAndOr)) LintMisleadingAndOr::process(context); + if (context.warningEnabled(LintWarning::Code_CommentDirective)) + lintComments(context, hotcomments); + std::sort(context.result.begin(), context.result.end(), WarningComparator()); return context.result; @@ -2630,23 +2731,30 @@ LintWarning::Code LintWarning::parseName(const char* name) return Code_Unknown; } -uint64_t LintWarning::parseMask(const std::vector& hotcomments) +uint64_t LintWarning::parseMask(const std::vector& hotcomments) { uint64_t result = 0; - for (const std::string& hc : hotcomments) + for (const HotComment& hc : hotcomments) { - if (hc.compare(0, 6, "nolint") != 0) + if (!hc.header) continue; - std::string::size_type name = hc.find_first_not_of(" \t", 6); + if (hc.content.compare(0, 6, "nolint") != 0) + continue; + + std::string::size_type name = hc.content.find_first_not_of(" \t", 6); // --!nolint disables everything if (name == std::string::npos) return ~0ull; + // --!nolint needs to be followed by a whitespace character + if (name == 6) + continue; + // --!nolint name disables the specific lint - LintWarning::Code code = LintWarning::parseName(hc.c_str() + name); + LintWarning::Code code = LintWarning::parseName(hc.content.c_str() + name); if (code != LintWarning::Code_Unknown) result |= 1ull << int(code); diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 04ebffc1..94e169f1 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -9,14 +9,12 @@ namespace Luau struct Quantifier { - ModulePtr module; TypeLevel level; std::vector generics; std::vector genericPacks; - Quantifier(ModulePtr module, TypeLevel level) - : module(module) - , level(level) + Quantifier(TypeLevel level) + : level(level) { } @@ -76,9 +74,9 @@ struct Quantifier } }; -void quantify(ModulePtr module, TypeId ty, TypeLevel level) +void quantify(TypeId ty, TypeLevel level) { - Quantifier q{std::move(module), level}; + Quantifier q{level}; DenseHashSet seen{nullptr}; visitTypeVarOnce(ty, q, seen); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index bacbca76..770c7a47 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -226,27 +226,11 @@ TarjanResult Tarjan::loop() return TarjanResult::Ok; } -void Tarjan::clear() -{ - typeToIndex.clear(); - indexToType.clear(); - packToIndex.clear(); - indexToPack.clear(); - lowlink.clear(); - stack.clear(); - onStack.clear(); - - edgesTy.clear(); - edgesTp.clear(); - worklist.clear(); -} - TarjanResult Tarjan::visitRoot(TypeId ty) { childCount = 0; ty = log->follow(ty); - clear(); auto [index, fresh] = indexify(ty); worklist.push_back({index, -1, -1}); return loop(); @@ -257,7 +241,6 @@ TarjanResult Tarjan::visitRoot(TypePackId tp) childCount = 0; tp = log->follow(tp); - clear(); auto [index, fresh] = indexify(tp); worklist.push_back({index, -1, -1}); return loop(); @@ -314,21 +297,17 @@ void FindDirty::visitSCC(int index) TarjanResult FindDirty::findDirty(TypeId ty) { - dirty.clear(); return visitRoot(ty); } TarjanResult FindDirty::findDirty(TypePackId tp) { - dirty.clear(); return visitRoot(tp); } std::optional Substitution::substitute(TypeId ty) { ty = log->follow(ty); - newTypes.clear(); - newPacks.clear(); auto result = findDirty(ty); if (result != TarjanResult::Ok) @@ -347,8 +326,6 @@ std::optional Substitution::substitute(TypeId ty) std::optional Substitution::substitute(TypePackId tp) { tp = log->follow(tp); - newTypes.clear(); - newPacks.clear(); auto result = findDirty(tp); if (result != TarjanResult::Ok) diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index dba694be..678001bf 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -26,9 +26,10 @@ * 3. Cyclic dependencies can be resolved by picking an arbitrary statement to check first. */ -#include "Luau/Parser.h" +#include "Luau/Ast.h" #include "Luau/DenseHash.h" #include "Luau/Common.h" +#include "Luau/StringUtils.h" #include #include diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index f5908683..54bd0d5e 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -933,12 +933,12 @@ struct Printer writer.symbol(")"); - if (writeTypes && func.hasReturnAnnotation) + if (writeTypes && func.returnAnnotation) { writer.symbol(":"); writer.space(); - visualizeTypeList(func.returnAnnotation, false); + visualizeTypeList(*func.returnAnnotation, false); } visualizeBlock(*func.body); @@ -989,9 +989,9 @@ struct Printer advance(typeAnnotation.location.begin); if (const auto& a = typeAnnotation.as()) { - if (a->hasPrefix) + if (a->prefix) { - writer.write(a->prefix.value); + writer.write(a->prefix->value); writer.symbol("."); } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 2208213f..d575e023 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -3,7 +3,6 @@ #include "Luau/Error.h" #include "Luau/Module.h" -#include "Luau/Parser.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/ToString.h" @@ -476,12 +475,11 @@ public: visitLocal(arg); } - if (!fn->hasReturnAnnotation) + if (!fn->returnAnnotation) { if (auto result = getScope(fn->body->location)) { TypePackId ret = result->returnType; - fn->hasReturnAnnotation = true; AstTypePack* variadicAnnotation = nullptr; const auto& [v, tail] = flatten(ret); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index f1c314cd..c29699b7 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -3,7 +3,6 @@ #include "Luau/Common.h" #include "Luau/ModuleResolver.h" -#include "Luau/Parser.h" #include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" @@ -24,16 +23,12 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) -LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) -LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) -LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) LUAU_FASTFLAGVARIABLE(LuauNoSealedTypeMod, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) @@ -43,7 +38,6 @@ LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAG(LuauUnionTagMatchFix) @@ -293,13 +287,10 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); } - if (FFlag::LuauPerModuleUnificationCache) - { - // Clear unifier cache since it's keyed off internal types that get deallocated - // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. - unifierState.cachedUnify.clear(); - unifierState.skipCacheForType.clear(); - } + // Clear unifier cache since it's keyed off internal types that get deallocated + // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. + unifierState.cachedUnify.clear(); + unifierState.skipCacheForType.clear(); if (FFlag::LuauTwoPassAliasDefinitionFix) duplicateTypeAliases.clear(); @@ -509,7 +500,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (FFlag::LuauTwoPassAliasDefinitionFix && typealias->name == Parser::errorName) + if (FFlag::LuauTwoPassAliasDefinitionFix && typealias->name == kParseNameError) continue; auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; @@ -1193,7 +1184,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. - if (FFlag::LuauTwoPassAliasDefinitionFix && name == Parser::errorName) + if (FFlag::LuauTwoPassAliasDefinitionFix && name == kParseNameError) return; std::optional binding; @@ -1222,7 +1213,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (FFlag::LuauProperTypeLevels) aliasScope->level.subLevel = subLevel; - auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); + auto [generics, genericPacks] = + createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); @@ -1464,7 +1456,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& ExprResult result; if (auto a = expr.as()) - result = checkExpr(scope, *a->expr, FFlag::LuauGroupExpectedType ? expectedType : std::nullopt); + result = checkExpr(scope, *a->expr, expectedType); else if (expr.is()) result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) @@ -1508,7 +1500,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - result = checkExpr(scope, *a, FFlag::LuauIfElseExpectedType2 ? expectedType : std::nullopt); + result = checkExpr(scope, *a, expectedType); else ice("Unhandled AstExpr?"); @@ -2093,6 +2085,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn return {numberType}; } case AstExprUnary::Len: + { tablify(operandType); operandType = stripFromNilAndReport(operandType, expr.location); @@ -2100,30 +2093,13 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn if (get(operandType)) return {errorRecoveryType(scope)}; - if (FFlag::LuauLengthOnCompositeType) - { - DenseHashSet seen{nullptr}; + DenseHashSet seen{nullptr}; - if (!hasLength(operandType, seen, &recursionCount)) - reportError(TypeError{expr.location, NotATable{operandType}}); - } - else - { - if (get(operandType)) - return {numberType}; // Not strictly correct: metatables permit overriding this - - if (auto p = get(operandType)) - { - if (p->type == PrimitiveTypeVar::String) - return {numberType}; - } - - if (!getTableType(operandType)) - reportError(TypeError{expr.location, NotATable{operandType}}); - } + if (!hasLength(operandType, seen, &recursionCount)) + reportError(TypeError{expr.location, NotATable{operandType}}); return {numberType}; - + } default: ice("Unknown AstExprUnary " + std::to_string(int(expr.op))); } @@ -2618,22 +2594,11 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIf resolve(result.predicates, falseScope, false); ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); - if (FFlag::LuauIfElseBranchTypeUnion) - { - if (falseType.type == trueType.type) - return {trueType.type}; - - std::vector types = reduceUnion({trueType.type, falseType.type}); - return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; - } - else - { - unify(falseType.type, trueType.type, expr.location); - - // TODO: normalize(UnionTypeVar{{trueType, falseType}}) - // For now both trueType and falseType must be the same type. + if (falseType.type == trueType.type) return {trueType.type}; - } + + std::vector types = reduceUnion({trueType.type, falseType.type}); + return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; } TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) @@ -2986,8 +2951,8 @@ std::pair TypeChecker::checkFunctionSignature( auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); TypePackId retPack; - if (expr.hasReturnAnnotation) - retPack = resolveTypePack(funScope, expr.returnAnnotation); + if (expr.returnAnnotation) + retPack = resolveTypePack(funScope, *expr.returnAnnotation); else if (isNonstrictMode()) retPack = anyTypePack; else if (expectedFunctionType) @@ -3181,7 +3146,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE // If we're in nonstrict mode we want to only report this missing return // statement if there are type annotations on the function. In strict mode // we report it regardless. - if (!isNonstrictMode() || function.hasReturnAnnotation) + if (!isNonstrictMode() || function.returnAnnotation) { reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retType}); } @@ -4403,11 +4368,7 @@ TypeId Instantiation::clean(TypeId ty) // Annoyingly, we have to do this even if there are no generics, // to replace any generic tables. - replaceGenerics.log = log; - replaceGenerics.level = level; - replaceGenerics.currentModule = currentModule; - replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); - replaceGenerics.genericPacks.assign(ftv->genericPacks.begin(), ftv->genericPacks.end()); + ReplaceGenerics replaceGenerics{log, arena, level, ftv->generics, ftv->genericPacks}; // TODO: What to do if this returns nullopt? // We don't have access to the error-reporting machinery @@ -4573,16 +4534,11 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (FFlag::LuauQuantifyInPlace2) { - Luau::quantify(currentModule, ty, scope->level); + Luau::quantify(ty, scope->level); return ty; } - quantification.log = TxnLog::empty(); - quantification.level = scope->level; - quantification.generics.clear(); - quantification.genericPacks.clear(); - quantification.currentModule = currentModule; - + Quantification quantification{¤tModule->internalTypes, scope->level}; std::optional qty = quantification.substitute(ty); if (!qty.has_value()) @@ -4596,18 +4552,14 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location FunctionTypeVar* qftv = getMutable(*qty); LUAU_ASSERT(qftv); - qftv->generics = quantification.generics; - qftv->genericPacks = quantification.genericPacks; + qftv->generics = std::move(quantification.generics); + qftv->genericPacks = std::move(quantification.genericPacks); return *qty; } TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { - LUAU_ASSERT(log != nullptr); - - instantiation.log = FFlag::LuauUseCommittingTxnLog ? log : TxnLog::empty(); - instantiation.level = scope->level; - instantiation.currentModule = currentModule; + Instantiation instantiation{FFlag::LuauUseCommittingTxnLog ? log : TxnLog::empty(), ¤tModule->internalTypes, scope->level}; std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; @@ -4620,10 +4572,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { - anyification.log = TxnLog::empty(); - anyification.anyType = anyType; - anyification.anyTypePack = anyTypePack; - anyification.currentModule = currentModule; + Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4636,10 +4585,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { - anyification.log = TxnLog::empty(); - anyification.anyType = anyType; - anyification.anyTypePack = anyTypePack; - anyification.currentModule = currentModule; + Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4823,7 +4769,8 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) return getSingletonTypes().errorRecoveryTypePack(guess); } -TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) { +TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) +{ return [this, sense](TypeId ty) -> std::optional { // any/error/free gets a special pass unconditionally because they can't be decided. if (get(ty) || get(ty) || get(ty)) @@ -4904,8 +4851,8 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (const auto& lit = annotation.as()) { std::optional tf; - if (lit->hasPrefix) - tf = scope->lookupImportedType(lit->prefix.value, lit->name.value); + if (lit->prefix) + tf = scope->lookupImportedType(lit->prefix->value, lit->name.value); else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_ice") ice("_luau_ice encountered", lit->location); @@ -4932,12 +4879,12 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (!tf) { - if (lit->name == Parser::errorName) + if (lit->name == kParseNameError) return errorRecoveryType(scope); std::string typeName; - if (lit->hasPrefix) - typeName = std::string(lit->prefix.value) + "."; + if (lit->prefix) + typeName = std::string(lit->prefix->value) + "."; typeName += lit->name.value; if (scope->lookupPack(typeName)) @@ -5038,12 +4985,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (notEnoughParameters && hasDefaultParameters) { // 'applyTypeFunction' is used to substitute default types that reference previous generic types - applyTypeFunction.log = TxnLog::empty(); - applyTypeFunction.typeArguments.clear(); - applyTypeFunction.typePackArguments.clear(); - applyTypeFunction.currentModule = currentModule; - applyTypeFunction.level = scope->level; - applyTypeFunction.encounteredForwardedType = false; + ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes, scope->level}; for (size_t i = 0; i < typesProvided; ++i) applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; @@ -5362,18 +5304,14 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (tf.typeParams.empty() && tf.typePackParams.empty()) return tf.type; - applyTypeFunction.typeArguments.clear(); + ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes, scope->level}; + for (size_t i = 0; i < tf.typeParams.size(); ++i) applyTypeFunction.typeArguments[tf.typeParams[i].ty] = typeParams[i]; - applyTypeFunction.typePackArguments.clear(); for (size_t i = 0; i < tf.typePackParams.size(); ++i) applyTypeFunction.typePackArguments[tf.typePackParams[i].tp] = typePackParams[i]; - applyTypeFunction.log = TxnLog::empty(); - applyTypeFunction.currentModule = currentModule; - applyTypeFunction.level = scope->level; - applyTypeFunction.encounteredForwardedType = false; std::optional maybeInstantiated = applyTypeFunction.substitute(tf.type); if (!maybeInstantiated.has_value()) { diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 8c6d5e49..593b54c8 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -5,6 +5,8 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" +LUAU_FASTFLAGVARIABLE(LuauTerminateCyclicMetatableIndexLookup, false) + namespace Luau { @@ -48,9 +50,19 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc } std::optional mtIndex = findMetatableEntry(errors, globalScope, ty, "__index", location); + int count = 0; while (mtIndex) { TypeId index = follow(*mtIndex); + + if (FFlag::LuauTerminateCyclicMetatableIndexLookup) + { + if (count >= 100) + return std::nullopt; + + ++count; + } + if (const auto& itt = getTableType(index)) { const auto& fit = itt->props.find(name); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 7e438e31..b2358c27 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,8 +23,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauLengthOnCompositeType) -LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauUnionTagMatchFix) @@ -385,8 +383,6 @@ bool maybeSingleton(TypeId ty) bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) { - LUAU_ASSERT(FFlag::LuauLengthOnCompositeType); - RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit); ty = follow(ty); @@ -555,7 +551,7 @@ bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs) static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs) { - if (FFlag::LuauMetatableAreEqualRecursion && areSeen(seen, &lhs, &rhs)) + if (areSeen(seen, &lhs, &rhs)) return true; return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable); diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index f037351e..c7f31822 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -7,6 +7,9 @@ #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif +#ifndef NOMINMAX +#define NOMINMAX +#endif #include const size_t kPageSize = 4096; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index a8ad5159..322f6ebf 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,7 +14,6 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTFLAGVARIABLE(LuauCommittingTxnLogFreeTpPromote, false) LUAU_FASTFLAG(LuauImmutableTypes) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); @@ -23,7 +22,6 @@ LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); -LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauUnionTagMatchFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowWithCommittingTxnLogInAnyUnification, false) @@ -116,7 +114,7 @@ struct PromoteTypeLevels { // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauCommittingTxnLogFreeTpPromote && FFlag::LuauUseCommittingTxnLog && !log.is(tp)) + if (FFlag::LuauUseCommittingTxnLog && !log.is(tp)) return true; promote(tp, FFlag::LuauUseCommittingTxnLog ? log.getMutable(tp) : getMutable(tp)); @@ -1242,7 +1240,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + if (subTpv->tail && superTpv->tail) { tryUnify_(*subTpv->tail, *superTpv->tail); break; @@ -1250,9 +1248,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) - tryUnify_(*subTpv->tail, *superTpv->tail); - else if (lFreeTail) + if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); else if (rFreeTail) tryUnify_(emptyTp, *subTpv->tail); @@ -1448,7 +1444,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + if (subTpv->tail && superTpv->tail) { tryUnify_(*subTpv->tail, *superTpv->tail); break; @@ -1456,9 +1452,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; - if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) - tryUnify_(*subTpv->tail, *superTpv->tail); - else if (lFreeTail) + if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); else if (rFreeTail) tryUnify_(emptyTp, *subTpv->tail); diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index ac5950c0..31cd01cc 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -594,8 +594,7 @@ public: AstArray genericPacks; AstLocal* self; AstArray args; - bool hasReturnAnnotation; - AstTypeList returnAnnotation; + std::optional returnAnnotation; bool vararg = false; Location varargLocation; AstTypePack* varargAnnotation; @@ -740,7 +739,7 @@ class AstStatIf : public AstStat public: LUAU_RTTI(AstStatIf) - AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, bool hasThen, const Location& thenLocation, + AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, const std::optional& thenLocation, const std::optional& elseLocation, bool hasEnd); void visit(AstVisitor* visitor) override; @@ -749,12 +748,10 @@ public: AstStatBlock* thenbody; AstStat* elsebody; - bool hasThen = false; - Location thenLocation; + std::optional thenLocation; // Active for 'elseif' as well - bool hasElse = false; - Location elseLocation; + std::optional elseLocation; bool hasEnd = false; }; @@ -849,8 +846,7 @@ public: AstArray vars; AstArray values; - bool hasEqualsSign = false; - Location equalsSignLocation; + std::optional equalsSignLocation; }; class AstStatFor : public AstStat @@ -1053,9 +1049,8 @@ public: void visit(AstVisitor* visitor) override; - bool hasPrefix; bool hasParameterList; - AstName prefix; + std::optional prefix; AstName name; AstArray parameters; }; diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index 460ef056..d7d867f4 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -233,4 +233,9 @@ private: bool readNames; }; +inline bool isSpace(char ch) +{ + return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; +} + } // namespace Luau diff --git a/Ast/include/Luau/ParseResult.h b/Ast/include/Luau/ParseResult.h new file mode 100644 index 00000000..17ce2e3b --- /dev/null +++ b/Ast/include/Luau/ParseResult.h @@ -0,0 +1,69 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Location.h" +#include "Luau/Lexer.h" +#include "Luau/StringUtils.h" + +namespace Luau +{ + +class AstStatBlock; + +class ParseError : public std::exception +{ +public: + ParseError(const Location& location, const std::string& message); + + virtual const char* what() const throw(); + + const Location& getLocation() const; + const std::string& getMessage() const; + + static LUAU_NORETURN void raise(const Location& location, const char* format, ...) LUAU_PRINTF_ATTR(2, 3); + +private: + Location location; + std::string message; +}; + +class ParseErrors : public std::exception +{ +public: + ParseErrors(std::vector errors); + + virtual const char* what() const throw(); + + const std::vector& getErrors() const; + +private: + std::vector errors; + std::string message; +}; + +struct HotComment +{ + bool header; + Location location; + std::string content; +}; + +struct Comment +{ + Lexeme::Type type; // Comment, BlockComment, or BrokenComment + Location location; +}; + +struct ParseResult +{ + AstStatBlock* root; + std::vector hotcomments; + std::vector errors; + + std::vector commentLocations; +}; + +static constexpr const char* kParseNameError = "%error-id%"; + +} // namespace Luau diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 40ecdcdd..4b5ae315 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -4,6 +4,7 @@ #include "Luau/Ast.h" #include "Luau/Lexer.h" #include "Luau/ParseOptions.h" +#include "Luau/ParseResult.h" #include "Luau/StringUtils.h" #include "Luau/DenseHash.h" #include "Luau/Common.h" @@ -14,37 +15,6 @@ namespace Luau { -class ParseError : public std::exception -{ -public: - ParseError(const Location& location, const std::string& message); - - virtual const char* what() const throw(); - - const Location& getLocation() const; - const std::string& getMessage() const; - - static LUAU_NORETURN void raise(const Location& location, const char* format, ...) LUAU_PRINTF_ATTR(2, 3); - -private: - Location location; - std::string message; -}; - -class ParseErrors : public std::exception -{ -public: - ParseErrors(std::vector errors); - - virtual const char* what() const throw(); - - const std::vector& getErrors() const; - -private: - std::vector errors; - std::string message; -}; - template class TempVector { @@ -80,34 +50,17 @@ private: size_t size_; }; -struct Comment -{ - Lexeme::Type type; // Comment, BlockComment, or BrokenComment - Location location; -}; - -struct ParseResult -{ - AstStatBlock* root; - std::vector hotcomments; - std::vector errors; - - std::vector commentLocations; -}; - class Parser { public: static ParseResult parse( const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options = ParseOptions()); - static constexpr const char* errorName = "%error-id%"; - private: struct Name; struct Binding; - Parser(const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator); + Parser(const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, const ParseOptions& options); bool blockFollow(const Lexeme& l); @@ -330,7 +283,7 @@ private: AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); - const Lexeme& nextLexeme(); + void nextLexeme(); struct Function { @@ -386,6 +339,9 @@ private: Allocator& allocator; std::vector commentLocations; + std::vector hotcomments; + + bool hotcommentHeader = true; unsigned int recursionCounter; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 9b5bc0c7..24a280da 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -167,8 +167,7 @@ AstExprFunction::AstExprFunction(const Location& location, const AstArrayreturnAnnotation = *returnAnnotation; } void AstExprFunction::visit(AstVisitor* visitor) @@ -195,8 +192,8 @@ void AstExprFunction::visit(AstVisitor* visitor) if (varargAnnotation) varargAnnotation->visit(visitor); - if (hasReturnAnnotation) - visitTypeList(visitor, returnAnnotation); + if (returnAnnotation) + visitTypeList(visitor, *returnAnnotation); body->visit(visitor); } @@ -375,21 +372,16 @@ void AstStatBlock::visit(AstVisitor* visitor) } } -AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, bool hasThen, - const Location& thenLocation, const std::optional& elseLocation, bool hasEnd) +AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, + const std::optional& thenLocation, const std::optional& elseLocation, bool hasEnd) : AstStat(ClassIndex(), location) , condition(condition) , thenbody(thenbody) , elsebody(elsebody) - , hasThen(hasThen) , thenLocation(thenLocation) + , elseLocation(elseLocation) , hasEnd(hasEnd) { - if (bool(elseLocation)) - { - hasElse = true; - this->elseLocation = *elseLocation; - } } void AstStatIf::visit(AstVisitor* visitor) @@ -492,12 +484,8 @@ AstStatLocal::AstStatLocal( : AstStat(ClassIndex(), location) , vars(vars) , values(values) + , equalsSignLocation(equalsSignLocation) { - if (bool(equalsSignLocation)) - { - hasEqualsSign = true; - this->equalsSignLocation = *equalsSignLocation; - } } void AstStatLocal::visit(AstVisitor* visitor) @@ -750,9 +738,8 @@ void AstStatError::visit(AstVisitor* visitor) AstTypeReference::AstTypeReference( const Location& location, std::optional prefix, AstName name, bool hasParameterList, const AstArray& parameters) : AstType(ClassIndex(), location) - , hasPrefix(bool(prefix)) , hasParameterList(hasParameterList) - , prefix(prefix ? *prefix : AstName()) + , prefix(prefix) , name(name) , parameters(parameters) { diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index a7aa24ca..d56c8860 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -101,11 +101,6 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name) LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); } -static bool isComment(const Lexeme& lexeme) -{ - return lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment; -} - static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"}; @@ -282,11 +277,6 @@ AstName AstNameTable::get(const char* name) const return getWithType(name, strlen(name)).first; } -inline bool isSpace(char ch) -{ - return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; -} - inline bool isAlpha(char ch) { // use or trick to convert to lower case and unsigned comparison to do range check @@ -372,7 +362,7 @@ const Lexeme& Lexer::next(bool skipComments) prevLocation = lexeme.location; lexeme = readNext(); - } while (skipComments && isComment(lexeme)); + } while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment)); return lexeme; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 30b32f91..235d6349 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -13,18 +13,15 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) -LUAU_FASTFLAGVARIABLE(LuauStartingBrokenComment, false) +LUAU_FASTFLAGVARIABLE(LuauParseAllHotComments, false) +LUAU_FASTFLAGVARIABLE(LuauTableFieldFunctionDebugname, false) namespace Luau { -inline bool isSpace(char ch) -{ - return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; -} - static bool isComment(const Lexeme& lexeme) { + LUAU_ASSERT(!FFlag::LuauParseAllHotComments); return lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment; } @@ -151,31 +148,37 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n { LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser"); - Parser p(buffer, bufferSize, names, allocator); + Parser p(buffer, bufferSize, names, allocator, FFlag::LuauParseAllHotComments ? options : ParseOptions()); try { - std::vector hotcomments; - - while (isComment(p.lexer.current()) || p.lexer.current().type == Lexeme::BrokenComment) + if (FFlag::LuauParseAllHotComments) { - const char* text = p.lexer.current().data; - unsigned int length = p.lexer.current().length; + AstStatBlock* root = p.parseChunk(); - if (length && text[0] == '!') + return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; + } + else + { + std::vector hotcomments; + + while (isComment(p.lexer.current()) || p.lexer.current().type == Lexeme::BrokenComment) { - unsigned int end = length; - while (end > 0 && isSpace(text[end - 1])) - --end; + const char* text = p.lexer.current().data; + unsigned int length = p.lexer.current().length; - hotcomments.push_back(std::string(text + 1, text + end)); - } + if (length && text[0] == '!') + { + unsigned int end = length; + while (end > 0 && isSpace(text[end - 1])) + --end; - const Lexeme::Type type = p.lexer.current().type; - const Location loc = p.lexer.current().location; + hotcomments.push_back({true, p.lexer.current().location, std::string(text + 1, text + end)}); + } + + const Lexeme::Type type = p.lexer.current().type; + const Location loc = p.lexer.current().location; - if (FFlag::LuauStartingBrokenComment) - { if (options.captureComments) p.commentLocations.push_back(Comment{type, loc}); @@ -184,22 +187,15 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n p.lexer.next(); } - else - { - p.lexer.next(); - if (options.captureComments) - p.commentLocations.push_back(Comment{type, loc}); - } + p.lexer.setSkipComments(true); + + p.options = options; + + AstStatBlock* root = p.parseChunk(); + + return ParseResult{root, hotcomments, p.parseErrors, std::move(p.commentLocations)}; } - - p.lexer.setSkipComments(true); - - p.options = options; - - AstStatBlock* root = p.parseChunk(); - - return ParseResult{root, hotcomments, p.parseErrors, std::move(p.commentLocations)}; } catch (ParseError& err) { @@ -210,8 +206,9 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n } } -Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator) - : lexer(buffer, bufferSize, names) +Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, const ParseOptions& options) + : options(options) + , lexer(buffer, bufferSize, names) , allocator(allocator) , recursionCounter(0) , endMismatchSuspect(Location(), Lexeme::Eof) @@ -224,14 +221,20 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc nameSelf = names.addStatic("self"); nameNumber = names.addStatic("number"); - nameError = names.addStatic(errorName); + nameError = names.addStatic(kParseNameError); nameNil = names.getOrAdd("nil"); // nil is a reserved keyword matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0); matchRecoveryStopOnToken[Lexeme::Type::Eof] = 1; + if (FFlag::LuauParseAllHotComments) + lexer.setSkipComments(true); + // read first lexeme nextLexeme(); + + // all hot comments parsed after the first non-comment lexeme are special in that they don't affect type checking / linting mode + hotcommentHeader = false; } bool Parser::blockFollow(const Lexeme& l) @@ -396,7 +399,9 @@ AstStat* Parser::parseIf() AstExpr* cond = parseExpr(); Lexeme matchThen = lexer.current(); - bool hasThen = expectAndConsume(Lexeme::ReservedThen, "if statement"); + std::optional thenLocation; + if (expectAndConsume(Lexeme::ReservedThen, "if statement")) + thenLocation = matchThen.location; AstStatBlock* thenbody = parseBlock(); @@ -434,7 +439,7 @@ AstStat* Parser::parseIf() hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchThenElse); } - return allocator.alloc(Location(start, end), cond, thenbody, elsebody, hasThen, matchThen.location, elseLocation, hasEnd); + return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation, hasEnd); } // while exp do block end @@ -769,7 +774,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { // note: `type` token is already parsed for us, so we just need to parse the rest - auto name = parseNameOpt("type name"); + std::optional name = parseNameOpt("type name"); // Use error name if the name is missing if (!name) @@ -925,7 +930,7 @@ AstStat* Parser::parseDeclaration(const Location& start) return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props)); } - else if (auto globalName = parseNameOpt("global variable name")) + else if (std::optional globalName = parseNameOpt("global variable name")) { expectAndConsume(':', "global variable declaration"); @@ -1066,7 +1071,7 @@ void Parser::parseExprList(TempVector& result) Parser::Binding Parser::parseBinding() { - auto name = parseNameOpt("variable name"); + std::optional name = parseNameOpt("variable name"); // Use placeholder if the name is missing if (!name) @@ -1325,7 +1330,7 @@ AstType* Parser::parseTableTypeAnnotation() } else { - auto name = parseNameOpt("table field"); + std::optional name = parseNameOpt("table field"); if (!name) break; @@ -1422,7 +1427,7 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray(Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList); @@ -1869,7 +1874,7 @@ AstExpr* Parser::parseExpr(unsigned int limit) // NAME AstExpr* Parser::parseNameExpr(const char* context) { - auto name = parseNameOpt(context); + std::optional name = parseNameOpt(context); if (!name) return allocator.alloc(lexer.current().location, copy({}), unsigned(parseErrors.size() - 1)); @@ -2233,6 +2238,12 @@ AstExpr* Parser::parseTableConstructor() AstExpr* key = allocator.alloc(name.location, nameString); AstExpr* value = parseExpr(); + if (FFlag::LuauTableFieldFunctionDebugname) + { + if (AstExprFunction* func = value->as()) + func->debugname = name.name; + } + items.push_back({AstExprTable::Item::Record, key, value}); } else @@ -2313,7 +2324,7 @@ std::optional Parser::parseNameOpt(const char* context) Parser::Name Parser::parseName(const char* context) { - if (auto name = parseNameOpt(context)) + if (std::optional name = parseNameOpt(context)) return *name; Location location = lexer.current().location; @@ -2324,7 +2335,7 @@ Parser::Name Parser::parseName(const char* context) Parser::Name Parser::parseIndexName(const char* context, const Position& previous) { - if (auto name = parseNameOpt(context)) + if (std::optional name = parseNameOpt(context)) return *name; // If we have a reserved keyword next at the same line, assume it's an incomplete name @@ -2379,7 +2390,7 @@ std::pair, AstArray> Parser::parseG if (shouldParseTypePackAnnotation(lexer)) { - auto typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePackAnnotation(); namePacks.push_back({name, nameLocation, typePack}); } @@ -2451,7 +2462,7 @@ AstArray Parser::parseTypeParams() { if (shouldParseTypePackAnnotation(lexer)) { - auto typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePackAnnotation(); parameters.push_back({{}, typePack}); } @@ -2821,25 +2832,57 @@ AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const return allocator.alloc(location, types, isMissing, unsigned(parseErrors.size() - 1)); } -const Lexeme& Parser::nextLexeme() +void Parser::nextLexeme() { if (options.captureComments) { - while (true) + if (FFlag::LuauParseAllHotComments) { - const Lexeme& lexeme = lexer.next(/*skipComments*/ false); - // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. - // The parser will turn this into a proper syntax error. - if (lexeme.type == Lexeme::BrokenComment) + Lexeme::Type type = lexer.next(/* skipComments= */ false).type; + + while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) + { + const Lexeme& lexeme = lexer.current(); commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - if (isComment(lexeme)) - commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - else - return lexeme; + + // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. + // The parser will turn this into a proper syntax error. + if (lexeme.type == Lexeme::BrokenComment) + return; + + // Comments starting with ! are called "hot comments" and contain directives for type checking / linting + if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') + { + const char* text = lexeme.data; + + unsigned int end = lexeme.length; + while (end > 0 && isSpace(text[end - 1])) + --end; + + hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); + } + + type = lexer.next(/* skipComments= */ false).type; + } + } + else + { + while (true) + { + const Lexeme& lexeme = lexer.next(/*skipComments*/ false); + // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. + // The parser will turn this into a proper syntax error. + if (lexeme.type == Lexeme::BrokenComment) + commentLocations.push_back(Comment{lexeme.type, lexeme.location}); + if (isComment(lexeme)) + commentLocations.push_back(Comment{lexeme.type, lexeme.location}); + else + return; + } } } else - return lexer.next(); + lexer.next(); } } // namespace Luau diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index ded50e53..8079830b 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -9,6 +9,12 @@ #include #ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif #include #endif diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index fe005aec..fb6ac373 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -4,8 +4,13 @@ #include "Luau/Common.h" #ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN -#include +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include #else #include #include diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 9a6e25c2..13304d57 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -37,6 +37,8 @@ enum class CompileFormat Binary }; +constexpr int MaxTraversalLimit = 50; + struct GlobalOptions { int optimizationLevel = 1; @@ -243,10 +245,115 @@ std::string runCode(lua_State* L, const std::string& source) return std::string(); } +// Replaces the top of the lua stack with the metatable __index for the value +// if it exists. Returns true iff __index exists. +static bool tryReplaceTopWithIndex(lua_State* L) +{ + if (luaL_getmetafield(L, -1, "__index")) + { + // Remove the table leaving __index on the top of stack + lua_remove(L, -2); + return true; + } + return false; +} + + +// This function is similar to lua_gettable, but it avoids calling any +// lua callback functions (e.g. __index) which might modify the Lua VM state. +static void safeGetTable(lua_State* L, int tableIndex) +{ + lua_pushvalue(L, tableIndex); // Duplicate the table + + // The loop invariant is that the table to search is at -1 + // and the key is at -2. + for (int loopCount = 0;; loopCount++) + { + lua_pushvalue(L, -2); // Duplicate the key + lua_rawget(L, -2); // Try to find the key + if (!lua_isnil(L, -1) || loopCount >= MaxTraversalLimit) + { + // Either the key has been found, and/or we have reached the max traversal limit + break; + } + else + { + lua_pop(L, 1); // Pop the nil result + if (!luaL_getmetafield(L, -1, "__index")) + { + lua_pushnil(L); + break; + } + else if (lua_istable(L, -1)) + { + // Replace the current table being searched with __index table + lua_replace(L, -2); + } + else + { + lua_pop(L, 1); // Pop the value + lua_pushnil(L); + break; + } + } + } + + lua_remove(L, -2); // Remove the table + lua_remove(L, -2); // Remove the original key +} + +// completePartialMatches finds keys that match the specified 'prefix' +// Note: the table/object to be searched must be on the top of the Lua stack +static void completePartialMatches(lua_State* L, bool completeOnlyFunctions, const std::string& editBuffer, std::string_view prefix, + const AddCompletionCallback& addCompletionCallback) +{ + for (int i = 0; i < MaxTraversalLimit && lua_istable(L, -1); i++) + { + // table, key + lua_pushnil(L); + + // Loop over all the keys in the current table + while (lua_next(L, -2) != 0) + { + if (lua_type(L, -2) == LUA_TSTRING) + { + // table, key, value + std::string_view key = lua_tostring(L, -2); + int valueType = lua_type(L, -1); + + // If the last separator was a ':' (i.e. a method call) then only functions should be completed. + bool requiredValueType = (!completeOnlyFunctions || valueType == LUA_TFUNCTION); + + if (!key.empty() && requiredValueType && Luau::startsWith(key, prefix)) + { + std::string completedComponent(key.substr(prefix.size())); + std::string completion(editBuffer + completedComponent); + if (valueType == LUA_TFUNCTION) + { + // Add an opening paren for function calls by default. + completion += "("; + } + addCompletionCallback(completion, std::string(key)); + } + } + lua_pop(L, 1); + } + + // Replace the current table being searched with an __index table if one exists + if (!tryReplaceTopWithIndex(L)) + { + break; + } + } +} + static void completeIndexer(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) { std::string_view lookup = editBuffer; - char lastSep = 0; + bool completeOnlyFunctions = false; + + // Push the global variable table to begin the search + lua_pushvalue(L, LUA_GLOBALSINDEX); for (;;) { @@ -255,60 +362,26 @@ static void completeIndexer(lua_State* L, const std::string& editBuffer, const A if (sep == std::string_view::npos) { - // table, key - lua_pushnil(L); - - while (lua_next(L, -2) != 0) - { - if (lua_type(L, -2) == LUA_TSTRING) - { - // table, key, value - std::string_view key = lua_tostring(L, -2); - int valueType = lua_type(L, -1); - - // If the last separator was a ':' (i.e. a method call) then only functions should be completed. - bool requiredValueType = (lastSep != ':' || valueType == LUA_TFUNCTION); - - if (!key.empty() && requiredValueType && Luau::startsWith(key, prefix)) - { - std::string completedComponent(key.substr(prefix.size())); - std::string completion(editBuffer + completedComponent); - if (valueType == LUA_TFUNCTION) - { - // Add an opening paren for function calls by default. - completion += "("; - } - addCompletionCallback(completion, std::string(key)); - } - } - lua_pop(L, 1); - } - + completePartialMatches(L, completeOnlyFunctions, editBuffer, prefix, addCompletionCallback); break; } else { // find the key in the table lua_pushlstring(L, prefix.data(), prefix.size()); - lua_rawget(L, -2); + safeGetTable(L, -2); lua_remove(L, -2); - if (lua_type(L, -1) == LUA_TSTRING) + if (lua_istable(L, -1) || tryReplaceTopWithIndex(L)) { - // Replace the string object with the string class to perform further lookups of string functions - // Note: We retrieve the string class from _G to prevent issues if the user assigns to `string`. - lua_pop(L, 1); // Pop the string instance - lua_getglobal(L, "_G"); - lua_pushlstring(L, "string", 6); - lua_rawget(L, -2); - lua_remove(L, -2); // Remove the global table - LUAU_ASSERT(lua_istable(L, -1)); + completeOnlyFunctions = lookup[sep] == ':'; + lookup.remove_prefix(sep + 1); } - else if (!lua_istable(L, -1)) + else + { + // Unable to search for keys, so stop searching break; - - lastSep = lookup[sep]; - lookup.remove_prefix(sep + 1); + } } } @@ -317,12 +390,6 @@ static void completeIndexer(lua_State* L, const std::string& editBuffer, const A void getCompletions(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) { - // look the value up in current global table first - lua_pushvalue(L, LUA_GLOBALSINDEX); - completeIndexer(L, editBuffer, addCompletionCallback); - - // and in actual global table after that - lua_getglobal(L, "_G"); completeIndexer(L, editBuffer, addCompletionCallback); } @@ -365,9 +432,7 @@ struct LinenoiseScopedHistory ic_set_history(historyFilepath.c_str(), -1 /* default entries (= 200) */); } - ~LinenoiseScopedHistory() - { - } + ~LinenoiseScopedHistory() {} std::string historyFilepath; }; diff --git a/CMakeLists.txt b/CMakeLists.txt index c19d2b40..c6ccebc5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -104,6 +104,20 @@ if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) endif() +# embed .natvis inside the library debug information +if(MSVC) + target_link_options(Luau.Ast INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Ast.natvis) + target_link_options(Luau.Analysis INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Analysis.natvis) + target_link_options(Luau.VM INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/VM.natvis) +endif() + +# make .natvis visible inside the solution +if(MSVC_IDE) + target_sources(Luau.Ast PRIVATE tools/natvis/Ast.natvis) + target_sources(Luau.Analysis PRIVATE tools/natvis/Analysis.natvis) + target_sources(Luau.VM PRIVATE tools/natvis/VM.natvis) +endif() + if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) diff --git a/Sources.cmake b/Sources.cmake index 773f6f35..615641eb 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -8,6 +8,7 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/Location.h Ast/include/Luau/ParseOptions.h Ast/include/Luau/Parser.h + Ast/include/Luau/ParseResult.h Ast/include/Luau/StringUtils.h Ast/include/Luau/TimeTrace.h diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 5cffba63..29d5f397 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -36,12 +36,9 @@ const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n" static Table* getcurrenv(lua_State* L) { if (L->ci == L->base_ci) /* no enclosing function? */ - return hvalue(gt(L)); /* use global table as environment */ + return L->gt; /* use global table as environment */ else - { - Closure* func = curr_func(L); - return func->env; - } + return curr_func(L)->env; } static LUAU_NOINLINE TValue* pseudo2addr(lua_State* L, int idx) @@ -53,11 +50,14 @@ static LUAU_NOINLINE TValue* pseudo2addr(lua_State* L, int idx) return registry(L); case LUA_ENVIRONINDEX: { - sethvalue(L, &L->env, getcurrenv(L)); - return &L->env; + sethvalue(L, &L->global->pseudotemp, getcurrenv(L)); + return &L->global->pseudotemp; } case LUA_GLOBALSINDEX: - return gt(L); + { + sethvalue(L, &L->global->pseudotemp, L->gt); + return &L->global->pseudotemp; + } default: { Closure* func = curr_func(L); @@ -237,6 +237,11 @@ void lua_replace(lua_State* L, int idx) func->env = hvalue(L->top - 1); luaC_barrier(L, func, L->top - 1); } + else if (idx == LUA_GLOBALSINDEX) + { + api_check(L, ttistable(L->top - 1)); + L->gt = hvalue(L->top - 1); + } else { setobj(L, o, L->top - 1); @@ -783,7 +788,7 @@ void lua_getfenv(lua_State* L, int idx) sethvalue(L, L->top, clvalue(o)->env); break; case LUA_TTHREAD: - setobj2s(L, L->top, gt(thvalue(o))); + sethvalue(L, L->top, thvalue(o)->gt); break; default: setnilvalue(L->top); @@ -914,7 +919,7 @@ int lua_setfenv(lua_State* L, int idx) clvalue(o)->env = hvalue(L->top - 1); break; case LUA_TTHREAD: - sethvalue(L, gt(thvalue(o)), hvalue(L->top - 1)); + thvalue(o)->gt = hvalue(L->top - 1); break; default: res = 0; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index e9930f7a..a4f93c62 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -419,7 +419,7 @@ static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* con } const char* debugname = p->debugname ? getstr(p->debugname) : NULL; - int linedefined = luaG_getline(p, 0); + int linedefined = getlinedefined(p); callback(context, debugname, linedefined, depth, buffer, size); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index a3982bc6..d87f0661 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,6 +17,8 @@ #include +LUAU_FASTFLAG(LuauReduceStackReallocs) + /* ** {====================================================== ** Error-recovery functions @@ -164,13 +166,14 @@ static void correctstack(lua_State* L, TValue* oldstack) void luaD_reallocstack(lua_State* L, int newsize) { TValue* oldstack = L->stack; - int realsize = newsize + 1 + EXTRA_STACK; - LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK - 1); + int realsize = newsize + (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK); + LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); luaM_reallocarray(L, L->stack, L->stacksize, realsize, TValue, L->memcat); + TValue* newstack = L->stack; for (int i = L->stacksize; i < realsize; i++) - setnilvalue(L->stack + i); /* erase new segment */ + setnilvalue(newstack + i); /* erase new segment */ L->stacksize = realsize; - L->stack_last = L->stack + newsize; + L->stack_last = newstack + newsize; correctstack(L, oldstack); } @@ -512,7 +515,7 @@ static void callerrfunc(lua_State* L, void* ud) static void restore_stack_limit(lua_State* L) { - LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK - 1); + LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); if (L->size_ci > LUAI_MAXCALLS) { /* there was an overflow? */ int inuse = cast_int(L->ci - L->base_ci); diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 72807f0f..1c1480d6 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -11,7 +11,7 @@ if ((char*)L->stack_last - (char*)L->top <= (n) * (int)sizeof(TValue)) \ luaD_growstack(L, n); \ else \ - condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK - 1)); + condhardstacktests(luaD_reallocstack(L, L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK))); #define incr_top(L) \ { \ diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 835572fa..724b24b2 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -268,7 +268,7 @@ static void traverseclosure(global_State* g, Closure* cl) static void traversestack(global_State* g, lua_State* l, bool clearstack) { - markvalue(g, gt(l)); + markobject(g, l->gt); if (l->namecall) stringmark(l->namecall); for (StkId o = l->stack; o < l->top; o++) @@ -643,7 +643,7 @@ static void markroot(lua_State* L) g->weak = NULL; markobject(g, g->mainthread); /* make global table be traversed before main stack */ - markvalue(g, gt(g->mainthread)); + markobject(g, g->mainthread->gt); markvalue(g, registry(L)); markmt(g); g->gcstate = GCSpropagate; diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 528d0944..2acb5d8a 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -77,7 +77,7 @@ #define luaC_checkGC(L) \ { \ - condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK - 1)); \ + condhardstacktests(luaD_reallocstack(L, L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK))); \ if (L->global->totalbytes >= L->global->GCthreshold) \ { \ condhardmemtests(luaC_validate(L), 1); \ diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index ce196520..30242e52 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -88,7 +88,7 @@ static void validateclosure(global_State* g, Closure* cl) static void validatestack(global_State* g, lua_State* l) { - validateref(g, obj2gco(l), gt(l)); + validateobjref(g, obj2gco(l), obj2gco(l->gt)); for (CallInfo* ci = l->base_ci; ci <= l->ci; ++ci) { @@ -370,6 +370,7 @@ static void dumpclosure(FILE* f, Closure* cl) fprintf(f, ",\"env\":"); dumpref(f, obj2gco(cl->env)); + if (cl->isC) { if (cl->nupvalues) @@ -411,11 +412,8 @@ static void dumpthread(FILE* f, lua_State* th) fprintf(f, "{\"type\":\"thread\",\"cat\":%d,\"size\":%d", th->memcat, int(size)); - if (iscollectable(&th->l_gt)) - { - fprintf(f, ",\"env\":"); - dumpref(f, gcvalue(&th->l_gt)); - } + fprintf(f, ",\"env\":"); + dumpref(f, obj2gco(th->gt)); Closure* tcl = 0; for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) diff --git a/VM/src/lperf.cpp b/VM/src/lperf.cpp index 2f6c7297..da68e376 100644 --- a/VM/src/lperf.cpp +++ b/VM/src/lperf.cpp @@ -3,6 +3,12 @@ #include "lua.h" #ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif #include #endif diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 6762c638..d6d127c0 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -11,6 +11,7 @@ #include "ldebug.h" LUAU_FASTFLAG(LuauGcPagedSweep) +LUAU_FASTFLAGVARIABLE(LuauReduceStackReallocs, false) /* ** Main thread combines a thread state and the global state @@ -31,10 +32,11 @@ static void stack_init(lua_State* L1, lua_State* L) /* initialize stack array */ L1->stack = luaM_newarray(L, BASIC_STACK_SIZE + EXTRA_STACK, TValue, L1->memcat); L1->stacksize = BASIC_STACK_SIZE + EXTRA_STACK; + TValue* stack = L1->stack; for (int i = 0; i < BASIC_STACK_SIZE + EXTRA_STACK; i++) - setnilvalue(L1->stack + i); /* erase new stack */ - L1->top = L1->stack; - L1->stack_last = L1->stack + (L1->stacksize - EXTRA_STACK) - 1; + setnilvalue(stack + i); /* erase new stack */ + L1->top = stack; + L1->stack_last = stack + (L1->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); /* initialize first ci */ L1->ci->func = L1->top; setnilvalue(L1->top++); /* `function' entry for this `ci' */ @@ -55,7 +57,7 @@ static void f_luaopen(lua_State* L, void* ud) { global_State* g = L->global; stack_init(L, L); /* init stack */ - sethvalue(L, gt(L), luaH_new(L, 0, 2)); /* table of globals */ + L->gt = luaH_new(L, 0, 2); /* table of globals */ sethvalue(L, registry(L), luaH_new(L, 0, 2)); /* registry */ luaS_resize(L, LUA_MINSTRTABSIZE); /* initial size of string table */ luaT_init(L); @@ -69,6 +71,7 @@ static void preinit_state(lua_State* L, global_State* g) L->global = g; L->stack = NULL; L->stacksize = 0; + L->gt = NULL; L->openupval = NULL; L->size_ci = 0; L->nCcalls = L->baseCcalls = 0; @@ -80,7 +83,6 @@ static void preinit_state(lua_State* L, global_State* g) L->stackstate = 0; L->activememcat = 0; L->userdata = NULL; - setnilvalue(gt(L)); } static void close_state(lua_State* L) @@ -116,7 +118,7 @@ lua_State* luaE_newthread(lua_State* L) preinit_state(L1, L->global); L1->activememcat = L->activememcat; // inherit the active memory category stack_init(L1, L); /* init stack */ - setobj2n(L, gt(L1), gt(L)); /* share table of globals */ + L1->gt = L->gt; /* share table of globals */ L1->singlestep = L->singlestep; LUAU_ASSERT(iswhite(obj2gco(L1))); return L1; @@ -144,14 +146,30 @@ void lua_resetthread(lua_State* L) ci->top = ci->base + LUA_MINSTACK; setnilvalue(ci->func); L->ci = ci; - luaD_reallocCI(L, BASIC_CI_SIZE); + if (FFlag::LuauReduceStackReallocs) + { + if (L->size_ci != BASIC_CI_SIZE) + luaD_reallocCI(L, BASIC_CI_SIZE); + } + else + { + luaD_reallocCI(L, BASIC_CI_SIZE); + } /* clear thread state */ L->status = LUA_OK; L->base = L->ci->base; L->top = L->ci->base; L->nCcalls = L->baseCcalls = 0; /* clear thread stack */ - luaD_reallocstack(L, BASIC_STACK_SIZE); + if (FFlag::LuauReduceStackReallocs) + { + if (L->stacksize != BASIC_STACK_SIZE + EXTRA_STACK) + luaD_reallocstack(L, BASIC_STACK_SIZE); + } + else + { + luaD_reallocstack(L, BASIC_STACK_SIZE); + } for (int i = 0; i < L->stacksize; i++) setnilvalue(L->stack + i); } @@ -193,6 +211,7 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->strt.size = 0; g->strt.nuse = 0; g->strt.hash = NULL; + setnilvalue(&g->pseudotemp); setnilvalue(registry(L)); g->gcstate = GCSpause; if (!FFlag::LuauGcPagedSweep) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 0708b71f..6dd89138 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -5,9 +5,6 @@ #include "lobject.h" #include "ltm.h" -/* table of globals */ -#define gt(L) (&L->l_gt) - /* registry */ #define registry(L) (&L->global->registry) @@ -177,6 +174,8 @@ typedef struct global_State TString* ttname[LUA_T_COUNT]; /* names for basic types */ TString* tmname[TM_N]; /* array with tag-method names */ + TValue pseudotemp; /* storage for temporary values used in pseudo2addr */ + TValue registry; /* registry table, used by lua_ref and LUA_REGISTRYINDEX */ int registryfree; /* next free slot in registry */ @@ -231,8 +230,7 @@ struct lua_State int cachedslot; /* when table operations or INDEX/NEWINDEX is invoked from Luau, what is the expected slot for lookup? */ - TValue l_gt; /* table of globals */ - TValue env; /* temporary place for environments */ + Table* gt; /* table of globals */ UpVal* openupval; /* list of open upvalues in this stack */ GCObject* gclist; diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index c3b662a2..6c31d36f 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -39,7 +39,7 @@ // When calling luau_callTM, we usually push the arguments to the top of the stack. // This is safe to do for complicated reasons: -// - stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) +// - stack guarantees EXTRA_STACK room beyond stack_last (see luaD_reallocstack) // - stack reallocation copies values past stack_last // All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 2472cd90..4e5435b7 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -116,12 +116,12 @@ static void resolveImportSafe(lua_State* L, Table* env, TValue* k, uint32_t id) // note: we call getimport with nil propagation which means that accesses to table chains like A.B.C will resolve in nil // this is technically not necessary but it reduces the number of exceptions when loading scripts that rely on getfenv/setfenv for global // injection - luaV_getimport(L, hvalue(gt(L)), self->k, self->id, /* propagatenil= */ true); + luaV_getimport(L, L->gt, self->k, self->id, /* propagatenil= */ true); } }; ResolveImport ri = {k, id}; - if (hvalue(gt(L))->safeenv) + if (L->gt->safeenv) { // luaD_pcall will make sure that if any C/Lua calls during import resolution fail, the thread state is restored back int oldTop = lua_gettop(L); @@ -171,7 +171,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size L->global->GCthreshold = SIZE_MAX; // env is 0 for current environment and a stack index otherwise - Table* envt = (env == 0) ? hvalue(gt(L)) : hvalue(luaA_toobject(L, env)); + Table* envt = (env == 0) ? L->gt : hvalue(luaA_toobject(L, env)); TString* source = luaS_new(L, chunkname); diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 31dd59c8..8a18a4d4 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -55,7 +55,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 { ptrdiff_t result = savestack(L, res); // using stack room beyond top is technically safe here, but for very complicated reasons: - // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated + // * The stack guarantees EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers // * we cannot use savestack/restorestack because the arguments are sometimes on the C++ stack @@ -76,7 +76,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3) { // using stack room beyond top is technically safe here, but for very complicated reasons: - // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated + // * The stack guarantees EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers // * we cannot use savestack/restorestack because the arguments are sometimes on the C++ stack diff --git a/fuzz/linter.cpp b/fuzz/linter.cpp index 55e0888b..04638d23 100644 --- a/fuzz/linter.cpp +++ b/fuzz/linter.cpp @@ -32,7 +32,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) Luau::LintOptions lintOptions; lintOptions.warningMask = ~0ull; - Luau::lint(parseResult.root, names, typeck.globalScope, nullptr, lintOptions); + Luau::lint(parseResult.root, names, typeck.globalScope, nullptr, {}, lintOptions); } return 0; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 27e53492..f407248a 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -227,7 +227,7 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) if (kFuzzLinter) { Luau::LintOptions lintOptions = {~0u}; - Luau::lint(parseResult.root, names, sharedEnv.globalScope, module.get(), lintOptions); + Luau::lint(parseResult.root, names, sharedEnv.globalScope, module.get(), {}, lintOptions); } } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 59e12574..1978a0d3 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1,7 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Autocomplete.h" #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" @@ -2610,6 +2609,27 @@ a = if temp then even elseif true then temp else e@9 CHECK(ac.entryMap.count("elseif") == 0); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_else_regression") +{ + ScopedFastFlag FFlagLuauIfElseExprFixCompletionIssue("LuauIfElseExprFixCompletionIssue", true); + check(R"( +local abcdef = 0; +local temp = false +local even = true; +local a +a = if temp then even else@1 +a = if temp then even else @2 +a = if temp then even else abc@3 + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("else") == 0); + ac = autocomplete('2'); + CHECK(ac.entryMap.count("else") == 0); + ac = autocomplete('3'); + CHECK(ac.entryMap.count("abcdef")); +} + TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") { check(R"( diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index cd7a21d8..f982c86f 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -10,6 +10,11 @@ #include #include +namespace Luau +{ +std::string rep(const std::string& s, size_t n); +} + using namespace Luau; static std::string compileFunction(const char* source, uint32_t id) @@ -1960,15 +1965,6 @@ RETURN R8 -1 )"); } -static std::string rep(const std::string& s, size_t n) -{ - std::string r; - r.reserve(s.length() * n); - for (size_t i = 0; i < n; ++i) - r += s; - return r; -} - TEST_CASE("RecursionParse") { // The test forcibly pushes the stack limit during compilation; in NoOpt, the stack consumption is much larger so we need to reduce the limit to diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 8b58d2ce..b09c1efb 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -492,6 +492,8 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { + ScopedFastFlag luauTableFieldFunctionDebugname{"LuauTableFieldFunctionDebugname", true}; + runConformance("debug.lua"); } @@ -890,6 +892,12 @@ TEST_CASE("Coverage") lua_pushstring(L, function); lua_setfield(L, -2, "name"); + lua_pushinteger(L, linedefined); + lua_setfield(L, -2, "linedefined"); + + lua_pushinteger(L, depth); + lua_setfield(L, -2, "depth"); + for (size_t i = 0; i < size; ++i) if (hits[i] != -1) { diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index c74bfa27..dbdd06a4 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -2,6 +2,7 @@ #include "Fixture.h" #include "Luau/AstQuery.h" +#include "Luau/Parser.h" #include "Luau/TypeVar.h" #include "Luau/TypeAttach.h" #include "Luau/Transpiler.h" @@ -112,7 +113,7 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars sourceModule->name = fromString(mainModuleName); sourceModule->root = result.root; sourceModule->mode = parseMode(result.hotcomments); - sourceModule->ignoreLints = LintWarning::parseMask(result.hotcomments); + sourceModule->hotcomments = std::move(result.hotcomments); if (!result.errors.empty()) { @@ -157,6 +158,7 @@ CheckResult Fixture::check(const std::string& source) LintResult Fixture::lint(const std::string& source, const std::optional& lintOptions) { ParseOptions parseOptions; + parseOptions.captureComments = true; configResolver.defaultConfig.mode = Mode::Nonstrict; parse(source, parseOptions); diff --git a/tests/Fixture.h b/tests/Fixture.h index ab852ef6..4e45a952 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -8,7 +8,6 @@ #include "Luau/Linter.h" #include "Luau/Location.h" #include "Luau/ModuleResolver.h" -#include "Luau/Parser.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index ea1a08fe..8a59acd1 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -2,7 +2,6 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Frontend.h" -#include "Luau/Parser.h" #include "Luau/RequireTracer.h" #include "Fixture.h" @@ -897,8 +896,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "clearStats") TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") { - ScopedFastFlag sffs("LuauTypeCheckTwice", true); - fileResolver.source["Module/A"] = R"( local a = 1 )"; diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp index 4a717275..cb508072 100644 --- a/tests/JsonEncoder.test.cpp +++ b/tests/JsonEncoder.test.cpp @@ -46,7 +46,7 @@ TEST_CASE("encode_AstStatBlock") AstStatBlock block{Location(), bodyArray}; CHECK_EQ( - (R"({"type":"AstStatBlock","location":"0,0 - 0,0","body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":["a_local"],"values":[]}]})"), + (R"({"type":"AstStatBlock","location":"0,0 - 0,0","body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":[{"type":null,"name":"a_local","location":"0,0 - 0,0"}],"values":[]}]})"), toJson(&block)); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 577415fc..d4b97360 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1395,12 +1395,10 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") TypeId colorType = typeChecker.globalTypes.addType(TableTypeVar{{}, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}); - getMutable(colorType)->props = { - {"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"} } - }; + getMutable(colorType)->props = {{"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; addGlobalBinding(typeChecker, "Color3", Binding{colorType, {}}); - + freeze(typeChecker.globalTypes); LintResult result = lintTyped(R"( @@ -1554,8 +1552,46 @@ _ = (math.random() < 0.5 and false) or 42 -- currently ignored )"); REQUIRE_EQ(result.warnings.size(), 2); - CHECK_EQ(result.warnings[0].text, "The and-or expression always evaluates to the second alternative because the first alternative is false; consider using if-then-else expression instead"); - CHECK_EQ(result.warnings[1].text, "The and-or expression always evaluates to the second alternative because the first alternative is nil; consider using if-then-else expression instead"); + CHECK_EQ(result.warnings[0].text, "The and-or expression always evaluates to the second alternative because the first alternative is false; " + "consider using if-then-else expression instead"); + CHECK_EQ(result.warnings[1].text, "The and-or expression always evaluates to the second alternative because the first alternative is nil; " + "consider using if-then-else expression instead"); +} + +TEST_CASE_FIXTURE(Fixture, "WrongComment") +{ + ScopedFastFlag sff("LuauParseAllHotComments", true); + + LintResult result = lint(R"( +--!strict +--!struct +--!nolintGlobal +--!nolint Global +--!nolint KnownGlobal +--!nolint UnknownGlobal +--! no more lint +--!strict here +do end +--!nolint +)"); + + REQUIRE_EQ(result.warnings.size(), 6); + CHECK_EQ(result.warnings[0].text, "Unknown comment directive 'struct'; did you mean 'strict'?"); + CHECK_EQ(result.warnings[1].text, "Unknown comment directive 'nolintGlobal'"); + CHECK_EQ(result.warnings[2].text, "nolint directive refers to unknown lint rule 'Global'"); + CHECK_EQ(result.warnings[3].text, "nolint directive refers to unknown lint rule 'KnownGlobal'; did you mean 'UnknownGlobal'?"); + CHECK_EQ(result.warnings[4].text, "Comment directive with the type checking mode has extra symbols at the end of the line"); + CHECK_EQ(result.warnings[5].text, "Comment directive is ignored because it is placed after the first non-comment token"); +} + +TEST_CASE_FIXTURE(Fixture, "WrongCommentMuteSelf") +{ + LintResult result = lint(R"( +--!nolint +--!struct +)"); + + REQUIRE_EQ(result.warnings.size(), 0); // --!nolint disables WrongComment lint :) } TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 931a8403..5bad9901 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index c1a8887b..0d4c088d 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" -#include "Luau/TypeInfer.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -300,8 +299,9 @@ TEST_CASE_FIXTURE(Fixture, "functions_can_have_return_annotations") AstStatFunction* statFunction = block->body.data[0]->as(); REQUIRE(statFunction != nullptr); - CHECK_EQ(statFunction->func->returnAnnotation.types.size, 1); - CHECK(statFunction->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunction->func->returnAnnotation.has_value()); + CHECK_EQ(statFunction->func->returnAnnotation->types.size, 1); + CHECK(statFunction->func->returnAnnotation->tailType == nullptr); } TEST_CASE_FIXTURE(Fixture, "functions_can_have_a_function_type_annotation") @@ -316,9 +316,9 @@ TEST_CASE_FIXTURE(Fixture, "functions_can_have_a_function_type_annotation") AstStatFunction* statFunc = block->body.data[0]->as(); REQUIRE(statFunc != nullptr); - AstArray& retTypes = statFunc->func->returnAnnotation.types; - REQUIRE(statFunc->func->hasReturnAnnotation); - CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunc->func->returnAnnotation.has_value()); + CHECK(statFunc->func->returnAnnotation->tailType == nullptr); + AstArray& retTypes = statFunc->func->returnAnnotation->types; REQUIRE(retTypes.size == 1); AstTypeFunction* funTy = retTypes.data[0]->as(); @@ -337,9 +337,9 @@ TEST_CASE_FIXTURE(Fixture, "function_return_type_should_disambiguate_from_functi AstStatFunction* statFunc = block->body.data[0]->as(); REQUIRE(statFunc != nullptr); - AstArray& retTypes = statFunc->func->returnAnnotation.types; - REQUIRE(statFunc->func->hasReturnAnnotation); - CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunc->func->returnAnnotation.has_value()); + CHECK(statFunc->func->returnAnnotation->tailType == nullptr); + AstArray& retTypes = statFunc->func->returnAnnotation->types; REQUIRE(retTypes.size == 2); AstTypeReference* ty0 = retTypes.data[0]->as(); @@ -363,9 +363,9 @@ TEST_CASE_FIXTURE(Fixture, "function_return_type_should_parse_as_function_type_a AstStatFunction* statFunc = block->body.data[0]->as(); REQUIRE(statFunc != nullptr); - AstArray& retTypes = statFunc->func->returnAnnotation.types; - REQUIRE(statFunc->func->hasReturnAnnotation); - CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunc->func->returnAnnotation.has_value()); + CHECK(statFunc->func->returnAnnotation->tailType == nullptr); + AstArray& retTypes = statFunc->func->returnAnnotation->types; REQUIRE(retTypes.size == 1); AstTypeFunction* funTy = retTypes.data[0]->as(); @@ -707,12 +707,25 @@ TEST_CASE_FIXTURE(Fixture, "mode_is_unset_if_no_hot_comment") TEST_CASE_FIXTURE(Fixture, "sense_hot_comment_on_first_line") { - ParseResult result = parseEx(" --!strict "); + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx(" --!strict ", options); std::optional mode = parseMode(result.hotcomments); REQUIRE(bool(mode)); CHECK_EQ(int(*mode), int(Mode::Strict)); } +TEST_CASE_FIXTURE(Fixture, "non_header_hot_comments") +{ + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx("do end --!strict", options); + std::optional mode = parseMode(result.hotcomments); + REQUIRE(!mode); +} + TEST_CASE_FIXTURE(Fixture, "stop_if_line_ends_with_hyphen") { CHECK_THROWS_AS(parse(" -"), std::exception); @@ -720,7 +733,10 @@ TEST_CASE_FIXTURE(Fixture, "stop_if_line_ends_with_hyphen") TEST_CASE_FIXTURE(Fixture, "nonstrict_mode") { - ParseResult result = parseEx("--!nonstrict"); + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx("--!nonstrict", options); CHECK(result.errors.empty()); std::optional mode = parseMode(result.hotcomments); REQUIRE(bool(mode)); @@ -729,7 +745,10 @@ TEST_CASE_FIXTURE(Fixture, "nonstrict_mode") TEST_CASE_FIXTURE(Fixture, "nocheck_mode") { - ParseResult result = parseEx("--!nocheck"); + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx("--!nocheck", options); CHECK(result.errors.empty()); std::optional mode = parseMode(result.hotcomments); REQUIRE(bool(mode)); @@ -1498,8 +1517,6 @@ return TEST_CASE_FIXTURE(Fixture, "parse_error_broken_comment") { - ScopedFastFlag luauStartingBrokenComment{"LuauStartingBrokenComment", true}; - const char* expected = "Expected identifier when parsing expression, got unfinished comment"; matchParseError("--[[unfinished work", expected); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index 1f9c9739..87a1e1e2 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -73,7 +73,7 @@ public: private: std::unique_ptr luaState; - // This is a simplicitic and incomplete pretty printer. + // This is a simplistic and incomplete pretty printer. // It is included here to test that the pretty printer hook is being called. // More elaborate tests to ensure correct output can be added if we introduce // a more feature rich pretty printer. @@ -158,12 +158,25 @@ TEST_CASE_FIXTURE(ReplFixture, "CompleteGlobalVariables") myvariable1 = 5 myvariable2 = 5 )"); - CompletionSet completions = getCompletionSet("myvar"); + { + // Try to complete globals that are added by the user's script + CompletionSet completions = getCompletionSet("myvar"); - std::string prefix = ""; - CHECK(completions.size() == 2); - CHECK(checkCompletion(completions, prefix, "myvariable1")); - CHECK(checkCompletion(completions, prefix, "myvariable2")); + std::string prefix = ""; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "myvariable1")); + CHECK(checkCompletion(completions, prefix, "myvariable2")); + } + { + // Try completing some builtin functions + CompletionSet completions = getCompletionSet("math.m"); + + std::string prefix = "math."; + CHECK(completions.size() == 3); + CHECK(checkCompletion(completions, prefix, "max(")); + CHECK(checkCompletion(completions, prefix, "min(")); + CHECK(checkCompletion(completions, prefix, "modf(")); + } } TEST_CASE_FIXTURE(ReplFixture, "CompleteTableKeys") @@ -206,4 +219,188 @@ TEST_CASE_FIXTURE(ReplFixture, "StringMethods") } } +TEST_CASE_FIXTURE(ReplFixture, "TableWithMetatableIndexTable") +{ + runCode(L, R"( + -- Create 't' which is a table with a metatable with an __index table + mt = {} + mt.__index = mt + + t = {} + setmetatable(t, mt) + + mt.mtkey1 = {x="x value", y="y value", 1, 2} + mt.mtkey2 = 2 + + t.tkey1 = {data1 = 2, data2 = "str", 3, 4} + t.tkey2 = 4 +)"); + { + CompletionSet completions = getCompletionSet("t.t"); + + std::string prefix = "t."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "tkey1")); + CHECK(checkCompletion(completions, prefix, "tkey2")); + } + { + CompletionSet completions = getCompletionSet("t.tkey1.data2:re"); + + std::string prefix = "t.tkey1.data2:"; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "rep(")); + CHECK(checkCompletion(completions, prefix, "reverse(")); + } + { + CompletionSet completions = getCompletionSet("t.mtk"); + + std::string prefix = "t."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "mtkey1")); + CHECK(checkCompletion(completions, prefix, "mtkey2")); + } + { + CompletionSet completions = getCompletionSet("t.mtkey1."); + + std::string prefix = "t.mtkey1."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "x")); + CHECK(checkCompletion(completions, prefix, "y")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithMetatableIndexFunction") +{ + runCode(L, R"( + -- Create 't' which is a table with a metatable with an __index function + mt = {} + mt.__index = function(table, key) + print("mt.__index called") + if key == "foo" then + return "FOO" + elseif key == "bar" then + return "BAR" + else + return nil + end + end + + t = {} + setmetatable(t, mt) + t.tkey = 0 +)"); + { + CompletionSet completions = getCompletionSet("t.t"); + + std::string prefix = "t."; + CHECK(completions.size() == 1); + CHECK(checkCompletion(completions, prefix, "tkey")); + } + { + // t.foo is a valid key, but should not be completed because it requires calling an __index function + CompletionSet completions = getCompletionSet("t.foo"); + + CHECK(completions.size() == 0); + } + { + // t.foo is a valid key, but should not be found because it requires calling an __index function + CompletionSet completions = getCompletionSet("t.foo:"); + + CHECK(completions.size() == 0); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithMultipleMetatableIndexTables") +{ + runCode(L, R"( + -- Create a table with a chain of metatables + mt2 = {} + mt2.__index = mt2 + + mt = {} + mt.__index = mt + setmetatable(mt, mt2) + + t = {} + setmetatable(t, mt) + + mt2.mt2key = {x=1, y=2} + mt.mtkey = 2 + t.tkey = 3 +)"); + { + CompletionSet completions = getCompletionSet("t."); + + std::string prefix = "t."; + CHECK(completions.size() == 4); + CHECK(checkCompletion(completions, prefix, "__index")); + CHECK(checkCompletion(completions, prefix, "tkey")); + CHECK(checkCompletion(completions, prefix, "mtkey")); + CHECK(checkCompletion(completions, prefix, "mt2key")); + } + { + CompletionSet completions = getCompletionSet("t.__index."); + + std::string prefix = "t.__index."; + CHECK(completions.size() == 3); + CHECK(checkCompletion(completions, prefix, "__index")); + CHECK(checkCompletion(completions, prefix, "mtkey")); + CHECK(checkCompletion(completions, prefix, "mt2key")); + } + { + CompletionSet completions = getCompletionSet("t.mt2key."); + + std::string prefix = "t.mt2key."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "x")); + CHECK(checkCompletion(completions, prefix, "y")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithDeepMetatableIndexTables") +{ + runCode(L, R"( +-- Creates a table with a chain of metatables of length `count` +function makeChainedTable(count) + local result = {} + result.__index = result + result[string.format("entry%d", count)] = { count = count } + if count == 0 then + return result + else + return setmetatable(result, makeChainedTable(count - 1)) + end +end + +t30 = makeChainedTable(30) +t60 = makeChainedTable(60) +)"); + { + // Check if entry0 exists + CompletionSet completions = getCompletionSet("t30.entry0"); + + std::string prefix = "t30."; + CHECK(checkCompletion(completions, prefix, "entry0")); + } + { + // Check if entry0.count exists + CompletionSet completions = getCompletionSet("t30.entry0.co"); + + std::string prefix = "t30.entry0."; + CHECK(checkCompletion(completions, prefix, "count")); + } + { + // Check if entry0 exists. With the max traversal limit of 50 in the repl, this should fail. + CompletionSet completions = getCompletionSet("t60.entry0"); + + CHECK(completions.size() == 0); + } + { + // Check if entry0.count exists. With the max traversal limit of 50 in the repl, this should fail. + CompletionSet completions = getCompletionSet("t60.entry0.co"); + + CHECK(completions.size() == 0); + } +} + TEST_SUITE_END(); diff --git a/tests/RequireTracer.test.cpp b/tests/RequireTracer.test.cpp index b9fd04d6..ba03f363 100644 --- a/tests/RequireTracer.test.cpp +++ b/tests/RequireTracer.test.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/RequireTracer.h" +#include "Luau/Parser.h" #include "Fixture.h" diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index a8729268..31d7ef10 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -598,15 +598,13 @@ TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_ /* * The two-pass alias definition system starts by ascribing a free TypeVar to each alias. It then * circles back to fill in the actual type later on. - * + * * If this free type is unified with something degenerate like `any`, we need to take extra care * to ensure that the alias actually binds to the type that the user expected. */ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any") { - ScopedFastFlag sff[] = { - {"LuauTwoPassAliasDefinitionFix", true} - }; + ScopedFastFlag sff[] = {{"LuauTwoPassAliasDefinitionFix", true}}; CheckResult result = check(R"( local function x() diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 572b882d..2ad11d01 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index df06884d..f3dfb214 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 0283ae19..98fa66eb 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 2652486b..c6d55793 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 8a2c6f27..f8fccf6b 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 93c0baf6..d677e28d 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 2bcd840c..eee0e0f1 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Fixture.h" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index df365fda..9021700d 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -426,8 +426,6 @@ TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauIfElseExpectedType2", true}, - {"LuauIfElseBranchTypeUnion", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index f19cb618..6bcd4b99 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -2168,8 +2167,6 @@ b() TEST_CASE_FIXTURE(Fixture, "length_operator_union") { - ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; - CheckResult result = check(R"( local x: {number} | {string} local y = #x @@ -2180,8 +2177,6 @@ local y = #x TEST_CASE_FIXTURE(Fixture, "length_operator_intersection") { - ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; - CheckResult result = check(R"( local x: {number} & {z:string} -- mixed tables are evil local y = #x @@ -2192,8 +2187,6 @@ local y = #x TEST_CASE_FIXTURE(Fixture, "length_operator_non_table_union") { - ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; - CheckResult result = check(R"( local x: {number} | any | string local y = #x @@ -2204,8 +2197,6 @@ local y = #x TEST_CASE_FIXTURE(Fixture, "length_operator_union_errors") { - ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; - CheckResult result = check(R"( local x: {number} | number | string local y = #x @@ -2214,4 +2205,38 @@ local y = #x LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") +{ + ScopedFastFlag sff{"LuauTerminateCyclicMetatableIndexLookup", true}; + + // t :: t1 where t1 = {metatable {__index: t1, __tostring: (t1) -> string}} + CheckResult result = check(R"( + local mt = {} + local t = setmetatable({}, mt) + mt.__index = t + + function mt:__tostring() + return t.p + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 't' does not have key 'p'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "give_up_after_one_metatable_index_look_up") +{ + CheckResult result = check(R"( + local data = { x = 5 } + local t1 = setmetatable({}, { __index = data }) + local t2 = setmetatable({}, t1) -- note: must be t1, not a new table + + local x1 = t1.x -- ok + local x2 = t2.x -- nope + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 't2' does not have key 'x'", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 531a382f..32358571 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -2,7 +2,6 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -4654,8 +4653,6 @@ a = setmetatable(a, { __call = function(x) end }) TEST_CASE_FIXTURE(Fixture, "infer_through_group_expr") { - ScopedFastFlag luauGroupExpectedType{"LuauGroupExpectedType", true}; - CheckResult result = check(R"( local function f(a: (number, number) -> number) return a(1, 3) end f(((function(a, b) return a + b end))) @@ -4735,21 +4732,14 @@ local a = if false then "a" elseif false then "b" else "c" TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") { - ScopedFastFlag sff3{"LuauIfElseBranchTypeUnion", true}; + CheckResult result = check(R"(local a: number? = if true then 42 else nil)"); - { - CheckResult result = check(R"(local a: number? = if true then 42 else nil)"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a"), {true}), "number?"); - } + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a"), {true}), "number?"); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_1") { - ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; - ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true}; - CheckResult result = check(R"( type X = {number | string} local a: X = if true then {"1", 2, 3} else {4, 5, 6} @@ -4761,9 +4751,6 @@ local a: X = if true then {"1", 2, 3} else {4, 5, 6} TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_2") { - ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; - ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true}; - CheckResult result = check(R"( local a: number? = if true then 1 else nil )"); @@ -4773,8 +4760,6 @@ local a: number? = if true then 1 else nil TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_3") { - ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; - CheckResult result = check(R"( local function times(n: any, f: () -> T) local result: {T} = {} @@ -5058,8 +5043,6 @@ end TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") { - ScopedFastFlag luauMetatableAreEqualRecursion{"LuauMetatableAreEqualRecursion", true}; - CheckResult result = check(R"( local function getIt() local y @@ -5076,8 +5059,6 @@ local c = a or b TEST_CASE_FIXTURE(Fixture, "bound_typepack_promote") { - ScopedFastFlag luauCommittingTxnLogFreeTpPromote{"LuauCommittingTxnLogFreeTpPromote", true}; - // No assertions should trigger check(R"( local function p() @@ -5251,7 +5232,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") {"LuauDiscriminableUnions2", true}, {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, - {"LuauLengthOnCompositeType", true}, }; CheckResult result = check(R"( @@ -5272,7 +5252,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") {"LuauDiscriminableUnions2", true}, {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, - {"LuauLengthOnCompositeType", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 8c7fb79a..4669ea8e 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 079870f5..cbe2e48f 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -930,8 +929,6 @@ type R = { m: F } TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") { - ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true}; - CheckResult result = check(R"( local a: () -> (number, ...string) local b: () -> (number, ...boolean) diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 759794e6..3b53ddfe 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 8b056544..c4931578 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 329e7b1f..e43161fa 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index f2ecc96b..b4f81bba 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -371,6 +371,15 @@ do st, msg = coroutine.close(co) assert(st and msg == nil) assert(f() == 42) + + -- closing a coroutine with a large stack + co = coroutine.create(function() + local function f(depth) return if depth > 0 then f(depth - 1) + depth else 0 end + coroutine.yield(f(100)) + end) + assert(coroutine.resume(co)) + st, msg = coroutine.close(co) + assert(st and msg == nil) end return 'OK' diff --git a/tests/conformance/coverage.lua b/tests/conformance/coverage.lua index f899603f..14d843a4 100644 --- a/tests/conformance/coverage.lua +++ b/tests/conformance/coverage.lua @@ -49,16 +49,24 @@ foo() c = getcoverage(foo) assert(#c == 1) assert(c[1].name == "foo") +assert(c[1].linedefined == 4) +assert(c[1].depth == 0) assert(validate(c[1], {5, 6, 7}, {})) bar() c = getcoverage(bar) assert(#c == 3) assert(c[1].name == "bar") +assert(c[1].linedefined == 10) +assert(c[1].depth == 0) assert(validate(c[1], {11, 15, 19}, {})) assert(c[2].name == "one") +assert(c[2].linedefined == 11) +assert(c[2].depth == 1) assert(validate(c[2], {12}, {})) assert(c[3].name == nil) +assert(c[3].linedefined == 15) +assert(c[3].depth == 1) assert(validate(c[3], {}, {16})) return 'OK' diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index 0e410000..0c8cc2d8 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -76,6 +76,9 @@ assert(baz(co, 2, "n") == nil) assert(baz(math.sqrt, "n") == "sqrt") assert(baz(math.sqrt, "f") == math.sqrt) -- yes this is pointless +local t = { foo = function() return 1 end } +assert(baz(t.foo, "n") == "foo") + -- info multi-arg returns function quux(...) return {debug.info(...)} diff --git a/tests/main.cpp b/tests/main.cpp index cd24e100..2af9f702 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -10,6 +10,9 @@ #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif +#ifndef NOMINMAX +#define NOMINMAX +#endif #include // IsDebuggerPresent #endif @@ -52,7 +55,7 @@ static bool debuggerPresent() #endif } -static int assertionHandler(const char* expr, const char* file, int line, const char* function) +static int testAssertionHandler(const char* expr, const char* file, int line, const char* function) { if (debuggerPresent()) LUAU_DEBUGBREAK(); @@ -218,7 +221,7 @@ static void setFastFlags(const std::vector& flags) int main(int argc, char** argv) { - Luau::assertHandler() = assertionHandler; + Luau::assertHandler() = testAssertionHandler; doctest::registerReporter("boost", 0, true); diff --git a/tools/natvis/Analysis.natvis b/tools/natvis/Analysis.natvis new file mode 100644 index 00000000..5de0140e --- /dev/null +++ b/tools/natvis/Analysis.natvis @@ -0,0 +1,78 @@ + + + + + AnyTypeVar + + + + {{ index=0, value={*($T1*)storage} }} + {{ index=1, value={*($T2*)storage} }} + {{ index=2, value={*($T3*)storage} }} + {{ index=3, value={*($T4*)storage} }} + {{ index=4, value={*($T5*)storage} }} + {{ index=5, value={*($T6*)storage} }} + {{ index=6, value={*($T7*)storage} }} + {{ index=7, value={*($T8*)storage} }} + {{ index=8, value={*($T9*)storage} }} + {{ index=9, value={*($T10*)storage} }} + {{ index=10, value={*($T11*)storage} }} + {{ index=11, value={*($T12*)storage} }} + {{ index=12, value={*($T13*)storage} }} + {{ index=13, value={*($T14*)storage} }} + {{ index=14, value={*($T15*)storage} }} + {{ index=15, value={*($T16*)storage} }} + {{ index=16, value={*($T17*)storage} }} + {{ index=17, value={*($T18*)storage} }} + {{ index=18, value={*($T19*)storage} }} + {{ index=19, value={*($T20*)storage} }} + {{ index=20, value={*($T21*)storage} }} + {{ index=21, value={*($T22*)storage} }} + {{ index=22, value={*($T23*)storage} }} + {{ index=23, value={*($T24*)storage} }} + {{ index=24, value={*($T25*)storage} }} + {{ index=25, value={*($T26*)storage} }} + {{ index=26, value={*($T27*)storage} }} + {{ index=27, value={*($T28*)storage} }} + {{ index=28, value={*($T29*)storage} }} + {{ index=29, value={*($T30*)storage} }} + {{ index=30, value={*($T31*)storage} }} + {{ index=31, value={*($T32*)storage} }} + + typeId + *($T1*)storage + *($T2*)storage + *($T3*)storage + *($T4*)storage + *($T5*)storage + *($T6*)storage + *($T7*)storage + *($T8*)storage + *($T9*)storage + *($T10*)storage + *($T11*)storage + *($T12*)storage + *($T13*)storage + *($T14*)storage + *($T15*)storage + *($T16*)storage + *($T17*)storage + *($T18*)storage + *($T19*)storage + *($T20*)storage + *($T21*)storage + *($T22*)storage + *($T23*)storage + *($T24*)storage + *($T25*)storage + *($T26*)storage + *($T27*)storage + *($T28*)storage + *($T29*)storage + *($T30*)storage + *($T31*)storage + *($T32*)storage + + + + diff --git a/tools/natvis/Ast.natvis b/tools/natvis/Ast.natvis new file mode 100644 index 00000000..322eb8f6 --- /dev/null +++ b/tools/natvis/Ast.natvis @@ -0,0 +1,25 @@ + + + + + AstArray size={size} + + size + + size + data + + + + + + + size_ + + size_ + storage._Mypair._Myval2._Myfirst + offset + + + + + diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis new file mode 100644 index 00000000..9924e194 --- /dev/null +++ b/tools/natvis/VM.natvis @@ -0,0 +1,269 @@ + + + + + nil + {(bool)value.b} + lightuserdata {value.p} + number = {value.n} + vector = {value.v[0]}, {value.v[1]}, {*(float*)&extra} + {value.gc->ts} + {value.gc->h} + function {value.gc->cl,view(short)} + userdata {value.gc->u} + thread {value.gc->th} + proto {value.gc->p} + upvalue {value.gc->uv} + deadkey + empty + + value.p + value.gc->ts + value.gc->h + value.gc->cl + value.gc->cl + value.gc->u + value.gc->th + value.gc->p + value.gc->uv + + fixed ({(int)value.gc->gch.marked}) + black ({(int)value.gc->gch.marked}) + white ({(int)value.gc->gch.marked}) + white ({(int)value.gc->gch.marked}) + gray ({(int)value.gc->gch.marked}) + + + + + + nil + {(bool)value.b} + lightuserdata {value.p} + number = {value.n} + vector = {value.v[0]}, {value.v[1]}, {*(float*)&extra} + {value.gc->ts} + {value.gc->h} + function {value.gc->cl,view(short)} + userdata {value.gc->u} + thread {value.gc->th} + proto {value.gc->p} + upvalue {value.gc->uv} + deadkey + empty + + (void**)value.p + value.gc->ts + value.gc->h + value.gc->cl + value.gc->cl + value.gc->u + value.gc->th + value.gc->p + value.gc->uv + + next + + + + + {key,na} = {val} + --- + + + + table + + metatable + + + [size] {1<<lsizenode} + + + 1<<lsizenode + node[$i] + + + + + [size] {sizearray} + + + sizearray + array[$i] + + + + + + + + + + + + + 1 + + + + + metatable->node[i].val + + + + i = i + 1 + + + "unknown",sb + + + tag + len + metatable + data + + + + + {c.f,na} + {l.p,na} + {c} + {l} + invalid + + + + {data,s} + + + + + {ci->func->value.gc->cl.c.f,na} + + + {ci->func->value.gc->cl.l.p->source->data,sb}:{ci->func->value.gc->cl.l.p->linedefined,d} {ci->func->value.gc->cl.l.p->debugname->data,sb} + + + {ci->func->value.gc->cl.l.p->source->data,sb}:{ci->func->value.gc->cl.l.p->linedefined,d} + + thread + + + {ci-base_ci} frames + + + ci-base_ci + + + base_ci[ci-base_ci - $i].func->value.gc->cl,view(short) + + + + + + {top-base} values + + + top-base + base + + + + + {top-stack} values + + + top-stack + stack + + + + + + + openupval + u.l.next + this + + + + l_gt + env + userdata + + + + + {source->data,sb}:{linedefined} function {debugname->data,sb} [{(int)numparams} arg, {(int)nups} upval] + {source->data,sb}:{linedefined} [{(int)numparams} arg, {(int)nups} upval] + + debugname + + constants + + + sizek + k[$i] + + + + + locals + + + sizelocvars + locvars[$i] + + + + + bytecode + + + sizecode + code[$i] + + + + + functions + + + sizep + p[$i] + + + + + upvals + + + sizeupvalues + upvalues[$i] + + + + + source + + + + + + + {(lua_Type)tt} + + + fixed ({(int)marked}) + black ({(int)marked}) + white ({(int)marked}) + white ({(int)marked}) + gray ({(int)marked}) + unknown + + memcat + + + + From a8eabedd570e9b3aba7e02ff2b0f4d8bdbf9efbb Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 24 Feb 2022 15:15:41 -0800 Subject: [PATCH 027/102] Sync to upstream/release/516 --- Analysis/include/Luau/Module.h | 9 +- Analysis/include/Luau/TypeInfer.h | 2 + Analysis/include/Luau/TypeUtils.h | 4 +- Analysis/include/Luau/Unifier.h | 29 ++- Analysis/src/Autocomplete.cpp | 4 +- Analysis/src/Linter.cpp | 1 + Analysis/src/Module.cpp | 27 +-- Analysis/src/Transpiler.cpp | 8 + Analysis/src/TypeInfer.cpp | 67 +++--- Analysis/src/TypeUtils.cpp | 8 +- Analysis/src/TypeVar.cpp | 110 ++-------- Analysis/src/Unifier.cpp | 126 ++++++----- Ast/src/Parser.cpp | 8 +- VM/include/lua.h | 12 +- VM/src/lapi.cpp | 20 +- VM/src/ldo.cpp | 3 +- VM/src/lfunc.cpp | 62 ++---- VM/src/lgc.cpp | 218 +++---------------- VM/src/lgc.h | 10 +- VM/src/lgcdebug.cpp | 72 +------ VM/src/linit.cpp | 2 +- VM/src/lmem.cpp | 127 ++--------- VM/src/lmem.h | 4 - VM/src/lobject.h | 5 +- VM/src/lstate.cpp | 32 +-- VM/src/lstate.h | 5 - VM/src/lstring.cpp | 70 ++---- VM/src/ltable.cpp | 4 +- VM/src/ludata.cpp | 2 +- fuzz/linter.cpp | 8 +- fuzz/luau.proto | 55 +++-- fuzz/proto.cpp | 237 ++++++++++++++------- fuzz/protoprint.cpp | 129 +++++++++-- fuzz/prototest.cpp | 12 +- fuzz/typeck.cpp | 6 +- tests/Autocomplete.test.cpp | 1 - tests/Conformance.test.cpp | 92 +++++++- tests/Linter.test.cpp | 13 ++ tests/Parser.test.cpp | 1 - tests/Transpiler.test.cpp | 15 ++ tests/TypeInfer.intersectionTypes.test.cpp | 6 +- tests/TypeInfer.refinements.test.cpp | 16 ++ tests/TypeInfer.singletons.test.cpp | 124 +++++++++++ tests/TypeInfer.tables.test.cpp | 18 ++ tests/TypeInfer.test.cpp | 6 - tests/TypeInfer.tryUnify.test.cpp | 4 +- tests/TypeVar.test.cpp | 12 -- tools/natvis/VM.natvis | 5 +- 48 files changed, 919 insertions(+), 892 deletions(-) diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 61200771..6c689b7c 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -13,8 +13,6 @@ #include #include -LUAU_FASTFLAG(LuauPrepopulateUnionOptionsBeforeAllocation) - namespace Luau { @@ -60,11 +58,8 @@ struct TypeArena template TypeId addType(T tv) { - if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation) - { - if constexpr (std::is_same_v) - LUAU_ASSERT(tv.options.size() >= 2); - } + if constexpr (std::is_same_v) + LUAU_ASSERT(tv.options.size() >= 2); return addTV(TypeVar(std::move(tv))); } diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 3c5ded3c..2440c810 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -31,6 +31,7 @@ bool doesCallError(const AstExprCall* call); bool hasBreak(AstStat* node); const AstStat* getFallthrough(const AstStat* node); +struct UnifierOptions; struct Unifier; // A substitution which replaces generic types in a given set by free types. @@ -245,6 +246,7 @@ struct TypeChecker * Treat any failures as type errors in the final typecheck report. */ bool unify(TypeId subTy, TypeId superTy, const Location& location); + bool unify(TypeId subTy, TypeId superTy, const Location& location, const UnifierOptions& options); bool unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); /** Attempt to unify the types. diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index ffddfe4b..42c1bc0b 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -13,7 +13,7 @@ namespace Luau using ScopePtr = std::shared_ptr; -std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globalScope, TypeId type, std::string entry, Location location); -std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const ScopePtr& globalScope, TypeId ty, Name name, Location location); +std::optional findMetatableEntry(ErrorVec& errors, TypeId type, std::string entry, Location location); +std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, Name name, Location location); } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 9db4e22b..fe822b01 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -19,11 +19,31 @@ enum Variance Invariant }; +// A substitution which replaces singleton types by their wider types +struct Widen : Substitution +{ + Widen(TypeArena* arena) + : Substitution(TxnLog::empty(), arena) + { + } + + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId ty) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId ty) override; + bool ignoreChildren(TypeId ty) override; +}; + +// TODO: Use this more widely. +struct UnifierOptions +{ + bool isFunctionCall = false; +}; + struct Unifier { TypeArena* const types; Mode mode; - ScopePtr globalScope; // sigh. Needed solely to get at string's metatable. DEPRECATED_TxnLog DEPRECATED_log; TxnLog log; @@ -34,9 +54,9 @@ struct Unifier UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState, + Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, + Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. @@ -65,7 +85,10 @@ private: void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer); + + TypeId widen(TypeId ty); TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + void cacheResult(TypeId subTy, TypeId superTy); public: diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 5a1ae397..29a2c6b5 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -236,10 +236,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { ty = follow(ty); - auto canUnify = [&typeArena, &module](TypeId subTy, TypeId superTy) { + auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); - Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); + Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); if (FFlag::LuauAutocompleteAvoidMutation && !FFlag::LuauUseCommittingTxnLog) { diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 8d7d2d97..7635dc0f 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -201,6 +201,7 @@ static bool similar(AstExpr* lhs, AstExpr* rhs) return true; } + CASE(AstExprIfElse) return similar(le->condition, re->condition) && similar(le->trueExpr, re->trueExpr) && similar(le->falseExpr, re->falseExpr); else { LUAU_ASSERT(!"Unknown expression type"); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 817a33e9..412b78bb 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -16,7 +16,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) // Remove with FFlagLuau LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypeAliasDefaults) LUAU_FASTFLAG(LuauImmutableTypes) -LUAU_FASTFLAGVARIABLE(LuauPrepopulateUnionOptionsBeforeAllocation, false) namespace Luau { @@ -379,28 +378,14 @@ void TypeCloner::operator()(const AnyTypeVar& t) void TypeCloner::operator()(const UnionTypeVar& t) { - if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation) - { - std::vector options; - options.reserve(t.options.size()); + std::vector options; + options.reserve(t.options.size()); - for (TypeId ty : t.options) - options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - TypeId result = dest.addType(UnionTypeVar{std::move(options)}); - seenTypes[typeId] = result; - } - else - { - TypeId result = dest.addType(UnionTypeVar{}); - seenTypes[typeId] = result; - - UnionTypeVar* option = getMutable(result); - LUAU_ASSERT(option != nullptr); - - for (TypeId ty : t.options) - option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - } + TypeId result = dest.addType(UnionTypeVar{std::move(options)}); + seenTypes[typeId] = result; } void TypeCloner::operator()(const IntersectionTypeVar& t) diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 54bd0d5e..a02d396b 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -1153,6 +1153,14 @@ struct Printer writer.symbol(")"); } } + else if (const auto& a = typeAnnotation.as()) + { + writer.keyword(a->value ? "true" : "false"); + } + else if (const auto& a = typeAnnotation.as()) + { + writer.string(std::string_view(a->value.data, a->value.size)); + } else if (typeAnnotation.is()) { writer.symbol("%error-type%"); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index c29699b7..faf60eb3 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -29,7 +29,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) -LUAU_FASTFLAGVARIABLE(LuauNoSealedTypeMod, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) @@ -38,14 +37,14 @@ LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) -LUAU_FASTFLAG(LuauUnionTagMatchFix) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauAnotherTypeLevelFix, false) +LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree) +LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) namespace Luau { @@ -1125,7 +1124,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco ty = follow(ty); - if (tableSelf && (FFlag::LuauNoSealedTypeMod ? tableSelf->state != TableState::Sealed : !selfTy->persistent)) + if (tableSelf && tableSelf->state != TableState::Sealed) tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; const FunctionTypeVar* funTy = get(ty); @@ -1138,7 +1137,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (tableSelf && (FFlag::LuauNoSealedTypeMod ? tableSelf->state != TableState::Sealed : !selfTy->persistent)) + if (tableSelf && tableSelf->state != TableState::Sealed) tableSelf->props[indexName->index.value] = { follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; } @@ -1210,8 +1209,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { ScopePtr aliasScope = childScope(scope, typealias.location); aliasScope->level = scope->level.incr(); - if (FFlag::LuauProperTypeLevels) - aliasScope->level.subLevel = subLevel; + aliasScope->level.subLevel = subLevel; auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); @@ -1624,7 +1622,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) { ErrorVec errors; - auto result = Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); + auto result = Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); reportErrors(errors); return result; } @@ -1632,7 +1630,7 @@ std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsTyp std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location) { ErrorVec errors; - auto result = Luau::findMetatableEntry(errors, globalScope, type, entry, location); + auto result = Luau::findMetatableEntry(errors, type, entry, location); reportErrors(errors); return result; } @@ -1751,13 +1749,23 @@ std::optional TypeChecker::getIndexTypeFromType( return std::nullopt; } - // TODO(amccord): Write some logic to correctly handle intersections. CLI-34659 - std::vector result = reduceUnion(parts); + if (FFlag::LuauDoNotTryToReduce) + { + if (parts.size() == 1) + return parts[0]; - if (result.size() == 1) - return result[0]; + return addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. + } + else + { + // TODO(amccord): Write some logic to correctly handle intersections. CLI-34659 + std::vector result = reduceUnion(parts); - return addType(IntersectionTypeVar{result}); + if (result.size() == 1) + return result[0]; + + return addType(IntersectionTypeVar{result}); + } } if (addErrors) @@ -2823,10 +2831,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level) { auto freshTy = [&]() { - if (FFlag::LuauProperTypeLevels) - return freshType(level); - else - return freshType(scope); + return freshType(level); }; if (auto globalName = funName.as()) @@ -3790,7 +3795,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - unify(fn, r, expr.location); + if (FFlag::LuauWidenIfSupertypeIsFree) + { + UnifierOptions options; + options.isFunctionCall = true; + unify(r, fn, expr.location, options); + } + else + unify(fn, r, expr.location); return {{retPack}}; } @@ -4243,9 +4255,15 @@ TypeId TypeChecker::anyIfNonstrict(TypeId ty) const } bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location) +{ + UnifierOptions options; + return unify(subTy, superTy, location, options); +} + +bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location, const UnifierOptions& options) { Unifier state = mkUnifier(location); - state.tryUnify(subTy, superTy); + state.tryUnify(subTy, superTy, options.isFunctionCall); if (FFlag::LuauUseCommittingTxnLog) state.log.commit(); @@ -4654,7 +4672,7 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d } }; - if (auto ttv = getTableType(FFlag::LuauUnionTagMatchFix ? utk->table : follow(utk->table))) + if (auto ttv = getTableType(utk->table)) accumulate(ttv->props); else if (auto ctv = get(follow(utk->table))) { @@ -4691,8 +4709,7 @@ ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location) { ScopePtr scope = std::make_shared(parent); - if (FFlag::LuauProperTypeLevels) - scope->level = parent->level; + scope->level = parent->level; scope->varargPack = parent->varargPack; currentModule->scopes.push_back(std::make_pair(location, scope)); @@ -4724,7 +4741,7 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) Unifier TypeChecker::mkUnifier(const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState}; + return Unifier{¤tModule->internalTypes, currentModule->mode, location, Variance::Covariant, unifierState}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -5444,7 +5461,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) { - LUAU_ASSERT(FFlag::LuauDiscriminableUnions2); + LUAU_ASSERT(FFlag::LuauDiscriminableUnions2 || FFlag::LuauAssertStripsFalsyTypes); const LValue* target = &lvalue; std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 593b54c8..c2435890 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -10,7 +10,7 @@ LUAU_FASTFLAGVARIABLE(LuauTerminateCyclicMetatableIndexLookup, false) namespace Luau { -std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globalScope, TypeId type, std::string entry, Location location) +std::optional findMetatableEntry(ErrorVec& errors, TypeId type, std::string entry, Location location) { type = follow(type); @@ -37,7 +37,7 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa return std::nullopt; } -std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const ScopePtr& globalScope, TypeId ty, Name name, Location location) +std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, Name name, Location location) { if (get(ty)) return ty; @@ -49,7 +49,7 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc return it->second.type; } - std::optional mtIndex = findMetatableEntry(errors, globalScope, ty, "__index", location); + std::optional mtIndex = findMetatableEntry(errors, ty, "__index", location); int count = 0; while (mtIndex) { @@ -82,7 +82,7 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc else errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); - mtIndex = findMetatableEntry(errors, globalScope, *mtIndex, "__index", location); + mtIndex = findMetatableEntry(errors, *mtIndex, "__index", location); } return std::nullopt; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index b2358c27..a1dcfdbe 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,9 +23,7 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) -LUAU_FASTFLAG(LuauUnionTagMatchFix) LUAU_FASTFLAG(LuauDiscriminableUnions2) namespace Luau @@ -145,20 +143,13 @@ bool isNil(TypeId ty) bool isBoolean(TypeId ty) { - if (FFlag::LuauRefactorTypeVarQuestions) - { - if (isPrim(ty, PrimitiveTypeVar::Boolean) || get(get(follow(ty)))) - return true; + if (isPrim(ty, PrimitiveTypeVar::Boolean) || get(get(follow(ty)))) + return true; - if (auto utv = get(follow(ty))) - return std::all_of(begin(utv), end(utv), isBoolean); + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isBoolean); - return false; - } - else - { - return isPrim(ty, PrimitiveTypeVar::Boolean); - } + return false; } bool isNumber(TypeId ty) @@ -168,20 +159,13 @@ bool isNumber(TypeId ty) bool isString(TypeId ty) { - if (FFlag::LuauRefactorTypeVarQuestions) - { - if (isPrim(ty, PrimitiveTypeVar::String) || get(get(follow(ty)))) - return true; + if (isPrim(ty, PrimitiveTypeVar::String) || get(get(follow(ty)))) + return true; - if (auto utv = get(follow(ty))) - return std::all_of(begin(utv), end(utv), isString); + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isString); - return false; - } - else - { - return isPrim(ty, PrimitiveTypeVar::String); - } + return false; } bool isThread(TypeId ty) @@ -194,45 +178,11 @@ bool isOptional(TypeId ty) if (isNil(ty)) return true; - if (FFlag::LuauRefactorTypeVarQuestions) - { - auto utv = get(follow(ty)); - if (!utv) - return false; - - return std::any_of(begin(utv), end(utv), isNil); - } - else - { - std::unordered_set seen; - std::deque queue{ty}; - while (!queue.empty()) - { - TypeId current = follow(queue.front()); - queue.pop_front(); - - if (seen.count(current)) - continue; - - seen.insert(current); - - if (isNil(current)) - return true; - - if (auto u = get(current)) - { - for (TypeId option : u->options) - { - if (isNil(option)) - return true; - - queue.push_back(option); - } - } - } - + auto utv = get(follow(ty)); + if (!utv) return false; - } + + return std::any_of(begin(utv), end(utv), isNil); } bool isTableIntersection(TypeId ty) @@ -263,38 +213,24 @@ std::optional getMetatable(TypeId type) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) return classType->metatable; - else if (FFlag::LuauRefactorTypeVarQuestions) + else if (isString(type)) { - if (isString(type)) - { - auto ptv = get(getSingletonTypes().stringType); - LUAU_ASSERT(ptv && ptv->metatable); - return ptv->metatable; - } - else - return std::nullopt; - } - else - { - if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) - { - LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); - return primitiveType->metatable; - } - else - return std::nullopt; + auto ptv = get(getSingletonTypes().stringType); + LUAU_ASSERT(ptv && ptv->metatable); + return ptv->metatable; } + + return std::nullopt; } const TableTypeVar* getTableType(TypeId type) { - if (FFlag::LuauUnionTagMatchFix) - type = follow(type); + type = follow(type); if (const TableTypeVar* ttv = get(type)) return ttv; else if (const MetatableTypeVar* mtv = get(type)) - return get(FFlag::LuauUnionTagMatchFix ? follow(mtv->table) : mtv->table); + return get(follow(mtv->table)); else return nullptr; } @@ -311,7 +247,7 @@ const std::string* getName(TypeId type) { if (mtv->syntheticName) return &*mtv->syntheticName; - type = FFlag::LuauUnionTagMatchFix ? follow(mtv->table) : mtv->table; + type = follow(mtv->table); } if (auto ttv = get(type)) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 322f6ebf..d0eba013 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -21,9 +21,8 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); -LUAU_FASTFLAG(LuauProperTypeLevels); -LUAU_FASTFLAGVARIABLE(LuauUnionTagMatchFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowWithCommittingTxnLogInAnyUnification, false) +LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree, false) namespace Luau { @@ -122,7 +121,7 @@ struct PromoteTypeLevels } }; -void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) +static void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) @@ -133,6 +132,7 @@ void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const Typ visitTypeVarOnce(ty, ptl, seen); } +// TODO: use this and make it static. void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -247,6 +247,48 @@ struct SkipCacheForType bool result = false; }; +bool Widen::isDirty(TypeId ty) +{ + return FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty)); +} + +bool Widen::isDirty(TypePackId) +{ + return false; +} + +TypeId Widen::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + auto stv = FFlag::LuauUseCommittingTxnLog ? log->getMutable(ty) : getMutable(ty); + LUAU_ASSERT(stv); + + if (get(stv)) + return getSingletonTypes().stringType; + else + { + // If this assert trips, it's likely we now have number singletons. + LUAU_ASSERT(get(stv)); + return getSingletonTypes().booleanType; + } +} + +TypePackId Widen::clean(TypePackId) +{ + throw std::runtime_error("Widen attempted to clean a dirty type pack?"); +} + +bool Widen::ignoreChildren(TypeId ty) +{ + // Sometimes we unify ("hi") -> free1 with (free2) -> free3, so don't ignore functions. + // TODO: should we be doing this? we would need to rework how checkCallOverload does the unification. + if (FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty))) + return false; + + // We only care about unions. + return !(FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty))); +} + static std::optional hasUnificationTooComplex(const ErrorVec& errors) { auto isUnificationTooComplex = [](const TypeError& te) { @@ -263,43 +305,22 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { - if (FFlag::LuauUnionTagMatchFix) + if (auto ttv = getTableType(type)) { - if (auto ttv = getTableType(type)) + for (auto&& [name, prop] : ttv->props) { - for (auto&& [name, prop] : ttv->props) - { - if (auto sing = get(follow(prop.type))) - return {{name, sing}}; - } - } - } - else - { - type = follow(type); - - if (auto ttv = get(type)) - { - for (auto&& [name, prop] : ttv->props) - { - if (auto sing = get(follow(prop.type))) - return {{name, sing}}; - } - } - else if (auto mttv = get(type)) - { - return getTableMatchTag(mttv->table); + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; } } return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState, +Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) - , globalScope(std::move(globalScope)) , log(parentLog) , location(location) , variance(variance) @@ -308,11 +329,10 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, +Unifier::Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) - , globalScope(std::move(globalScope)) , DEPRECATED_log(sharedSeen) , log(parentLog, sharedSeen) , location(location) @@ -435,6 +455,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } else if (superFree) { + TypeLevel superLevel = superFree->level; + occursCheck(superTy, subTy); bool occursFailed = false; if (FFlag::LuauUseCommittingTxnLog) @@ -442,8 +464,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else occursFailed = bool(get(superTy)); - TypeLevel superLevel = superFree->level; - // Unification can't change the level of a generic. auto subGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy); if (subGeneric && !subGeneric->level.subsumes(superLevel)) @@ -459,20 +479,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (FFlag::LuauUseCommittingTxnLog) { promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); - log.replace(superTy, BoundTypeVar(subTy)); + log.replace(superTy, BoundTypeVar(widen(subTy))); } else { - if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); - else if (auto subLevel = getMutableLevel(subTy)) - { - if (!subLevel->subsumes(superFree->level)) - *subLevel = superFree->level; - } + promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); DEPRECATED_log(superTy); - *asMutable(superTy) = BoundTypeVar(subTy); + *asMutable(superTy) = BoundTypeVar(widen(subTy)); } } @@ -507,16 +521,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } else { - if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); - else if (auto superLevel = getMutableLevel(superTy)) - { - if (!superLevel->subsumes(subFree->level)) - { - DEPRECATED_log(superTy); - *superLevel = subFree->level; - } - } + promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); DEPRECATED_log(subTy); *asMutable(subTy) = BoundTypeVar(superTy); @@ -2064,6 +2069,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } } +TypeId Unifier::widen(TypeId ty) +{ + if (!FFlag::LuauWidenIfSupertypeIsFree) + return ty; + + Widen widen{types}; + std::optional result = widen.substitute(ty); + // TODO: what does it mean for substitution to fail to widen? + return result.value_or(ty); +} + TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { ty = follow(ty); @@ -2915,7 +2931,7 @@ void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) { - return Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); + return Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); } void Unifier::occursCheck(TypeId needle, TypeId haystack) @@ -3096,9 +3112,9 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { if (FFlag::LuauUseCommittingTxnLog) - return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, &log}; + return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; else - return Unifier{types, mode, globalScope, DEPRECATED_log.sharedSeen, location, variance, sharedState, &log}; + return Unifier{types, mode, DEPRECATED_log.sharedSeen, location, variance, sharedState, &log}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 235d6349..8767daa0 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -12,7 +12,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) -LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) LUAU_FASTFLAGVARIABLE(LuauParseAllHotComments, false) LUAU_FASTFLAGVARIABLE(LuauTableFieldFunctionDebugname, false) @@ -2372,11 +2371,11 @@ std::pair, AstArray> Parser::parseG { Location nameLocation = lexer.current().location; AstName name = parseName().name; - if (lexer.current().type == Lexeme::Dot3 || (FFlag::LuauParseRecoverTypePackEllipsis && seenPack)) + if (lexer.current().type == Lexeme::Dot3 || seenPack) { seenPack = true; - if (FFlag::LuauParseRecoverTypePackEllipsis && lexer.current().type != Lexeme::Dot3) + if (lexer.current().type != Lexeme::Dot3) report(lexer.current().location, "Generic types come before generic type packs"); else nextLexeme(); @@ -2414,9 +2413,6 @@ std::pair, AstArray> Parser::parseG } else { - if (!FFlag::LuauParseRecoverTypePackEllipsis && seenPack) - report(lexer.current().location, "Generic types come before generic type packs"); - if (withDefaultValues && lexer.current().type == '=') { seenDefault = true; diff --git a/VM/include/lua.h b/VM/include/lua.h index c5dcef25..af0e2835 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -44,7 +44,7 @@ typedef int (*lua_Continuation)(lua_State* L, int status); ** prototype for memory-allocation functions */ -typedef void* (*lua_Alloc)(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsize); +typedef void* (*lua_Alloc)(void* ud, void* ptr, size_t osize, size_t nsize); /* non-return type */ #define l_noret void LUA_NORETURN @@ -178,11 +178,11 @@ LUA_API int lua_pushthread(lua_State* L); /* ** get functions (Lua -> stack) */ -LUA_API void lua_gettable(lua_State* L, int idx); -LUA_API void lua_getfield(lua_State* L, int idx, const char* k); -LUA_API void lua_rawgetfield(lua_State* L, int idx, const char* k); -LUA_API void lua_rawget(lua_State* L, int idx); -LUA_API void lua_rawgeti(lua_State* L, int idx, int n); +LUA_API int lua_gettable(lua_State* L, int idx); +LUA_API int lua_getfield(lua_State* L, int idx, const char* k); +LUA_API int lua_rawgetfield(lua_State* L, int idx, const char* k); +LUA_API int lua_rawget(lua_State* L, int idx); +LUA_API int lua_rawgeti(lua_State* L, int idx, int n); LUA_API void lua_createtable(lua_State* L, int narr, int nrec); LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 29d5f397..39c76e08 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -659,16 +659,16 @@ int lua_pushthread(lua_State* L) ** get functions (Lua -> stack) */ -void lua_gettable(lua_State* L, int idx) +int lua_gettable(lua_State* L, int idx) { luaC_checkthreadsleep(L); StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_gettable(L, t, L->top - 1, L->top - 1); - return; + return ttype(L->top - 1); } -void lua_getfield(lua_State* L, int idx, const char* k) +int lua_getfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); StkId t = index2addr(L, idx); @@ -677,10 +677,10 @@ void lua_getfield(lua_State* L, int idx, const char* k) setsvalue(L, &key, luaS_new(L, k)); luaV_gettable(L, t, &key, L->top); api_incr_top(L); - return; + return ttype(L->top - 1); } -void lua_rawgetfield(lua_State* L, int idx, const char* k) +int lua_rawgetfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); StkId t = index2addr(L, idx); @@ -689,26 +689,26 @@ void lua_rawgetfield(lua_State* L, int idx, const char* k) setsvalue(L, &key, luaS_new(L, k)); setobj2s(L, L->top, luaH_getstr(hvalue(t), tsvalue(&key))); api_incr_top(L); - return; + return ttype(L->top - 1); } -void lua_rawget(lua_State* L, int idx) +int lua_rawget(lua_State* L, int idx) { luaC_checkthreadsleep(L); StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top - 1, luaH_get(hvalue(t), L->top - 1)); - return; + return ttype(L->top - 1); } -void lua_rawgeti(lua_State* L, int idx, int n) +int lua_rawgeti(lua_State* L, int idx, int n) { luaC_checkthreadsleep(L); StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top, luaH_getnum(hvalue(t), n)); api_incr_top(L); - return; + return ttype(L->top - 1); } void lua_createtable(lua_State* L, int narray, int nrec) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index d87f0661..b5ae496b 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -151,8 +151,7 @@ l_noret luaD_throw(lua_State* L, int errcode) static void correctstack(lua_State* L, TValue* oldstack) { L->top = (L->top - oldstack) + L->stack; - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - for (UpVal* up = L->openupval; up != NULL; up = (UpVal*)up->next) + for (UpVal* up = L->openupval; up != NULL; up = up->u.l.threadnext) up->v = (up->v - oldstack) + L->stack; for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) { diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 582d4627..66447a95 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,13 +6,10 @@ #include "lmem.h" #include "lgc.h" -LUAU_FASTFLAGVARIABLE(LuauNoDirectUpvalRemoval, false) -LUAU_FASTFLAG(LuauGcPagedSweep) - Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_newgco(L, Proto, sizeof(Proto), L->activememcat); - luaC_link(L, f, LUA_TPROTO); + luaC_init(L, f, LUA_TPROTO); f->k = NULL; f->sizek = 0; f->p = NULL; @@ -40,7 +37,7 @@ Proto* luaF_newproto(lua_State* L) Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) { Closure* c = luaM_newgco(L, Closure, sizeLclosure(nelems), L->activememcat); - luaC_link(L, c, LUA_TFUNCTION); + luaC_init(L, c, LUA_TFUNCTION); c->isC = 0; c->env = e; c->nupvalues = cast_byte(nelems); @@ -55,7 +52,7 @@ Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e) { Closure* c = luaM_newgco(L, Closure, sizeCclosure(nelems), L->activememcat); - luaC_link(L, c, LUA_TFUNCTION); + luaC_init(L, c, LUA_TFUNCTION); c->isC = 1; c->env = e; c->nupvalues = cast_byte(nelems); @@ -82,8 +79,7 @@ UpVal* luaF_findupval(lua_State* L, StkId level) return p; } - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - pp = (UpVal**)&p->next; + pp = &p->u.l.threadnext; } UpVal* uv = luaM_newgco(L, UpVal, sizeof(UpVal), L->activememcat); /* not found: create a new one */ @@ -94,19 +90,10 @@ UpVal* luaF_findupval(lua_State* L, StkId level) // chain the upvalue in the threads open upvalue list at the proper position UpVal* next = *pp; - - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - uv->next = (GCObject*)next; - - if (FFlag::LuauGcPagedSweep) - { - uv->u.l.threadprev = pp; - if (next) - { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - next->u.l.threadprev = (UpVal**)&uv->next; - } - } + uv->u.l.threadnext = next; + uv->u.l.threadprev = pp; + if (next) + next->u.l.threadprev = &uv->u.l.threadnext; *pp = uv; @@ -125,15 +112,11 @@ void luaF_unlinkupval(UpVal* uv) uv->u.l.next->u.l.prev = uv->u.l.prev; uv->u.l.prev->u.l.next = uv->u.l.next; - if (FFlag::LuauGcPagedSweep) - { - // unlink upvalue from the thread open upvalue list - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and this and the following cast will not be required - *uv->u.l.threadprev = (UpVal*)uv->next; + // unlink upvalue from the thread open upvalue list + *uv->u.l.threadprev = uv->u.l.threadnext; - if (UpVal* next = (UpVal*)uv->next) - next->u.l.threadprev = uv->u.l.threadprev; - } + if (UpVal* next = uv->u.l.threadnext) + next->u.l.threadprev = uv->u.l.threadprev; } void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) @@ -145,34 +128,27 @@ void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) void luaF_close(lua_State* L, StkId level) { - global_State* g = L->global; // TODO: remove with FFlagLuauNoDirectUpvalRemoval + global_State* g = L->global; UpVal* uv; while (L->openupval != NULL && (uv = L->openupval)->v >= level) { GCObject* o = obj2gco(uv); LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); - if (!FFlag::LuauGcPagedSweep) - L->openupval = (UpVal*)uv->next; /* remove from `open' list */ + // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue + luaF_unlinkupval(uv); - if (FFlag::LuauGcPagedSweep && isdead(g, o)) + if (isdead(g, o)) { - // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue - luaF_unlinkupval(uv); // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again uv->v = &uv->u.value; } - else if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o)) - { - luaF_freeupval(L, uv, NULL); /* free upvalue */ - } else { - // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue - luaF_unlinkupval(uv); setobj(L, &uv->u.value, uv->v); - uv->v = &uv->u.value; /* now current value lives here */ - luaC_linkupval(L, uv); /* link upvalue into `gcroot' list */ + uv->v = &uv->u.value; + // GC state of a new closed upvalue has to be initialized + luaC_initupval(L, uv); } } } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 724b24b2..8c3a2029 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauGcPagedSweep, false) - #define GC_SWEEPMAX 40 #define GC_SWEEPCOST 10 #define GC_SWEEPPAGESTEPCOST 4 @@ -64,7 +62,6 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, case GCSatomic: g->gcstats.currcycle.atomictime += seconds; break; - case GCSsweepstring: case GCSsweep: g->gcstats.currcycle.sweeptime += seconds; break; @@ -490,65 +487,6 @@ static void freeobj(lua_State* L, GCObject* o, lua_Page* page) } } -#define sweepwholelist(L, p) sweeplist(L, p, SIZE_MAX) - -static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - GCObject* curr; - global_State* g = L->global; - int deadmask = otherwhite(g); - LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); /* make sure we never sweep fixed objects */ - while ((curr = *p) != NULL && count-- > 0) - { - int alive = (curr->gch.marked ^ WHITEBITS) & deadmask; - if (curr->gch.tt == LUA_TTHREAD) - { - sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval); /* sweep open upvalues */ - - lua_State* th = gco2th(curr); - - if (alive) - { - resetbit(th->stackstate, THREAD_SLEEPINGBIT); - shrinkstack(th); - } - } - if (alive) - { /* not dead? */ - LUAU_ASSERT(!isdead(g, curr)); - makewhite(g, curr); /* make it white (for next cycle) */ - p = &curr->gch.next; - } - else - { /* must erase `curr' */ - LUAU_ASSERT(isdead(g, curr)); - *p = curr->gch.next; - if (curr == g->rootgc) /* is the first element of the list? */ - g->rootgc = curr->gch.next; /* adjust first */ - freeobj(L, curr, NULL); - } - } - - return p; -} - -static void deletelist(lua_State* L, GCObject** p, GCObject* limit) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - GCObject* curr; - while ((curr = *p) != limit) - { - if (curr->gch.tt == LUA_TTHREAD) /* delete open upvalues of each thread */ - deletelist(L, (GCObject**)&gco2th(curr)->openupval, NULL); - - *p = curr->gch.next; - freeobj(L, curr, NULL); - } -} - static void shrinkbuffers(lua_State* L) { global_State* g = L->global; @@ -570,8 +508,6 @@ static void shrinkbuffersfull(lua_State* L) static bool deletegco(void* context, lua_Page* page, GCObject* gco) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - // we are in the process of deleting everything // threads with open upvalues will attempt to close them all on removal // but those upvalues might point to stack values that were already deleted @@ -598,32 +534,13 @@ void luaC_freeall(lua_State* L) LUAU_ASSERT(L == g->mainthread); - if (FFlag::LuauGcPagedSweep) - { - luaM_visitgco(L, L, deletegco); + luaM_visitgco(L, L, deletegco); - for (int i = 0; i < g->strt.size; i++) /* free all string lists */ - LUAU_ASSERT(g->strt.hash[i] == NULL); + for (int i = 0; i < g->strt.size; i++) /* free all string lists */ + LUAU_ASSERT(g->strt.hash[i] == NULL); - LUAU_ASSERT(L->global->strt.nuse == 0); - LUAU_ASSERT(g->strbufgc == NULL); - } - else - { - LUAU_ASSERT(L->next == NULL); /* mainthread is at the end of rootgc list */ - - deletelist(L, &g->rootgc, obj2gco(L)); - - for (int i = 0; i < g->strt.size; i++) /* free all string lists */ - deletelist(L, (GCObject**)&g->strt.hash[i], NULL); - - LUAU_ASSERT(L->global->strt.nuse == 0); - deletelist(L, (GCObject**)&g->strbufgc, NULL); - - // unfortunately, when string objects are freed, the string table use count is decremented - // even when the string is a buffer that wasn't placed into the table - L->global->strt.nuse = 0; - } + LUAU_ASSERT(L->global->strt.nuse == 0); + LUAU_ASSERT(g->strbufgc == NULL); } static void markmt(global_State* g) @@ -687,26 +604,13 @@ static size_t atomic(lua_State* L) g->weak = NULL; /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); - g->sweepstrgc = 0; - - if (FFlag::LuauGcPagedSweep) - { - g->sweepgcopage = g->allgcopages; - g->gcstate = GCSsweep; - } - else - { - g->sweepgc = &g->rootgc; - g->gcstate = GCSsweepstring; - } - + g->sweepgcopage = g->allgcopages; + g->gcstate = GCSsweep; return work; } static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; int deadmask = otherwhite(g); @@ -740,8 +644,6 @@ static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) // a version of generic luaM_visitpage specialized for the main sweep stage static int sweepgcopage(lua_State* L, lua_Page* page) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - char* start; char* end; int busyBlocks; @@ -819,75 +721,29 @@ static size_t gcstep(lua_State* L, size_t limit) cost = atomic(L); /* finish mark phase */ - if (FFlag::LuauGcPagedSweep) - LUAU_ASSERT(g->gcstate == GCSsweep); - else - LUAU_ASSERT(g->gcstate == GCSsweepstring); - break; - } - case GCSsweepstring: - { - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - while (g->sweepstrgc < g->strt.size && cost < limit) - { - sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++]); - - cost += GC_SWEEPCOST; - } - - // nothing more to sweep? - if (g->sweepstrgc >= g->strt.size) - { - // sweep string buffer list and preserve used string count - uint32_t nuse = L->global->strt.nuse; - - sweepwholelist(L, (GCObject**)&g->strbufgc); - - L->global->strt.nuse = nuse; - - g->gcstate = GCSsweep; // end sweep-string phase - } + LUAU_ASSERT(g->gcstate == GCSsweep); break; } case GCSsweep: { - if (FFlag::LuauGcPagedSweep) + while (g->sweepgcopage && cost < limit) { - while (g->sweepgcopage && cost < limit) - { - lua_Page* next = luaM_getnextgcopage(g->sweepgcopage); // page sweep might destroy the page + lua_Page* next = luaM_getnextgcopage(g->sweepgcopage); // page sweep might destroy the page - int steps = sweepgcopage(L, g->sweepgcopage); + int steps = sweepgcopage(L, g->sweepgcopage); - g->sweepgcopage = next; - cost += steps * GC_SWEEPPAGESTEPCOST; - } - - // nothing more to sweep? - if (g->sweepgcopage == NULL) - { - // don't forget to visit main thread - sweepgco(L, NULL, obj2gco(g->mainthread)); - - shrinkbuffers(L); - g->gcstate = GCSpause; /* end collection */ - } + g->sweepgcopage = next; + cost += steps * GC_SWEEPPAGESTEPCOST; } - else + + // nothing more to sweep? + if (g->sweepgcopage == NULL) { - while (*g->sweepgc && cost < limit) - { - g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX); + // don't forget to visit main thread + sweepgco(L, NULL, obj2gco(g->mainthread)); - cost += GC_SWEEPMAX * GC_SWEEPCOST; - } - - if (*g->sweepgc == NULL) - { /* nothing more to sweep? */ - shrinkbuffers(L); - g->gcstate = GCSpause; /* end collection */ - } + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ } break; } @@ -1013,26 +869,18 @@ void luaC_fullgc(lua_State* L) if (g->gcstate <= GCSatomic) { /* reset sweep marks to sweep all elements (returning them to white) */ - g->sweepstrgc = 0; - if (FFlag::LuauGcPagedSweep) - g->sweepgcopage = g->allgcopages; - else - g->sweepgc = &g->rootgc; + g->sweepgcopage = g->allgcopages; /* reset other collector lists */ g->gray = NULL; g->grayagain = NULL; g->weak = NULL; - - if (FFlag::LuauGcPagedSweep) - g->gcstate = GCSsweep; - else - g->gcstate = GCSsweepstring; + g->gcstate = GCSsweep; } - LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + LUAU_ASSERT(g->gcstate == GCSsweep); /* finish any pending sweep phase */ while (g->gcstate != GCSpause) { - LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + LUAU_ASSERT(g->gcstate == GCSsweep); gcstep(L, SIZE_MAX); } @@ -1120,30 +968,19 @@ void luaC_barrierback(lua_State* L, Table* t) g->grayagain = o; } -void luaC_linkobj(lua_State* L, GCObject* o, uint8_t tt) +void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt) { global_State* g = L->global; - if (!FFlag::LuauGcPagedSweep) - { - o->gch.next = g->rootgc; - g->rootgc = o; - } o->gch.marked = luaC_white(g); o->gch.tt = tt; o->gch.memcat = L->activememcat; } -void luaC_linkupval(lua_State* L, UpVal* uv) +void luaC_initupval(lua_State* L, UpVal* uv) { global_State* g = L->global; GCObject* o = obj2gco(uv); - if (!FFlag::LuauGcPagedSweep) - { - o->gch.next = g->rootgc; /* link upvalue into `rootgc' list */ - g->rootgc = o; - } - if (isgray(o)) { if (keepinvariant(g)) @@ -1221,9 +1058,6 @@ const char* luaC_statename(int state) case GCSatomic: return "atomic"; - case GCSsweepstring: - return "sweepstring"; - case GCSsweep: return "sweep"; diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 2acb5d8a..253e269f 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -13,9 +13,7 @@ #define GCSpropagate 1 #define GCSpropagateagain 2 #define GCSatomic 3 -// TODO: remove with FFlagLuauGcPagedSweep -#define GCSsweepstring 4 -#define GCSsweep 5 +#define GCSsweep 4 /* ** macro to tell when main invariant (white objects cannot point to black @@ -132,13 +130,13 @@ luaC_wakethread(L); \ } -#define luaC_link(L, o, tt) luaC_linkobj(L, cast_to(GCObject*, (o)), tt) +#define luaC_init(L, o, tt) luaC_initobj(L, cast_to(GCObject*, (o)), tt) LUAI_FUNC void luaC_freeall(lua_State* L); LUAI_FUNC void luaC_step(lua_State* L, bool assist); LUAI_FUNC void luaC_fullgc(lua_State* L); -LUAI_FUNC void luaC_linkobj(lua_State* L, GCObject* o, uint8_t tt); -LUAI_FUNC void luaC_linkupval(lua_State* L, UpVal* uv); +LUAI_FUNC void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt); +LUAI_FUNC void luaC_initupval(lua_State* L, UpVal* uv); LUAI_FUNC void luaC_barrierupval(lua_State* L, GCObject* v); LUAI_FUNC void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v); LUAI_FUNC void luaC_barriertable(lua_State* L, Table* t, GCObject* v); diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index 30242e52..2b38619b 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -13,8 +13,6 @@ #include #include -LUAU_FASTFLAG(LuauGcPagedSweep) - static void validateobjref(global_State* g, GCObject* f, GCObject* t) { LUAU_ASSERT(!isdead(g, t)); @@ -104,8 +102,7 @@ static void validatestack(global_State* g, lua_State* l) if (l->namecall) validateobjref(g, obj2gco(l), obj2gco(l->namecall)); - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - for (UpVal* uv = l->openupval; uv; uv = (UpVal*)uv->next) + for (UpVal* uv = l->openupval; uv; uv = uv->u.l.threadnext) { LUAU_ASSERT(uv->tt == LUA_TUPVAL); LUAU_ASSERT(uv->v != &uv->u.value); @@ -141,7 +138,7 @@ static void validateobj(global_State* g, GCObject* o) /* dead objects can only occur during sweep */ if (isdead(g, o)) { - LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + LUAU_ASSERT(g->gcstate == GCSsweep); return; } @@ -180,18 +177,6 @@ static void validateobj(global_State* g, GCObject* o) } } -static void validatelist(global_State* g, GCObject* o) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - while (o) - { - validateobj(g, o); - - o = o->gch.next; - } -} - static void validategraylist(global_State* g, GCObject* o) { if (!keepinvariant(g)) @@ -224,8 +209,6 @@ static void validategraylist(global_State* g, GCObject* o) static bool validategco(void* context, lua_Page* page, GCObject* gco) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - lua_State* L = (lua_State*)context; global_State* g = L->global; @@ -248,20 +231,9 @@ void luaC_validate(lua_State* L) validategraylist(g, g->gray); validategraylist(g, g->grayagain); - if (FFlag::LuauGcPagedSweep) - { - validategco(L, NULL, obj2gco(g->mainthread)); + validategco(L, NULL, obj2gco(g->mainthread)); - luaM_visitgco(L, L, validategco); - } - else - { - for (int i = 0; i < g->strt.size; ++i) - validatelist(g, (GCObject*)(g->strt.hash[i])); - - validatelist(g, g->rootgc); - validatelist(g, (GCObject*)(g->strbufgc)); - } + luaM_visitgco(L, L, validategco); for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) { @@ -521,30 +493,8 @@ static void dumpobj(FILE* f, GCObject* o) } } -static void dumplist(FILE* f, GCObject* o) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - while (o) - { - dumpref(f, o); - fputc(':', f); - dumpobj(f, o); - fputc(',', f); - fputc('\n', f); - - // thread has additional list containing collectable objects that are not present in rootgc - if (o->gch.tt == LUA_TTHREAD) - dumplist(f, (GCObject*)gco2th(o)->openupval); - - o = o->gch.next; - } -} - static bool dumpgco(void* context, lua_Page* page, GCObject* gco) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - FILE* f = (FILE*)context; dumpref(f, gco); @@ -563,19 +513,9 @@ void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* fprintf(f, "{\"objects\":{\n"); - if (FFlag::LuauGcPagedSweep) - { - dumpgco(f, NULL, obj2gco(g->mainthread)); + dumpgco(f, NULL, obj2gco(g->mainthread)); - luaM_visitgco(L, f, dumpgco); - } - else - { - dumplist(f, g->rootgc); - dumplist(f, (GCObject*)(g->strbufgc)); - for (int i = 0; i < g->strt.size; ++i) - dumplist(f, (GCObject*)(g->strt.hash[i])); - } + luaM_visitgco(L, f, dumpgco); fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , fprintf(f, "},\"roots\":{\n"); diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index c93f431f..fd95f596 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -68,7 +68,7 @@ void luaL_sandboxthread(lua_State* L) lua_setsafeenv(L, LUA_GLOBALSINDEX, true); } -static void* l_alloc(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsize) +static void* l_alloc(void* ud, void* ptr, size_t osize, size_t nsize) { (void)ud; (void)osize; diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 19617b8c..899cb0c0 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -78,8 +78,6 @@ * allocated pages. */ -LUAU_FASTFLAG(LuauGcPagedSweep) - #ifndef __has_feature #define __has_feature(x) 0 #endif @@ -98,8 +96,10 @@ LUAU_FASTFLAG(LuauGcPagedSweep) * To prevent some of them accidentally growing and us losing memory without realizing it, we're going to lock * the sizes of all critical structures down. */ -#if defined(__APPLE__) && !defined(__MACH__) +#if defined(__APPLE__) #define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : gcc32) +#elif defined(__i386__) +#define ABISWITCH(x64, ms32, gcc32) (gcc32) #else // Android somehow uses a similar ABI to MSVC, *not* to iOS... #define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : ms32) @@ -114,14 +114,8 @@ static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table #endif static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); -// TODO (FFlagLuauGcPagedSweep): this will become ABISWITCH(16, 16, 16) -static_assert(offsetof(Udata, data) == ABISWITCH(24, 16, 16), "size mismatch for userdata header"); -// TODO (FFlagLuauGcPagedSweep): this will become ABISWITCH(48, 32, 32) -static_assert(sizeof(Table) == ABISWITCH(56, 36, 36), "size mismatch for table header"); - -// TODO (FFlagLuauGcPagedSweep): new code with old 'next' pointer requires that GCObject start at the same point as TString/UpVal -static_assert(offsetof(GCObject, uv) == 0, "UpVal data must be located at the start of the GCObject"); -static_assert(offsetof(GCObject, ts) == 0, "TString data must be located at the start of the GCObject"); +static_assert(offsetof(Udata, data) == ABISWITCH(16, 16, 12), "size mismatch for userdata header"); +static_assert(sizeof(Table) == ABISWITCH(48, 32, 32), "size mismatch for table header"); const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; @@ -208,53 +202,13 @@ l_noret luaM_toobig(lua_State* L) luaG_runerror(L, "memory allocation error: block too big"); } -static lua_Page* newpageold(lua_State* L, uint8_t sizeClass) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - global_State* g = L->global; - lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, kPageSize); - if (!page) - luaD_throw(L, LUA_ERRMEM); - - int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + kBlockHeader; - int blockCount = (kPageSize - offsetof(lua_Page, data)) / blockSize; - - ASAN_POISON_MEMORY_REGION(page->data, blockSize * blockCount); - - // setup page header - page->prev = NULL; - page->next = NULL; - - page->gcolistprev = NULL; - page->gcolistnext = NULL; - - page->pageSize = kPageSize; - page->blockSize = blockSize; - - // note: we start with the last block in the page and move downward - // either order would work, but that way we don't need to store the block count in the page - // additionally, GC stores objects in singly linked lists, and this way the GC lists end up in increasing pointer order - page->freeList = NULL; - page->freeNext = (blockCount - 1) * blockSize; - page->busyBlocks = 0; - - // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) - LUAU_ASSERT(!g->freepages[sizeClass]); - g->freepages[sizeClass] = page; - - return page; -} - static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int blockSize, int blockCount) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; LUAU_ASSERT(pageSize - int(offsetof(lua_Page, data)) >= blockSize * blockCount); - lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, pageSize); + lua_Page* page = (lua_Page*)(*g->frealloc)(g->ud, NULL, 0, pageSize); if (!page) luaD_throw(L, LUA_ERRMEM); @@ -290,8 +244,6 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, uint8_t sizeClass, bool storeMetadata) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + (storeMetadata ? kBlockHeader : 0); int blockCount = (kPageSize - offsetof(lua_Page, data)) / blockSize; @@ -304,29 +256,8 @@ static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** g return page; } -static void freepageold(lua_State* L, lua_Page* page, uint8_t sizeClass) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - global_State* g = L->global; - - // remove page from freelist - if (page->next) - page->next->prev = page->prev; - - if (page->prev) - page->prev->next = page->next; - else if (g->freepages[sizeClass] == page) - g->freepages[sizeClass] = page->next; - - // so long - (*g->frealloc)(L, g->ud, page, kPageSize, 0); -} - static void freepage(lua_State* L, lua_Page** gcopageset, lua_Page* page) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; if (gcopageset) @@ -342,13 +273,11 @@ static void freepage(lua_State* L, lua_Page** gcopageset, lua_Page* page) } // so long - (*g->frealloc)(L, g->ud, page, page->pageSize, 0); + (*g->frealloc)(g->ud, page, page->pageSize, 0); } static void freeclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, lua_Page* page, uint8_t sizeClass) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - // remove page from freelist if (page->next) page->next->prev = page->prev; @@ -368,12 +297,7 @@ static void* newblock(lua_State* L, int sizeClass) // slow path: no page in the freelist, allocate a new one if (!page) - { - if (FFlag::LuauGcPagedSweep) - page = newclasspage(L, g->freepages, NULL, sizeClass, true); - else - page = newpageold(L, sizeClass); - } + page = newclasspage(L, g->freepages, NULL, sizeClass, true); LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); @@ -416,8 +340,6 @@ static void* newblock(lua_State* L, int sizeClass) static void* newgcoblock(lua_State* L, int sizeClass) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; lua_Page* page = g->freegcopages[sizeClass]; @@ -496,17 +418,11 @@ static void freeblock(lua_State* L, int sizeClass, void* block) // if it's the last block in the page, we don't need the page if (page->busyBlocks == 0) - { - if (FFlag::LuauGcPagedSweep) - freeclasspage(L, g->freepages, NULL, page, sizeClass); - else - freepageold(L, page, sizeClass); - } + freeclasspage(L, g->freepages, NULL, page, sizeClass); } static void freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); LUAU_ASSERT(page && page->busyBlocks > 0); LUAU_ASSERT(page->blockSize == kSizeClassConfig.sizeOfClass[sizeClass]); LUAU_ASSERT(block >= page->data && block < (char*)page + page->pageSize); @@ -544,7 +460,7 @@ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) int nclass = sizeclass(nsize); - void* block = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); + void* block = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(g->ud, NULL, 0, nsize); if (block == NULL && nsize > 0) luaD_throw(L, LUA_ERRMEM); @@ -556,9 +472,6 @@ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) { - if (!FFlag::LuauGcPagedSweep) - return (GCObject*)luaM_new_(L, nsize, memcat); - // we need to accommodate space for link for free blocks (freegcolink) LUAU_ASSERT(nsize >= kGCOLinkOffset + sizeof(void*)); @@ -602,7 +515,7 @@ void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) if (oclass >= 0) freeblock(L, oclass, block); else - (*g->frealloc)(L, g->ud, block, osize, 0); + (*g->frealloc)(g->ud, block, osize, 0); g->totalbytes -= osize; g->memcatbytes[memcat] -= osize; @@ -610,12 +523,6 @@ void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) void luaM_freegco_(lua_State* L, GCObject* block, size_t osize, uint8_t memcat, lua_Page* page) { - if (!FFlag::LuauGcPagedSweep) - { - luaM_free_(L, block, osize, memcat); - return; - } - global_State* g = L->global; LUAU_ASSERT((osize == 0) == (block == NULL)); @@ -652,7 +559,7 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 // if either block needs to be allocated using a block allocator, we can't use realloc directly if (nclass >= 0 || oclass >= 0) { - result = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); + result = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(g->ud, NULL, 0, nsize); if (result == NULL && nsize > 0) luaD_throw(L, LUA_ERRMEM); @@ -662,11 +569,11 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 if (oclass >= 0) freeblock(L, oclass, block); else - (*g->frealloc)(L, g->ud, block, osize, 0); + (*g->frealloc)(g->ud, block, osize, 0); } else { - result = (*g->frealloc)(L, g->ud, block, osize, nsize); + result = (*g->frealloc)(g->ud, block, osize, nsize); if (result == NULL && nsize > 0) luaD_throw(L, LUA_ERRMEM); } @@ -679,8 +586,6 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 void luaM_getpagewalkinfo(lua_Page* page, char** start, char** end, int* busyBlocks, int* blockSize) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - int blockCount = (page->pageSize - offsetof(lua_Page, data)) / page->blockSize; LUAU_ASSERT(page->freeNext >= -page->blockSize && page->freeNext <= (blockCount - 1) * page->blockSize); @@ -700,8 +605,6 @@ lua_Page* luaM_getnextgcopage(lua_Page* page) void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - char* start; char* end; int busyBlocks; @@ -730,8 +633,6 @@ void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context void luaM_visitgco(lua_State* L, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; for (lua_Page* curr = g->allgcopages; curr;) diff --git a/VM/src/lmem.h b/VM/src/lmem.h index 1bfe48fa..00788452 100644 --- a/VM/src/lmem.h +++ b/VM/src/lmem.h @@ -7,11 +7,7 @@ struct lua_Page; union GCObject; -// TODO: remove with FFlagLuauGcPagedSweep and rename luaM_newgco to luaM_new -#define luaM_new(L, t, size, memcat) cast_to(t*, luaM_new_(L, size, memcat)) #define luaM_newgco(L, t, size, memcat) cast_to(t*, luaM_newgco_(L, size, memcat)) -// TODO: remove with FFlagLuauGcPagedSweep and rename luaM_freegco to luaM_free -#define luaM_free(L, p, size, memcat) luaM_free_(L, (p), size, memcat) #define luaM_freegco(L, p, size, memcat, page) luaM_freegco_(L, obj2gco(p), size, memcat, page) #define luaM_arraysize_(n, e) ((cast_to(size_t, (n)) <= SIZE_MAX / (e)) ? (n) * (e) : (luaM_toobig(L), SIZE_MAX)) diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 57ffd82a..5e02c2ea 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -15,7 +15,6 @@ typedef union GCObject GCObject; */ // clang-format off #define CommonHeader \ - GCObject* next; /* TODO: remove with FFlagLuauGcPagedSweep */ \ uint8_t tt; uint8_t marked; uint8_t memcat // clang-format on @@ -233,6 +232,8 @@ typedef struct TString int16_t atom; // 2 byte padding + TString* next; // next string in the hash table bucket or the string buffer linked list + unsigned int hash; unsigned int len; @@ -327,7 +328,7 @@ typedef struct UpVal struct UpVal* next; /* thread double linked list (when open) */ - // TODO: when FFlagLuauGcPagedSweep is removed, old outer 'next' value will be placed here + struct UpVal* threadnext; /* note: this is the location of a pointer to this upvalue in the previous element that can be either an UpVal or a lua_State */ struct UpVal** threadprev; } l; diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index d6d127c0..d4f3f0a1 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -10,7 +10,6 @@ #include "ldo.h" #include "ldebug.h" -LUAU_FASTFLAG(LuauGcPagedSweep) LUAU_FASTFLAGVARIABLE(LuauReduceStackReallocs, false) /* @@ -90,8 +89,6 @@ static void close_state(lua_State* L) global_State* g = L->global; luaF_close(L, L->stack); /* close all upvalues for this thread */ luaC_freeall(L); /* collect all objects */ - if (!FFlag::LuauGcPagedSweep) - LUAU_ASSERT(g->rootgc == obj2gco(L)); LUAU_ASSERT(g->strbufgc == NULL); LUAU_ASSERT(g->strt.nuse == 0); luaM_freearray(L, L->global->strt.hash, L->global->strt.size, TString*, 0); @@ -99,22 +96,20 @@ static void close_state(lua_State* L) for (int i = 0; i < LUA_SIZECLASSES; i++) { LUAU_ASSERT(g->freepages[i] == NULL); - if (FFlag::LuauGcPagedSweep) - LUAU_ASSERT(g->freegcopages[i] == NULL); + LUAU_ASSERT(g->freegcopages[i] == NULL); } - if (FFlag::LuauGcPagedSweep) - LUAU_ASSERT(g->allgcopages == NULL); + LUAU_ASSERT(g->allgcopages == NULL); LUAU_ASSERT(g->totalbytes == sizeof(LG)); LUAU_ASSERT(g->memcatbytes[0] == sizeof(LG)); for (int i = 1; i < LUA_MEMORY_CATEGORIES; i++) LUAU_ASSERT(g->memcatbytes[i] == 0); - (*g->frealloc)(L, g->ud, L, sizeof(LG), 0); + (*g->frealloc)(g->ud, L, sizeof(LG), 0); } lua_State* luaE_newthread(lua_State* L) { lua_State* L1 = luaM_newgco(L, lua_State, sizeof(lua_State), L->activememcat); - luaC_link(L, L1, LUA_TTHREAD); + luaC_init(L, L1, LUA_TTHREAD); preinit_state(L1, L->global); L1->activememcat = L->activememcat; // inherit the active memory category stack_init(L1, L); /* init stack */ @@ -184,13 +179,11 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) int i; lua_State* L; global_State* g; - void* l = (*f)(NULL, ud, NULL, 0, sizeof(LG)); + void* l = (*f)(ud, NULL, 0, sizeof(LG)); if (l == NULL) return NULL; L = (lua_State*)l; g = &((LG*)L)->g; - if (!FFlag::LuauGcPagedSweep) - L->next = NULL; L->tt = LUA_TTHREAD; L->marked = g->currentwhite = bit2mask(WHITE0BIT, FIXEDBIT); L->memcat = 0; @@ -214,11 +207,6 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) setnilvalue(&g->pseudotemp); setnilvalue(registry(L)); g->gcstate = GCSpause; - if (!FFlag::LuauGcPagedSweep) - g->rootgc = obj2gco(L); - g->sweepstrgc = 0; - if (!FFlag::LuauGcPagedSweep) - g->sweepgc = &g->rootgc; g->gray = NULL; g->grayagain = NULL; g->weak = NULL; @@ -230,14 +218,10 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) for (i = 0; i < LUA_SIZECLASSES; i++) { g->freepages[i] = NULL; - if (FFlag::LuauGcPagedSweep) - g->freegcopages[i] = NULL; - } - if (FFlag::LuauGcPagedSweep) - { - g->allgcopages = NULL; - g->sweepgcopage = NULL; + g->freegcopages[i] = NULL; } + g->allgcopages = NULL; + g->sweepgcopage = NULL; for (i = 0; i < LUA_T_COUNT; i++) g->mt[i] = NULL; for (i = 0; i < LUA_UTAG_LIMIT; i++) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 6dd89138..3ee96718 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -142,11 +142,6 @@ typedef struct global_State uint8_t gcstate; /* state of garbage collector */ - int sweepstrgc; /* position of sweep in `strt' */ - // TODO: remove with FFlagLuauGcPagedSweep - GCObject* rootgc; /* list of all collectable objects */ - // TODO: remove with FFlagLuauGcPagedSweep - GCObject** sweepgc; /* position of sweep in `rootgc' */ GCObject* gray; /* list of gray objects */ GCObject* grayagain; /* list of objects to be traversed atomically */ GCObject* weak; /* list of weak tables (to be cleared) */ diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index 9bbc43de..87250146 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAG(LuauGcPagedSweep) - unsigned int luaS_hash(const char* str, size_t len) { // Note that this hashing algorithm is replicated in BytecodeBuilder.cpp, BytecodeBuilder::getStringHash @@ -46,8 +44,6 @@ unsigned int luaS_hash(const char* str, size_t len) void luaS_resize(lua_State* L, int newsize) { - if (L->global->gcstate == GCSsweepstring) - return; /* cannot resize during GC traverse */ TString** newhash = luaM_newarray(L, newsize, TString*, 0); stringtable* tb = &L->global->strt; for (int i = 0; i < newsize; i++) @@ -58,13 +54,11 @@ void luaS_resize(lua_State* L, int newsize) TString* p = tb->hash[i]; while (p) { /* for each node in the list */ - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - TString* next = (TString*)p->next; /* save next */ + TString* next = p->next; /* save next */ unsigned int h = p->hash; int h1 = lmod(h, newsize); /* new position */ LUAU_ASSERT(cast_int(h % newsize) == lmod(h, newsize)); - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - p->next = (GCObject*)newhash[h1]; /* chain it */ + p->next = newhash[h1]; /* chain it */ newhash[h1] = p; p = next; } @@ -91,8 +85,7 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, l) : -1; tb = &L->global->strt; h = lmod(h, tb->size); - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the case will not be required - ts->next = (GCObject*)tb->hash[h]; /* chain new entry */ + ts->next = tb->hash[h]; /* chain new entry */ tb->hash[h] = ts; tb->nuse++; if (tb->nuse > cast_to(uint32_t, tb->size) && tb->size <= INT_MAX / 2) @@ -104,20 +97,9 @@ static void linkstrbuf(lua_State* L, TString* ts) { global_State* g = L->global; - if (FFlag::LuauGcPagedSweep) - { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - ts->next = (GCObject*)g->strbufgc; - g->strbufgc = ts; - ts->marked = luaC_white(g); - } - else - { - GCObject* o = obj2gco(ts); - o->gch.next = (GCObject*)g->strbufgc; - g->strbufgc = gco2ts(o); - o->gch.marked = luaC_white(g); - } + ts->next = g->strbufgc; + g->strbufgc = ts; + ts->marked = luaC_white(g); } static void unlinkstrbuf(lua_State* L, TString* ts) @@ -130,14 +112,12 @@ static void unlinkstrbuf(lua_State* L, TString* ts) { if (curr == ts) { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - *p = (TString*)curr->next; + *p = curr->next; return; } else { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - p = (TString**)&curr->next; + p = &curr->next; } } @@ -167,8 +147,7 @@ TString* luaS_buffinish(lua_State* L, TString* ts) int bucket = lmod(h, tb->size); // search if we already have this string in the hash table - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - for (TString* el = tb->hash[bucket]; el != NULL; el = (TString*)el->next) + for (TString* el = tb->hash[bucket]; el != NULL; el = el->next) { if (el->len == ts->len && memcmp(el->data, ts->data, ts->len) == 0) { @@ -187,8 +166,7 @@ TString* luaS_buffinish(lua_State* L, TString* ts) // Complete string object ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - ts->next = (GCObject*)tb->hash[bucket]; // chain new entry + ts->next = tb->hash[bucket]; // chain new entry tb->hash[bucket] = ts; tb->nuse++; @@ -201,8 +179,7 @@ TString* luaS_buffinish(lua_State* L, TString* ts) TString* luaS_newlstr(lua_State* L, const char* str, size_t l) { unsigned int h = luaS_hash(str, l); - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - for (TString* el = L->global->strt.hash[lmod(h, L->global->strt.size)]; el != NULL; el = (TString*)el->next) + for (TString* el = L->global->strt.hash[lmod(h, L->global->strt.size)]; el != NULL; el = el->next) { if (el->len == l && (memcmp(str, getstr(el), l) == 0)) { @@ -217,8 +194,6 @@ TString* luaS_newlstr(lua_State* L, const char* str, size_t l) static bool unlinkstr(lua_State* L, TString* ts) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; TString** p = &g->strt.hash[lmod(ts->hash, g->strt.size)]; @@ -227,14 +202,12 @@ static bool unlinkstr(lua_State* L, TString* ts) { if (curr == ts) { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - *p = (TString*)curr->next; + *p = curr->next; return true; } else { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - p = (TString**)&curr->next; + p = &curr->next; } } @@ -243,20 +216,11 @@ static bool unlinkstr(lua_State* L, TString* ts) void luaS_free(lua_State* L, TString* ts, lua_Page* page) { - if (FFlag::LuauGcPagedSweep) - { - // Unchain from the string table - if (!unlinkstr(L, ts)) - unlinkstrbuf(L, ts); // An unlikely scenario when we have a string buffer on our hands - else - L->global->strt.nuse--; - - luaM_freegco(L, ts, sizestring(ts->len), ts->memcat, page); - } + // Unchain from the string table + if (!unlinkstr(L, ts)) + unlinkstrbuf(L, ts); // An unlikely scenario when we have a string buffer on our hands else - { L->global->strt.nuse--; - luaM_free(L, ts, sizestring(ts->len), ts->memcat); - } + luaM_freegco(L, ts, sizestring(ts->len), ts->memcat, page); } diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index c57374e0..0412ea76 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -425,7 +425,7 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) Table* luaH_new(lua_State* L, int narray, int nhash) { Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); - luaC_link(L, t, LUA_TTABLE); + luaC_init(L, t, LUA_TTABLE); t->metatable = NULL; t->flags = cast_byte(~0); t->array = NULL; @@ -742,7 +742,7 @@ int luaH_getn(Table* t) Table* luaH_clone(lua_State* L, Table* tt) { Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); - luaC_link(L, t, LUA_TTABLE); + luaC_init(L, t, LUA_TTABLE); t->metatable = tt->metatable; t->flags = tt->flags; t->array = NULL; diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index 758a9bdb..0dfac508 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -12,7 +12,7 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) if (s > INT_MAX - sizeof(Udata)) luaM_toobig(L); Udata* u = luaM_newgco(L, Udata, sizeudata(s), L->activememcat); - luaC_link(L, u, LUA_TUSERDATA); + luaC_init(L, u, LUA_TUSERDATA); u->len = int(s); u->metatable = NULL; LUAU_ASSERT(tag >= 0 && tag <= 255); diff --git a/fuzz/linter.cpp b/fuzz/linter.cpp index 04638d23..0bdd49f5 100644 --- a/fuzz/linter.cpp +++ b/fuzz/linter.cpp @@ -1,10 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include -#include "Luau/TypeInfer.h" -#include "Luau/Linter.h" + #include "Luau/BuiltinDefinitions.h" -#include "Luau/ModuleResolver.h" #include "Luau/Common.h" +#include "Luau/Linter.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) { diff --git a/fuzz/luau.proto b/fuzz/luau.proto index c78fcf31..190b8c5b 100644 --- a/fuzz/luau.proto +++ b/fuzz/luau.proto @@ -96,11 +96,13 @@ message ExprIndexExpr { } message ExprFunction { - repeated Local args = 1; - required bool vararg = 2; - required StatBlock body = 3; - repeated Type types = 4; - repeated Type rettypes = 5; + repeated Typename generics = 1; + repeated Typename genericpacks = 2; + repeated Local args = 3; + required bool vararg = 4; + required StatBlock body = 5; + repeated Type types = 6; + repeated Type rettypes = 7; } message TableItem { @@ -153,7 +155,10 @@ message ExprBinary { message ExprIfElse { required Expr cond = 1; required Expr then = 2; - required Expr else = 3; + oneof else_oneof { + Expr else = 3; + ExprIfElse elseif = 4; + } } message LValue { @@ -183,6 +188,7 @@ message Stat { StatFunction function = 14; StatLocalFunction local_function = 15; StatTypeAlias type_alias = 16; + StatRequireIntoLocalHelper require_into_local = 17; } } @@ -276,9 +282,16 @@ message StatLocalFunction { } message StatTypeAlias { - required Typename name = 1; - required Type type = 2; - repeated Typename generics = 3; + required bool export = 1; + required Typename name = 2; + required Type type = 3; + repeated Typename generics = 4; + repeated Typename genericpacks = 5; +} + +message StatRequireIntoLocalHelper { + required Local var = 1; + required int32 modulenum = 2; } message Type { @@ -292,6 +305,8 @@ message Type { TypeIntersection intersection = 7; TypeClass class = 8; TypeRef ref = 9; + TypeBoolean boolean = 10; + TypeString string = 11; } } @@ -301,7 +316,8 @@ message TypePrimitive { message TypeLiteral { required Typename name = 1; - repeated Typename generics = 2; + repeated Type generics = 2; + repeated Typename genericpacks = 3; } message TypeTableItem { @@ -320,8 +336,10 @@ message TypeTable { } message TypeFunction { - repeated Type args = 1; - repeated Type rets = 2; + repeated Typename generics = 1; + repeated Typename genericpacks = 2; + repeated Type args = 3; + repeated Type rets = 4; // TODO: vararg? } @@ -347,3 +365,16 @@ message TypeRef { required Local prefix = 1; required Typename index = 2; } + +message TypeBoolean { + required bool val = 1; +} + +message TypeString { + required string val = 1; +} + +message ModuleSet { + optional StatBlock module = 1; + required StatBlock program = 2; +} diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index f407248a..912fef23 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -2,16 +2,17 @@ #include "src/libfuzzer/libfuzzer_macro.h" #include "luau.pb.h" -#include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" -#include "Luau/ModuleResolver.h" -#include "Luau/ModuleResolver.h" -#include "Luau/Compiler.h" -#include "Luau/Linter.h" #include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" +#include "Luau/Compiler.h" +#include "Luau/Frontend.h" +#include "Luau/Linter.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Parser.h" #include "Luau/ToString.h" #include "Luau/Transpiler.h" +#include "Luau/TypeInfer.h" #include "lua.h" #include "lualib.h" @@ -30,7 +31,7 @@ const bool kFuzzTypes = true; static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); -std::string protoprint(const luau::StatBlock& stat, bool types); +std::vector protoprint(const luau::ModuleSet& stat, bool types); LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) @@ -38,6 +39,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTarjanChildLimit) +LUAU_FASTFLAG(DebugLuauFreezeArena) std::chrono::milliseconds kInterruptTimeout(10); std::chrono::time_point interruptDeadline; @@ -135,10 +137,58 @@ int registerTypes(Luau::TypeChecker& env) return 0; } +struct FuzzFileResolver : Luau::FileResolver +{ + std::optional readSource(const Luau::ModuleName& name) override + { + auto it = source.find(name); + if (it == source.end()) + return std::nullopt; -static std::string debugsource; + return Luau::SourceCode{it->second, Luau::SourceCode::Module}; + } -DEFINE_PROTO_FUZZER(const luau::StatBlock& message) + std::optional resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* expr) override + { + if (Luau::AstExprGlobal* g = expr->as()) + return Luau::ModuleInfo{g->name.value}; + + return std::nullopt; + } + + std::string getHumanReadableModuleName(const Luau::ModuleName& name) const override + { + return name; + } + + std::optional getEnvironmentForModule(const Luau::ModuleName& name) const override + { + return std::nullopt; + } + + std::unordered_map source; +}; + +struct FuzzConfigResolver : Luau::ConfigResolver +{ + FuzzConfigResolver() + { + defaultConfig.mode = Luau::Mode::Nonstrict; // typecheckTwice option will cover Strict mode + defaultConfig.enabledLint.warningMask = ~0ull; + defaultConfig.parseOptions.captureComments = true; + } + + virtual const Luau::Config& getConfig(const Luau::ModuleName& name) const override + { + return defaultConfig; + } + + Luau::Config defaultConfig; +}; + +static std::vector debugsources; + +DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) { FInt::LuauTypeInferRecursionLimit.value = 100; FInt::LuauTypeInferTypePackLoopLimit.value = 100; @@ -151,91 +201,90 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) if (strncmp(flag->name, "Luau", 4) == 0) flag->value = true; - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); + FFlag::DebugLuauFreezeArena.value = true; - std::string source = protoprint(message, kFuzzTypes); + std::vector sources = protoprint(message, kFuzzTypes); // stash source in a global for easier crash dump debugging - debugsource = source; - - Luau::ParseResult parseResult = Luau::Parser::parse(source.c_str(), source.size(), names, allocator); - - // "static" here is to accelerate fuzzing process by only creating and populating the type environment once - static Luau::NullModuleResolver moduleResolver; - static Luau::InternalErrorReporter iceHandler; - static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); - static int once = registerTypes(sharedEnv); - (void)once; - static int once2 = (Luau::freeze(sharedEnv.globalTypes), 0); - (void)once2; - - iceHandler.onInternalError = [](const char* error) { - printf("ICE: %s\n", error); - LUAU_ASSERT(!"ICE"); - }; + debugsources = sources; static bool debug = getenv("LUAU_DEBUG") != 0; if (debug) { - fprintf(stdout, "--\n%s\n", source.c_str()); + for (std::string& source : sources) + fprintf(stdout, "--\n%s\n", source.c_str()); fflush(stdout); } - std::string bytecode; + // parse all sources + std::vector> parseAllocators; + std::vector> parseNameTables; - // compile - if (kFuzzCompiler && parseResult.errors.empty()) + Luau::ParseOptions parseOptions; + parseOptions.captureComments = true; + + std::vector parseResults; + + for (std::string& source : sources) { - Luau::CompileOptions compileOptions; + parseAllocators.push_back(std::make_unique()); + parseNameTables.push_back(std::make_unique(*parseAllocators.back())); - try - { - Luau::BytecodeBuilder bcb; - Luau::compileOrThrow(bcb, parseResult.root, names, compileOptions); - bytecode = bcb.getBytecode(); - } - catch (const Luau::CompileError&) - { - // not all valid ASTs can be compiled due to limits on number of registers - } + parseResults.push_back(Luau::Parser::parse(source.c_str(), source.size(), *parseNameTables.back(), *parseAllocators.back(), parseOptions)); } - // typecheck - if (kFuzzTypeck && parseResult.root) - { - Luau::SourceModule sourceModule; - sourceModule.root = parseResult.root; - sourceModule.mode = Luau::Mode::Nonstrict; - - Luau::TypeChecker typeck(&moduleResolver, &iceHandler); - typeck.globalScope = sharedEnv.globalScope; - - Luau::ModulePtr module = nullptr; - - try - { - module = typeck.check(sourceModule, Luau::Mode::Nonstrict); - } - catch (std::exception&) - { - // This catches internal errors that the type checker currently (unfortunately) throws in some cases - } - - // lint (note that we need access to types so we need to do this with typeck in scope) - if (kFuzzLinter) - { - Luau::LintOptions lintOptions = {~0u}; - Luau::lint(parseResult.root, names, sharedEnv.globalScope, module.get(), {}, lintOptions); - } - } - - // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down - // note: it's important for typeck to be destroyed at this point! + // typecheck all sources if (kFuzzTypeck) { - for (auto& p : sharedEnv.globalScope->bindings) + static FuzzFileResolver fileResolver; + static Luau::NullConfigResolver configResolver; + static Luau::FrontendOptions options{true, true}; + static Luau::Frontend frontend(&fileResolver, &configResolver, options); + + static int once = registerTypes(frontend.typeChecker); + (void)once; + static int once2 = (Luau::freeze(frontend.typeChecker.globalTypes), 0); + (void)once2; + + frontend.iceHandler.onInternalError = [](const char* error) { + printf("ICE: %s\n", error); + LUAU_ASSERT(!"ICE"); + }; + + // restart + frontend.clear(); + fileResolver.source.clear(); + + // load sources + for (size_t i = 0; i < sources.size(); i++) + { + std::string name = "module" + std::to_string(i); + fileResolver.source[name] = sources[i]; + } + + // check sources + for (size_t i = 0; i < sources.size(); i++) + { + std::string name = "module" + std::to_string(i); + + try + { + Luau::CheckResult result = frontend.check(name, std::nullopt); + + // lint (note that we need access to types so we need to do this with typeck in scope) + if (kFuzzLinter && result.errors.empty()) + frontend.lint(name, std::nullopt); + } + catch (std::exception&) + { + // This catches internal errors that the type checker currently (unfortunately) throws in some cases + } + } + + // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down + // note: it's important for typeck to be destroyed at this point! + for (auto& p : frontend.typeChecker.globalScope->bindings) { Luau::ToStringOptions opts; opts.exhaustive = true; @@ -246,12 +295,44 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) } } - if (kFuzzTranspile && parseResult.root) + if (kFuzzTranspile) { - transpileWithTypes(*parseResult.root); + for (Luau::ParseResult& parseResult : parseResults) + { + if (parseResult.root) + transpileWithTypes(*parseResult.root); + } } - // run resulting bytecode + std::string bytecode; + + // compile + if (kFuzzCompiler) + { + for (size_t i = 0; i < parseResults.size(); i++) + { + Luau::ParseResult& parseResult = parseResults[i]; + Luau::AstNameTable& parseNameTable = *parseNameTables[i]; + + if (parseResult.errors.empty()) + { + Luau::CompileOptions compileOptions; + + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, parseResult.root, parseNameTable, compileOptions); + bytecode = bcb.getBytecode(); + } + catch (const Luau::CompileError&) + { + // not all valid ASTs can be compiled due to limits on number of registers + } + } + } + } + + // run resulting bytecode (from last successfully compiler module) if (kFuzzVM && bytecode.size()) { static lua_State* globalState = createGlobalState(); diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index e61b6936..66a89f24 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -208,6 +208,35 @@ struct ProtoToLuau source += std::to_string(name.index() & 0xff); } + template + void genericidents(const T& node) + { + if (node.generics_size() || node.genericpacks_size()) + { + source += '<'; + bool first = true; + + for (size_t i = 0; i < node.generics_size(); ++i) + { + if (!first) + source += ','; + first = false; + ident(node.generics(i)); + } + + for (size_t i = 0; i < node.genericpacks_size(); ++i) + { + if (!first) + source += ','; + first = false; + ident(node.genericpacks(i)); + source += "..."; + } + + source += '>'; + } + } + void print(const luau::Expr& expr) { if (expr.has_group()) @@ -240,6 +269,8 @@ struct ProtoToLuau print(expr.unary()); else if (expr.has_binary()) print(expr.binary()); + else if (expr.has_ifelse()) + print(expr.ifelse()); else source += "_"; } @@ -350,6 +381,7 @@ struct ProtoToLuau void function(const luau::ExprFunction& expr) { + genericidents(expr); source += "("; for (int i = 0; i < expr.args_size(); ++i) { @@ -478,12 +510,21 @@ struct ProtoToLuau void print(const luau::ExprIfElse& expr) { - source += " if "; + source += "if "; print(expr.cond()); source += " then "; print(expr.then()); - source += " else "; - print(expr.else_()); + + if (expr.has_else_()) + { + source += " else "; + print(expr.else_()); + } + else if (expr.has_elseif()) + { + source += " else"; + print(expr.elseif()); + } } void print(const luau::LValue& expr) @@ -534,6 +575,8 @@ struct ProtoToLuau print(stat.local_function()); else if (stat.has_type_alias()) print(stat.type_alias()); + else if (stat.has_require_into_local()) + print(stat.require_into_local()); else source += "do end\n"; } @@ -804,26 +847,24 @@ struct ProtoToLuau void print(const luau::StatTypeAlias& stat) { + if (stat.export_()) + source += "export "; + source += "type "; ident(stat.name()); - - if (stat.generics_size()) - { - source += '<'; - for (size_t i = 0; i < stat.generics_size(); ++i) - { - if (i != 0) - source += ','; - ident(stat.generics(i)); - } - source += '>'; - } - + genericidents(stat); source += " = "; print(stat.type()); source += '\n'; } + void print(const luau::StatRequireIntoLocalHelper& stat) + { + source += "local "; + print(stat.var()); + source += " = require(module" + std::to_string(stat.modulenum() % 2) + ")\n"; + } + void print(const luau::Type& type) { if (type.has_primitive()) @@ -844,6 +885,10 @@ struct ProtoToLuau print(type.class_()); else if (type.has_ref()) print(type.ref()); + else if (type.has_boolean()) + print(type.boolean()); + else if (type.has_string()) + print(type.string()); else source += "any"; } @@ -858,15 +903,28 @@ struct ProtoToLuau { ident(type.name()); - if (type.generics_size()) + if (type.generics_size() || type.genericpacks_size()) { source += '<'; + bool first = true; + for (size_t i = 0; i < type.generics_size(); ++i) { - if (i != 0) + if (!first) source += ','; - ident(type.generics(i)); + first = false; + print(type.generics(i)); } + + for (size_t i = 0; i < type.genericpacks_size(); ++i) + { + if (!first) + source += ','; + first = false; + ident(type.genericpacks(i)); + source += "..."; + } + source += '>'; } } @@ -893,6 +951,7 @@ struct ProtoToLuau void print(const luau::TypeFunction& type) { + genericidents(type); source += '('; for (size_t i = 0; i < type.args_size(); ++i) { @@ -950,12 +1009,38 @@ struct ProtoToLuau source += '.'; ident(type.index()); } + + void print(const luau::TypeBoolean& type) + { + source += type.val() ? "true" : "false"; + } + + void print(const luau::TypeString& type) + { + source += '"'; + for (char ch : type.val()) + if (isgraph(ch)) + source += ch; + source += '"'; + } }; -std::string protoprint(const luau::StatBlock& stat, bool types) +std::vector protoprint(const luau::ModuleSet& stat, bool types) { + std::vector result; + + if (stat.has_module()) + { + ProtoToLuau printer; + printer.types = types; + printer.print(stat.module()); + result.push_back(printer.source); + } + ProtoToLuau printer; printer.types = types; - printer.print(stat); - return printer.source; + printer.print(stat.program()); + result.push_back(printer.source); + + return result; } diff --git a/fuzz/prototest.cpp b/fuzz/prototest.cpp index 804e708a..ccaa1971 100644 --- a/fuzz/prototest.cpp +++ b/fuzz/prototest.cpp @@ -2,11 +2,15 @@ #include "src/libfuzzer/libfuzzer_macro.h" #include "luau.pb.h" -std::string protoprint(const luau::StatBlock& stat, bool types); +std::vector protoprint(const luau::ModuleSet& stat, bool types); -DEFINE_PROTO_FUZZER(const luau::StatBlock& message) +DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) { - std::string source = protoprint(message, true); + std::vector sources = protoprint(message, true); - printf("%s\n", source.c_str()); + for (size_t i = 0; i < sources.size(); i++) + { + printf("Module 'l%d':\n", int(i)); + printf("%s\n", sources[i].c_str()); + } } diff --git a/fuzz/typeck.cpp b/fuzz/typeck.cpp index 5020c771..3905cc19 100644 --- a/fuzz/typeck.cpp +++ b/fuzz/typeck.cpp @@ -1,9 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include -#include "Luau/TypeInfer.h" + #include "Luau/BuiltinDefinitions.h" -#include "Luau/ModuleResolver.h" #include "Luau/Common.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 1978a0d3..6aadef32 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2752,7 +2752,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") ScopedFastFlag sffs[] = { {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, - {"LuauRefactorTypeVarQuestions", true}, }; check(R"( diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index b09c1efb..eb6ca749 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -712,6 +712,47 @@ TEST_CASE("Reference") CHECK(dtorhits == 2); } +TEST_CASE("ApiTables") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_newtable(L); + lua_pushnumber(L, 123.0); + lua_setfield(L, -2, "key"); + lua_pushstring(L, "test"); + lua_rawseti(L, -2, 5); + + // lua_gettable + lua_pushstring(L, "key"); + CHECK(lua_gettable(L, -2) == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_getfield + CHECK(lua_getfield(L, -1, "key") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_rawgetfield + CHECK(lua_rawgetfield(L, -1, "key") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_rawget + lua_pushstring(L, "key"); + CHECK(lua_rawget(L, -2) == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_rawgeti + CHECK(lua_rawgeti(L, -1, 5) == LUA_TSTRING); + CHECK(strcmp(lua_tostring(L, -1), "test") == 0); + lua_pop(L, 1); + + lua_pop(L, 1); +} + TEST_CASE("ApiFunctionCalls") { StateRef globalState = runConformance("apicalls.lua"); @@ -796,7 +837,7 @@ TEST_CASE("ExceptionObject") return ExceptionResult{false, ""}; }; - auto reallocFunc = [](lua_State* L, void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { + auto reallocFunc = [](void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { if (nsize == 0) { free(ptr); @@ -923,4 +964,53 @@ TEST_CASE("StringConversion") runConformance("strconv.lua"); } +TEST_CASE("GCDump") +{ + // internal function, declared in lgc.h - not exposed via lua.h + extern void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // push various objects on stack to cover different paths + lua_createtable(L, 1, 2); + lua_pushstring(L, "value"); + lua_setfield(L, -2, "key"); + + lua_pushinteger(L, 42); + lua_rawseti(L, -2, 1000); + + lua_pushinteger(L, 42); + lua_rawseti(L, -2, 1); + + lua_pushvalue(L, -1); + lua_setmetatable(L, -2); + + lua_newuserdata(L, 42); + lua_pushvalue(L, -2); + lua_setmetatable(L, -2); + + lua_pushinteger(L, 1); + lua_pushcclosure(L, lua_silence, "test", 1); + + lua_State* CL = lua_newthread(L); + + lua_pushstring(CL, "local x x = {} local function f() x[1] = math.abs(42) end function foo() coroutine.yield() end foo() return f"); + lua_loadstring(CL); + lua_resume(CL, nullptr, 0); + +#ifdef _WIN32 + const char* path = "NUL"; +#else + const char* path = "/dev/null"; +#endif + + FILE* f = fopen(path, "w"); + REQUIRE(f); + + luaC_dump(L, f, nullptr); + + fclose(f); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index d4b97360..ab19cea3 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1594,4 +1594,17 @@ TEST_CASE_FIXTURE(Fixture, "WrongCommentMuteSelf") REQUIRE_EQ(result.warnings.size(), 0); // --!nolint disables WrongComment lint :) } +TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsIfStatAndExpr") +{ + LintResult result = lint(R"( +if if 1 then 2 else 3 then +elseif if 1 then 2 else 3 then +elseif if 0 then 5 else 4 then +end +)"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 0d4c088d..77e49ce3 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2575,7 +2575,6 @@ do end TEST_CASE_FIXTURE(Fixture, "recover_expected_type_pack") { ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauParseRecoverTypePackEllipsis{"LuauParseRecoverTypePackEllipsis", true}; ParseResult result = tryParse(R"( type Y = (T...) -> U... diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index ac5be859..332aba9e 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -651,4 +651,19 @@ local a: Packed CHECK_EQ(code, transpile(code, {}, true).code); } + +TEST_CASE_FIXTURE(Fixture, "transpile_singleton_types") +{ + ScopedFastFlag luauParseSingletonTypes{"LuauParseSingletonTypes", true}; + + std::string code = R"( +type t1 = 'hello' +type t2 = true +type t3 = '' +type t4 = false + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index d677e28d..26881b5c 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -175,6 +175,8 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guarante TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_depth") { + ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; + CheckResult result = check(R"( type A = {x: {y: {z: {thing: string}}}} type B = {x: {y: {z: {thing: string}}}} @@ -184,7 +186,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_dep )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("r")); + CHECK_EQ("string & string", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_mixed_types") @@ -218,7 +220,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_part_missing_ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_property_of_type_any") { CheckResult result = check(R"( - type A = {x: number} + type A = {y: number} type B = {x: any} local t: A & B diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 48e6be6a..bff8926c 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1115,6 +1115,22 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_ LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "refine_a_property_not_to_be_nil_through_an_intersection_table") +{ + ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; + + CheckResult result = check(R"( + type T = {} & {f: ((string) -> string)?} + local function f(t: T, x) + if t.f then + t.f(x) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") { ScopedFastFlag sff[] = { diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 9021700d..856549bd 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -439,4 +439,128 @@ local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_singleton") +{ + ScopedFastFlag sff[]{ + {"LuauSingletonTypes", true}, + {"LuauEqConstraint", true}, + {"LuauDiscriminableUnions2", true}, + {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWeakEqConstraint", false}, + }; + + CheckResult result = check(R"( + local function foo(f, x) + if x == "hi" then + f(x) + f("foo") + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 18}))); + // should be ((string) -> a..., string) -> () but needs lower bounds calculation + CHECK_EQ("((string) -> (b...), a) -> ()", toString(requireType("foo"))); +} + +// TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") +// { +// ScopedFastFlag sff[]{ +// {"LuauParseSingletonTypes", true}, +// {"LuauSingletonTypes", true}, +// {"LuauDiscriminableUnions2", true}, +// {"LuauEqConstraint", true}, +// {"LuauWidenIfSupertypeIsFree", true}, +// {"LuauWeakEqConstraint", false}, +// }; + +// CheckResult result = check(R"( +// local function foo(f, x): "hello"? -- anyone there? +// return if x == "hi" +// then f(x) +// else nil +// end +// )"); + +// LUAU_REQUIRE_NO_ERRORS(result); + +// CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); +// CHECK_EQ(R"(((string) -> ("hello"?, b...), a) -> "hello"?)", toString(requireType("foo"))); +// } + +TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauWidenIfSupertypeIsFree", true}, + }; + + CheckResult result = check(R"( + local foo: "foo" = "foo" + local copy = foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireType("copy"))); +} + +TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauDiscriminableUnions2", true}, + {"LuauWidenIfSupertypeIsFree", true}, + }; + + CheckResult result = check(R"( + type Cat = {tag: "Cat", meows: boolean} + type Dog = {tag: "Dog", barks: boolean} + type Animal = Cat | Dog + + local function f(tag: "Cat" | "Dog"): Animal? + if tag == "Cat" then + local result = {tag = tag, meows = true} + return result + elseif tag == "Dog" then + local result = {tag = tag, barks = true} + return result + else + return nil + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauWidenIfSupertypeIsFree", true}, + }; + + CheckResult result = check(R"( + local function foo(t, x) + if x == "hi" or x == "bye" then + table.insert(t, x) + end + + return t + end + + local t = foo({}, "hi") + table.insert(t, "totally_unrelated_type" :: "totally_unrelated_type") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{string}", toString(requireType("t"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 6bcd4b99..aa949789 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2239,4 +2239,22 @@ TEST_CASE_FIXTURE(Fixture, "give_up_after_one_metatable_index_look_up") CHECK_EQ("Type 't2' does not have key 'x'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "confusing_indexing") +{ + ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; + + CheckResult result = check(R"( + type T = {} & {p: number | string} + local function f(t: T) + return t.p + end + + local foo = f({p = "string"}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number | string", toString(requireType("foo"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 32358571..f44d9fd8 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5127,8 +5127,6 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types") { - ScopedFastFlag noSealedTypeMod{"LuauNoSealedTypeMod", true}; - fileResolver.source["game/A"] = R"( export type Type = { unrelated: boolean } return {} @@ -5190,7 +5188,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, }; @@ -5210,7 +5207,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, }; @@ -5230,7 +5226,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, }; @@ -5250,7 +5245,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, }; diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 4669ea8e..c9bf5103 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -19,7 +19,7 @@ struct TryUnifyFixture : Fixture ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; - Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, unifierState}; + Unifier state{&arena, Mode::Strict, Location{}, Variance::Covariant, unifierState}; }; TEST_SUITE_BEGIN("TryUnifyTests"); @@ -261,8 +261,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") { - ScopedFastFlag luauUnionTagMatchFix{"LuauUnionTagMatchFix", true}; - TypeVar redirect{FreeTypeVar{TypeLevel{}}}; TypeVar table{TableTypeVar{}}; TypeVar metatable{MetatableTypeVar{&redirect, &table}}; diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index e43161fa..fd5f4dbc 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -361,16 +361,12 @@ local b: (T, T, T) -> T TEST_CASE("isString_on_string_singletons") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; CHECK(isString(&helloString)); } TEST_CASE("isString_on_unions_of_various_string_singletons") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; TypeVar union_{UnionTypeVar{{&helloString, &byeString}}}; @@ -380,8 +376,6 @@ TEST_CASE("isString_on_unions_of_various_string_singletons") TEST_CASE("proof_that_isString_uses_all_of") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; TypeVar booleanType{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}}; @@ -392,16 +386,12 @@ TEST_CASE("proof_that_isString_uses_all_of") TEST_CASE("isBoolean_on_boolean_singletons") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; CHECK(isBoolean(&trueBool)); } TEST_CASE("isBoolean_on_unions_of_true_or_false_singletons") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; TypeVar union_{UnionTypeVar{{&trueBool, &falseBool}}}; @@ -411,8 +401,6 @@ TEST_CASE("isBoolean_on_unions_of_true_or_false_singletons") TEST_CASE("proof_that_isBoolean_uses_all_of") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; TypeVar stringType{PrimitiveTypeVar{PrimitiveTypeVar::String}}; diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis index 9924e194..ccc7e390 100644 --- a/tools/natvis/VM.natvis +++ b/tools/natvis/VM.natvis @@ -183,13 +183,12 @@ openupval - u.l.next + u.l.threadnext this - l_gt - env + gt userdata From 9bfecab5baf796c0b358a88fbc8ca8d70af2da12 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 4 Mar 2022 08:19:20 -0800 Subject: [PATCH 028/102] Sync to upstream/release/517 --- Analysis/include/Luau/TxnLog.h | 32 ++- Analysis/include/Luau/TypeInfer.h | 18 -- Analysis/include/Luau/TypeVar.h | 5 +- Analysis/include/Luau/Unifier.h | 2 +- Analysis/src/Autocomplete.cpp | 66 ++--- Analysis/src/BuiltinDefinitions.cpp | 13 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 1 - Analysis/src/JsonEncoder.cpp | 44 +-- Analysis/src/Linter.cpp | 157 +++++++++- Analysis/src/Module.cpp | 5 +- Analysis/src/ToString.cpp | 119 +------- Analysis/src/Transpiler.cpp | 59 ++-- Analysis/src/TxnLog.cpp | 84 +++++- Analysis/src/TypeInfer.cpp | 301 ++++++++------------ Analysis/src/TypeVar.cpp | 27 +- Analysis/src/Unifier.cpp | 137 +++++++-- Ast/src/Parser.cpp | 3 +- CLI/Repl.cpp | 35 +-- VM/include/lua.h | 3 + VM/src/lapi.cpp | 52 +++- VM/src/laux.cpp | 46 +-- VM/src/ldebug.cpp | 5 + VM/src/lgc.cpp | 74 ++++- VM/src/lgc.h | 7 - VM/src/lnumprint.cpp | 9 - VM/src/lnumutils.h | 1 - VM/src/lstate.h | 23 +- VM/src/ltable.cpp | 3 +- VM/src/ltablib.cpp | 21 ++ bench/gc/test_GC_Boehm_Trees.lua | 3 + bench/gc/test_GC_Tree_Pruning_Eager.lua | 2 +- bench/gc/test_GC_Tree_Pruning_Gen.lua | 2 +- bench/gc/test_GC_Tree_Pruning_Lazy.lua | 2 +- extern/isocline/include/isocline.h | 2 +- fuzz/number.cpp | 4 - fuzz/proto.cpp | 2 +- tests/Autocomplete.test.cpp | 10 +- tests/Conformance.test.cpp | 6 +- tests/Linter.test.cpp | 107 +++++++ tests/Parser.test.cpp | 6 - tests/ToDot.test.cpp | 2 - tests/ToString.test.cpp | 2 - tests/Transpiler.test.cpp | 3 - tests/TypeInfer.aliases.test.cpp | 3 +- tests/TypeInfer.builtins.test.cpp | 27 ++ tests/TypeInfer.generics.test.cpp | 89 ++++++ tests/TypeInfer.provisional.test.cpp | 95 +----- tests/TypeInfer.refinements.test.cpp | 11 +- tests/TypeInfer.singletons.test.cpp | 44 +-- tests/TypeInfer.tables.test.cpp | 81 +++++- tests/TypeInfer.test.cpp | 71 +++++ tests/TypeInfer.tryUnify.test.cpp | 7 +- tests/TypeInfer.typePacks.cpp | 39 --- tests/TypeInfer.unionTypes.test.cpp | 31 +- tests/conformance/nextvar.lua | 38 +++ 55 files changed, 1267 insertions(+), 774 deletions(-) diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index f238e258..f8105383 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -12,6 +12,8 @@ LUAU_FASTFLAG(LuauShareTxnSeen); namespace Luau { +using TypeOrPackId = const void*; + // Log of where what TypeIds we are rebinding and what they used to be // Remove with LuauUseCommitTxnLog struct DEPRECATED_TxnLog @@ -23,7 +25,7 @@ struct DEPRECATED_TxnLog { } - explicit DEPRECATED_TxnLog(std::vector>* sharedSeen) + explicit DEPRECATED_TxnLog(std::vector>* sharedSeen) : originalSeenSize(sharedSeen->size()) , ownedSeen() , sharedSeen(sharedSeen) @@ -48,15 +50,23 @@ struct DEPRECATED_TxnLog void pushSeen(TypeId lhs, TypeId rhs); void popSeen(TypeId lhs, TypeId rhs); + bool haveSeen(TypePackId lhs, TypePackId rhs); + void pushSeen(TypePackId lhs, TypePackId rhs); + void popSeen(TypePackId lhs, TypePackId rhs); + private: std::vector> typeVarChanges; std::vector> typePackChanges; std::vector>> tableChanges; size_t originalSeenSize; + bool haveSeen(TypeOrPackId lhs, TypeOrPackId rhs); + void pushSeen(TypeOrPackId lhs, TypeOrPackId rhs); + void popSeen(TypeOrPackId lhs, TypeOrPackId rhs); + public: - std::vector> ownedSeen; // used to avoid infinite recursion when types are cyclic - std::vector>* sharedSeen; // shared with all the descendent logs + std::vector> ownedSeen; // used to avoid infinite recursion when types are cyclic + std::vector>* sharedSeen; // shared with all the descendent logs }; // Pending state for a TypeVar. Generated by a TxnLog and committed via @@ -127,12 +137,12 @@ struct TxnLog } } - explicit TxnLog(std::vector>* sharedSeen) + explicit TxnLog(std::vector>* sharedSeen) : sharedSeen(sharedSeen) { } - TxnLog(TxnLog* parent, std::vector>* sharedSeen) + TxnLog(TxnLog* parent, std::vector>* sharedSeen) : parent(parent) , sharedSeen(sharedSeen) { @@ -173,6 +183,10 @@ struct TxnLog void pushSeen(TypeId lhs, TypeId rhs); void popSeen(TypeId lhs, TypeId rhs); + bool haveSeen(TypePackId lhs, TypePackId rhs) const; + void pushSeen(TypePackId lhs, TypePackId rhs); + void popSeen(TypePackId lhs, TypePackId rhs); + // Queues a type for modification. The original type will not change until commit // is called. Use pending to get the pending state. // @@ -316,12 +330,16 @@ private: // TxnLogs; use sharedSeen instead. This field exists because in the tree // of TxnLogs, the root must own its seen set. In all descendant TxnLogs, // this is an empty vector. - std::vector> ownedSeen; + std::vector> ownedSeen; + + bool haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const; + void pushSeen(TypeOrPackId lhs, TypeOrPackId rhs); + void popSeen(TypeOrPackId lhs, TypeOrPackId rhs); public: // Used to avoid infinite recursion when types are cyclic. // Shared with all the descendent TxnLogs. - std::vector>* sharedSeen; + std::vector>* sharedSeen; }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 2440c810..839043cc 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -73,24 +73,6 @@ struct Instantiation : Substitution TypePackId clean(TypePackId tp) override; }; -// A substitution which replaces free types by generic types. -struct Quantification : Substitution -{ - Quantification(TypeArena* arena, TypeLevel level) - : Substitution(TxnLog::empty(), arena) - , level(level) - { - } - - TypeLevel level; - std::vector generics; - std::vector genericPacks; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - // A substitution which replaces free types by any struct Anyification : Substitution { diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 8d1a9fa6..29578dcd 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -298,7 +298,7 @@ struct TableTypeVar TableTypeVar() = default; explicit TableTypeVar(TableState state, TypeLevel level); - TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, TableState state = TableState::Unsealed); + TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, TableState state); Props props; std::optional indexer; @@ -477,6 +477,9 @@ bool isOptional(TypeId ty); bool isTableIntersection(TypeId ty); bool isOverloadedFunction(TypeId ty); +// True when string is a subtype of ty +bool maybeString(TypeId ty); + std::optional getMetatable(TypeId type); TableTypeVar* getMutableTableType(TypeId type); const TableTypeVar* getTableType(TypeId type); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index fe822b01..4c0462fe 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -56,7 +56,7 @@ struct Unifier Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); - Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, + Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 29a2c6b5..c3de8d0e 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -16,7 +16,6 @@ LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); -LUAU_FASTFLAGVARIABLE(PreferToCallFunctionsForIntersects, false); LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); static const std::unordered_set kStatementStartingKeywords = { @@ -272,55 +271,34 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); - if (FFlag::PreferToCallFunctionsForIntersects) - { - auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) { - auto [retHead, retTail] = flatten(ftv->retType); + auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) { + auto [retHead, retTail] = flatten(ftv->retType); - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + return true; + + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) return true; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) - { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return true; - } - - return false; - }; - - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty); ftv && checkFunctionType(ftv)) - { - return TypeCorrectKind::CorrectFunctionResult; } - else if (const IntersectionTypeVar* itv = get(ty)) - { - for (TypeId id : itv->parts) - { - if (const FunctionTypeVar* ftv = get(id); ftv && checkFunctionType(ftv)) - { - return TypeCorrectKind::CorrectFunctionResult; - } - } - } - } - else + + return false; + }; + + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty); ftv && checkFunctionType(ftv)) { - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty)) + return TypeCorrectKind::CorrectFunctionResult; + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + for (TypeId id : itv->parts) { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) + if (const FunctionTypeVar* ftv = get(id); ftv && checkFunctionType(ftv)) { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return TypeCorrectKind::CorrectFunctionResult; + return TypeCorrectKind::CorrectFunctionResult; } } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index d72422a5..e4e5dab8 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -9,6 +9,7 @@ #include LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) +LUAU_FASTFLAGVARIABLE(LuauTableCloneType, false) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -283,8 +284,16 @@ void registerBuiltinTypes(TypeChecker& typeChecker) attachMagicFunction(getGlobalBinding(typeChecker, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(typeChecker, "select"), magicFunctionSelect); - auto tableLib = getMutable(getGlobalBinding(typeChecker, "table")); - attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack); + if (TableTypeVar* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + { + // tabTy is a generic table type which we can't express via declaration syntax yet + ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); + + if (FFlag::LuauTableCloneType) + ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + + attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); + } attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index bf6e1193..471b61ad 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -170,7 +170,6 @@ declare function gcinfo(): number move: ({V}, number, number, number, {V}?) -> {V}, clear: ({[K]: V}) -> (), - freeze: ({[K]: V}) -> {[K]: V}, isfrozen: ({[K]: V}) -> boolean, } diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index ec399158..811e7c24 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -5,8 +5,6 @@ #include "Luau/StringUtils.h" #include "Luau/Common.h" -LUAU_FASTFLAG(LuauTypeAliasDefaults) - namespace Luau { @@ -369,38 +367,24 @@ struct AstJsonEncoder : public AstVisitor void write(const AstGenericType& genericType) { - if (FFlag::LuauTypeAliasDefaults) - { - writeRaw("{"); - bool c = pushComma(); - write("name", genericType.name); - if (genericType.defaultValue) - write("type", genericType.defaultValue); - popComma(c); - writeRaw("}"); - } - else - { - write(genericType.name); - } + writeRaw("{"); + bool c = pushComma(); + write("name", genericType.name); + if (genericType.defaultValue) + write("type", genericType.defaultValue); + popComma(c); + writeRaw("}"); } void write(const AstGenericTypePack& genericTypePack) { - if (FFlag::LuauTypeAliasDefaults) - { - writeRaw("{"); - bool c = pushComma(); - write("name", genericTypePack.name); - if (genericTypePack.defaultValue) - write("type", genericTypePack.defaultValue); - popComma(c); - writeRaw("}"); - } - else - { - write(genericTypePack.name); - } + writeRaw("{"); + bool c = pushComma(); + write("name", genericTypePack.name); + if (genericTypePack.defaultValue) + write("type", genericTypePack.defaultValue); + popComma(c); + writeRaw("}"); } void write(AstExprTable::Item::Kind kind) diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 7635dc0f..56c4e3e8 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -13,6 +13,7 @@ #include LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) +LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) namespace Luau { @@ -233,6 +234,20 @@ public: } private: + struct FunctionInfo + { + explicit FunctionInfo(AstExprFunction* ast) + : ast(ast) + , dominatedGlobals({}) + , conditionalExecution(false) + { + } + + AstExprFunction* ast; + DenseHashSet dominatedGlobals; + bool conditionalExecution; + }; + struct Global { AstExprGlobal* firstRef = nullptr; @@ -241,6 +256,9 @@ private: bool assigned = false; bool builtin = false; + bool definedInModuleScope = false; + bool definedAsFunction = false; + bool readBeforeWritten = false; std::optional deprecated; }; @@ -248,7 +266,8 @@ private: DenseHashMap globals; std::vector globalRefs; - std::vector functionStack; + std::vector functionStack; + LintGlobalLocal() : globals(AstName()) @@ -291,12 +310,18 @@ private: "Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local", g.firstRef->name.value, top->location.begin.line + 1); } + else if (FFlag::LuauLintGlobalNeverReadBeforeWritten && g.assigned && !g.readBeforeWritten && !g.definedInModuleScope && + g.firstRef->name != context->placeholder) + { + emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, + "Global '%s' is never read before being written. Consider changing it to local", g.firstRef->name.value); + } } } bool visit(AstExprFunction* node) override { - functionStack.push_back(node); + functionStack.emplace_back(node); node->body->visit(this); @@ -307,6 +332,11 @@ private: bool visit(AstExprGlobal* node) override { + if (FFlag::LuauLintGlobalNeverReadBeforeWritten && !functionStack.empty() && !functionStack.back().dominatedGlobals.contains(node->name)) + { + Global& g = globals[node->name]; + g.readBeforeWritten = true; + } trackGlobalRef(node); if (node->name == context->placeholder) @@ -335,6 +365,21 @@ private: { Global& g = globals[gv->name]; + if (FFlag::LuauLintGlobalNeverReadBeforeWritten) + { + if (functionStack.empty()) + { + g.definedInModuleScope = true; + } + else + { + if (!functionStack.back().conditionalExecution) + { + functionStack.back().dominatedGlobals.insert(gv->name); + } + } + } + if (g.builtin) emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); @@ -369,7 +414,14 @@ private: emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); else + { g.assigned = true; + if (FFlag::LuauLintGlobalNeverReadBeforeWritten) + { + g.definedAsFunction = true; + g.definedInModuleScope = functionStack.empty(); + } + } trackGlobalRef(gv); } @@ -377,6 +429,98 @@ private: return true; } + class HoldConditionalExecution + { + public: + HoldConditionalExecution(LintGlobalLocal& p) + : p(p) + { + if (!p.functionStack.empty() && !p.functionStack.back().conditionalExecution) + { + resetToFalse = true; + p.functionStack.back().conditionalExecution = true; + } + } + ~HoldConditionalExecution() + { + if (resetToFalse) + p.functionStack.back().conditionalExecution = false; + } + + private: + bool resetToFalse = false; + LintGlobalLocal& p; + }; + + bool visit(AstStatIf* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->condition->visit(this); + node->thenbody->visit(this); + if (node->elsebody) + node->elsebody->visit(this); + + return false; + } + + bool visit(AstStatWhile* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->condition->visit(this); + node->body->visit(this); + + return false; + } + + bool visit(AstStatRepeat* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->condition->visit(this); + node->body->visit(this); + + return false; + } + + bool visit(AstStatFor* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->from->visit(this); + node->to->visit(this); + + if (node->step) + node->step->visit(this); + + node->body->visit(this); + + return false; + } + + bool visit(AstStatForIn* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + for (AstExpr* expr : node->values) + expr->visit(this); + + node->body->visit(this); + + return false; + } + void trackGlobalRef(AstExprGlobal* node) { Global& g = globals[node->name]; @@ -390,7 +534,12 @@ private: // to reduce the cost of tracking we only track this for user globals if (!g.builtin) { - g.functionRef = functionStack; + g.functionRef.clear(); + g.functionRef.reserve(functionStack.size()); + for (const FunctionInfo& entry : functionStack) + { + g.functionRef.push_back(entry.ast); + } } } else @@ -401,7 +550,7 @@ private: // we need to find a common prefix between all uses of a global size_t prefix = 0; - while (prefix < g.functionRef.size() && prefix < functionStack.size() && g.functionRef[prefix] == functionStack[prefix]) + while (prefix < g.functionRef.size() && prefix < functionStack.size() && g.functionRef[prefix] == functionStack[prefix].ast) prefix++; g.functionRef.resize(prefix); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 412b78bb..76dc72d2 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) // Remove with FFlagLuauImmutableTypes LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) -LUAU_FASTFLAG(LuauTypeAliasDefaults) LUAU_FASTFLAG(LuauImmutableTypes) namespace Luau @@ -463,7 +462,7 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, See TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); std::optional defaultValue; - if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + if (param.defaultValue) defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); result.typeParams.push_back({ty, defaultValue}); @@ -474,7 +473,7 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, See TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); std::optional defaultValue; - if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + if (param.defaultValue) defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); result.typePackParams.push_back({tp, defaultValue}); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5e79b841..010ca361 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,8 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauTypeAliasDefaults) - /* * Prefix generic typenames with gen- * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 @@ -288,28 +286,15 @@ struct TypeVarStringifier else first = false; - if (FFlag::LuauTypeAliasDefaults) - { - bool wrap = !singleTp && get(follow(tp)); + bool wrap = !singleTp && get(follow(tp)); - if (wrap) - state.emit("("); + if (wrap) + state.emit("("); - stringify(tp); + stringify(tp); - if (wrap) - state.emit(")"); - } - else - { - if (!singleTp) - state.emit("("); - - stringify(tp); - - if (!singleTp) - state.emit(")"); - } + if (wrap) + state.emit(")"); } if (types.size() || typePacks.size()) @@ -1105,100 +1090,8 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } -std::string toStringNamedFunction_DEPRECATED(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) -{ - std::string s = prefix; - - auto toString_ = [&opts](TypeId ty) -> std::string { - ToStringResult res = toStringDetailed(ty, opts); - opts.nameMap = std::move(res.nameMap); - return res.name; - }; - - auto toStringPack_ = [&opts](TypePackId ty) -> std::string { - ToStringResult res = toStringDetailed(ty, opts); - opts.nameMap = std::move(res.nameMap); - return res.name; - }; - - if (!opts.hideNamedFunctionTypeParameters && (!ftv.generics.empty() || !ftv.genericPacks.empty())) - { - s += "<"; - - bool first = true; - for (TypeId g : ftv.generics) - { - if (!first) - s += ", "; - first = false; - s += toString_(g); - } - - for (TypePackId gp : ftv.genericPacks) - { - if (!first) - s += ", "; - first = false; - s += toStringPack_(gp); - } - - s += ">"; - } - - s += "("; - - auto argPackIter = begin(ftv.argTypes); - auto argNameIter = ftv.argNames.begin(); - - bool first = true; - while (argPackIter != end(ftv.argTypes)) - { - if (!first) - s += ", "; - first = false; - - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (argNameIter != ftv.argNames.end()) - { - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; - ++argNameIter; - } - else - { - s += "_: "; - } - - s += toString_(*argPackIter); - ++argPackIter; - } - - if (argPackIter.tail()) - { - if (auto vtp = get(*argPackIter.tail())) - s += ", ...: " + toString_(vtp->ty); - else - s += ", ...: " + toStringPack_(*argPackIter.tail()); - } - - s += "): "; - - size_t retSize = size(ftv.retType); - bool hasTail = !finite(ftv.retType); - if (retSize == 0 && !hasTail) - s += "()"; - else if ((retSize == 0 && hasTail) || (retSize == 1 && !hasTail)) - s += toStringPack_(ftv.retType); - else - s += "(" + toStringPack_(ftv.retType) + ")"; - - return s; -} - std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) { - if (!FFlag::LuauTypeAliasDefaults) - return toStringNamedFunction_DEPRECATED(prefix, ftv, opts); - ToStringResult result; StringifierState state(opts, result, opts.nameMap); TypeVarStringifier tvs{state}; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index a02d396b..92ed241e 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,8 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauTypeAliasDefaults) - namespace { bool isIdentifierStartChar(char c) @@ -796,21 +794,14 @@ struct Printer { comma(); - if (FFlag::LuauTypeAliasDefaults) - { - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o.location.begin); + writer.identifier(o.name.value); - if (o.defaultValue) - { - writer.maybeSpace(o.defaultValue->location.begin, 2); - writer.symbol("="); - visualizeTypeAnnotation(*o.defaultValue); - } - } - else + if (o.defaultValue) { - writer.identifier(o.name.value); + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*o.defaultValue); } } @@ -818,23 +809,15 @@ struct Printer { comma(); - if (FFlag::LuauTypeAliasDefaults) - { - writer.advance(o.location.begin); - writer.identifier(o.name.value); - writer.symbol("..."); + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); - if (o.defaultValue) - { - writer.maybeSpace(o.defaultValue->location.begin, 2); - writer.symbol("="); - visualizeTypePackAnnotation(*o.defaultValue, false); - } - } - else + if (o.defaultValue) { - writer.identifier(o.name.value); - writer.symbol("..."); + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypePackAnnotation(*o.defaultValue, false); } } @@ -882,18 +865,14 @@ struct Printer { comma(); - if (FFlag::LuauTypeAliasDefaults) - writer.advance(o.location.begin); - + writer.advance(o.location.begin); writer.identifier(o.name.value); } for (const auto& o : func.genericPacks) { comma(); - if (FFlag::LuauTypeAliasDefaults) - writer.advance(o.location.begin); - + writer.advance(o.location.begin); writer.identifier(o.name.value); writer.symbol("..."); } @@ -1023,18 +1002,14 @@ struct Printer { comma(); - if (FFlag::LuauTypeAliasDefaults) - writer.advance(o.location.begin); - + writer.advance(o.location.begin); writer.identifier(o.name.value); } for (const auto& o : a->genericPacks) { comma(); - if (FFlag::LuauTypeAliasDefaults) - writer.advance(o.location.begin); - + writer.advance(o.location.begin); writer.identifier(o.name.value); writer.symbol("..."); } diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 00067bdd..c7bf1e62 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -61,22 +61,52 @@ void DEPRECATED_TxnLog::concat(DEPRECATED_TxnLog rhs) bool DEPRECATED_TxnLog::haveSeen(TypeId lhs, TypeId rhs) { - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); + return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); } void DEPRECATED_TxnLog::pushSeen(TypeId lhs, TypeId rhs) { - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - sharedSeen->push_back(sortedPair); + pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); } void DEPRECATED_TxnLog::popSeen(TypeId lhs, TypeId rhs) +{ + popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +bool DEPRECATED_TxnLog::haveSeen(TypePackId lhs, TypePackId rhs) +{ + return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void DEPRECATED_TxnLog::pushSeen(TypePackId lhs, TypePackId rhs) +{ + pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void DEPRECATED_TxnLog::popSeen(TypePackId lhs, TypePackId rhs) +{ + popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +bool DEPRECATED_TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) { LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); +} + +void DEPRECATED_TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) +{ + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + sharedSeen->push_back(sortedPair); +} + +void DEPRECATED_TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) +{ + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); LUAU_ASSERT(sortedPair == sharedSeen->back()); sharedSeen->pop_back(); } @@ -186,10 +216,40 @@ TxnLog TxnLog::inverse() } bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) const +{ + return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void TxnLog::pushSeen(TypeId lhs, TypeId rhs) +{ + pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void TxnLog::popSeen(TypeId lhs, TypeId rhs) +{ + popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +bool TxnLog::haveSeen(TypePackId lhs, TypePackId rhs) const +{ + return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void TxnLog::pushSeen(TypePackId lhs, TypePackId rhs) +{ + pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void TxnLog::popSeen(TypePackId lhs, TypePackId rhs) +{ + popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const { LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) { return true; @@ -203,19 +263,19 @@ bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) const return false; } -void TxnLog::pushSeen(TypeId lhs, TypeId rhs) +void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) { LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); sharedSeen->push_back(sortedPair); } -void TxnLog::popSeen(TypeId lhs, TypeId rhs) +void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) { LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); LUAU_ASSERT(sortedPair == sharedSeen->back()); sharedSeen->pop_back(); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index faf60eb3..8e6b3b52 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -29,22 +29,20 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) -LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) -LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) +LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(LuauAnotherTypeLevelFix, false) LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree) LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) +LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) namespace Luau { @@ -445,7 +443,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (FFlag::LuauQuantifyInPlace2 ? containsFunctionCallOrReturn(**protoIter) : containsFunctionCall(**protoIter)) + if (containsFunctionCallOrReturn(**protoIter)) { while (checkIter != protoIter) { @@ -1676,7 +1674,7 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (tableType->state == TableState::Free) { - TypeId result = FFlag::LuauAscribeCorrectLevelToInferredProperitesOfFreeTables ? freshType(tableType->level) : freshType(scope); + TypeId result = freshType(tableType->level); tableType->props[name] = {result}; return result; } @@ -1776,31 +1774,62 @@ std::optional TypeChecker::getIndexTypeFromType( std::vector TypeChecker::reduceUnion(const std::vector& types) { - std::set s; - - for (TypeId t : types) + if (FFlag::LuauDoNotAccidentallyDependOnPointerOrdering) { - if (const UnionTypeVar* utv = get(follow(t))) + std::vector result; + for (TypeId t : types) { - std::vector r = reduceUnion(utv->options); - for (TypeId ty : r) - s.insert(ty); + t = follow(t); + if (get(t) || get(t)) + return {t}; + + if (const UnionTypeVar* utv = get(t)) + { + std::vector r = reduceUnion(utv->options); + for (TypeId ty : r) + { + ty = follow(ty); + if (get(ty) || get(ty)) + return {ty}; + + if (std::find(result.begin(), result.end(), ty) == result.end()) + result.push_back(ty); + } + } + else if (std::find(result.begin(), result.end(), t) == result.end()) + result.push_back(t); } - else - s.insert(t); - } - // If any of them are ErrorTypeVars/AnyTypeVars, decay into them. - for (TypeId t : s) + return result; + } + else { - t = follow(t); - if (get(t) || get(t)) - return {t}; - } + std::set s; - std::vector r(s.begin(), s.end()); - std::sort(r.begin(), r.end()); - return r; + for (TypeId t : types) + { + if (const UnionTypeVar* utv = get(follow(t))) + { + std::vector r = reduceUnion(utv->options); + for (TypeId ty : r) + s.insert(ty); + } + else + s.insert(t); + } + + // If any of them are ErrorTypeVars/AnyTypeVars, decay into them. + for (TypeId t : s) + { + t = follow(t); + if (get(t) || get(t)) + return {t}; + } + + std::vector r(s.begin(), s.end()); + std::sort(r.begin(), r.end()); + return r; + } } std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) @@ -2811,7 +2840,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = freshType(FFlag::LuauAnotherTypeLevelFix ? exprTable->level : scope->level); + TypeId resultType = freshType(exprTable->level); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return resultType; } @@ -4453,51 +4482,6 @@ TypePackId ReplaceGenerics::clean(TypePackId tp) return addTypePack(TypePackVar(FreeTypePack{level})); } -bool Quantification::isDirty(TypeId ty) -{ - if (const TableTypeVar* ttv = log->getMutable(ty)) - return level.subsumes(ttv->level) && ((ttv->state == TableState::Free) || (ttv->state == TableState::Unsealed)); - else if (const FreeTypeVar* ftv = log->getMutable(ty)) - return level.subsumes(ftv->level); - else - return false; -} - -bool Quantification::isDirty(TypePackId tp) -{ - if (const FreeTypePack* ftv = log->getMutable(tp)) - return level.subsumes(ftv->level); - else - return false; -} - -TypeId Quantification::clean(TypeId ty) -{ - LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = log->getMutable(ty)) - { - TableState state = (ttv->state == TableState::Unsealed ? TableState::Sealed : TableState::Generic); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, state}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - return addType(std::move(clone)); - } - else - { - TypeId generic = addType(GenericTypeVar{level}); - generics.push_back(generic); - return generic; - } -} - -TypePackId Quantification::clean(TypePackId tp) -{ - LUAU_ASSERT(isDirty(tp)); - TypePackId genericPack = addTypePack(TypePackVar(GenericTypePack{level})); - genericPacks.push_back(genericPack); - return genericPack; -} - bool Anyification::isDirty(TypeId ty) { if (const TableTypeVar* ttv = log->getMutable(ty)) @@ -4550,29 +4534,8 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) return ty; - if (FFlag::LuauQuantifyInPlace2) - { - Luau::quantify(ty, scope->level); - return ty; - } - - Quantification quantification{¤tModule->internalTypes, scope->level}; - std::optional qty = quantification.substitute(ty); - - if (!qty.has_value()) - { - reportError(location, UnificationTooComplex{}); - return errorRecoveryType(scope); - } - - if (ty == *qty) - return ty; - - FunctionTypeVar* qftv = getMutable(*qty); - LUAU_ASSERT(qftv); - qftv->generics = std::move(quantification.generics); - qftv->genericPacks = std::move(quantification.genericPacks); - return *qty; + Luau::quantify(ty, scope->level); + return ty; } TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) @@ -4915,35 +4878,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) return tf->type; - bool hasDefaultTypes = false; - bool hasDefaultPacks = false; bool parameterCountErrorReported = false; + bool hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + bool hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); - if (FFlag::LuauTypeAliasDefaults) + if (!lit->hasParameterList) { - hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); - hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); - - if (!lit->hasParameterList) - { - if ((!tf->typeParams.empty() && !hasDefaultTypes) || (!tf->typePackParams.empty() && !hasDefaultPacks)) - { - reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - parameterCountErrorReported = true; - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); - } - } - } - else - { - if (!lit->hasParameterList && !tf->typePackParams.empty()) + if ((!tf->typeParams.empty() && !hasDefaultTypes) || (!tf->typePackParams.empty() && !hasDefaultPacks)) { reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + parameterCountErrorReported = true; if (!FFlag::LuauErrorRecoveryType) return errorRecoveryType(scope); } @@ -4986,72 +4934,69 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (typePackParams.empty() && !extraTypes.empty()) typePackParams.push_back(addTypePack(extraTypes)); - if (FFlag::LuauTypeAliasDefaults) + size_t typesProvided = typeParams.size(); + size_t typesRequired = tf->typeParams.size(); + + size_t packsProvided = typePackParams.size(); + size_t packsRequired = tf->typePackParams.size(); + + bool notEnoughParameters = + (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); + bool hasDefaultParameters = hasDefaultTypes || hasDefaultPacks; + + // Add default type and type pack parameters if that's required and it's possible + if (notEnoughParameters && hasDefaultParameters) { - size_t typesProvided = typeParams.size(); - size_t typesRequired = tf->typeParams.size(); + // 'applyTypeFunction' is used to substitute default types that reference previous generic types + ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes, scope->level}; - size_t packsProvided = typePackParams.size(); - size_t packsRequired = tf->typePackParams.size(); + for (size_t i = 0; i < typesProvided; ++i) + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; - bool notEnoughParameters = - (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); - bool hasDefaultParameters = hasDefaultTypes || hasDefaultPacks; - - // Add default type and type pack parameters if that's required and it's possible - if (notEnoughParameters && hasDefaultParameters) + if (typesProvided < typesRequired) { - // 'applyTypeFunction' is used to substitute default types that reference previous generic types - ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes, scope->level}; - - for (size_t i = 0; i < typesProvided; ++i) - applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; - - if (typesProvided < typesRequired) + for (size_t i = typesProvided; i < typesRequired; ++i) { - for (size_t i = typesProvided; i < typesRequired; ++i) + TypeId defaultTy = tf->typeParams[i].defaultValue.value_or(nullptr); + + if (!defaultTy) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTy); + + if (!maybeInstantiated.has_value()) { - TypeId defaultTy = tf->typeParams[i].defaultValue.value_or(nullptr); - - if (!defaultTy) - break; - - std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTy); - - if (!maybeInstantiated.has_value()) - { - reportError(annotation.location, UnificationTooComplex{}); - maybeInstantiated = errorRecoveryType(scope); - } - - applyTypeFunction.typeArguments[tf->typeParams[i].ty] = *maybeInstantiated; - typeParams.push_back(*maybeInstantiated); + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryType(scope); } + + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = *maybeInstantiated; + typeParams.push_back(*maybeInstantiated); } + } - for (size_t i = 0; i < packsProvided; ++i) - applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = typePackParams[i]; + for (size_t i = 0; i < packsProvided; ++i) + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = typePackParams[i]; - if (packsProvided < packsRequired) + if (packsProvided < packsRequired) + { + for (size_t i = packsProvided; i < packsRequired; ++i) { - for (size_t i = packsProvided; i < packsRequired; ++i) + TypePackId defaultTp = tf->typePackParams[i].defaultValue.value_or(nullptr); + + if (!defaultTp) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTp); + + if (!maybeInstantiated.has_value()) { - TypePackId defaultTp = tf->typePackParams[i].defaultValue.value_or(nullptr); - - if (!defaultTp) - break; - - std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTp); - - if (!maybeInstantiated.has_value()) - { - reportError(annotation.location, UnificationTooComplex{}); - maybeInstantiated = errorRecoveryTypePack(scope); - } - - applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = *maybeInstantiated; - typePackParams.push_back(*maybeInstantiated); + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryTypePack(scope); } + + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = *maybeInstantiated; + typePackParams.push_back(*maybeInstantiated); } } } @@ -5343,12 +5288,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId instantiated = *maybeInstantiated; - // TODO: CLI-46926 it's not a good idea to rename the type here TypeId target = follow(instantiated); bool needsClone = follow(tf.type) == target; + bool shouldMutate = (!FFlag::LuauOnlyMutateInstantiatedTables || getTableType(tf.type)); TableTypeVar* ttv = getMutableTableType(target); - - if (ttv && needsClone) + + if (shouldMutate && ttv && needsClone) { // Substitution::clone is a shallow clone. If this is a metatable type, we // want to mutate its table, so we need to explicitly clone that table as @@ -5368,7 +5313,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } } - if (ttv) + if (shouldMutate && ttv) { ttv->instantiatedTypeParams = typeParams; ttv->instantiatedTypePackParams = typePackParams; @@ -5382,7 +5327,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st { LUAU_ASSERT(scope->parent); - const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; + const TypeLevel level = levelOpt.value_or(scope->level); std::vector generics; @@ -5390,7 +5335,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st { std::optional defaultValue; - if (FFlag::LuauTypeAliasDefaults && generic.defaultValue) + if (generic.defaultValue) defaultValue = resolveType(scope, *generic.defaultValue); Name n = generic.name.value; @@ -5426,7 +5371,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st { std::optional defaultValue; - if (FFlag::LuauTypeAliasDefaults && genericPack.defaultValue) + if (genericPack.defaultValue) defaultValue = resolveTypePack(scope, *genericPack.defaultValue); Name n = genericPack.name.value; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index a1dcfdbe..5af2c8a6 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -24,6 +24,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) +LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) LUAU_FASTFLAG(LuauDiscriminableUnions2) namespace Luau @@ -157,6 +158,7 @@ bool isNumber(TypeId ty) return isPrim(ty, PrimitiveTypeVar::Number); } +// Returns true when ty is a subtype of string bool isString(TypeId ty) { if (isPrim(ty, PrimitiveTypeVar::String) || get(get(follow(ty)))) @@ -168,6 +170,27 @@ bool isString(TypeId ty) return false; } +// Returns true when ty is a supertype of string +bool maybeString(TypeId ty) +{ + if (FFlag::LuauSubtypingAddOptPropsToUnsealedTables) + { + ty = follow(ty); + + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) + return true; + + if (auto utv = get(ty)) + return std::any_of(begin(utv), end(utv), maybeString); + + return false; + } + else + { + return isString(ty); + } +} + bool isThread(TypeId ty) { return isPrim(ty, PrimitiveTypeVar::Thread); @@ -684,7 +707,7 @@ TypeId SingletonTypes::makeStringMetatable() {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, {"upper", {stringToStringType}}, {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}})})}}, + {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, {"pack", {arena->addType(FunctionTypeVar{ arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack, @@ -761,6 +784,8 @@ void persist(TypeId ty) } else if (auto ttv = get(t)) { + LUAU_ASSERT(ttv->state != TableState::Free && ttv->state != TableState::Unsealed); + for (const auto& [_name, prop] : ttv->props) queue.push_back(prop.type); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index d0eba013..6c29486a 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,11 +18,13 @@ LUAU_FASTFLAG(LuauImmutableTypes) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); -LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); +LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauFollowWithCommittingTxnLogInAnyUnification, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree, false) +LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) +LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, true) namespace Luau { @@ -329,7 +331,7 @@ Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, +Unifier::Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) @@ -656,26 +658,85 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId failed = true; } - if (FFlag::LuauUseCommittingTxnLog) + if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) { - if (i == count - 1) - { - log.concat(std::move(innerState.log)); - } + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); } else { - if (i != count - 1) + if (FFlag::LuauUseCommittingTxnLog) { - innerState.DEPRECATED_log.rollback(); + if (i == count - 1) + { + log.concat(std::move(innerState.log)); + } } else { - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + if (i != count - 1) + { + innerState.DEPRECATED_log.rollback(); + } + else + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } } - } - ++i; + ++i; + } + } + + // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. + if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) + { + auto tryBind = [this, subTy](TypeId superOption) { + superOption = FFlag::LuauUseCommittingTxnLog ? log.follow(superOption) : follow(superOption); + + // just skip if the superOption is not free-ish. + auto ttv = log.getMutable(superOption); + if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) + return; + + // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. + // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. + if (FFlag::LuauUseCommittingTxnLog) + { + if (log.haveSeen(subTy, superOption)) + { + // TODO: would it be nice for TxnLog::replace to do this? + if (log.is(superOption)) + log.bindTable(superOption, subTy); + else + log.replace(superOption, *subTy); + } + } + else + { + if (DEPRECATED_log.haveSeen(subTy, superOption)) + { + if (auto ttv = getMutable(superOption)) + { + DEPRECATED_log(ttv); + ttv->boundTo = subTy; + } + else + { + DEPRECATED_log(superOption); + *asMutable(superOption) = BoundTypeVar(subTy); + } + } + } + }; + + if (auto utv = (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) + { + for (TypeId ty : utv) + tryBind(ty); + } + else + tryBind(superTy); } if (unificationTooComplex) @@ -1163,6 +1224,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (superTp == subTp) return; + if (FFlag::LuauTxnLogSeesTypePacks2 && log.haveSeen(superTp, subTp)) + return; + if (log.getMutable(superTp)) { occursCheck(superTp, subTp); @@ -1365,6 +1429,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (superTp == subTp) return; + if (FFlag::LuauTxnLogSeesTypePacks2 && DEPRECATED_log.haveSeen(superTp, subTp)) + return; + if (get(superTp)) { occursCheck(superTp, subTp); @@ -1619,6 +1686,17 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal DEPRECATED_log.pushSeen(superFunction->generics[i], subFunction->generics[i]); } + if (FFlag::LuauTxnLogSeesTypePacks2) + { + for (size_t i = 0; i < numGenericPacks; i++) + { + if (FFlag::LuauUseCommittingTxnLog) + log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + else + DEPRECATED_log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + } + } + CountMismatch::Context context = ctx; if (!isFunctionCall) @@ -1708,6 +1786,17 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal ctx = context; + if (FFlag::LuauTxnLogSeesTypePacks2) + { + for (int i = int(numGenericPacks) - 1; 0 <= i; i--) + { + if (FFlag::LuauUseCommittingTxnLog) + log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + else + DEPRECATED_log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + } + } + for (int i = int(numGenerics) - 1; 0 <= i; i--) { if (FFlag::LuauUseCommittingTxnLog) @@ -1760,7 +1849,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) std::vector extraProperties; // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !subTable->indexer && subTable->state != TableState::Free) + if (!subTable->indexer && subTable->state != TableState::Free) { for (const auto& [propName, superProp] : superTable->props) { @@ -1769,7 +1858,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) bool isAny = FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(superProp.type)) : get(follow(superProp.type)); - if (subIter == subTable->props.end() && !isOptional(superProp.type) && !isAny) + if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) missingProperties.push_back(propName); } @@ -1781,7 +1870,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // And vice versa if we're invariant - if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && + if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && superTable->state != TableState::Free) { for (const auto& [propName, subProp] : subTable->props) @@ -1790,7 +1879,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) bool isAny = FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(subProp.type)) : get(follow(subProp.type)); - if (superIter == superTable->props.end() && !isOptional(subProp.type) && !isAny) + if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) extraProperties.push_back(propName); } @@ -1830,7 +1919,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) innerState.DEPRECATED_log.rollback(); } } - else if (subTable->indexer && isString(subTable->indexer->indexType)) + else if (subTable->indexer && maybeString(subTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. @@ -1855,9 +1944,11 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) innerState.DEPRECATED_log.rollback(); } } - else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get(follow(prop.type)))) + // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` + // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: should isOptional(anyType) be true? + // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) { } else if (subTable->state == TableState::Free) @@ -1887,7 +1978,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // If both lt and rt contain the property, then // we're done since we already unified them above } - else if (superTable->indexer && isString(superTable->indexer->indexType)) + else if (superTable->indexer && maybeString(superTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. @@ -1936,9 +2027,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (variance == Covariant) { } - else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 - // TODO: should isOptional(anyType) be true? + else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && (isOptional(prop.type) || get(follow(prop.type)))) { } else if (superTable->state == TableState::Free) @@ -2333,7 +2422,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec bool errorReported = false; // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !subTable->indexer) + if (!subTable->indexer) { for (const auto& [propName, superProp] : superTable->props) { diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 8767daa0..1cb8f134 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -11,7 +11,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauParseAllHotComments, false) LUAU_FASTFLAGVARIABLE(LuauTableFieldFunctionDebugname, false) @@ -779,7 +778,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ FFlag::LuauParseTypeAliasDefaults); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ true); expectAndConsume('=', "type alias"); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 13304d57..5fd6d341 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -413,29 +413,22 @@ static void completeRepl(ic_completion_env_t* cenv, const char* editBuffer) ic_complete_word(cenv, editBuffer, icGetCompletions, isMethodOrFunctionChar); } -struct LinenoiseScopedHistory +static void loadHistory(const char* name) { - LinenoiseScopedHistory() + std::string path; + + if (const char* home = getenv("HOME")) { - const std::string name(".luau_history"); - - if (const char* home = getenv("HOME")) - { - historyFilepath = joinPaths(home, name); - } - else if (const char* userProfile = getenv("USERPROFILE")) - { - historyFilepath = joinPaths(userProfile, name); - } - - if (!historyFilepath.empty()) - ic_set_history(historyFilepath.c_str(), -1 /* default entries (= 200) */); + path = joinPaths(home, name); + } + else if (const char* userProfile = getenv("USERPROFILE")) + { + path = joinPaths(userProfile, name); } - ~LinenoiseScopedHistory() {} - - std::string historyFilepath; -}; + if (!path.empty()) + ic_set_history(path.c_str(), -1 /* default entries (= 200) */); +} static void runReplImpl(lua_State* L) { @@ -447,8 +440,10 @@ static void runReplImpl(lua_State* L) // Prevent auto insertion of braces ic_enable_brace_insertion(false); + // Loads history from the given file; isocline automatically saves the history on process exit + loadHistory(".luau_history"); + std::string buffer; - LinenoiseScopedHistory scopedHistory; for (;;) { diff --git a/VM/include/lua.h b/VM/include/lua.h index af0e2835..0a561f27 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -265,6 +265,8 @@ LUA_API double lua_clock(); LUA_API void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)); +LUA_API void lua_clonefunction(lua_State* L, int idx); + /* ** reference system, can be used to pin objects */ @@ -324,6 +326,7 @@ typedef struct lua_Debug lua_Debug; /* activation record */ /* Functions to be called by the debugger in specific events */ typedef void (*lua_Hook)(lua_State* L, lua_Debug* ar); +LUA_API int lua_stackdepth(lua_State* L); LUA_API int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar); LUA_API int lua_getargument(lua_State* L, int level, int n); LUA_API const char* lua_getlocal(lua_State* L, int level, int n); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 39c76e08..f7f15442 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,7 +14,7 @@ #include -LUAU_FASTFLAGVARIABLE(LuauGcForwardMetatableBarrier, false) +LUAU_FASTFLAG(LuauGcAdditionalStats) const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" @@ -876,16 +876,7 @@ int lua_setmetatable(lua_State* L, int objindex) luaG_runerror(L, "Attempt to modify a readonly table"); hvalue(obj)->metatable = mt; if (mt) - { - if (FFlag::LuauGcForwardMetatableBarrier) - { - luaC_objbarrier(L, hvalue(obj), mt); - } - else - { - luaC_objbarriert(L, hvalue(obj), mt); - } - } + luaC_objbarrier(L, hvalue(obj), mt); break; } case LUA_TUSERDATA: @@ -1069,6 +1060,8 @@ int lua_gc(lua_State* L, int what, int data) g->GCthreshold = 0; bool waspaused = g->gcstate == GCSpause; + double startmarktime = g->gcstats.currcycle.marktime; + double startsweeptime = g->gcstats.currcycle.sweeptime; // track how much work the loop will actually perform size_t actualwork = 0; @@ -1086,6 +1079,31 @@ int lua_gc(lua_State* L, int what, int data) } } + if (FFlag::LuauGcAdditionalStats) + { + // record explicit step statistics + GCCycleStats* cyclestats = g->gcstate == GCSpause ? &g->gcstats.lastcycle : &g->gcstats.currcycle; + + double totalmarktime = cyclestats->marktime - startmarktime; + double totalsweeptime = cyclestats->sweeptime - startsweeptime; + + if (totalmarktime > 0.0) + { + cyclestats->markexplicitsteps++; + + if (totalmarktime > cyclestats->markmaxexplicittime) + cyclestats->markmaxexplicittime = totalmarktime; + } + + if (totalsweeptime > 0.0) + { + cyclestats->sweepexplicitsteps++; + + if (totalsweeptime > cyclestats->sweepmaxexplicittime) + cyclestats->sweepmaxexplicittime = totalsweeptime; + } + } + // if cycle hasn't finished, advance threshold forward for the amount of extra work performed if (g->gcstate != GCSpause) { @@ -1299,6 +1317,18 @@ void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)) L->global->udatagc[tag] = dtor; } +LUA_API void lua_clonefunction(lua_State* L, int idx) +{ + StkId p = index2addr(L, idx); + api_check(L, isLfunction(p)); + + luaC_checkthreadsleep(L); + + Closure* cl = clvalue(p); + Closure* newcl = luaF_newLclosure(L, 0, L->gt, cl->l.p); + setclvalue(L, L->top - 1, newcl); +} + lua_Callbacks* lua_callbacks(lua_State* L) { return &L->global->cb; diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 71975a52..9a6f7793 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,8 +11,6 @@ #include -LUAU_FASTFLAG(LuauSchubfach) - /* convert a stack index to positive */ #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -480,18 +478,13 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) switch (lua_type(L, idx)) { case LUA_TNUMBER: - if (FFlag::LuauSchubfach) - { - double n = lua_tonumber(L, idx); - char s[LUAI_MAXNUM2STR]; - char* e = luai_num2str(s, n); - lua_pushlstring(L, s, e - s); - } - else - { - lua_pushstring(L, lua_tostring(L, idx)); - } + { + double n = lua_tonumber(L, idx); + char s[LUAI_MAXNUM2STR]; + char* e = luai_num2str(s, n); + lua_pushlstring(L, s, e - s); break; + } case LUA_TSTRING: lua_pushvalue(L, idx); break; @@ -505,29 +498,18 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) { const float* v = lua_tovector(L, idx); - if (FFlag::LuauSchubfach) + char s[LUAI_MAXNUM2STR * LUA_VECTOR_SIZE]; + char* e = s; + for (int i = 0; i < LUA_VECTOR_SIZE; ++i) { - char s[LUAI_MAXNUM2STR * LUA_VECTOR_SIZE]; - char* e = s; - for (int i = 0; i < LUA_VECTOR_SIZE; ++i) + if (i != 0) { - if (i != 0) - { - *e++ = ','; - *e++ = ' '; - } - e = luai_num2str(e, v[i]); + *e++ = ','; + *e++ = ' '; } - lua_pushlstring(L, s, e - s); - } - else - { -#if LUA_VECTOR_SIZE == 4 - lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); -#else - lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); -#endif + e = luai_num2str(e, v[i]); } + lua_pushlstring(L, s, e - s); break; } default: diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index a4f93c62..7a9947b7 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -168,6 +168,11 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, return status; } +int lua_stackdepth(lua_State* L) +{ + return int(L->ci - L->base_ci); +} + int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar) { int status = 0; diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 8c3a2029..a656854e 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -11,6 +11,8 @@ #include "lmem.h" #include "ludata.h" +LUAU_FASTFLAGVARIABLE(LuauGcAdditionalStats, false) + #include #define GC_SWEEPMAX 40 @@ -53,17 +55,28 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, case GCSpause: // record root mark time if we have switched to next state if (g->gcstate == GCSpropagate) + { g->gcstats.currcycle.marktime += seconds; + + if (FFlag::LuauGcAdditionalStats && assist) + g->gcstats.currcycle.markassisttime += seconds; + } break; case GCSpropagate: case GCSpropagateagain: g->gcstats.currcycle.marktime += seconds; + + if (FFlag::LuauGcAdditionalStats && assist) + g->gcstats.currcycle.markassisttime += seconds; break; case GCSatomic: g->gcstats.currcycle.atomictime += seconds; break; case GCSsweep: g->gcstats.currcycle.sweeptime += seconds; + + if (FFlag::LuauGcAdditionalStats && assist) + g->gcstats.currcycle.sweepassisttime += seconds; break; default: LUAU_ASSERT(!"Unexpected GC state"); @@ -78,7 +91,7 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, static void startGcCycleStats(global_State* g) { g->gcstats.currcycle.starttimestamp = lua_clock(); - g->gcstats.currcycle.waittime = g->gcstats.currcycle.starttimestamp - g->gcstats.lastcycle.endtimestamp; + g->gcstats.currcycle.pausetime = g->gcstats.currcycle.starttimestamp - g->gcstats.lastcycle.endtimestamp; } static void finishGcCycleStats(global_State* g) @@ -585,10 +598,21 @@ static size_t atomic(lua_State* L) LUAU_ASSERT(g->gcstate == GCSatomic); size_t work = 0; + double currts = lua_clock(); + double prevts = currts; + /* remark occasional upvalues of (maybe) dead threads */ work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ work += propagateall(g); + + if (FFlag::LuauGcAdditionalStats) + { + currts = lua_clock(); + g->gcstats.currcycle.atomictimeupval += currts - prevts; + prevts = currts; + } + /* remark weak tables */ g->gray = g->weak; g->weak = NULL; @@ -596,16 +620,41 @@ static size_t atomic(lua_State* L) markobject(g, L); /* mark running thread */ markmt(g); /* mark basic metatables (again) */ work += propagateall(g); + + if (FFlag::LuauGcAdditionalStats) + { + currts = lua_clock(); + g->gcstats.currcycle.atomictimeweak += currts - prevts; + prevts = currts; + } + /* remark gray again */ g->gray = g->grayagain; g->grayagain = NULL; work += propagateall(g); - work += cleartable(L, g->weak); /* remove collected objects from weak tables */ + + if (FFlag::LuauGcAdditionalStats) + { + currts = lua_clock(); + g->gcstats.currcycle.atomictimegray += currts - prevts; + prevts = currts; + } + + /* remove collected objects from weak tables */ + work += cleartable(L, g->weak); g->weak = NULL; + + if (FFlag::LuauGcAdditionalStats) + { + currts = lua_clock(); + g->gcstats.currcycle.atomictimeclear += currts - prevts; + } + /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); g->sweepgcopage = g->allgcopages; g->gcstate = GCSsweep; + return work; } @@ -693,6 +742,9 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) { + if (FFlag::LuauGcAdditionalStats) + g->gcstats.currcycle.propagatework = g->gcstats.currcycle.explicitwork + g->gcstats.currcycle.assistwork; + // perform one iteration over 'gray again' list g->gray = g->grayagain; g->grayagain = NULL; @@ -710,6 +762,10 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { + if (FFlag::LuauGcAdditionalStats) + g->gcstats.currcycle.propagateagainwork = + g->gcstats.currcycle.explicitwork + g->gcstats.currcycle.assistwork - g->gcstats.currcycle.propagatework; + g->gcstate = GCSatomic; } break; @@ -811,6 +867,12 @@ static size_t getheaptrigger(global_State* g, size_t heapgoal) void luaC_step(lua_State* L, bool assist) { global_State* g = L->global; + + if (assist) + g->gcstats.currcycle.assistrequests += g->gcstepsize; + else + g->gcstats.currcycle.explicitrequests += g->gcstepsize; + int lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -833,6 +895,11 @@ void luaC_step(lua_State* L, bool assist) recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); + if (lastgcstate == GCSpropagate) + g->gcstats.currcycle.markrequests += g->gcstepsize; + else if (lastgcstate == GCSsweep) + g->gcstats.currcycle.sweeprequests += g->gcstepsize; + // at the end of the last cycle if (g->gcstate == GCSpause) { @@ -844,6 +911,9 @@ void luaC_step(lua_State* L, bool assist) finishGcCycleStats(g); + if (FFlag::LuauGcAdditionalStats) + g->gcstats.currcycle.starttotalsizebytes = g->totalbytes; + g->gcstats.currcycle.heapgoalsizebytes = heapgoal; g->gcstats.currcycle.heaptriggersizebytes = heaptrigger; } diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 253e269f..cbeeebd4 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -111,13 +111,6 @@ luaC_barrierf(L, obj2gco(p), obj2gco(o)); \ } -// TODO: remove with FFlagLuauGcForwardMetatableBarrier -#define luaC_objbarriert(L, t, o) \ - { \ - if (isblack(obj2gco(t)) && iswhite(obj2gco(o))) \ - luaC_barriertable(L, t, obj2gco(o)); \ - } - #define luaC_upvalbarrier(L, uv, tv) \ { \ if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || ((UpVal*)uv)->v != &((UpVal*)uv)->u.value)) \ diff --git a/VM/src/lnumprint.cpp b/VM/src/lnumprint.cpp index 2fd0f1bb..d64e3ca4 100644 --- a/VM/src/lnumprint.cpp +++ b/VM/src/lnumprint.cpp @@ -6,7 +6,6 @@ #include "lcommon.h" #include -#include // TODO: Remove with LuauSchubfach #ifdef _MSC_VER #include @@ -18,8 +17,6 @@ // The code uses the notation from the paper for local variables where appropriate, and refers to paper sections/figures/results. -LUAU_FASTFLAGVARIABLE(LuauSchubfach, false) - // 9.8.2. Precomputed table for 128-bit overestimates of powers of 10 (see figure 3 for table bounds) // To avoid storing 616 128-bit numbers directly we use a technique inspired by Dragonbox implementation and store 16 consecutive // powers using a 128-bit baseline and a bitvector with 1-bit scale and 3-bit offset for the delta between each entry and base*5^k @@ -275,12 +272,6 @@ inline char* trimzero(char* end) char* luai_num2str(char* buf, double n) { - if (!FFlag::LuauSchubfach) - { - snprintf(buf, LUAI_MAXNUM2STR, LUA_NUMBER_FMT, n); - return buf + strlen(buf); - } - // IEEE-754 union { diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index fba07bc3..549b4630 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -55,7 +55,6 @@ LUAU_FASTMATH_END #define luai_num2unsigned(i, n) ((i) = (unsigned)(long long)(n)) #endif -#define LUA_NUMBER_FMT "%.14g" /* TODO: Remove with LuauSchubfach */ #define LUAI_MAXNUM2STR 48 LUAI_FUNC char* luai_num2str(char* buf, double n); diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 3ee96718..b2bedb48 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -77,25 +77,46 @@ typedef struct CallInfo struct GCCycleStats { + size_t starttotalsizebytes = 0; size_t heapgoalsizebytes = 0; size_t heaptriggersizebytes = 0; - double waittime = 0.0; // time from end of the last cycle to the start of a new one + double pausetime = 0.0; // time from end of the last cycle to the start of a new one double starttimestamp = 0.0; double endtimestamp = 0.0; double marktime = 0.0; + double markassisttime = 0.0; + double markmaxexplicittime = 0.0; + size_t markexplicitsteps = 0; + size_t markrequests = 0; double atomicstarttimestamp = 0.0; size_t atomicstarttotalsizebytes = 0; double atomictime = 0.0; + // specific atomic stage parts + double atomictimeupval = 0.0; + double atomictimeweak = 0.0; + double atomictimegray = 0.0; + double atomictimeclear = 0.0; + double sweeptime = 0.0; + double sweepassisttime = 0.0; + double sweepmaxexplicittime = 0.0; + size_t sweepexplicitsteps = 0; + size_t sweeprequests = 0; + + size_t assistrequests = 0; + size_t explicitrequests = 0; size_t assistwork = 0; size_t explicitwork = 0; + size_t propagatework = 0; + size_t propagateagainwork = 0; + size_t endtotalsizebytes = 0; }; diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 0412ea76..ef0b4b93 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -447,7 +447,8 @@ void luaH_free(lua_State* L, Table* t, lua_Page* page) { if (t->node != dummynode) luaM_freearray(L, t->node, sizenode(t), LuaNode, t->memcat); - luaM_freearray(L, t->array, t->sizearray, TValue, t->memcat); + if (t->array) + luaM_freearray(L, t->array, t->sizearray, TValue, t->memcat); luaM_freegco(L, t, sizeof(Table), t->memcat, page); } diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 0d3374ef..00753742 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -2,6 +2,7 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lualib.h" +#include "lapi.h" #include "lstate.h" #include "ltable.h" #include "lstring.h" @@ -9,6 +10,8 @@ #include "ldebug.h" #include "lvm.h" +LUAU_FASTFLAGVARIABLE(LuauTableClone, false) + static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -507,6 +510,23 @@ static int tisfrozen(lua_State* L) return 1; } +static int tclone(lua_State* L) +{ + if (!FFlag::LuauTableClone) + luaG_runerror(L, "table.clone is not available"); + + luaL_checktype(L, 1, LUA_TTABLE); + luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); + + Table* tt = luaH_clone(L, hvalue(L->base)); + + TValue v; + sethvalue(L, &v, tt); + luaA_pushobject(L, &v); + + return 1; +} + static const luaL_Reg tab_funcs[] = { {"concat", tconcat}, {"foreach", foreach}, @@ -524,6 +544,7 @@ static const luaL_Reg tab_funcs[] = { {"clear", tclear}, {"freeze", tfreeze}, {"isfrozen", tisfrozen}, + {"clone", tclone}, {NULL, NULL}, }; diff --git a/bench/gc/test_GC_Boehm_Trees.lua b/bench/gc/test_GC_Boehm_Trees.lua index 1451769f..5abad1d8 100644 --- a/bench/gc/test_GC_Boehm_Trees.lua +++ b/bench/gc/test_GC_Boehm_Trees.lua @@ -74,4 +74,7 @@ function test() end end +bench.runs = 6 +bench.extraRuns = 2 + bench.runCode(test, "GC: Boehm tree") diff --git a/bench/gc/test_GC_Tree_Pruning_Eager.lua b/bench/gc/test_GC_Tree_Pruning_Eager.lua index 514766ae..2111d9ff 100644 --- a/bench/gc/test_GC_Tree_Pruning_Eager.lua +++ b/bench/gc/test_GC_Tree_Pruning_Eager.lua @@ -40,7 +40,7 @@ function test() local tree = { id = 0 } - for i = 1,1000 do + for i = 1,100 do fill_tree(tree, 10) prune_tree(tree, 0) diff --git a/bench/gc/test_GC_Tree_Pruning_Gen.lua b/bench/gc/test_GC_Tree_Pruning_Gen.lua index a8d0f40a..f88bd7f4 100644 --- a/bench/gc/test_GC_Tree_Pruning_Gen.lua +++ b/bench/gc/test_GC_Tree_Pruning_Gen.lua @@ -42,7 +42,7 @@ function test() local tree = { id = 0 } fill_tree(tree, 16) - for i = 1,1000 do + for i = 1,100 do local small_tree = { id = 0 } fill_tree(small_tree, 8) diff --git a/bench/gc/test_GC_Tree_Pruning_Lazy.lua b/bench/gc/test_GC_Tree_Pruning_Lazy.lua index 8cb69192..3ea6bbef 100644 --- a/bench/gc/test_GC_Tree_Pruning_Lazy.lua +++ b/bench/gc/test_GC_Tree_Pruning_Lazy.lua @@ -46,7 +46,7 @@ function test() local tree = { id = 0 } - for i = 1,1000 do + for i = 1,100 do fill_tree(tree, 10) prune_tree(tree, 0) diff --git a/extern/isocline/include/isocline.h b/extern/isocline/include/isocline.h index 0d46cf3f..a7e03ed2 100644 --- a/extern/isocline/include/isocline.h +++ b/extern/isocline/include/isocline.h @@ -259,7 +259,7 @@ void ic_complete_qword( ic_completion_env_t* cenv, const char* prefix, ic_comple /// The `escape_char` is the escaping character, usually `\` but use 0 to not have escape characters. /// The `quote_chars` define the quotes, use NULL for the default `"\'\""` quotes. /// @see ic_complete_word() which uses the default values for `non_word_chars`, `quote_chars` and `\` for escape characters. -void ic_complete_qword_ex( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t fun, +void ic_complete_qword_ex( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, ic_is_char_class_fun_t* is_word_char, char escape_char, const char* quote_chars ); /// \} diff --git a/fuzz/number.cpp b/fuzz/number.cpp index 70447409..31c953e3 100644 --- a/fuzz/number.cpp +++ b/fuzz/number.cpp @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAG(LuauSchubfach); - #define LUAI_MAXNUM2STR 48 char* luai_num2str(char* buf, double n); @@ -17,8 +15,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) if (Size < 8) return 0; - FFlag::LuauSchubfach.value = true; - double num; memcpy(&num, Data, 8); diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 912fef23..1022831b 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -59,7 +59,7 @@ void interrupt(lua_State* L, int gc) } } -void* allocate(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsize) +void* allocate(void* ud, void* ptr, size_t osize, size_t nsize) { if (nsize == 0) { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 6aadef32..ce890ba8 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) LUAU_FASTFLAG(LuauUseCommittingTxnLog) +LUAU_FASTFLAG(LuauTableCloneType) using namespace Luau; @@ -262,7 +263,7 @@ TEST_CASE_FIXTURE(ACFixture, "get_member_completions") auto ac = autocomplete('1'); - CHECK_EQ(16, ac.entryMap.size()); + CHECK_EQ(FFlag::LuauTableCloneType ? 17 : 16, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); @@ -2235,7 +2236,7 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") auto ac = autocompleteSource(frontend, source, Position{1, 24}, nullCallback).result; - CHECK_EQ(16, ac.entryMap.size()); + CHECK_EQ(FFlag::LuauTableCloneType ? 17 : 16, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); @@ -2695,8 +2696,6 @@ local r4 = t:bar1(@4) TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_parameters") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - check(R"( type A = () -> T )"); @@ -2709,8 +2708,6 @@ type A = () -> T TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_pack_parameters") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - check(R"( type A = () -> T )"); @@ -2768,7 +2765,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag preferToCallFunctionsForIntersects("PreferToCallFunctionsForIntersects", true); check(R"( local bar: ((number) -> number) & (number, number) -> number) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index eb6ca749..63fbb363 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -241,6 +241,8 @@ TEST_CASE("Math") TEST_CASE("Table") { + ScopedFastFlag sff("LuauTableClone", true); + runConformance("nextvar.lua"); } @@ -465,6 +467,8 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { + ScopedFastFlag sff("LuauTableCloneType", true); + runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; @@ -959,8 +963,6 @@ TEST_CASE("Coverage") TEST_CASE("StringConversion") { - ScopedFastFlag sff{"LuauSchubfach", true}; - runConformance("strconv.lua"); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index ab19cea3..4d6c207c 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -157,6 +157,113 @@ return bar() CHECK_EQ(result.warnings[0].text, "Global 'foo' is only used in the enclosing function 'bar'; consider changing it to local"); } +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMultiFx") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + foo = 6 + return foo +end + +function baz() + foo = 6 + return foo +end + +return bar() + baz() +)"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Global 'foo' is never read before being written. Consider changing it to local"); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMultiFxWithRead") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + foo = 6 + return foo +end + +function baz() + foo = 6 + return foo +end + +function read() + print(foo) +end + +return bar() + baz() + read() +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalWithConditional") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + if true then foo = 6 end + return foo +end + +function baz() + foo = 6 + return foo +end + +return bar() + baz() +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocal3WithConditionalRead") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + foo = 6 + return foo +end + +function baz() + foo = 6 + return foo +end + +function read() + if false then print(foo) end +end + +return bar() + baz() + read() +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalInnerRead") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function foo() + local f = function() return bar end + f() + bar = 42 +end + +function baz() bar = 0 end + +return foo() + baz() +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMulti") { LintResult result = lint(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 77e49ce3..7f6a6c0d 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1988,8 +1988,6 @@ TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - AstStat* stat = parse(R"( type A = {} type B = {} @@ -2005,8 +2003,6 @@ type G = (U...) -> T... TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - matchParseError("type Y = {}", "Expected default type after type name", Location{{0, 20}, {0, 21}}); matchParseError("type Y = {}", "Expected default type pack after type pack name", Location{{0, 29}, {0, 30}}); matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); @@ -2574,8 +2570,6 @@ do end TEST_CASE_FIXTURE(Fixture, "recover_expected_type_pack") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ParseResult result = tryParse(R"( type Y = (T...) -> U... )"); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 0ca9c994..29bdd866 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -96,8 +96,6 @@ n2 [label="number"]; TEST_CASE_FIXTURE(Fixture, "function") { - ScopedFastFlag luauQuantifyInPlace2{"LuauQuantifyInPlace2", true}; - CheckResult result = check(R"( local function f(a, ...: string) return a end )"); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index bbb26291..6713a589 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -500,8 +500,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") { - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( local function f(a: number, b: string) end local function test(...: T...): U... diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 332aba9e..5f0295b0 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -641,9 +641,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_to_string") TEST_CASE_FIXTURE(Fixture, "transpile_type_alias_default_type_parameters") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - std::string code = R"( type Packed = (T, U, V...)->(W...) local a: Packed diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 31d7ef10..d584eb2d 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -625,9 +625,8 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni ScopedFastFlag sff[] = { {"LuauTwoPassAliasDefinitionFix", true}, - // We also force these two flags because this surfaced an unfortunate interaction. + // We also force this flag because it surfaced an unfortunate interaction. {"LuauErrorRecoveryType", true}, - {"LuauQuantifyInPlace2", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index f3dfb214..bf990770 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -934,4 +934,31 @@ TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_fir CHECK_EQ("(nil) -> nil", toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "table_freeze_is_generic") +{ + CheckResult result = check(R"( + local t1: {a: number} = {a = 42} + local t2: {b: string} = {b = "hello"} + local t3: {boolean} = {false, true} + + local tf1 = table.freeze(t1) + local tf2 = table.freeze(t2) + local tf3 = table.freeze(t3) + + local a = tf1.a + local b = tf2.b + local c = tf3[2] + + local d = tf1.b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("string", toString(requireType("b"))); + CHECK_EQ("boolean", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index f8fccf6b..c482847b 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -697,4 +697,93 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") +{ + ScopedFastFlag sffs[] = { + { "LuauTableSubtypingVariance2", true }, + { "LuauUnsealedTableLiteral", true }, + { "LuauPropertiesGetExpectedType", true }, + { "LuauRecursiveTypeParameterRestriction", true }, + }; + + CheckResult result = check(R"( +--!strict +-- At one point this produced a UAF +type T = { a: U, b: a } +type U = { c: T?, d : a } +local x: T = { a = { c = nil, d = 5 }, b = 37 } +x.a.c = x +local y: T = { a = { c = nil, d = 5 }, b = 37 } +y.a.c = y + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), + R"(Type 'y' could not be converted into 'T' +caused by: + Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' +caused by: + Property 'd' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") +{ + ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; + + CheckResult result = check(R"( +--!strict +type Dispatcher = { + useMemo: (create: () -> T...) -> T... +} + +local TheDispatcher: Dispatcher = { + useMemo = function(create: () -> U...): U... + return create() + end +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification2") +{ + ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; + + CheckResult result = check(R"( +--!strict +type Dispatcher = { + useMemo: (create: () -> T...) -> T... +} + +local TheDispatcher: Dispatcher = { + useMemo = function(create) + return create() + end +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification3") +{ + ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; + + CheckResult result = check(R"( +--!strict +type Dispatcher = { + useMemo: (arg: S, create: (S) -> T...) -> T... +} + +local TheDispatcher: Dispatcher = { + useMemo = function(arg: T, create: (T) -> U...): U... + return create(arg) + end +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index eee0e0f1..2e16b21e 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -8,7 +8,6 @@ #include LUAU_FASTFLAG(LuauEqConstraint) -LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -40,16 +39,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string old_expected = R"( - function f(a:{fn:()->(free,free...)}): () - if type(a) == 'boolean'then - local a1:boolean=a - elseif a.fn()then - local a2:{fn:()->(free,free...)}=a - end - end - )"; - const std::string expected = R"( function f(a:{fn:()->(a,b...)}): () if type(a) == 'boolean'then @@ -60,10 +49,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ(expected, decorateWithTypes(code)); - else - CHECK_EQ(old_expected, decorateWithTypes(code)); + CHECK_EQ(expected, decorateWithTypes(code)); } TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") @@ -135,46 +121,6 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_constrains_free_type_into_free_table") CHECK_EQ("number", toString(tm->givenType)); } -TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") -{ - CheckResult result = check(R"( - local a: {x: number, y: number, [any]: any} | {y: number} - - function f(t) - t.y = 1 - return t - end - - local b = f(a) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // :( - // Should be the same as the type of a - REQUIRE_EQ("{| y: number |}", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") -{ - CheckResult result = check(R"( - local a: {y: number} | {x: number, y: number, [any]: any} - - function f(t) - t.y = 1 - return t - end - - local b = f(a) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // :( - // Should be the same as the type of a - REQUIRE_EQ("{| [any]: any, x: number, y: number |}", toString(requireType("b"))); -} - // Luau currently doesn't yet know how to allow assignments when the binding was refined. TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") { @@ -557,25 +503,6 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doct } } -TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") -{ - CheckResult result = check(R"( - --!strict - local function setNumber(t: { p: number? }, x:number) t.p = x end - local function getString(t: { p: string? }):string return t.p or "" end - -- This shouldn't type-check! - local function oh(x:number): string - local t: {} = {} - setNumber(t, x) - return getString(t) - end - local s: string = oh(37) - )"); - - // Really this should return an error, but it doesn't - LUAU_REQUIRE_NO_ERRORS(result); -} - // Should be in TypeInfer.tables.test.cpp // It's unsound to instantiate tables containing generic methods, // since mutating properties means table properties should be invariant. @@ -600,25 +527,9 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") -{ - // Mutability in type function application right now can create strange recursive types - // TODO: instantiation right now is problematic, in this example should either leave the Table type alone - // or it should rename the type to 'Self' so that the result will be 'Self
' - CheckResult result = check(R"( -type Table = { a: number } -type Self = T -local a: Self
- )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "Table
"); -} - TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { ScopedFastFlag sff[]{ - {"LuauQuantifyInPlace2", true}, {"LuauReturnAnyInsteadOfICE", true}, }; @@ -664,8 +575,6 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") { - ScopedFastFlag sff{"LuauQuantifyInPlace2", true}; - CheckResult result = check(R"( local function f() return end local g = function() return f() end @@ -676,8 +585,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { - ScopedFastFlag sff{"LuauQuantifyInPlace2", true}; - CheckResult result = check(R"( --!strict local function f(...) return ... end diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index bff8926c..a5147d56 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,7 +8,6 @@ LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAG(LuauWeakEqConstraint) -LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -1179,20 +1178,14 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") { LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); - else - CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); } CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" - else - CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" } TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 856549bd..3ed536ea 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -465,30 +465,32 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si CHECK_EQ("((string) -> (b...), a) -> ()", toString(requireType("foo"))); } -// TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") -// { -// ScopedFastFlag sff[]{ -// {"LuauParseSingletonTypes", true}, -// {"LuauSingletonTypes", true}, -// {"LuauDiscriminableUnions2", true}, -// {"LuauEqConstraint", true}, -// {"LuauWidenIfSupertypeIsFree", true}, -// {"LuauWeakEqConstraint", false}, -// }; +TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauDiscriminableUnions2", true}, + {"LuauEqConstraint", true}, + {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWeakEqConstraint", false}, + {"LuauDoNotAccidentallyDependOnPointerOrdering", true} + }; -// CheckResult result = check(R"( -// local function foo(f, x): "hello"? -- anyone there? -// return if x == "hi" -// then f(x) -// else nil -// end -// )"); + CheckResult result = check(R"( + local function foo(f, x): "hello"? -- anyone there? + return if x == "hi" + then f(x) + else nil + end + )"); -// LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); -// CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); -// CHECK_EQ(R"(((string) -> ("hello"?, b...), a) -> "hello"?)", toString(requireType("foo"))); -// } + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); + CHECK_EQ(R"(((string) -> (a, c...), b) -> "hello"?)", toString(requireType("foo"))); + // CHECK_EQ(R"(((string) -> ("hello"?, b...), a) -> "hello"?)", toString(requireType("foo"))); +} TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") { diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index aa949789..da035ba1 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1219,13 +1219,12 @@ TEST_CASE_FIXTURE(Fixture, "passing_compatible_unions_to_a_generic_table_without { CheckResult result = check(R"( type A = {x: number, y: number, [any]: any} | {y: number} - local a: A function f(t) t.y = 1 end - f(a) + f({y = 5} :: A) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -2165,6 +2164,44 @@ b() CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); } +TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") +{ + ScopedFastFlag sffs[] = { + {"LuauTableSubtypingVariance2", true}, + {"LuauSubtypingAddOptPropsToUnsealedTables", true}, + }; + + CheckResult result = check(R"( + --!strict + local function setNumber(t: { p: number? }, x:number) t.p = x end + local function getString(t: { p: string? }):string return t.p or "" end + -- This shouldn't type-check! + local function oh(x:number): string + local t: {} = {} + setNumber(t, x) + return getString(t) + end + local s: string = oh(37) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "top_table_type") +{ + CheckResult result = check(R"( + --!strict + type Table = { [any] : any } + type HasTable = { p: Table? } + type HasHasTable = { p: HasTable? } + local t : Table = { p = 5 } + local u : HasTable = { p = { p = 5 } } + local v : HasHasTable = { p = { p = { p = 5 } } } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "length_operator_union") { CheckResult result = check(R"( @@ -2257,4 +2294,44 @@ TEST_CASE_FIXTURE(Fixture, "confusing_indexing") CHECK_EQ("number | string", toString(requireType("foo"))); } +TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") +{ + ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; + + CheckResult result = check(R"( + local a: {x: number, y: number, [any]: any} | {y: number} + + function f(t) + t.y = 1 + return t + end + + local b = f(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") +{ + ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; + + CheckResult result = check(R"( + local a: {y: number} | {x: number, y: number, [any]: any} + + function f(t) + t.y = 1 + return t + end + + local b = f(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index f44d9fd8..f63579b5 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4044,6 +4044,49 @@ type t0 = any CHECK(ttv->instantiatedTypeParams.empty()); } +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_2") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + CheckResult result = check(R"( +type X = T +type K = X +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("math"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_3") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + CheckResult result = check(R"( +type X = T +local a = {} +a.x = 4 +local b: X +a.y = 5 +local c: X +c = b +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("a"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + TEST_CASE_FIXTURE(Fixture, "bound_free_table_export_is_ok") { CheckResult result = check(R"( @@ -4065,6 +4108,21 @@ return m LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + // Mutability in type function application right now can create strange recursive types + CheckResult result = check(R"( +type Table = { a: number } +type Self = T +local a: Self
+ )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a")), "Table"); +} + TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") { TypeId mathTy = requireType(typeChecker.globalScope, "math"); @@ -5284,4 +5342,17 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "global_singleton_types_are_sealed") +{ + CheckResult result = check(R"( +local function f(x: string) + local p = x:split('a') + p = table.pack(table.unpack(p, 1, #p - 1)) + return p +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index c9bf5103..f6ee3ccc 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -7,8 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauQuantifyInPlace2); - using namespace Luau; LUAU_FASTFLAG(LuauUseCommittingTxnLog) @@ -167,10 +165,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails" )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("(number) -> boolean", toString(requireType("f"))); - else - CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); + CHECK_EQ("(number) -> boolean", toString(requireType("f"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index cbe2e48f..6b96f449 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -622,9 +622,6 @@ type Other = Packed TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_explicit") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: T, b: U } @@ -654,9 +651,6 @@ local c: Y = { a = "s" } TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_self") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: T, b: U } @@ -682,9 +676,6 @@ local a: Y TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_chained") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: T, b: U, c: V } @@ -700,9 +691,6 @@ local b: Y TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_explicit") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: (T...) -> () } local a: Y<> @@ -715,9 +703,6 @@ local a: Y<> TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_ty") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: T, b: (U...) -> T } @@ -731,9 +716,6 @@ local a: Y TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_tp") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: (T...) -> U... } local a: Y @@ -746,9 +728,6 @@ local a: Y TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_chained_tp") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: (T...) -> U..., b: (T...) -> V... } local a: Y @@ -761,9 +740,6 @@ local a: Y TEST_CASE_FIXTURE(Fixture, "type_alias_default_mixed_self") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: (T, U, V...) -> W... } local a: Y @@ -782,9 +758,6 @@ local d: Y ()> TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: T } local a: Y = { a = 2 } @@ -834,9 +807,6 @@ local a: Y<...number> TEST_CASE_FIXTURE(Fixture, "type_alias_default_export") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - fileResolver.source["Module/Types"] = R"( export type A = { a: T, b: U } export type B = { a: T, b: U } @@ -882,9 +852,6 @@ local h: Types.H<> TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_skip_brackets") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = (T...) -> number local a: Y @@ -897,9 +864,6 @@ local a: Y TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_confusing_types") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type A = (T, V...) -> (U, W...) type B = A @@ -914,9 +878,6 @@ type C = A TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_recursive_type") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type F ()> = (K) -> V type R = { m: F } diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 3b53ddfe..0e0b6ebb 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -400,10 +400,10 @@ local e = a.z CHECK_EQ("Type 'A | B | C | D' does not have key 'z'", toString(result.errors[3])); } -TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") +TEST_CASE_FIXTURE(Fixture, "unify_unsealed_table_union_check") { CheckResult result = check(R"( -local x: { x: number } = { x = 3 } +local x = { x = 3 } type A = number? type B = string? local y: { x: number, y: A | B } @@ -413,7 +413,7 @@ y = x LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local x: { x: number } = { x = 3 } +local x = { x = 3 } local a: number? = 2 local y = {} @@ -426,6 +426,31 @@ y = x LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") +{ + ScopedFastFlag sffs[] = { + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + {"LuauSubtypingAddOptPropsToUnsealedTables", true}, + }; + + CheckResult result = check(R"( + -- the difference between this and unify_unsealed_table_union_check is the type annotation on x +local t = { x = 3, y = true } +local x: { x: number } = t +type A = number? +type B = string? +local y: { x: number, y: A | B } +-- Shouldn't typecheck! +y = x +-- If it does, we can convert any type to any other type +y.y = 5 +local oh : boolean = t.y + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "error_detailed_union_part") { CheckResult result = check(R"( diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index 94ba5ccf..e85fcbe8 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -512,4 +512,42 @@ do assert(#t == 7) end +-- test clone +do + local t = {a = 1, b = 2, 3, 4, 5} + local tt = table.clone(t) + + assert(#tt == 3) + assert(tt.a == 1 and tt.b == 2) + + t.c = 3 + assert(tt.c == nil) + + t = table.freeze({"test"}) + tt = table.clone(t) + assert(table.isfrozen(t) and not table.isfrozen(tt)) + + t = setmetatable({}, {}) + tt = table.clone(t) + assert(getmetatable(t) == getmetatable(tt)) + + t = setmetatable({}, {__metatable = "protected"}) + assert(not pcall(table.clone, t)) + + function order(t) + local r = '' + for k,v in pairs(t) do + r ..= tostring(v) + end + return v + end + + t = {a = 1, b = 2, c = 3, d = 4, e = 5, f = 6} + tt = table.clone(t) + assert(order(t) == order(tt)) + + assert(not pcall(table.clone)) + assert(not pcall(table.clone, 42)) +end + return"OK" From feea507be3f6991a76907e92220f50db05d7a98e Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 11 Mar 2022 08:31:18 -0800 Subject: [PATCH 029/102] Sync to upstream/release/518 --- Analysis/include/Luau/TxnLog.h | 55 - Analysis/include/Luau/TypePack.h | 1 - Analysis/include/Luau/Unifier.h | 1 - Analysis/src/Autocomplete.cpp | 56 +- Analysis/src/BuiltinDefinitions.cpp | 17 +- Analysis/src/TxnLog.cpp | 117 -- Analysis/src/TypeInfer.cpp | 672 ++++------ Analysis/src/TypePack.cpp | 26 +- Analysis/src/Unifier.cpp | 1826 +++++++-------------------- VM/include/lua.h | 35 +- VM/include/luaconf.h | 27 - VM/include/lualib.h | 4 +- VM/src/laux.cpp | 15 + VM/src/lbaselib.cpp | 11 +- VM/src/lgc.h | 7 + tests/Autocomplete.test.cpp | 52 +- tests/TypeInfer.aliases.test.cpp | 23 + tests/TypeInfer.builtins.test.cpp | 51 +- tests/TypeInfer.generics.test.cpp | 45 + tests/TypeInfer.tables.test.cpp | 50 + tests/TypeInfer.test.cpp | 38 +- tests/TypeInfer.tryUnify.test.cpp | 44 +- tests/TypeInfer.typePacks.cpp | 13 +- tests/TypeInfer.unionTypes.test.cpp | 10 +- tools/LuauVisualize.py | 107 ++ tools/lldb-formatters.lldb | 2 + 26 files changed, 1137 insertions(+), 2168 deletions(-) create mode 100644 tools/LuauVisualize.py create mode 100644 tools/lldb-formatters.lldb diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index f8105383..c8ebaaeb 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -14,61 +14,6 @@ namespace Luau using TypeOrPackId = const void*; -// Log of where what TypeIds we are rebinding and what they used to be -// Remove with LuauUseCommitTxnLog -struct DEPRECATED_TxnLog -{ - DEPRECATED_TxnLog() - : originalSeenSize(0) - , ownedSeen() - , sharedSeen(&ownedSeen) - { - } - - explicit DEPRECATED_TxnLog(std::vector>* sharedSeen) - : originalSeenSize(sharedSeen->size()) - , ownedSeen() - , sharedSeen(sharedSeen) - { - } - - DEPRECATED_TxnLog(const DEPRECATED_TxnLog&) = delete; - DEPRECATED_TxnLog& operator=(const DEPRECATED_TxnLog&) = delete; - - DEPRECATED_TxnLog(DEPRECATED_TxnLog&&) = default; - DEPRECATED_TxnLog& operator=(DEPRECATED_TxnLog&&) = default; - - void operator()(TypeId a); - void operator()(TypePackId a); - void operator()(TableTypeVar* a); - - void rollback(); - - void concat(DEPRECATED_TxnLog rhs); - - bool haveSeen(TypeId lhs, TypeId rhs); - void pushSeen(TypeId lhs, TypeId rhs); - void popSeen(TypeId lhs, TypeId rhs); - - bool haveSeen(TypePackId lhs, TypePackId rhs); - void pushSeen(TypePackId lhs, TypePackId rhs); - void popSeen(TypePackId lhs, TypePackId rhs); - -private: - std::vector> typeVarChanges; - std::vector> typePackChanges; - std::vector>> tableChanges; - size_t originalSeenSize; - - bool haveSeen(TypeOrPackId lhs, TypeOrPackId rhs); - void pushSeen(TypeOrPackId lhs, TypeOrPackId rhs); - void popSeen(TypeOrPackId lhs, TypeOrPackId rhs); - -public: - std::vector> ownedSeen; // used to avoid infinite recursion when types are cyclic - std::vector>* sharedSeen; // shared with all the descendent logs -}; - // Pending state for a TypeVar. Generated by a TxnLog and committed via // TxnLog::commit. struct PendingType diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index c74bad11..946be356 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -105,7 +105,6 @@ private: const TypePack* tp = nullptr; size_t currentIndex = 0; - // Only used if LuauUseCommittingTxnLog is true. const TxnLog* log; }; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 4c0462fe..71958f4a 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -45,7 +45,6 @@ struct Unifier TypeArena* const types; Mode mode; - DEPRECATED_TxnLog DEPRECATED_log; TxnLog log; ErrorVec errors; Location location; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index c3de8d0e..e94c432f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,9 +13,6 @@ #include #include -LUAU_FASTFLAG(LuauUseCommittingTxnLog) -LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); static const std::unordered_set kStatementStartingKeywords = { @@ -240,28 +237,9 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); - if (FFlag::LuauAutocompleteAvoidMutation && !FFlag::LuauUseCommittingTxnLog) - { - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - superTy = clone(superTy, *typeArena, seenTypes, seenTypePacks, cloneState); - subTy = clone(subTy, *typeArena, seenTypes, seenTypePacks, cloneState); - - auto errors = unifier.canUnify(subTy, superTy); - return errors.empty(); - } - else - { - unifier.tryUnify(subTy, superTy); - - bool ok = unifier.errors.empty(); - - if (!FFlag::LuauUseCommittingTxnLog) - unifier.DEPRECATED_log.rollback(); - - return ok; - } + unifier.tryUnify(subTy, superTy); + bool ok = unifier.errors.empty(); + return ok; }; auto typeAtPosition = findExpectedTypeAt(module, node, position); @@ -403,28 +381,14 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { - if (FFlag::LuauMissingFollowACMetatables) + TypeId followed = follow(indexIt->second.type); + if (get(followed) || get(followed)) + autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen); + else if (auto indexFunction = get(followed)) { - TypeId followed = follow(indexIt->second.type); - if (get(followed) || get(followed)) - autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen); - else if (auto indexFunction = get(followed)) - { - std::optional indexFunctionResult = first(indexFunction->retType); - if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); - } - } - else - { - if (get(indexIt->second.type) || get(indexIt->second.type)) - autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); - else if (auto indexFunction = get(indexIt->second.type)) - { - std::optional indexFunctionResult = first(indexFunction->retType); - if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); - } + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); } } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index e4e5dab8..bf9ef303 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -10,6 +10,7 @@ LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) LUAU_FASTFLAGVARIABLE(LuauTableCloneType, false) +LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -376,11 +377,19 @@ static std::optional> magicFunctionSetMetaTable( TypeId mtTy = arena.addType(mtv); - AstExpr* targetExpr = expr.args.data[0]; - if (AstExprLocal* targetLocal = targetExpr->as()) + if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) { - const Name targetName(targetLocal->local->name.value); - scope->bindings[targetLocal->local] = Binding{mtTy, expr.location}; + return ExprResult{}; + } + + if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) + { + AstExpr* targetExpr = expr.args.data[0]; + if (AstExprLocal* targetLocal = targetExpr->as()) + { + const Name targetName(targetLocal->local->name.value); + scope->bindings[targetLocal->local] = Binding{mtTy, expr.location}; + } } return ExprResult{arena.addTypePack({mtTy})}; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index c7bf1e62..876f5f05 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,110 +7,9 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauUseCommittingTxnLog, false) - namespace Luau { -void DEPRECATED_TxnLog::operator()(TypeId a) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - typeVarChanges.emplace_back(a, *a); -} - -void DEPRECATED_TxnLog::operator()(TypePackId a) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - typePackChanges.emplace_back(a, *a); -} - -void DEPRECATED_TxnLog::operator()(TableTypeVar* a) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - tableChanges.emplace_back(a, a->boundTo); -} - -void DEPRECATED_TxnLog::rollback() -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - for (auto it = typeVarChanges.rbegin(); it != typeVarChanges.rend(); ++it) - std::swap(*asMutable(it->first), it->second); - - for (auto it = typePackChanges.rbegin(); it != typePackChanges.rend(); ++it) - std::swap(*asMutable(it->first), it->second); - - for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) - std::swap(it->first->boundTo, it->second); - - LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); - sharedSeen->resize(originalSeenSize); -} - -void DEPRECATED_TxnLog::concat(DEPRECATED_TxnLog rhs) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - typeVarChanges.insert(typeVarChanges.end(), rhs.typeVarChanges.begin(), rhs.typeVarChanges.end()); - rhs.typeVarChanges.clear(); - - typePackChanges.insert(typePackChanges.end(), rhs.typePackChanges.begin(), rhs.typePackChanges.end()); - rhs.typePackChanges.clear(); - - tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); - rhs.tableChanges.clear(); -} - -bool DEPRECATED_TxnLog::haveSeen(TypeId lhs, TypeId rhs) -{ - return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -void DEPRECATED_TxnLog::pushSeen(TypeId lhs, TypeId rhs) -{ - pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -void DEPRECATED_TxnLog::popSeen(TypeId lhs, TypeId rhs) -{ - popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -bool DEPRECATED_TxnLog::haveSeen(TypePackId lhs, TypePackId rhs) -{ - return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -void DEPRECATED_TxnLog::pushSeen(TypePackId lhs, TypePackId rhs) -{ - pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -void DEPRECATED_TxnLog::popSeen(TypePackId lhs, TypePackId rhs) -{ - popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -bool DEPRECATED_TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); -} - -void DEPRECATED_TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - sharedSeen->push_back(sortedPair); -} - -void DEPRECATED_TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - LUAU_ASSERT(sortedPair == sharedSeen->back()); - sharedSeen->pop_back(); -} - const std::string nullPendingResult = ""; std::string toString(PendingType* pending) @@ -170,8 +69,6 @@ const TxnLog* TxnLog::empty() void TxnLog::concat(TxnLog rhs) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - for (auto& [ty, rep] : rhs.typeVarChanges) typeVarChanges[ty] = std::move(rep); @@ -181,8 +78,6 @@ void TxnLog::concat(TxnLog rhs) void TxnLog::commit() { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - for (auto& [ty, rep] : typeVarChanges) *asMutable(ty) = rep.get()->pending; @@ -194,16 +89,12 @@ void TxnLog::commit() void TxnLog::clear() { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - typeVarChanges.clear(); typePackChanges.clear(); } TxnLog TxnLog::inverse() { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - TxnLog inversed(sharedSeen); for (auto& [ty, _rep] : typeVarChanges) @@ -247,8 +138,6 @@ void TxnLog::popSeen(TypePackId lhs, TypePackId rhs) bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) { @@ -265,16 +154,12 @@ bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); sharedSeen->push_back(sortedPair); } void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); LUAU_ASSERT(sortedPair == sharedSeen->back()); sharedSeen->pop_back(); @@ -282,7 +167,6 @@ void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) PendingType* TxnLog::queue(TypeId ty) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(!ty->persistent); // Explicitly don't look in ancestors. If we have discovered something new @@ -296,7 +180,6 @@ PendingType* TxnLog::queue(TypeId ty) PendingTypePack* TxnLog::queue(TypePackId tp) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(!tp->persistent); // Explicitly don't look in ancestors. If we have discovered something new diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8e6b3b52..3fe4c90e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -24,7 +24,6 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) @@ -36,6 +35,7 @@ LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) +LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) @@ -43,6 +43,8 @@ LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree) LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) +LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) +LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) namespace Luau { @@ -652,18 +654,15 @@ ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const Location& location) { Unifier state = mkUnifier(location); - if (FFlag::LuauUseCommittingTxnLog && FFlag::DebugLuauFreezeDuringUnification) + if (FFlag::DebugLuauFreezeDuringUnification) freeze(currentModule->internalTypes); state.tryUnify(subTy, superTy); - if (FFlag::LuauUseCommittingTxnLog && FFlag::DebugLuauFreezeDuringUnification) + if (FFlag::DebugLuauFreezeDuringUnification) unfreeze(currentModule->internalTypes); - if (!state.errors.empty() && !FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); - - if (state.errors.empty() && FFlag::LuauUseCommittingTxnLog) + if (state.errors.empty()) state.log.commit(); return state.errors; @@ -847,8 +846,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) state.tryUnify(valuePack, variablePack); reportErrors(state.errors); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); // In the code 'local T = {}', we wish to ascribe the name 'T' to the type of the table for error-reporting purposes. // We also want to do this for 'local T = setmetatable(...)'. @@ -1040,8 +1038,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) Unifier state = mkUnifier(firstValue->location); checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); reportErrors(state.errors); } @@ -1102,8 +1099,53 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } + else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify) + { + TypeId exprTy = checkExpr(scope, *name->expr).type; + TableTypeVar* ttv = getMutableTableType(exprTy); + if (!ttv) + { + if (isTableIntersection(exprTy)) + reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); + else if (!get(exprTy) && !get(exprTy)) + reportError(TypeError{function.location, OnlyTablesCanHaveMethods{exprTy}}); + } + else if (ttv->state == TableState::Sealed) + reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); + + ty = follow(ty); + + if (ttv && ttv->state != TableState::Sealed) + ttv->props[name->index.value] = {ty, /* deprecated */ false, {}, name->indexLocation}; + + if (function.func->self) + { + const FunctionTypeVar* funTy = get(ty); + if (!funTy) + ice("Methods should be functions"); + + std::optional arg0 = first(funTy->argTypes); + if (!arg0) + ice("Methods should always have at least 1 argument (self)"); + } + + checkFunctionBody(funScope, ty, *function.func); + + if (ttv && ttv->state != TableState::Sealed) + ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; + } + else if (FFlag::LuauStatFunctionSimplify) + { + LUAU_ASSERT(function.name->is()); + + ty = follow(ty); + + checkFunctionBody(funScope, ty, *function.func); + } else if (function.func->self) { + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify); + AstExprIndexName* indexName = function.name->as(); if (!indexName) ice("member function declaration has malformed name expression"); @@ -1141,6 +1183,8 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify); + TypeId leftType = checkLValueBinding(scope, *function.name); checkFunctionBody(funScope, ty, *function.func); @@ -1217,6 +1261,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; + + if (FFlag::LuauFixIncorrectLineNumberDuplicateType) + scope->typeAliasLocations[name] = typealias.location; } } else @@ -2102,9 +2149,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn Unifier state = mkUnifier(expr.location); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); - - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) @@ -2283,9 +2328,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!isEquality) { state.tryUnify(rhsType, lhsType); - - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); } bool needsMetamethod = !isEquality; @@ -2336,8 +2379,7 @@ TypeId TypeChecker::checkRelationalOperation( return errorRecoveryType(booleanType); } - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); } } @@ -2347,8 +2389,7 @@ TypeId TypeChecker::checkRelationalOperation( state.tryUnify( instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); reportErrors(state.errors); return booleanType; @@ -2464,25 +2505,15 @@ TypeId TypeChecker::checkBinaryOperation( TypePackId fallbackArguments = freshTypePack(scope); TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); state.errors.clear(); - - if (FFlag::LuauUseCommittingTxnLog) - { - state.log.clear(); - } - else - { - state.DEPRECATED_log.rollback(); - } + state.log.clear(); state.tryUnify(actualFunctionType, fallbackFunctionType, /*isFunctionCall*/ true); - if (FFlag::LuauUseCommittingTxnLog && state.errors.empty()) + if (state.errors.empty()) state.log.commit(); - else if (!state.errors.empty() && !FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); } - if (FFlag::LuauUseCommittingTxnLog && !hasErrors) + if (!hasErrors) { state.log.commit(); } @@ -2729,13 +2760,11 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { - if (!FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); reportError(expr.location, UnknownProperty{lhs, name}); retType = errorRecoveryType(retType); } - else if (FFlag::LuauUseCommittingTxnLog) + else state.log.commit(); return retType; @@ -3209,7 +3238,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } // Returns the minimum number of arguments the argument list can accept. -static size_t getMinParameterCount(TypePackId tp) +static size_t getMinParameterCount_DEPRECATED(TypePackId tp) { size_t minCount = 0; size_t optionalCount = 0; @@ -3235,6 +3264,32 @@ static size_t getMinParameterCount(TypePackId tp) return minCount; } +static size_t getMinParameterCount(TxnLog* log, TypePackId tp) +{ + size_t minCount = 0; + size_t optionalCount = 0; + + auto it = begin(tp, log); + auto endIter = end(tp); + + while (it != endIter) + { + TypeId ty = *it; + if (isOptional(ty)) + ++optionalCount; + else + { + minCount += optionalCount; + optionalCount = 0; + minCount++; + } + + ++it; + } + + return minCount; +} + void TypeChecker::checkArgumentList( const ScopePtr& scope, Unifier& state, TypePackId argPack, TypePackId paramPack, const std::vector& argLocations) { @@ -3248,396 +3303,199 @@ void TypeChecker::checkArgumentList( size_t paramIndex = 0; - size_t minParams = getMinParameterCount(paramPack); + size_t minParams = FFlag::LuauFixIncorrectLineNumberDuplicateType ? 0 : getMinParameterCount_DEPRECATED(paramPack); - if (FFlag::LuauUseCommittingTxnLog) + while (true) { - while (true) + state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; + + if (argIter == endIter && paramIter == endIter) { - state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; + std::optional argTail = argIter.tail(); + std::optional paramTail = paramIter.tail(); - if (argIter == endIter && paramIter == endIter) + // If we hit the end of both type packs simultaneously, then there are definitely no further type + // errors to report. All we need to do is tie up any free tails. + // + // If one side has a free tail and the other has none at all, we create an empty pack and bind the + // free tail to that. + + if (argTail) { - std::optional argTail = argIter.tail(); - std::optional paramTail = paramIter.tail(); - - // If we hit the end of both type packs simultaneously, then there are definitely no further type - // errors to report. All we need to do is tie up any free tails. - // - // If one side has a free tail and the other has none at all, we create an empty pack and bind the - // free tail to that. - - if (argTail) + if (state.log.getMutable(state.log.follow(*argTail))) { - if (state.log.getMutable(state.log.follow(*argTail))) - { - if (paramTail) - state.tryUnify(*paramTail, *argTail); - else - state.log.replace(*argTail, TypePackVar(TypePack{{}})); - } - } - else if (paramTail) - { - // argTail is definitely empty - if (state.log.getMutable(state.log.follow(*paramTail))) - state.log.replace(*paramTail, TypePackVar(TypePack{{}})); - } - - return; - } - else if (argIter == endIter) - { - // Not enough arguments. - - // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. - if (argIter.tail()) - { - TypePackId tail = *argIter.tail(); - if (state.log.getMutable(tail)) - { - // Unify remaining parameters so we don't leave any free-types hanging around. - while (paramIter != endIter) - { - state.tryUnify(errorRecoveryType(anyType), *paramIter); - ++paramIter; - } - return; - } - else if (auto vtp = state.log.getMutable(tail)) - { - while (paramIter != endIter) - { - state.tryUnify(vtp->ty, *paramIter); - ++paramIter; - } - - return; - } - else if (state.log.getMutable(tail)) - { - std::vector rest; - rest.reserve(std::distance(paramIter, endIter)); - while (paramIter != endIter) - { - rest.push_back(*paramIter); - ++paramIter; - } - - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); - state.tryUnify(varPack, tail); - return; - } - } - - // If any remaining unfulfilled parameters are nonoptional, this is a problem. - while (paramIter != endIter) - { - TypeId t = state.log.follow(*paramIter); - if (isOptional(t)) - { - } // ok - else if (state.log.getMutable(t)) - { - } // ok - else if (isNonstrictMode() && state.log.getMutable(t)) - { - } // ok + if (paramTail) + state.tryUnify(*paramTail, *argTail); else - { - state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); - return; - } - ++paramIter; + state.log.replace(*argTail, TypePackVar(TypePack{{}})); } } - else if (paramIter == endIter) + else if (paramTail) { - // too many parameters passed - if (!paramIter.tail()) - { - while (argIter != endIter) - { - // The use of unify here is deliberate. We don't want this unification - // to be undoable. - unify(errorRecoveryType(scope), *argIter, state.location); - ++argIter; - } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } - TypePackId tail = state.log.follow(*paramIter.tail()); + // argTail is definitely empty + if (state.log.getMutable(state.log.follow(*paramTail))) + state.log.replace(*paramTail, TypePackVar(TypePack{{}})); + } + return; + } + else if (argIter == endIter) + { + // Not enough arguments. + + // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. + if (argIter.tail()) + { + TypePackId tail = *argIter.tail(); if (state.log.getMutable(tail)) { - // Function is variadic. Ok. + // Unify remaining parameters so we don't leave any free-types hanging around. + while (paramIter != endIter) + { + state.tryUnify(errorRecoveryType(anyType), *paramIter); + ++paramIter; + } return; } else if (auto vtp = state.log.getMutable(tail)) { - // Function is variadic and requires that all subsequent parameters - // be compatible with a type. - size_t argIndex = paramIndex; - while (argIter != endIter) + while (paramIter != endIter) { - Location location = state.location; - - if (argIndex < argLocations.size()) - location = argLocations[argIndex]; - - unify(*argIter, vtp->ty, location); - ++argIter; - ++argIndex; + state.tryUnify(vtp->ty, *paramIter); + ++paramIter; } return; } else if (state.log.getMutable(tail)) { - // Create a type pack out of the remaining argument types - // and unify it with the tail. std::vector rest; - rest.reserve(std::distance(argIter, endIter)); - while (argIter != endIter) + rest.reserve(std::distance(paramIter, endIter)); + while (paramIter != endIter) { - rest.push_back(*argIter); - ++argIter; + rest.push_back(*paramIter); + ++paramIter; } - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); state.tryUnify(varPack, tail); return; } - else if (state.log.getMutable(tail)) - { - state.log.replace(tail, TypePackVar(TypePack{{}})); - return; - } - else if (state.log.getMutable(tail)) - { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } } - else + + // If any remaining unfulfilled parameters are nonoptional, this is a problem. + while (paramIter != endIter) { - unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); - ++argIter; + TypeId t = state.log.follow(*paramIter); + if (isOptional(t)) + { + } // ok + else if (state.log.getMutable(t)) + { + } // ok + else if (isNonstrictMode() && state.log.getMutable(t)) + { + } // ok + else + { + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + minParams = getMinParameterCount(&state.log, paramPack); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + return; + } ++paramIter; } - - ++paramIndex; } - } - else - { - while (true) + else if (paramIter == endIter) { - state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; - - if (argIter == endIter && paramIter == endIter) + // too many parameters passed + if (!paramIter.tail()) { - std::optional argTail = argIter.tail(); - std::optional paramTail = paramIter.tail(); - - // If we hit the end of both type packs simultaneously, then there are definitely no further type - // errors to report. All we need to do is tie up any free tails. - // - // If one side has a free tail and the other has none at all, we create an empty pack and bind the - // free tail to that. - - if (argTail) + while (argIter != endIter) { - if (get(*argTail)) - { - if (paramTail) - state.tryUnify(*paramTail, *argTail); - else - { - state.DEPRECATED_log(*argTail); - *asMutable(*argTail) = TypePack{{}}; - } - } + // The use of unify here is deliberate. We don't want this unification + // to be undoable. + unify(errorRecoveryType(scope), *argIter, state.location); + ++argIter; } - else if (paramTail) + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + minParams = getMinParameterCount(&state.log, paramPack); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; + } + TypePackId tail = state.log.follow(*paramIter.tail()); + + if (state.log.getMutable(tail)) + { + // Function is variadic. Ok. + return; + } + else if (auto vtp = state.log.getMutable(tail)) + { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. + size_t argIndex = paramIndex; + while (argIter != endIter) { - // argTail is definitely empty - if (get(*paramTail)) - { - state.DEPRECATED_log(*paramTail); - *asMutable(*paramTail) = TypePack{{}}; - } + Location location = state.location; + + if (argIndex < argLocations.size()) + location = argLocations[argIndex]; + + unify(*argIter, vtp->ty, location); + ++argIter; + ++argIndex; } return; } - else if (argIter == endIter) + else if (state.log.getMutable(tail)) { - // Not enough arguments. - - // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. - if (argIter.tail()) + // Create a type pack out of the remaining argument types + // and unify it with the tail. + std::vector rest; + rest.reserve(std::distance(argIter, endIter)); + while (argIter != endIter) { - TypePackId tail = *argIter.tail(); - if (get(tail)) - { - // Unify remaining parameters so we don't leave any free-types hanging around. - while (paramIter != endIter) - { - state.tryUnify(*paramIter, errorRecoveryType(anyType)); - ++paramIter; - } - return; - } - else if (auto vtp = get(tail)) - { - while (paramIter != endIter) - { - state.tryUnify(*paramIter, vtp->ty); - ++paramIter; - } - - return; - } - else if (get(tail)) - { - std::vector rest; - rest.reserve(std::distance(paramIter, endIter)); - while (paramIter != endIter) - { - rest.push_back(*paramIter); - ++paramIter; - } - - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); - state.tryUnify(varPack, tail); - return; - } + rest.push_back(*argIter); + ++argIter; } - // If any remaining unfulfilled parameters are nonoptional, this is a problem. - while (paramIter != endIter) - { - TypeId t = follow(*paramIter); - if (isOptional(t)) - { - } // ok - else if (get(t)) - { - } // ok - else if (isNonstrictMode() && get(t)) - { - } // ok - else - { - state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); - return; - } - ++paramIter; - } + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); + state.tryUnify(varPack, tail); + return; } - else if (paramIter == endIter) + else if (state.log.getMutable(tail)) { - // too many parameters passed - if (!paramIter.tail()) - { - while (argIter != endIter) - { - unify(*argIter, errorRecoveryType(scope), state.location); - ++argIter; - } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } - TypePackId tail = *paramIter.tail(); - - if (get(tail)) - { - // Function is variadic. Ok. - return; - } - else if (auto vtp = get(tail)) - { - // Function is variadic and requires that all subsequent parameters - // be compatible with a type. - size_t argIndex = paramIndex; - while (argIter != endIter) - { - Location location = state.location; - - if (argIndex < argLocations.size()) - location = argLocations[argIndex]; - - unify(*argIter, vtp->ty, location); - ++argIter; - ++argIndex; - } - - return; - } - else if (get(tail)) - { - // Create a type pack out of the remaining argument types - // and unify it with the tail. - std::vector rest; - rest.reserve(std::distance(argIter, endIter)); - while (argIter != endIter) - { - rest.push_back(*argIter); - ++argIter; - } - - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - state.tryUnify(tail, varPack); - return; - } - else if (get(tail)) - { - if (FFlag::LuauUseCommittingTxnLog) - { - state.log.replace(tail, TypePackVar(TypePack{{}})); - } - else - { - state.DEPRECATED_log(tail); - *asMutable(tail) = TypePack{}; - } - - return; - } - else if (get(tail)) - { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } + state.log.replace(tail, TypePackVar(TypePack{{}})); + return; } - else + else if (state.log.getMutable(tail)) { - unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); - ++argIter; - ++paramIter; + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + // TODO: Better error message? + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + minParams = getMinParameterCount(&state.log, paramPack); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; } - - ++paramIndex; } + else + { + unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); + ++argIter; + ++paramIter; + } + + ++paramIndex; } } @@ -3882,9 +3740,6 @@ std::optional> TypeChecker::checkCallOverload(const Scope checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); if (!state.errors.empty()) { - if (!FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); - return {}; } @@ -3912,14 +3767,10 @@ std::optional> TypeChecker::checkCallOverload(const Scope overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); - - if (!FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); } else { - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); if (isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) { @@ -3976,8 +3827,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal if (editedState.errors.empty()) { - if (FFlag::LuauUseCommittingTxnLog) - editedState.log.commit(); + editedState.log.commit(); reportError(TypeError{expr.location, FunctionDoesNotTakeSelf{}}); // This is a little bit suspect: If this overload would work with a . replaced by a : @@ -3987,8 +3837,6 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); return true; } - else if (!FFlag::LuauUseCommittingTxnLog) - editedState.DEPRECATED_log.rollback(); } else if (ftv->hasSelf) { @@ -4010,8 +3858,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal if (editedState.errors.empty()) { - if (FFlag::LuauUseCommittingTxnLog) - editedState.log.commit(); + editedState.log.commit(); reportError(TypeError{expr.location, FunctionRequiresSelf{}}); // This is a little bit suspect: If this overload would work with a : replaced by a . @@ -4021,8 +3868,6 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); return true; } - else if (!FFlag::LuauUseCommittingTxnLog) - editedState.DEPRECATED_log.rollback(); } } } @@ -4082,7 +3927,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); } - if (FFlag::LuauUseCommittingTxnLog && state.errors.empty()) + if (state.errors.empty()) state.log.commit(); if (i > 0) @@ -4092,9 +3937,6 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast s += "and "; s += toString(overload); - - if (!FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); } if (overloadsThatMatchArgCount.size() == 0) @@ -4168,24 +4010,16 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L // just performed. There's not a great way to pass that into checkExpr. Instead, we store // the inverse of the current log, and commit it. When we're done, we'll commit all the // inverses. This isn't optimal, and a better solution is welcome here. - if (FFlag::LuauUseCommittingTxnLog) - { - inverseLogs.push_back(state.log.inverse()); - state.log.commit(); - } + inverseLogs.push_back(state.log.inverse()); + state.log.commit(); } tp->head.push_back(actualType); } } - if (FFlag::LuauUseCommittingTxnLog) - { - for (TxnLog& log : inverseLogs) - log.commit(); - } - else - state.DEPRECATED_log.rollback(); + for (TxnLog& log : inverseLogs) + log.commit(); return {pack, predicates}; } @@ -4294,8 +4128,7 @@ bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location, Unifier state = mkUnifier(location); state.tryUnify(subTy, superTy, options.isFunctionCall); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); reportErrors(state.errors); @@ -4308,8 +4141,7 @@ bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const Location& lo state.ctx = ctx; state.tryUnify(subTy, superTy); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); reportErrors(state.errors); @@ -4321,8 +4153,7 @@ bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s Unifier state = mkUnifier(location); unifyWithInstantiationIfNeeded(scope, subTy, superTy, state); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); reportErrors(state.errors); @@ -4352,31 +4183,18 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors - if (FFlag::LuauUseCommittingTxnLog) - state.log.concat(std::move(child.log)); - else - state.DEPRECATED_log.concat(std::move(child.DEPRECATED_log)); + state.log.concat(std::move(child.log)); state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); } else { - if (!FFlag::LuauUseCommittingTxnLog) - child.DEPRECATED_log.rollback(); - state.tryUnify(instantiated, superTy, /*isFunctionCall*/ false); } } else { - if (FFlag::LuauUseCommittingTxnLog) - { - state.log.concat(std::move(child.log)); - } - else - { - state.DEPRECATED_log.concat(std::move(child.DEPRECATED_log)); - } + state.log.concat(std::move(child.log)); } } } @@ -4540,7 +4358,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { - Instantiation instantiation{FFlag::LuauUseCommittingTxnLog ? log : TxnLog::empty(), ¤tModule->internalTypes, scope->level}; + Instantiation instantiation{log, ¤tModule->internalTypes, scope->level}; std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index b15548a8..91123f46 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAG(LuauUseCommittingTxnLog) - namespace Luau { @@ -51,16 +49,8 @@ TypePackIterator::TypePackIterator(TypePackId typePack, const TxnLog* log) { while (tp && tp->head.empty()) { - if (FFlag::LuauUseCommittingTxnLog) - { - currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; - tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; - } - else - { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; - } + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; } } @@ -71,16 +61,8 @@ TypePackIterator& TypePackIterator::operator++() ++currentIndex; while (tp && currentIndex >= tp->head.size()) { - if (FFlag::LuauUseCommittingTxnLog) - { - currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; - tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; - } - else - { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; - } + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; currentIndex = 0; } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 6c29486a..7b781f26 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -15,30 +15,28 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTFLAG(LuauImmutableTypes) -LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) -LUAU_FASTFLAGVARIABLE(LuauFollowWithCommittingTxnLogInAnyUnification, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree, false) LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, true) +LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) +LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) +LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) namespace Luau { struct PromoteTypeLevels { - DEPRECATED_TxnLog& DEPRECATED_log; TxnLog& log; const TypeArena* typeArena = nullptr; TypeLevel minLevel; - explicit PromoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel) - : DEPRECATED_log(DEPRECATED_log) - , log(log) + PromoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel) + : log(log) , typeArena(typeArena) , minLevel(minLevel) { @@ -50,15 +48,7 @@ struct PromoteTypeLevels LUAU_ASSERT(t); if (minLevel.subsumesStrict(t->level)) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.changeLevel(ty, minLevel); - } - else - { - DEPRECATED_log(ty); - t->level = minLevel; - } + log.changeLevel(ty, minLevel); } } @@ -81,10 +71,10 @@ struct PromoteTypeLevels { // Surprise, it's actually a BoundTypeVar that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauUseCommittingTxnLog && !log.is(ty)) + if (!log.is(ty)) return true; - promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); + promote(ty, log.getMutable(ty)); return true; } @@ -94,7 +84,7 @@ struct PromoteTypeLevels if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) return false; - promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); + promote(ty, log.getMutable(ty)); return true; } @@ -107,7 +97,7 @@ struct PromoteTypeLevels if (ttv.state != TableState::Free && ttv.state != TableState::Generic) return true; - promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); + promote(ty, log.getMutable(ty)); return true; } @@ -115,33 +105,33 @@ struct PromoteTypeLevels { // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauUseCommittingTxnLog && !log.is(tp)) + if (!log.is(tp)) return true; - promote(tp, FFlag::LuauUseCommittingTxnLog ? log.getMutable(tp) : getMutable(tp)); + promote(tp, log.getMutable(tp)); return true; } }; -static void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) +static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) return; - PromoteTypeLevels ptl{DEPRECATED_log, log, typeArena, minLevel}; + PromoteTypeLevels ptl{log, typeArena, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(ty, ptl, seen); } // TODO: use this and make it static. -void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) +void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (FFlag::LuauImmutableTypes && tp->owningArena != typeArena) return; - PromoteTypeLevels ptl{DEPRECATED_log, log, typeArena, minLevel}; + PromoteTypeLevels ptl{log, typeArena, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(tp, ptl, seen); } @@ -251,7 +241,7 @@ struct SkipCacheForType bool Widen::isDirty(TypeId ty) { - return FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty)); + return log->is(ty); } bool Widen::isDirty(TypePackId) @@ -262,7 +252,7 @@ bool Widen::isDirty(TypePackId) TypeId Widen::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - auto stv = FFlag::LuauUseCommittingTxnLog ? log->getMutable(ty) : getMutable(ty); + auto stv = log->getMutable(ty); LUAU_ASSERT(stv); if (get(stv)) @@ -284,11 +274,11 @@ bool Widen::ignoreChildren(TypeId ty) { // Sometimes we unify ("hi") -> free1 with (free2) -> free3, so don't ignore functions. // TODO: should we be doing this? we would need to rework how checkCallOverload does the unification. - if (FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty))) + if (log->is(ty)) return false; // We only care about unions. - return !(FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty))); + return !log->is(ty); } static std::optional hasUnificationTooComplex(const ErrorVec& errors) @@ -335,7 +325,6 @@ Unifier::Unifier(TypeArena* types, Mode mode, std::vector(superTy); - auto subFree = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - superFree = log.getMutable(superTy); - subFree = log.getMutable(subTy); - } + auto superFree = log.getMutable(superTy); + auto subFree = log.getMutable(subTy); if (superFree && subFree && superFree->level.subsumes(subFree->level)) { occursCheck(subTy, superTy); // The occurrence check might have caused superTy no longer to be a free type - bool occursFailed = false; - if (FFlag::LuauUseCommittingTxnLog) - occursFailed = bool(log.getMutable(subTy)); - else - occursFailed = bool(get(subTy)); + bool occursFailed = bool(log.getMutable(subTy)); if (!occursFailed) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.replace(subTy, BoundTypeVar(superTy)); - } - else - { - DEPRECATED_log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); - } + log.replace(subTy, BoundTypeVar(superTy)); } return; } else if (superFree && subFree) { - if (!FFlag::LuauErrorRecoveryType && !FFlag::LuauUseCommittingTxnLog) - { - DEPRECATED_log(superTy); - subFree->level = min(subFree->level, superFree->level); - } - occursCheck(superTy, subTy); - bool occursFailed = false; - if (FFlag::LuauUseCommittingTxnLog) - occursFailed = bool(log.getMutable(superTy)); - else - occursFailed = bool(get(superTy)); - - if (!FFlag::LuauErrorRecoveryType && !FFlag::LuauUseCommittingTxnLog) - { - *asMutable(superTy) = BoundTypeVar(subTy); - return; - } + bool occursFailed = bool(log.getMutable(superTy)); if (!occursFailed) { - if (FFlag::LuauUseCommittingTxnLog) + if (superFree->level.subsumes(subFree->level)) { - if (superFree->level.subsumes(subFree->level)) - { - log.changeLevel(subTy, superFree->level); - } + log.changeLevel(subTy, superFree->level); + } - log.replace(superTy, BoundTypeVar(subTy)); - } - else - { - DEPRECATED_log(superTy); - *asMutable(superTy) = BoundTypeVar(subTy); - subFree->level = min(subFree->level, superFree->level); - } + log.replace(superTy, BoundTypeVar(subTy)); } return; @@ -460,14 +398,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool TypeLevel superLevel = superFree->level; occursCheck(superTy, subTy); - bool occursFailed = false; - if (FFlag::LuauUseCommittingTxnLog) - occursFailed = bool(log.getMutable(superTy)); - else - occursFailed = bool(get(superTy)); + bool occursFailed = bool(log.getMutable(superTy)); // Unification can't change the level of a generic. - auto subGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy); + auto subGeneric = log.getMutable(subTy); if (subGeneric && !subGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 @@ -478,18 +412,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // The occurrence check might have caused superTy no longer to be a free type if (!occursFailed) { - if (FFlag::LuauUseCommittingTxnLog) - { - promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); - log.replace(superTy, BoundTypeVar(widen(subTy))); - } - else - { - promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); - - DEPRECATED_log(superTy); - *asMutable(superTy) = BoundTypeVar(widen(subTy)); - } + promoteTypeLevels(log, types, superLevel, subTy); + log.replace(superTy, BoundTypeVar(widen(subTy))); } return; @@ -499,14 +423,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool TypeLevel subLevel = subFree->level; occursCheck(subTy, superTy); - bool occursFailed = false; - if (FFlag::LuauUseCommittingTxnLog) - occursFailed = bool(log.getMutable(subTy)); - else - occursFailed = bool(get(subTy)); + bool occursFailed = bool(log.getMutable(subTy)); // Unification can't change the level of a generic. - auto superGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy); + auto superGeneric = log.getMutable(superTy); if (superGeneric && !superGeneric->level.subsumes(subFree->level)) { // TODO: a more informative error message? CLI-39912 @@ -516,18 +436,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursFailed) { - if (FFlag::LuauUseCommittingTxnLog) - { - promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); - log.replace(subTy, BoundTypeVar(superTy)); - } - else - { - promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); - - DEPRECATED_log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); - } + promoteTypeLevels(log, types, subLevel, superTy); + log.replace(subTy, BoundTypeVar(superTy)); } return; @@ -550,55 +460,38 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Here, we assume that the types unify. If they do not, we will find out as we roll back // the stack. - if (FFlag::LuauUseCommittingTxnLog) - { - if (log.haveSeen(superTy, subTy)) - return; + if (log.haveSeen(superTy, subTy)) + return; - log.pushSeen(superTy, subTy); - } - else - { - if (DEPRECATED_log.haveSeen(superTy, subTy)) - return; + log.pushSeen(superTy, subTy); - DEPRECATED_log.pushSeen(superTy, subTy); - } - - if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) + if (const UnionTypeVar* uv = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, uv, superTy); } - else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) + else if (const UnionTypeVar* uv = log.getMutable(superTy)) { tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); } - else if (const IntersectionTypeVar* uv = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) + else if (const IntersectionTypeVar* uv = log.getMutable(superTy)) { tryUnifyTypeWithIntersection(subTy, superTy, uv); } - else if (const IntersectionTypeVar* uv = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) + else if (const IntersectionTypeVar* uv = log.getMutable(subTy)) { tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); } - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyPrimitives(subTy, superTy); - else if (FFlag::LuauSingletonTypes && - ((FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) || - (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) && - (FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy))) + else if (FFlag::LuauSingletonTypes && (log.getMutable(superTy) || log.getMutable(superTy)) && + log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyFunctions(subTy, superTy, isFunctionCall); - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + else if (log.getMutable(superTy) && log.getMutable(subTy)) { tryUnifyTables(subTy, superTy, isIntersection); @@ -607,29 +500,23 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + else if (log.getMutable(superTy)) tryUnifyWithMetatable(subTy, superTy, /*reversed*/ false); - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + else if (log.getMutable(subTy)) tryUnifyWithMetatable(superTy, subTy, /*reversed*/ true); - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + else if (log.getMutable(superTy)) tryUnifyWithClass(subTy, superTy, /*reversed*/ false); // Unification of nonclasses with classes is almost, but not quite symmetrical. // The order in which we perform this test is significant in the case that both types are classes. - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + else if (log.getMutable(subTy)) tryUnifyWithClass(subTy, superTy, /*reversed*/ true); else reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - if (FFlag::LuauUseCommittingTxnLog) - log.popSeen(superTy, subTy); - else - DEPRECATED_log.popSeen(superTy, subTy); + log.popSeen(superTy, subTy); } void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId superTy) @@ -660,28 +547,12 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) { - if (!FFlag::LuauUseCommittingTxnLog) - innerState.DEPRECATED_log.rollback(); } else { - if (FFlag::LuauUseCommittingTxnLog) + if (i == count - 1) { - if (i == count - 1) - { - log.concat(std::move(innerState.log)); - } - } - else - { - if (i != count - 1) - { - innerState.DEPRECATED_log.rollback(); - } - else - { - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - } + log.concat(std::move(innerState.log)); } ++i; @@ -692,7 +563,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) { auto tryBind = [this, subTy](TypeId superOption) { - superOption = FFlag::LuauUseCommittingTxnLog ? log.follow(superOption) : follow(superOption); + superOption = log.follow(superOption); // just skip if the superOption is not free-ish. auto ttv = log.getMutable(superOption); @@ -701,36 +572,17 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. - if (FFlag::LuauUseCommittingTxnLog) + if (log.haveSeen(subTy, superOption)) { - if (log.haveSeen(subTy, superOption)) - { - // TODO: would it be nice for TxnLog::replace to do this? - if (log.is(superOption)) - log.bindTable(superOption, subTy); - else - log.replace(superOption, *subTy); - } - } - else - { - if (DEPRECATED_log.haveSeen(subTy, superOption)) - { - if (auto ttv = getMutable(superOption)) - { - DEPRECATED_log(ttv); - ttv->boundTo = subTy; - } - else - { - DEPRECATED_log(superOption); - *asMutable(superOption) = BoundTypeVar(subTy); - } - } + // TODO: would it be nice for TxnLog::replace to do this? + if (log.is(superOption)) + log.bindTable(superOption, subTy); + else + log.replace(superOption, *subTy); } }; - if (auto utv = (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) + if (auto utv = log.getMutable(superTy)) { for (TypeId ty : utv) tryBind(ty); @@ -815,10 +667,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp if (innerState.errors.empty()) { found = true; - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + log.concat(std::move(innerState.log)); break; } @@ -833,9 +682,6 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp if (!failedOption) failedOption = {innerState.errors.front()}; } - - if (!FFlag::LuauUseCommittingTxnLog) - innerState.DEPRECATED_log.rollback(); } if (unificationTooComplex) @@ -870,10 +716,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I firstFailedOption = {innerState.errors.front()}; } - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + log.concat(std::move(innerState.log)); } if (unificationTooComplex) @@ -915,19 +758,13 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV if (innerState.errors.empty()) { found = true; - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + log.concat(std::move(innerState.log)); break; } else if (auto e = hasUnificationTooComplex(innerState.errors)) { unificationTooComplex = e; } - - if (!FFlag::LuauUseCommittingTxnLog) - innerState.DEPRECATED_log.rollback(); } if (unificationTooComplex) @@ -971,78 +808,6 @@ void Unifier::cacheResult(TypeId subTy, TypeId superTy) sharedState.cachedUnify.insert({subTy, superTy}); } -struct DEPRECATED_WeirdIter -{ - TypePackId packId; - const TypePack* pack; - size_t index; - bool growing; - TypeLevel level; - - DEPRECATED_WeirdIter(TypePackId packId) - : packId(packId) - , pack(get(packId)) - , index(0) - , growing(false) - { - while (pack && pack->head.empty() && pack->tail) - { - packId = *pack->tail; - pack = get(packId); - } - } - - DEPRECATED_WeirdIter(const DEPRECATED_WeirdIter&) = default; - - const TypeId& operator*() - { - LUAU_ASSERT(good()); - return pack->head[index]; - } - - bool good() const - { - return pack != nullptr && index < pack->head.size(); - } - - bool advance() - { - if (!pack) - return good(); - - if (index < pack->head.size()) - ++index; - - if (growing || index < pack->head.size()) - return good(); - - if (pack->tail) - { - packId = follow(*pack->tail); - pack = get(packId); - index = 0; - } - - return good(); - } - - bool canGrow() const - { - return nullptr != get(packId); - } - - void grow(TypePackId newTail) - { - LUAU_ASSERT(canGrow()); - level = get(packId)->level; - *asMutable(packId) = Unifiable::Bound(newTail); - packId = newTail; - pack = get(newTail); - index = 0; - growing = true; - } -}; - struct WeirdIter { TypePackId packId; @@ -1141,9 +906,6 @@ ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) Unifier s = makeChildUnifier(); s.tryUnify_(subTy, superTy); - if (!FFlag::LuauUseCommittingTxnLog) - s.DEPRECATED_log.rollback(); - return s.errors; } @@ -1152,9 +914,6 @@ ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunction Unifier s = makeChildUnifier(); s.tryUnify_(subTy, superTy, isFunctionCall); - if (!FFlag::LuauUseCommittingTxnLog) - s.DEPRECATED_log.rollback(); - return s.errors; } @@ -1200,419 +959,207 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal return; } - if (FFlag::LuauUseCommittingTxnLog) + superTp = log.follow(superTp); + subTp = log.follow(subTp); + + while (auto tp = log.getMutable(subTp)) { - superTp = log.follow(superTp); - subTp = log.follow(subTp); + if (tp->head.empty() && tp->tail) + subTp = log.follow(*tp->tail); + else + break; + } - while (auto tp = log.getMutable(subTp)) + while (auto tp = log.getMutable(superTp)) + { + if (tp->head.empty() && tp->tail) + superTp = log.follow(*tp->tail); + else + break; + } + + if (superTp == subTp) + return; + + if (FFlag::LuauTxnLogSeesTypePacks2 && log.haveSeen(superTp, subTp)) + return; + + if (log.getMutable(superTp)) + { + occursCheck(superTp, subTp); + + if (!log.getMutable(superTp)) { - if (tp->head.empty() && tp->tail) - subTp = log.follow(*tp->tail); - else - break; + log.replace(superTp, Unifiable::Bound(subTp)); } + } + else if (log.getMutable(subTp)) + { + occursCheck(subTp, superTp); - while (auto tp = log.getMutable(superTp)) + if (!log.getMutable(subTp)) { - if (tp->head.empty() && tp->tail) - superTp = log.follow(*tp->tail); - else - break; + log.replace(subTp, Unifiable::Bound(superTp)); } + } + else if (log.getMutable(superTp)) + tryUnifyWithAny(subTp, superTp); + else if (log.getMutable(subTp)) + tryUnifyWithAny(superTp, subTp); + else if (log.getMutable(superTp)) + tryUnifyVariadics(subTp, superTp, false); + else if (log.getMutable(subTp)) + tryUnifyVariadics(superTp, subTp, true); + else if (log.getMutable(superTp) && log.getMutable(subTp)) + { + auto superTpv = log.getMutable(superTp); + auto subTpv = log.getMutable(subTp); - if (superTp == subTp) - return; + // If the size of two heads does not match, but both packs have free tail + // We set the sentinel variable to say so to avoid growing it forever. + auto [superTypes, superTail] = logAwareFlatten(superTp, log); + auto [subTypes, subTail] = logAwareFlatten(subTp, log); - if (FFlag::LuauTxnLogSeesTypePacks2 && log.haveSeen(superTp, subTp)) - return; + bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && + (subTail && log.getMutable(*subTail)); - if (log.getMutable(superTp)) + auto superIter = WeirdIter(superTp, log); + auto subIter = WeirdIter(subTp, log); + + auto mkFreshType = [this](TypeLevel level) { + return types->freshType(level); + }; + + const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); + + int loopCount = 0; + + do { - occursCheck(superTp, subTp); + if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) + ice("Detected possibly infinite TypePack growth"); - if (!log.getMutable(superTp)) + ++loopCount; + + if (superIter.good() && subIter.growing) { - log.replace(superTp, Unifiable::Bound(subTp)); + subIter.pushType(mkFreshType(subIter.level)); } - } - else if (log.getMutable(subTp)) - { - occursCheck(subTp, superTp); - if (!log.getMutable(subTp)) + if (subIter.good() && superIter.growing) { - log.replace(subTp, Unifiable::Bound(superTp)); + superIter.pushType(mkFreshType(superIter.level)); } - } - else if (log.getMutable(superTp)) - tryUnifyWithAny(subTp, superTp); - else if (log.getMutable(subTp)) - tryUnifyWithAny(superTp, subTp); - else if (log.getMutable(superTp)) - tryUnifyVariadics(subTp, superTp, false); - else if (log.getMutable(subTp)) - tryUnifyVariadics(superTp, subTp, true); - else if (log.getMutable(superTp) && log.getMutable(subTp)) - { - auto superTpv = log.getMutable(superTp); - auto subTpv = log.getMutable(subTp); - // If the size of two heads does not match, but both packs have free tail - // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = logAwareFlatten(superTp, log); - auto [subTypes, subTail] = logAwareFlatten(subTp, log); - - bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && - (subTail && log.getMutable(*subTail)); - - auto superIter = WeirdIter(superTp, log); - auto subIter = WeirdIter(subTp, log); - - auto mkFreshType = [this](TypeLevel level) { - return types->freshType(level); - }; - - const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); - - int loopCount = 0; - - do + if (superIter.good() && subIter.good()) { - if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) - ice("Detected possibly infinite TypePack growth"); + tryUnify_(*subIter, *superIter); - ++loopCount; + if (!errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; - if (superIter.good() && subIter.growing) + superIter.advance(); + subIter.advance(); + continue; + } + + // If both are at the end, we're done + if (!superIter.good() && !subIter.good()) + { + if (subTpv->tail && superTpv->tail) { - subIter.pushType(mkFreshType(subIter.level)); + tryUnify_(*subTpv->tail, *superTpv->tail); + break; } - if (subIter.good() && superIter.growing) + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; + if (lFreeTail) + tryUnify_(emptyTp, *superTpv->tail); + else if (rFreeTail) + tryUnify_(emptyTp, *subTpv->tail); + + break; + } + + // If both tails are free, bind one to the other and call it a day + if (superIter.canGrow() && subIter.canGrow()) + return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); + + // If just one side is free on its tail, grow it to fit the other side. + // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. + if (superIter.canGrow()) + superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + else if (subIter.canGrow()) + subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + else + { + // A union type including nil marks an optional argument + if (superIter.good() && isOptional(*superIter)) { - superIter.pushType(mkFreshType(superIter.level)); - } - - if (superIter.good() && subIter.good()) - { - tryUnify_(*subIter, *superIter); - - if (!errors.empty() && !firstPackErrorPos) - firstPackErrorPos = loopCount; - superIter.advance(); + continue; + } + else if (subIter.good() && isOptional(*subIter)) + { subIter.advance(); continue; } - // If both are at the end, we're done - if (!superIter.good() && !subIter.good()) + // In nonstrict mode, any also marks an optional argument. + else if (superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) { - if (subTpv->tail && superTpv->tail) - { - tryUnify_(*subTpv->tail, *superTpv->tail); - break; - } - - const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; - const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail) - tryUnify_(emptyTp, *superTpv->tail); - else if (rFreeTail) - tryUnify_(emptyTp, *subTpv->tail); - - break; + superIter.advance(); + continue; } - // If both tails are free, bind one to the other and call it a day - if (superIter.canGrow() && subIter.canGrow()) - return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); - - // If just one side is free on its tail, grow it to fit the other side. - // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. - if (superIter.canGrow()) - superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - else if (subIter.canGrow()) - subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - else + if (log.getMutable(superIter.packId)) { - // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) - { - superIter.advance(); - continue; - } - else if (subIter.good() && isOptional(*subIter)) - { - subIter.advance(); - continue; - } - - // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) - { - superIter.advance(); - continue; - } - - if (log.getMutable(superIter.packId)) - { - tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); - return; - } - - if (log.getMutable(subIter.packId)) - { - tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); - return; - } - - if (!isFunctionCall && subIter.good()) - { - // Sometimes it is ok to pass too many arguments - return; - } - - // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking the values returned by a function, we swap - // these to produce the expected error message. - size_t expectedSize = size(superTp); - size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result) - std::swap(expectedSize, actualSize); - reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); - - while (superIter.good()) - { - tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); - superIter.advance(); - } - - while (subIter.good()) - { - tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); - subIter.advance(); - } - + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); return; } - } while (!noInfiniteGrowth); - } - else - { - reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); - } + if (log.getMutable(subIter.packId)) + { + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + return; + } + + if (!isFunctionCall && subIter.good()) + { + // Sometimes it is ok to pass too many arguments + return; + } + + // This is a bit weird because we don't actually know expected vs actual. We just know + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. + size_t expectedSize = size(superTp); + size_t actualSize = size(subTp); + if (ctx == CountMismatch::Result) + std::swap(expectedSize, actualSize); + reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + + while (superIter.good()) + { + tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); + superIter.advance(); + } + + while (subIter.good()) + { + tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); + subIter.advance(); + } + + return; + } + + } while (!noInfiniteGrowth); } else { - superTp = follow(superTp); - subTp = follow(subTp); - - while (auto tp = get(subTp)) - { - if (tp->head.empty() && tp->tail) - subTp = follow(*tp->tail); - else - break; - } - - while (auto tp = get(superTp)) - { - if (tp->head.empty() && tp->tail) - superTp = follow(*tp->tail); - else - break; - } - - if (superTp == subTp) - return; - - if (FFlag::LuauTxnLogSeesTypePacks2 && DEPRECATED_log.haveSeen(superTp, subTp)) - return; - - if (get(superTp)) - { - occursCheck(superTp, subTp); - - if (!get(superTp)) - { - DEPRECATED_log(superTp); - *asMutable(superTp) = Unifiable::Bound(subTp); - } - } - else if (get(subTp)) - { - occursCheck(subTp, superTp); - - if (!get(subTp)) - { - DEPRECATED_log(subTp); - *asMutable(subTp) = Unifiable::Bound(superTp); - } - } - - else if (get(superTp)) - tryUnifyWithAny(subTp, superTp); - - else if (get(subTp)) - tryUnifyWithAny(superTp, subTp); - - else if (get(superTp)) - tryUnifyVariadics(subTp, superTp, false); - else if (get(subTp)) - tryUnifyVariadics(superTp, subTp, true); - - else if (get(superTp) && get(subTp)) - { - auto superTpv = get(superTp); - auto subTpv = get(subTp); - - // If the size of two heads does not match, but both packs have free tail - // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = flatten(superTp); - auto [subTypes, subTail] = flatten(subTp); - - bool noInfiniteGrowth = - (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); - - auto superIter = DEPRECATED_WeirdIter{superTp}; - auto subIter = DEPRECATED_WeirdIter{subTp}; - - auto mkFreshType = [this](TypeLevel level) { - return types->freshType(level); - }; - - const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); - - int loopCount = 0; - - do - { - if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) - ice("Detected possibly infinite TypePack growth"); - - ++loopCount; - - if (superIter.good() && subIter.growing) - asMutable(subIter.pack)->head.push_back(mkFreshType(subIter.level)); - - if (subIter.good() && superIter.growing) - asMutable(superIter.pack)->head.push_back(mkFreshType(superIter.level)); - - if (superIter.good() && subIter.good()) - { - tryUnify_(*subIter, *superIter); - - if (!errors.empty() && !firstPackErrorPos) - firstPackErrorPos = loopCount; - - superIter.advance(); - subIter.advance(); - continue; - } - - // If both are at the end, we're done - if (!superIter.good() && !subIter.good()) - { - if (subTpv->tail && superTpv->tail) - { - tryUnify_(*subTpv->tail, *superTpv->tail); - break; - } - - const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; - const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; - if (lFreeTail) - tryUnify_(emptyTp, *superTpv->tail); - else if (rFreeTail) - tryUnify_(emptyTp, *subTpv->tail); - - break; - } - - // If both tails are free, bind one to the other and call it a day - if (superIter.canGrow() && subIter.canGrow()) - return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); - - // If just one side is free on its tail, grow it to fit the other side. - // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. - if (superIter.canGrow()) - superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - - else if (subIter.canGrow()) - subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - - else - { - // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) - { - superIter.advance(); - continue; - } - else if (subIter.good() && isOptional(*subIter)) - { - subIter.advance(); - continue; - } - - // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) - { - superIter.advance(); - continue; - } - - if (get(superIter.packId)) - { - tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); - return; - } - - if (get(subIter.packId)) - { - tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); - return; - } - - if (!isFunctionCall && subIter.good()) - { - // Sometimes it is ok to pass too many arguments - return; - } - - // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking the values returned by a function, we swap - // these to produce the expected error message. - size_t expectedSize = size(superTp); - size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result) - std::swap(expectedSize, actualSize); - reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); - - while (superIter.good()) - { - tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); - superIter.advance(); - } - - while (subIter.good()) - { - tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); - subIter.advance(); - } - - return; - } - - } while (!noInfiniteGrowth); - } - else - { - reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); - } + reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); } } @@ -1650,14 +1197,8 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) { - FunctionTypeVar* superFunction = getMutable(superTy); - FunctionTypeVar* subFunction = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - superFunction = log.getMutable(superTy); - subFunction = log.getMutable(subTy); - } + FunctionTypeVar* superFunction = log.getMutable(superTy); + FunctionTypeVar* subFunction = log.getMutable(subTy); if (!superFunction || !subFunction) ice("passed non-function types to unifyFunction"); @@ -1680,20 +1221,14 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal for (size_t i = 0; i < numGenerics; i++) { - if (FFlag::LuauUseCommittingTxnLog) - log.pushSeen(superFunction->generics[i], subFunction->generics[i]); - else - DEPRECATED_log.pushSeen(superFunction->generics[i], subFunction->generics[i]); + log.pushSeen(superFunction->generics[i], subFunction->generics[i]); } if (FFlag::LuauTxnLogSeesTypePacks2) { for (size_t i = 0; i < numGenericPacks; i++) { - if (FFlag::LuauUseCommittingTxnLog) - log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); - else - DEPRECATED_log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } } @@ -1734,14 +1269,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); } - if (FFlag::LuauUseCommittingTxnLog) - { - log.concat(std::move(innerState.log)); - } - else - { - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - } + log.concat(std::move(innerState.log)); } else { @@ -1754,33 +1282,19 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (!FFlag::LuauImmutableTypes) { - if (FFlag::LuauUseCommittingTxnLog) + if (superFunction->definition && !subFunction->definition && !subTy->persistent) { - if (superFunction->definition && !subFunction->definition && !subTy->persistent) - { - PendingType* newSubTy = log.queue(subTy); - FunctionTypeVar* newSubFtv = getMutable(newSubTy); - LUAU_ASSERT(newSubFtv); - newSubFtv->definition = superFunction->definition; - } - else if (!superFunction->definition && subFunction->definition && !superTy->persistent) - { - PendingType* newSuperTy = log.queue(superTy); - FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); - LUAU_ASSERT(newSuperFtv); - newSuperFtv->definition = subFunction->definition; - } + PendingType* newSubTy = log.queue(subTy); + FunctionTypeVar* newSubFtv = getMutable(newSubTy); + LUAU_ASSERT(newSubFtv); + newSubFtv->definition = superFunction->definition; } - else + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) { - if (superFunction->definition && !subFunction->definition && !subTy->persistent) - { - subFunction->definition = superFunction->definition; - } - else if (!superFunction->definition && subFunction->definition && !superTy->persistent) - { - superFunction->definition = subFunction->definition; - } + PendingType* newSuperTy = log.queue(superTy); + FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); + LUAU_ASSERT(newSuperFtv); + newSuperFtv->definition = subFunction->definition; } } @@ -1790,19 +1304,13 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { for (int i = int(numGenericPacks) - 1; 0 <= i; i--) { - if (FFlag::LuauUseCommittingTxnLog) - log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); - else - DEPRECATED_log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } } for (int i = int(numGenerics) - 1; 0 <= i; i--) { - if (FFlag::LuauUseCommittingTxnLog) - log.popSeen(superFunction->generics[i], subFunction->generics[i]); - else - DEPRECATED_log.popSeen(superFunction->generics[i], subFunction->generics[i]); + log.popSeen(superFunction->generics[i], subFunction->generics[i]); } } @@ -1833,14 +1341,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!FFlag::LuauTableSubtypingVariance2) return DEPRECATED_tryUnifyTables(subTy, superTy, isIntersection); - TableTypeVar* superTable = getMutable(superTy); - TableTypeVar* subTable = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } + TableTypeVar* superTable = log.getMutable(superTy); + TableTypeVar* subTable = log.getMutable(subTy); if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1855,8 +1357,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - bool isAny = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(superProp.type)) : get(follow(superProp.type)); + bool isAny = log.getMutable(log.follow(superProp.type)); if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) missingProperties.push_back(propName); @@ -1877,8 +1378,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto superIter = superTable->props.find(propName); - bool isAny = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(subProp.type)) : get(follow(subProp.type)); + bool isAny = log.is(log.follow(subProp.type)); if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) extraProperties.push_back(propName); } @@ -1906,18 +1406,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (FFlag::LuauUseCommittingTxnLog) - { - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - } - else - { - if (innerState.errors.empty()) - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - else - innerState.DEPRECATED_log.rollback(); - } + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); } else if (subTable->indexer && maybeString(subTable->indexer->indexType)) { @@ -1931,18 +1421,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (FFlag::LuauUseCommittingTxnLog) - { - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - } - else - { - if (innerState.errors.empty()) - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - else - innerState.DEPRECATED_log.rollback(); - } + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); } else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get(follow(prop.type)))) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` @@ -1953,22 +1433,30 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } else if (subTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - PendingType* pendingSub = log.queue(subTy); - TableTypeVar* ttv = getMutable(pendingSub); - LUAU_ASSERT(ttv); - ttv->props[name] = prop; - subTable = ttv; - } - else - { - DEPRECATED_log(subTy); - subTable->props[name] = prop; - } + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* ttv = getMutable(pendingSub); + LUAU_ASSERT(ttv); + ttv->props[name] = prop; + subTable = ttv; } else missingProperties.push_back(name); + + if (FFlag::LuauTxnLogCheckForInvalidation) + { + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) + { + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; + } + } } for (const auto& [name, prop] : subTable->props) @@ -1990,18 +1478,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (FFlag::LuauUseCommittingTxnLog) - { - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - } - else - { - if (innerState.errors.empty()) - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - else - innerState.DEPRECATED_log.rollback(); - } + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); } else if (superTable->state == TableState::Unsealed) { @@ -2011,18 +1489,10 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) Property clone = prop; clone.type = deeplyOptional(clone.type); - if (FFlag::LuauUseCommittingTxnLog) - { - PendingType* pendingSuper = log.queue(superTy); - TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); - pendingSuperTtv->props[name] = clone; - superTable = pendingSuperTtv; - } - else - { - DEPRECATED_log(superTy); - superTable->props[name] = clone; - } + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = clone; + superTable = pendingSuperTtv; } else if (variance == Covariant) { @@ -2032,21 +1502,29 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } else if (superTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - PendingType* pendingSuper = log.queue(superTy); - TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); - pendingSuperTtv->props[name] = prop; - superTable = pendingSuperTtv; - } - else - { - DEPRECATED_log(superTy); - superTable->props[name] = prop; - } + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = prop; + superTable = pendingSuperTtv; } else extraProperties.push_back(name); + + if (FFlag::LuauTxnLogCheckForInvalidation) + { + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) + { + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; + } + } } // Unify indexers @@ -2060,18 +1538,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - if (FFlag::LuauUseCommittingTxnLog) - { - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - } - else - { - if (innerState.errors.empty()) - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - else - innerState.DEPRECATED_log.rollback(); - } + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); } else if (superTable->indexer) { @@ -2081,15 +1549,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. // TODO: we only need to do this if the supertype's indexer is read/write // since that can add indexed elements. - if (FFlag::LuauUseCommittingTxnLog) - { - log.changeIndexer(subTy, superTable->indexer); - } - else - { - DEPRECATED_log(subTy); - subTable->indexer = superTable->indexer; - } + log.changeIndexer(subTy, superTable->indexer); } } else if (subTable->indexer && variance == Invariant) @@ -2097,15 +1557,29 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Symmetric if we are invariant if (superTable->state == TableState::Unsealed || superTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.changeIndexer(superTy, subTable->indexer); - } + log.changeIndexer(superTy, subTable->indexer); + } + } + + if (FFlag::LuauTxnLogDontRetryForIndexers) + { + // Changing the indexer can invalidate the table pointers. + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + else if (FFlag::LuauTxnLogCheckForInvalidation) + { + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) + { + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); else - { - DEPRECATED_log(superTy); - superTable->indexer = subTable->indexer; - } + return; } } @@ -2134,27 +1608,11 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (superTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(superTy, subTy); - } - else - { - DEPRECATED_log(superTable); - superTable->boundTo = subTy; - } + log.bindTable(superTy, subTy); } else if (subTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(subTy, superTy); - } - else - { - DEPRECATED_log(subTable); - subTable->boundTo = superTy; - } + log.bindTable(subTy, superTy); } } @@ -2197,14 +1655,8 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt Resetter resetter{&variance}; variance = Invariant; - TableTypeVar* superTable = getMutable(superTy); - TableTypeVar* subTable = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } + TableTypeVar* superTable = log.getMutable(superTy); + TableTypeVar* subTable = log.getMutable(subTy); if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -2231,15 +1683,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt // avoid creating a cycle when the types are already pointing at each other if (follow(superTy) != follow(subTy)) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(superTy, subTy); - } - else - { - DEPRECATED_log(superTable); - superTable->boundTo = subTy; - } + log.bindTable(superTy, subTy); } return; } @@ -2268,14 +1712,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. if (subTable->state == TableState::Unsealed) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.changeIndexer(subTy, superTable->indexer); - } - else - { - subTable->indexer = superTable->indexer; - } + log.changeIndexer(subTy, superTable->indexer); } else reportError(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); @@ -2295,14 +1732,8 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { - TableTypeVar* freeTable = getMutable(superTy); - TableTypeVar* subTable = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - freeTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } + TableTypeVar* freeTable = log.getMutable(superTy); + TableTypeVar* subTable = log.getMutable(subTy); if (!freeTable || !subTable) ice("passed non-table types to tryUnifyFreeTable"); @@ -2323,22 +1754,11 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) * I believe this is guaranteed to terminate eventually because this will * only happen when a free table is bound to another table. */ - if (FFlag::LuauUseCommittingTxnLog) - { - if (!log.getMutable(superTy) || !log.getMutable(subTy)) - return tryUnify_(subTy, superTy); + if (!log.getMutable(superTy) || !log.getMutable(subTy)) + return tryUnify_(subTy, superTy); - if (TableTypeVar* pendingFreeTtv = log.getMutable(superTy); pendingFreeTtv && pendingFreeTtv->boundTo) - return tryUnify_(subTy, superTy); - } - else - { - if (!get(superTy) || !get(subTy)) - return tryUnify_(subTy, superTy); - - if (freeTable->boundTo) - return tryUnify_(subTy, superTy); - } + if (TableTypeVar* pendingFreeTtv = log.getMutable(superTy); pendingFreeTtv && pendingFreeTtv->boundTo) + return tryUnify_(subTy, superTy); } else { @@ -2346,17 +1766,10 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) // properties than we previously thought. Else, it is an error. if (subTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - PendingType* pendingSub = log.queue(subTy); - TableTypeVar* pendingSubTtv = getMutable(pendingSub); - LUAU_ASSERT(pendingSubTtv); - pendingSubTtv->props.insert({freeName, freeProp}); - } - else - { - subTable->props.insert({freeName, freeProp}); - } + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* pendingSubTtv = getMutable(pendingSub); + LUAU_ASSERT(pendingSubTtv); + pendingSubTtv->props.insert({freeName, freeProp}); } else reportError(TypeError{location, UnknownProperty{subTy, freeName}}); @@ -2370,47 +1783,23 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + log.concat(std::move(innerState.log)); } else if (subTable->state == TableState::Free && freeTable->indexer) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.changeIndexer(superTy, subTable->indexer); - } - else - { - freeTable->indexer = subTable->indexer; - } + log.changeIndexer(superTy, subTable->indexer); } if (!freeTable->boundTo && subTable->state != TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(superTy, subTy); - } - else - { - DEPRECATED_log(freeTable); - freeTable->boundTo = subTy; - } + log.bindTable(superTy, subTy); } } void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection) { - TableTypeVar* superTable = getMutable(superTy); - TableTypeVar* subTable = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } + TableTypeVar* superTable = log.getMutable(superTy); + TableTypeVar* subTable = log.getMutable(subTy); if (!superTable || !subTable) ice("passed non-table types to unifySealedTables"); @@ -2476,77 +1865,39 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (superTable->indexer || subTable->indexer) { - if (FFlag::LuauUseCommittingTxnLog) + if (superTable->indexer && subTable->indexer) + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (subTable->state == TableState::Unsealed) { - if (superTable->indexer && subTable->indexer) - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - else if (subTable->state == TableState::Unsealed) + if (superTable->indexer && !subTable->indexer) { - if (superTable->indexer && !subTable->indexer) - { - log.changeIndexer(subTy, superTable->indexer); - } + log.changeIndexer(subTy, superTable->indexer); } - else if (superTable->state == TableState::Unsealed) + } + else if (superTable->state == TableState::Unsealed) + { + if (subTable->indexer && !superTable->indexer) { - if (subTable->indexer && !superTable->indexer) - { - log.changeIndexer(superTy, subTable->indexer); - } + log.changeIndexer(superTy, subTable->indexer); } - else if (superTable->indexer) + } + else if (superTable->indexer) + { + innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); + for (const auto& [name, type] : subTable->props) { - innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); - for (const auto& [name, type] : subTable->props) - { - const auto& it = superTable->props.find(name); - if (it == superTable->props.end()) - innerState.tryUnify_(type.type, superTable->indexer->indexResultType); - } + const auto& it = superTable->props.find(name); + if (it == superTable->props.end()) + innerState.tryUnify_(type.type, superTable->indexer->indexResultType); } - else - innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } else - { - if (superTable->indexer && subTable->indexer) - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - else if (subTable->state == TableState::Unsealed) - { - if (superTable->indexer && !subTable->indexer) - subTable->indexer = superTable->indexer; - } - else if (superTable->state == TableState::Unsealed) - { - if (subTable->indexer && !superTable->indexer) - superTable->indexer = subTable->indexer; - } - else if (superTable->indexer) - { - innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); - // We already try to unify properties in both tables. - // Skip those and just look for the ones remaining and see if they fit into the indexer. - for (const auto& [name, type] : subTable->props) - { - const auto& it = superTable->props.find(name); - if (it == superTable->props.end()) - innerState.tryUnify_(type.type, superTable->indexer->indexResultType); - } - } - else - innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - } + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } - if (FFlag::LuauUseCommittingTxnLog) - { - if (!errorReported) - log.concat(std::move(innerState.log)); - } + if (!errorReported) + log.concat(std::move(innerState.log)); else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - - if (errorReported) return; if (!missingPropertiesInSuper.empty()) @@ -2594,8 +1945,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy}}; - if (const MetatableTypeVar* subMetatable = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) + if (const MetatableTypeVar* subMetatable = log.getMutable(subTy)) { Unifier innerState = makeChildUnifier(); innerState.tryUnify_(subMetatable->table, superMetatable->table); @@ -2606,27 +1956,16 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) else if (!innerState.errors.empty()) reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + log.concat(std::move(innerState.log)); } - else if (TableTypeVar* subTable = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : getMutable(subTy)) + else if (TableTypeVar* subTable = log.getMutable(subTy)) { switch (subTable->state) { case TableState::Free: { tryUnify_(subTy, superMetatable->table); - - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(subTy, superTy); - } - else - { - subTable->boundTo = superTy; - } + log.bindTable(subTy, superTy); break; } @@ -2637,8 +1976,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) reportError(mismatchError); } } - else if (FFlag::LuauUseCommittingTxnLog ? (log.getMutable(subTy) || log.getMutable(subTy)) - : (get(subTy) || get(subTy))) + else if (log.getMutable(subTy) || log.getMutable(subTy)) { } else @@ -2711,28 +2049,13 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - if (FFlag::LuauUseCommittingTxnLog) + if (innerState.errors.empty()) { - if (innerState.errors.empty()) - { - log.concat(std::move(innerState.log)); - } - else - { - ok = false; - } + log.concat(std::move(innerState.log)); } else { - if (innerState.errors.empty()) - { - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - } - else - { - ok = false; - innerState.DEPRECATED_log.rollback(); - } + ok = false; } } } @@ -2747,15 +2070,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (!ok) return; - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(subTy, superTy); - } - else - { - DEPRECATED_log(subTable); - subTable->boundTo = superTy; - } + log.bindTable(subTy, superTy); } else return fail(); @@ -2771,54 +2086,30 @@ static void queueTypePack(std::vector& queue, DenseHashSet& { while (true) { - a = FFlag::LuauFollowWithCommittingTxnLogInAnyUnification ? state.log.follow(a) : follow(a); + a = state.log.follow(a); if (seenTypePacks.find(a)) break; seenTypePacks.insert(a); - if (FFlag::LuauUseCommittingTxnLog) + if (state.log.getMutable(a)) { - if (state.log.getMutable(a)) - { - state.log.replace(a, Unifiable::Bound{anyTypePack}); - } - else if (auto tp = state.log.getMutable(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + state.log.replace(a, Unifiable::Bound{anyTypePack}); } - else + else if (auto tp = state.log.getMutable(a)) { - if (get(a)) - { - state.DEPRECATED_log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; } } } void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool reversed, int subOffset) { - const VariadicTypePack* superVariadic = get(superTp); - - if (FFlag::LuauUseCommittingTxnLog) - { - superVariadic = log.getMutable(superTp); - } + const VariadicTypePack* superVariadic = log.getMutable(superTp); if (!superVariadic) ice("passed non-variadic pack to tryUnifyVariadics"); @@ -2843,15 +2134,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever TypePackId tail = follow(*maybeTail); if (get(tail)) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.replace(tail, BoundTypePack(superTp)); - } - else - { - DEPRECATED_log(tail); - *asMutable(tail) = BoundTypePack{superTp}; - } + log.replace(tail, BoundTypePack(superTp)); } else if (const VariadicTypePack* vtp = get(tail)) { @@ -2882,103 +2165,54 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas { while (!queue.empty()) { - if (FFlag::LuauUseCommittingTxnLog) + TypeId ty = state.log.follow(queue.back()); + queue.pop_back(); + + // Types from other modules don't have free types + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + continue; + + if (seen.find(ty)) + continue; + + seen.insert(ty); + + if (state.log.getMutable(ty)) { - TypeId ty = state.log.follow(queue.back()); - queue.pop_back(); - - // Types from other modules don't have free types - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) - continue; - - if (seen.find(ty)) - continue; - - seen.insert(ty); - - if (state.log.getMutable(ty)) - { - state.log.replace(ty, BoundTypeVar{anyType}); - } - else if (auto fun = state.log.getMutable(ty)) - { - queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); - } - else if (auto table = state.log.getMutable(ty)) - { - for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); - - if (table->indexer) - { - queue.push_back(table->indexer->indexType); - queue.push_back(table->indexer->indexResultType); - } - } - else if (auto mt = state.log.getMutable(ty)) - { - queue.push_back(mt->table); - queue.push_back(mt->metatable); - } - else if (state.log.getMutable(ty)) - { - // ClassTypeVars never contain free typevars. - } - else if (auto union_ = state.log.getMutable(ty)) - queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = state.log.getMutable(ty)) - queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); - else - { - } // Primitives, any, errors, and generics are left untouched. + state.log.replace(ty, BoundTypeVar{anyType}); } + else if (auto fun = state.log.getMutable(ty)) + { + queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = state.log.getMutable(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = state.log.getMutable(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (state.log.getMutable(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = state.log.getMutable(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = state.log.getMutable(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); else { - TypeId ty = follow(queue.back()); - queue.pop_back(); - if (seen.find(ty)) - continue; - seen.insert(ty); - - if (get(ty)) - { - state.DEPRECATED_log(ty); - *asMutable(ty) = BoundTypeVar{anyType}; - } - else if (auto fun = get(ty)) - { - queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); - } - else if (auto table = get(ty)) - { - for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); - - if (table->indexer) - { - queue.push_back(table->indexer->indexType); - queue.push_back(table->indexer->indexResultType); - } - } - else if (auto mt = get(ty)) - { - queue.push_back(mt->table); - queue.push_back(mt->metatable); - } - else if (get(ty)) - { - // ClassTypeVars never contain free typevars. - } - else if (auto union_ = get(ty)) - queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = get(ty)) - queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); - else - { - } // Primitives, any, errors, and generics are left untouched. - } + } // Primitives, any, errors, and generics are left untouched. } } @@ -3038,79 +2272,39 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays occursCheck(seen, needle, tv); }; - if (FFlag::LuauUseCommittingTxnLog) + needle = log.follow(needle); + haystack = log.follow(haystack); + + if (seen.find(haystack)) + return; + + seen.insert(haystack); + + if (log.getMutable(needle)) + return; + + if (!log.getMutable(needle)) + ice("Expected needle to be free"); + + if (needle == haystack) { - needle = log.follow(needle); - haystack = log.follow(haystack); + reportError(TypeError{location, OccursCheckFailed{}}); + log.replace(needle, *getSingletonTypes().errorRecoveryType()); - if (seen.find(haystack)) - return; - - seen.insert(haystack); - - if (log.getMutable(needle)) - return; - - if (!log.getMutable(needle)) - ice("Expected needle to be free"); - - if (needle == haystack) - { - reportError(TypeError{location, OccursCheckFailed{}}); - log.replace(needle, *getSingletonTypes().errorRecoveryType()); - - return; - } - - if (log.getMutable(haystack)) - return; - else if (auto a = log.getMutable(haystack)) - { - for (TypeId ty : a->options) - check(ty); - } - else if (auto a = log.getMutable(haystack)) - { - for (TypeId ty : a->parts) - check(ty); - } + return; } - else + + if (log.getMutable(haystack)) + return; + else if (auto a = log.getMutable(haystack)) { - needle = follow(needle); - haystack = follow(haystack); - - if (seen.find(haystack)) - return; - - seen.insert(haystack); - - if (get(needle)) - return; - - if (!get(needle)) - ice("Expected needle to be free"); - - if (needle == haystack) - { - reportError(TypeError{location, OccursCheckFailed{}}); - DEPRECATED_log(needle); - *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); - return; - } - - if (get(haystack)) - return; - else if (auto a = get(haystack)) - { - for (TypeId ty : a->options) - check(ty); - } - else if (auto a = get(haystack)) - { - for (TypeId ty : a->parts) - check(ty); - } + for (TypeId ty : a->options) + check(ty); + } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->parts) + check(ty); } } @@ -3123,87 +2317,45 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) { - if (FFlag::LuauUseCommittingTxnLog) + needle = log.follow(needle); + haystack = log.follow(haystack); + + if (seen.find(haystack)) + return; + + seen.insert(haystack); + + if (log.getMutable(needle)) + return; + + if (!log.getMutable(needle)) + ice("Expected needle pack to be free"); + + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + + while (!log.getMutable(haystack)) { - needle = log.follow(needle); - haystack = log.follow(haystack); - - if (seen.find(haystack)) - return; - - seen.insert(haystack); - - if (log.getMutable(needle)) - return; - - if (!log.getMutable(needle)) - ice("Expected needle pack to be free"); - - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - - while (!log.getMutable(haystack)) + if (needle == haystack) { - if (needle == haystack) - { - reportError(TypeError{location, OccursCheckFailed{}}); - log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); + reportError(TypeError{location, OccursCheckFailed{}}); + log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); - return; - } - - if (auto a = get(haystack); a && a->tail) - { - haystack = log.follow(*a->tail); - continue; - } - - break; + return; } - } - else - { - needle = follow(needle); - haystack = follow(haystack); - if (seen.find(haystack)) - return; - - seen.insert(haystack); - - if (get(needle)) - return; - - if (!get(needle)) - ice("Expected needle pack to be free"); - - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - - while (!get(haystack)) + if (auto a = get(haystack); a && a->tail) { - if (needle == haystack) - { - reportError(TypeError{location, OccursCheckFailed{}}); - DEPRECATED_log(needle); - *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); - } - - if (auto a = get(haystack); a && a->tail) - { - haystack = follow(*a->tail); - continue; - } - - break; + haystack = log.follow(*a->tail); + continue; } + + break; } } Unifier Unifier::makeChildUnifier() { - if (FFlag::LuauUseCommittingTxnLog) - return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; - else - return Unifier{types, mode, DEPRECATED_log.sharedSeen, location, variance, sharedState, &log}; + return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; } bool Unifier::isNonstrictMode() const diff --git a/VM/include/lua.h b/VM/include/lua.h index 0a561f27..274c4ed9 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -229,19 +229,46 @@ LUA_API void lua_setthreaddata(lua_State* L, void* data); enum lua_GCOp { + /* stop and resume incremental garbage collection */ LUA_GCSTOP, LUA_GCRESTART, + + /* run a full GC cycle; not recommended for latency sensitive applications */ LUA_GCCOLLECT, + + /* return the heap size in KB and the remainder in bytes */ LUA_GCCOUNT, LUA_GCCOUNTB, + + /* return 1 if GC is active (not stopped); note that GC may not be actively collecting even if it's running */ LUA_GCISRUNNING, - // garbage collection is handled by 'assists' that perform some amount of GC work matching pace of allocation - // explicit GC steps allow to perform some amount of work at custom points to offset the need for GC assists - // note that GC might also be paused for some duration (until bytes allocated meet the threshold) - // if an explicit step is performed during this pause, it will trigger the start of the next collection cycle + /* + ** perform an explicit GC step, with the step size specified in KB + ** + ** garbage collection is handled by 'assists' that perform some amount of GC work matching pace of allocation + ** explicit GC steps allow to perform some amount of work at custom points to offset the need for GC assists + ** note that GC might also be paused for some duration (until bytes allocated meet the threshold) + ** if an explicit step is performed during this pause, it will trigger the start of the next collection cycle + */ LUA_GCSTEP, + /* + ** tune GC parameters G (goal), S (step multiplier) and step size (usually best left ignored) + ** + ** garbage collection is incremental and tries to maintain the heap size to balance memory and performance overhead + ** this overhead is determined by G (goal) which is the ratio between total heap size and the amount of live data in it + ** G is specified in percentages; by default G=200% which means that the heap is allowed to grow to ~2x the size of live data. + ** + ** collector tries to collect S% of allocated bytes by interrupting the application after step size bytes were allocated. + ** when S is too small, collector may not be able to catch up and the effective goal that can be reached will be larger. + ** S is specified in percentages; by default S=200% which means that collector will run at ~2x the pace of allocations. + ** + ** it is recommended to set S in the interval [100 / (G - 100), 100 + 100 / (G - 100))] with a minimum value of 150%; for example: + ** - for G=200%, S should be in the interval [150%, 200%] + ** - for G=150%, S should be in the interval [200%, 300%] + ** - for G=125%, S should be in the interval [400%, 500%] + */ LUA_GCSETGOAL, LUA_GCSETSTEPMUL, LUA_GCSETSTEPSIZE, diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index c5bf1c18..b93cbf7c 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -59,33 +59,6 @@ #define LUA_IDSIZE 256 #endif -/* -@@ LUAI_GCGOAL defines the desired top heap size in relation to the live heap -@* size at the end of the GC cycle -** CHANGE it if you want the GC to run faster or slower (higher values -** mean larger GC pauses which mean slower collection.) You can also change -** this value dynamically. -*/ -#ifndef LUAI_GCGOAL -#define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ -#endif - -/* -@@ LUAI_GCSTEPMUL / LUAI_GCSTEPSIZE define the default speed of garbage collection -@* relative to memory allocation. -** Every LUAI_GCSTEPSIZE KB allocated, incremental collector collects LUAI_GCSTEPSIZE -** times LUAI_GCSTEPMUL% bytes. -** CHANGE it if you want to change the granularity of the garbage -** collection. -*/ -#ifndef LUAI_GCSTEPMUL -#define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ -#endif - -#ifndef LUAI_GCSTEPSIZE -#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ -#endif - /* LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function */ #ifndef LUA_MINSTACK #define LUA_MINSTACK 20 diff --git a/VM/include/lualib.h b/VM/include/lualib.h index baf27b47..bebd0a0f 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -54,6 +54,8 @@ LUALIB_API lua_State* luaL_newstate(void); LUALIB_API const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint); +LUALIB_API const char* luaL_typename(lua_State* L, int idx); + /* ** =============================================================== ** some useful macros @@ -66,8 +68,6 @@ LUALIB_API const char* luaL_findtable(lua_State* L, int idx, const char* fname, #define luaL_checkstring(L, n) (luaL_checklstring(L, (n), NULL)) #define luaL_optstring(L, n, d) (luaL_optlstring(L, (n), (d), NULL)) -#define luaL_typename(L, i) lua_typename(L, lua_type(L, (i))) - #define luaL_getmetatable(L, n) (lua_getfield(L, LUA_REGISTRYINDEX, (n))) #define luaL_opt(L, f, n, d) (lua_isnoneornil(L, (n)) ? (d) : f(L, (n))) diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 9a6f7793..9fe2ebb6 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,6 +11,8 @@ #include +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauMorePreciseLuaLTypeName, false) + /* convert a stack index to positive */ #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -333,6 +335,19 @@ const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) return NULL; } +const char* luaL_typename(lua_State* L, int idx) +{ + if (DFFlag::LuauMorePreciseLuaLTypeName) + { + const TValue* obj = luaA_toobject(L, idx); + return luaT_objtypename(L, obj); + } + else + { + return lua_typename(L, lua_type(L, idx)); + } +} + /* ** {====================================================== ** Generic Buffer manipulation diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 988fd315..2307598e 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_DYNAMIC_FASTFLAG(LuauMorePreciseLuaLTypeName) + static void writestring(const char* s, size_t l) { fwrite(s, 1, l, stdout); @@ -186,7 +188,14 @@ static int luaB_gcinfo(lua_State* L) static int luaB_type(lua_State* L) { luaL_checkany(L, 1); - lua_pushstring(L, luaL_typename(L, 1)); + if (DFFlag::LuauMorePreciseLuaLTypeName) + { + lua_pushstring(L, lua_typename(L, lua_type(L, 1))); + } + else + { + lua_pushstring(L, luaL_typename(L, 1)); + } return 1; } diff --git a/VM/src/lgc.h b/VM/src/lgc.h index cbeeebd4..ad8ee78a 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -6,6 +6,13 @@ #include "lobject.h" #include "lstate.h" +/* +** Default settings for GC tunables (settable via lua_gc) +*/ +#define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ +#define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ +#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ + /* ** Possible states of the Garbage Collector */ diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index ce890ba8..55b0618a 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAG(LuauTableCloneType) using namespace Luau; @@ -1912,14 +1911,9 @@ local bar: @1= foo CHECK(!ac.entryMap.count("foo")); } -// Switch back to TEST_CASE_FIXTURE with regular ACFixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("type_correct_function_no_parenthesis") +TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") { - ScopedFastFlag sff_LuauUseCommittingTxnLog = ScopedFastFlag("LuauUseCommittingTxnLog", true); - ACFixture fix; - - fix.check(R"( + check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end local function bar2(a: string) return a .. 'x' end @@ -1927,7 +1921,7 @@ local function bar2(a: string) return a .. 'x' end return target(b@1 )"); - auto ac = fix.autocomplete('1'); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); @@ -1937,8 +1931,6 @@ return target(b@1 TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - check(R"( local function bar(a: number) return -a end local abc = b@1 @@ -1952,8 +1944,6 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - check(R"( local function foo() return 1 end local function bar(a: number) return -a end @@ -1978,14 +1968,9 @@ local fp: @1= f CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// Switch back to TEST_CASE_FIXTURE with regular ACFixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("type_correct_keywords") +TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") { - ScopedFastFlag sff_LuauUseCommittingTxnLog = ScopedFastFlag("LuauUseCommittingTxnLog", true); - ACFixture fix; - - fix.check(R"( + check(R"( local function a(x: boolean) end local function b(x: number?) end local function c(x: (number) -> string) end @@ -2002,26 +1987,26 @@ local dc = d(f@4) local ec = e(f@5) )"); - auto ac = fix.autocomplete('1'); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("tru")); CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); - ac = fix.autocomplete('2'); + ac = autocomplete('2'); CHECK(ac.entryMap.count("ni")); CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); - ac = fix.autocomplete('3'); + ac = autocomplete('3'); CHECK(ac.entryMap.count("false")); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = fix.autocomplete('4'); + ac = autocomplete('4'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = fix.autocomplete('5'); + ac = autocomplete('5'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); } @@ -2512,23 +2497,21 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE("autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - fix.loadDefinition(R"( + loadDefinition(R"( declare y: { x: number, } )"); - fix.fileResolver.source["Module/A"] = R"( + fileResolver.source["Module/A"] = R"( local a = y. )"; - fix.frontend.check("Module/A"); + frontend.check("Module/A"); - auto ac = autocomplete(fix.frontend, "Module/A", Position{1, 21}, nullCallback); + auto ac = autocomplete(frontend, "Module/A", Position{1, 21}, nullCallback); REQUIRE(ac.entryMap.count("x")); CHECK_EQ(ac.entryMap["x"].documentationSymbol, "@test/global/y.x"); @@ -2646,8 +2629,6 @@ local a: A<(number, s@1> TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - check(R"( local function foo1() return 1 end local function foo2() return "1" end @@ -2720,7 +2701,6 @@ type A = () -> T TEST_CASE_FIXTURE(ACFixture, "autocomplete_oop_implicit_self") { - ScopedFastFlag flag("LuauMissingFollowACMetatables", true); check(R"( --!strict local Class = {} @@ -2764,8 +2744,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - check(R"( local bar: ((number) -> number) & (number, number) -> number) local abc = b@1 diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index d584eb2d..711c0aa1 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -7,6 +7,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauFixIncorrectLineNumberDuplicateType) + TEST_SUITE_BEGIN("TypeAliases"); TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") @@ -241,6 +243,27 @@ TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") CHECK_EQ(dtd->name, "Foo"); } +TEST_CASE_FIXTURE(Fixture, "reported_location_is_correct_when_type_alias_are_duplicates") +{ + CheckResult result = check(R"( + type A = string + type B = number + type C = string + type B = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto dtd = get(result.errors[0]); + REQUIRE(dtd); + CHECK_EQ(dtd->name, "B"); + + if (FFlag::LuauFixIncorrectLineNumberDuplicateType) + CHECK_EQ(dtd->previousLocation.begin.line + 1, 3); + else + CHECK_EQ(dtd->previousLocation.begin.line + 1, 1); +} + TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index bf990770..8da655b3 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -8,8 +8,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUseCommittingTxnLog) - TEST_SUITE_BEGIN("BuiltinTests"); TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") @@ -443,28 +441,19 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") CHECK_EQ(*typeChecker.numberType, *requireType("n3")); } -// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("thread_is_a_type") +TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - CheckResult result = fix.check(R"( + CheckResult result = check(R"( local co = coroutine.create(function() end) )"); - // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. - CHECK(result.errors.size() == 0); - CHECK_EQ(*fix.typeChecker.threadType, *fix.requireType("co")); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.threadType, *requireType("co")); } -// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("coroutine_resume_anything_goes") +TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - CheckResult result = fix.check(R"( + CheckResult result = check(R"( local function nifty(x, y) print(x, y) local z = coroutine.yield(1, 2) @@ -477,17 +466,12 @@ TEST_CASE("coroutine_resume_anything_goes") local answer = coroutine.resume(co, 3) )"); - // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. - CHECK(result.errors.size() == 0); + LUAU_REQUIRE_NO_ERRORS(result); } -// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("coroutine_wrap_anything_goes") +TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - CheckResult result = fix.check(R"( + CheckResult result = check(R"( --!nonstrict local function nifty(x, y) print(x, y) @@ -501,8 +485,7 @@ TEST_CASE("coroutine_wrap_anything_goes") local answer = f(3) )"); - // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. - CHECK(result.errors.size() == 0); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") @@ -961,4 +944,18 @@ TEST_CASE_FIXTURE(Fixture, "table_freeze_is_generic") CHECK_EQ("*unknown*", toString(requireType("d"))); } +TEST_CASE_FIXTURE(Fixture, "set_metatable_needs_arguments") +{ + ScopedFastFlag sff{"LuauSetMetaTableArgsCheck", true}; + CheckResult result = check(R"( +local a = {b=setmetatable} +a.b() +a:b() +a:b({}) + )"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 0}, {2, 5}}, CountMismatch{2, 0}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{3, 0}, {3, 5}}, CountMismatch{2, 1}})); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index c482847b..547fbab1 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauFixArgumentCountMismatchAmountWithGenericTypes) + TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -786,4 +788,47 @@ local TheDispatcher: Dispatcher = { LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_few") +{ + CheckResult result = check(R"( +function test(a: number) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +wrapper(test) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 1 is specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") +{ + CheckResult result = check(R"( +function test2(a: number, b: string) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +wrapper(test2, 1, "", 3) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 3 arguments, but 4 are specified)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 4 are specified)"); +} + + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index da035ba1..a5eba5df 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2334,4 +2334,54 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); } +TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf1") +{ + ScopedFastFlag sff{"LuauTxnLogCheckForInvalidation", true}; + + CheckResult result = check(R"( +-- This example produced a UAF at one point, caused by pointers to table types becoming +-- invalidated by child unifiers. (Calling log.concat can cause pointers to become invalid.) +type _Entry = { + a: number, + + middle: (self: _Entry) -> (), + + z: number +} + +export type AnyEntry = _Entry + +local Entry = {} +Entry.__index = Entry + +function Entry:dispose() + self:middle() + forgetChildren(self) -- unify free with sealed AnyEntry +end + +function forgetChildren(parent: AnyEntry) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf2") +{ + ScopedFastFlag sff{"LuauTxnLogCheckForInvalidation", true}; + + CheckResult result = check(R"( +-- Another example that UAFd, this time found by fuzzing. +local _ +do +_._ *= (_[{n0=_[{[{[_]=_,}]=_,}],}])[_] +_ = (_.n0) +end +_._ *= (_[false])[_] +_ = (_.cos) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index f63579b5..d7bbad20 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5144,7 +5144,6 @@ end TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") { - ScopedFastFlag committingTxnLog{"LuauUseCommittingTxnLog", true}; ScopedFastFlag subtypingVariance{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( @@ -5355,4 +5354,41 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify", true}; + + fileResolver.source["game/isAMagicMock"] = R"( +--!nonstrict +return function(value) + return false +end + )"; + + CheckResult result = check(R"( +--!nonstrict +local MagicMock = {} +MagicMock.is = require(game.isAMagicMock) + +function MagicMock.is(value) + return false +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify", true}; + + CheckResult result = check(R"( +function string.len(): number + return 1 +end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index f6ee3ccc..d8de2594 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -9,8 +9,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUseCommittingTxnLog) - struct TryUnifyFixture : Fixture { TypeArena arena; @@ -43,8 +41,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") state.tryUnify(&functionTwo, &functionOne); CHECK(state.errors.empty()); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); CHECK_EQ(functionOne, functionTwo); } @@ -86,8 +83,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") CHECK(state.errors.empty()); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -110,9 +106,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") CHECK_EQ(1, state.errors.size()); - if (!FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -217,34 +210,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unifica CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } -TEST_CASE("undo_new_prop_on_unsealed_table") -{ - ScopedFastFlag flags[] = { - {"LuauTableSubtypingVariance2", true}, - // This test makes no sense with a committing TxnLog. - {"LuauUseCommittingTxnLog", false}, - }; - // I am not sure how to make this happen in Luau code. - - TryUnifyFixture fix; - - TypeId unsealedTable = fix.arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); - TypeId sealedTable = - fix.arena.addType(TableTypeVar{{{"prop", Property{getSingletonTypes().numberType}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); - - const TableTypeVar* ttv = get(unsealedTable); - REQUIRE(ttv); - - fix.state.tryUnify(sealedTable, unsealedTable); - - // To be honest, it's really quite spooky here that we're amending an unsealed table in this case. - CHECK(!ttv->props.empty()); - - fix.state.DEPRECATED_log.rollback(); - - CHECK(ttv->props.empty()); -} - TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") { TypePackId threeNumbers = arena.addTypePack(TypePack{{typeChecker.numberType, typeChecker.numberType, typeChecker.numberType}, std::nullopt}); @@ -267,11 +232,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") { - ScopedFastFlag sffs[] = { - {"LuauUseCommittingTxnLog", true}, - {"LuauFollowWithCommittingTxnLogInAnyUnification", true}, - }; - TypePackVar free{FreeTypePack{TypeLevel{}}}; TypePackVar target{TypePack{}}; diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 6b96f449..fcc21c18 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -9,8 +9,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUseCommittingTxnLog) - TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") @@ -264,13 +262,9 @@ TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") CHECK_EQ(toString(requireType("foo")), "(...number) -> ()"); } -// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("type_pack_hidden_free_tail_infinite_growth") +TEST_CASE_FIXTURE(Fixture, "type_pack_hidden_free_tail_infinite_growth") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - CheckResult result = fix.check(R"( + CheckResult result = check(R"( --!nonstrict if _ then _[function(l0)end],l0 = _ @@ -282,8 +276,7 @@ elseif _ then end )"); - // Switch back to LUAU_REQUIRE_ERRORS(result) when using TEST_CASE_FIXTURE. - CHECK(result.errors.size() > 0); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 0e0b6ebb..ad4cecd8 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -7,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauEqConstraint) -LUAU_FASTFLAG(LuauUseCommittingTxnLog) using namespace Luau; @@ -282,19 +281,16 @@ local c = b:foo(1, 2) CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } -TEST_CASE("optional_union_follow") +TEST_CASE_FIXTURE(Fixture, "optional_union_follow") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - CheckResult result = fix.check(R"( + CheckResult result = check(R"( local y: number? = 2 local x = y local function f(a: number, b: typeof(x), c: typeof(x)) return -a end return f() )"); - REQUIRE_EQ(result.errors.size(), 1); - // LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); auto acm = get(result.errors[0]); REQUIRE(acm); diff --git a/tools/LuauVisualize.py b/tools/LuauVisualize.py new file mode 100644 index 00000000..40f8d6be --- /dev/null +++ b/tools/LuauVisualize.py @@ -0,0 +1,107 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# HACK: LLDB's python API doesn't afford anything helpful for getting at variadic template parameters. +# We're forced to resort to parsing names as strings. +def templateParams(s): + depth = 0 + start = s.find('<') + 1 + result = [] + for i, c in enumerate(s[start:], start): + if c == '<': + depth += 1 + elif c == '>': + if depth == 0: + result.append(s[start: i].strip()) + break + depth -= 1 + elif c == ',' and depth == 0: + result.append(s[start: i].strip()) + start = i + 1 + return result + +def getType(target, typeName): + stars = 0 + + typeName = typeName.strip() + while typeName.endswith('*'): + stars += 1 + typeName = typeName[:-1] + + if typeName.startswith('const '): + typeName = typeName[6:] + + ty = target.FindFirstType(typeName.strip()) + for _ in range(stars): + ty = ty.GetPointerType() + + return ty + +def luau_variant_summary(valobj, internal_dict, options): + type_id = valobj.GetChildMemberWithName("typeid").GetValueAsUnsigned() + storage = valobj.GetChildMemberWithName("storage") + params = templateParams(valobj.GetType().GetCanonicalType().GetName()) + stored_type = params[type_id] + value = storage.Cast(stored_type.GetPointerType()).Dereference() + return stored_type.GetDisplayTypeName() + " [" + value.GetValue() + "]" + +class LuauVariantSyntheticChildrenProvider: + node_names = ["type", "value"] + + def __init__(self, valobj, internal_dict): + self.valobj = valobj + self.type_index = None + self.current_type = None + self.type_params = [] + self.stored_value = None + + def num_children(self): + return len(self.node_names) + + def has_children(self): + return True + + def get_child_index(self, name): + try: + return self.node_names.index(name) + except ValueError: + return -1 + + def get_child_at_index(self, index): + try: + node = self.node_names[index] + except IndexError: + return None + + if node == "type": + if self.current_type: + return self.valobj.CreateValueFromExpression(node, f"(const char*)\"{self.current_type.GetDisplayTypeName()}\"") + else: + return self.valobj.CreateValueFromExpression(node, "(const char*)\"\"") + elif node == "value": + if self.stored_value is not None: + if self.current_type is not None: + return self.valobj.CreateValueFromData(node, self.stored_value.GetData(), self.current_type) + else: + return self.valobj.CreateValueExpression(node, "(const char*)\"\"") + else: + return self.valobj.CreateValueFromExpression(node, "(const char*)\"\"") + else: + return None + + def update(self): + self.type_index = self.valobj.GetChildMemberWithName("typeid").GetValueAsSigned() + self.type_params = templateParams(self.valobj.GetType().GetCanonicalType().GetName()) + + if len(self.type_params) > self.type_index: + self.current_type = getType(self.valobj.GetTarget(), self.type_params[self.type_index]) + + if self.current_type: + storage = self.valobj.GetChildMemberWithName("storage") + self.stored_value = storage.Cast(self.current_type.GetPointerType()).Dereference() + else: + self.stored_value = None + else: + self.current_type = None + self.stored_value = None + + return False diff --git a/tools/lldb-formatters.lldb b/tools/lldb-formatters.lldb new file mode 100644 index 00000000..3868ac20 --- /dev/null +++ b/tools/lldb-formatters.lldb @@ -0,0 +1,2 @@ +type synthetic add -x "^Luau::Variant<.+>$" -l LuauVisualize.LuauVariantSyntheticChildrenProvider +type summary add -x "^Luau::Variant<.+>$" -l LuauVisualize.luau_variant_summary From adecd840675c642f1f553d39930aa8014e82faa7 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 17 Mar 2022 17:06:25 -0700 Subject: [PATCH 030/102] Sync to upstream/release/519 --- Analysis/include/Luau/Error.h | 1 + Analysis/include/Luau/Unifier.h | 2 + Analysis/src/Autocomplete.cpp | 198 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 144 +- Analysis/src/Error.cpp | 11 +- Analysis/src/Module.cpp | 7 + Analysis/src/TypeInfer.cpp | 127 +- Analysis/src/TypeVar.cpp | 10 +- Analysis/src/Unifier.cpp | 78 +- Ast/src/Parser.cpp | 117 +- Compiler/src/Builtins.cpp | 4 +- Compiler/src/Compiler.cpp | 4 - Sources.cmake | 7 + VM/include/lua.h | 15 +- VM/src/lapi.cpp | 20 +- VM/src/lbaselib.cpp | 16 +- VM/src/lgc.h | 4 +- VM/src/lmem.cpp | 2 +- VM/src/lstring.cpp | 2 +- VM/src/ltable.cpp | 71 +- VM/src/ltm.cpp | 3 +- VM/src/ludata.cpp | 8 +- VM/src/ludata.h | 3 + VM/src/lvmexecute.cpp | 5 + tests/Autocomplete.test.cpp | 162 + tests/Compiler.test.cpp | 2 - tests/Conformance.test.cpp | 142 +- tests/Linter.test.cpp | 2 - tests/Module.test.cpp | 10 +- tests/NonstrictMode.test.cpp | 22 +- tests/TypeInfer.aliases.test.cpp | 21 + tests/TypeInfer.anyerror.test.cpp | 335 ++ tests/TypeInfer.builtins.test.cpp | 60 + tests/TypeInfer.definitions.test.cpp | 18 + tests/TypeInfer.functions.test.cpp | 1338 +++++ tests/TypeInfer.generics.test.cpp | 301 ++ tests/TypeInfer.intersectionTypes.test.cpp | 28 + tests/TypeInfer.loops.test.cpp | 473 ++ tests/TypeInfer.modules.test.cpp | 310 ++ tests/TypeInfer.oop.test.cpp | 275 + tests/TypeInfer.operators.test.cpp | 759 +++ tests/TypeInfer.primitives.test.cpp | 100 + tests/TypeInfer.refinements.test.cpp | 18 + tests/TypeInfer.singletons.test.cpp | 116 +- tests/TypeInfer.tables.test.cpp | 500 ++ tests/TypeInfer.test.cpp | 4413 ----------------- tests/TypeInfer.typePacks.cpp | 83 + tests/TypeInfer.unionTypes.test.cpp | 16 + tests/conformance/basic.lua | 19 + tests/conformance/debugger.lua | 4 +- tests/conformance/errors.lua | 2 + tests/conformance/interrupt.lua | 11 + tools/{gdb-printers.py => gdb_printers.py} | 0 tools/lldb-formatters.lldb | 2 - tools/lldb_formatters.lldb | 2 + .../{LuauVisualize.py => lldb_formatters.py} | 0 56 files changed, 5643 insertions(+), 4760 deletions(-) create mode 100644 tests/TypeInfer.anyerror.test.cpp create mode 100644 tests/TypeInfer.functions.test.cpp create mode 100644 tests/TypeInfer.loops.test.cpp create mode 100644 tests/TypeInfer.modules.test.cpp create mode 100644 tests/TypeInfer.oop.test.cpp create mode 100644 tests/TypeInfer.operators.test.cpp create mode 100644 tests/TypeInfer.primitives.test.cpp create mode 100644 tests/conformance/interrupt.lua rename tools/{gdb-printers.py => gdb_printers.py} (100%) delete mode 100644 tools/lldb-formatters.lldb create mode 100644 tools/lldb_formatters.lldb rename tools/{LuauVisualize.py => lldb_formatters.py} (100%) diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index a71e0224..72350255 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -107,6 +107,7 @@ struct FunctionDoesNotTakeSelf struct FunctionRequiresSelf { + // TODO: Delete with LuauAnyInIsOptionalIsOptional int requiredExtraNils = 0; bool operator==(const FunctionRequiresSelf& rhs) const; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 71958f4a..f1ffbcc0 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -86,6 +86,8 @@ private: void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer); TypeId widen(TypeId ty); + TypePackId widen(TypePackId tp); + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); void cacheResult(TypeId subTy, TypeId superTy); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index e94c432f..492edf25 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); +LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -228,11 +229,22 @@ static std::optional findExpectedTypeAt(const Module& module, AstNode* n return *it; } +static bool checkTypeMatch(TypeArena* typeArena, TypeId subTy, TypeId superTy) +{ + InternalErrorReporter iceReporter; + UnifierSharedState unifierState(&iceReporter); + Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); + + return unifier.canUnify(subTy, superTy).empty(); +} + static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, Position position, TypeId ty) { ty = follow(ty); auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); @@ -249,20 +261,30 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); - auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return true; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) + auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { + if (FFlag::LuauSelfCallAutocompleteFix) { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return true; - } + if (std::optional firstRetTy = first(ftv->retType)) + return checkTypeMatch(typeArena, *firstRetTy, expectedType); - return false; + return false; + } + else + { + auto [retHead, retTail] = flatten(ftv->retType); + + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + return true; + + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return true; + } + + return false; + } }; // We also want to suggest functions that return compatible result @@ -281,7 +303,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } } - return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + if (FFlag::LuauSelfCallAutocompleteFix) + return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + else + return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } enum class PropIndexType @@ -291,16 +316,22 @@ enum class PropIndexType Key, }; -static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes, - AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) +static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId rootTy, TypeId ty, PropIndexType indexType, + const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, + std::optional containingClass = std::nullopt) { + if (FFlag::LuauSelfCallAutocompleteFix) + rootTy = follow(rootTy); + ty = follow(ty); if (seen.count(ty)) return; seen.insert(ty); - auto isWrongIndexer = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { + auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + if (indexType == PropIndexType::Key) return false; @@ -331,6 +362,48 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId return colonIndex; } }; + auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) { + LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix); + + if (indexType == PropIndexType::Key) + return false; + + bool calledWithSelf = indexType == PropIndexType::Colon; + + auto isCompatibleCall = [typeArena, rootTy, calledWithSelf](const FunctionTypeVar* ftv) { + if (get(rootTy)) + { + // Calls on classes require strict match between how function is declared and how it's called + return calledWithSelf == ftv->hasSelf; + } + + if (std::optional firstArgTy = first(ftv->argTypes)) + { + if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) + return calledWithSelf; + } + + return !calledWithSelf; + }; + + if (const FunctionTypeVar* ftv = get(type)) + return !isCompatibleCall(ftv); + + // For intersections, any part that is successful makes the whole call successful + if (const IntersectionTypeVar* itv = get(type)) + { + for (auto subType : itv->parts) + { + if (const FunctionTypeVar* ftv = get(Luau::follow(subType))) + { + if (isCompatibleCall(ftv)) + return false; + } + } + } + + return calledWithSelf; + }; auto fillProps = [&](const ClassTypeVar::Props& props) { for (const auto& [name, prop] : props) @@ -349,7 +422,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryKind::Property, type, prop.deprecated, - isWrongIndexer(type), + FFlag::LuauSelfCallAutocompleteFix ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), typeCorrect, containingClass, &prop, @@ -361,34 +434,60 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } }; - if (auto cls = get(ty)) - { - containingClass = containingClass.value_or(cls); - fillProps(cls->props); - if (cls->parent) - autocompleteProps(module, typeArena, *cls->parent, indexType, nodes, result, seen, cls); - } - else if (auto tbl = get(ty)) - fillProps(tbl->props); - else if (auto mt = get(ty)) - { - autocompleteProps(module, typeArena, mt->table, indexType, nodes, result, seen); - - auto mtable = get(mt->metatable); - if (!mtable) - return; - + auto fillMetatableProps = [&](const TableTypeVar* mtable) { auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { TypeId followed = follow(indexIt->second.type); if (get(followed) || get(followed)) - autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen); + { + autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen); + } else if (auto indexFunction = get(followed)) { std::optional indexFunctionResult = first(indexFunction->retType); if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + } + } + }; + + if (auto cls = get(ty)) + { + containingClass = containingClass.value_or(cls); + fillProps(cls->props); + if (cls->parent) + autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, cls); + } + else if (auto tbl = get(ty)) + fillProps(tbl->props); + else if (auto mt = get(ty)) + { + autocompleteProps(module, typeArena, rootTy, mt->table, indexType, nodes, result, seen); + + if (FFlag::LuauSelfCallAutocompleteFix) + { + if (auto mtable = get(mt->metatable)) + fillMetatableProps(mtable); + } + else + { + auto mtable = get(mt->metatable); + if (!mtable) + return; + + auto indexIt = mtable->props.find("__index"); + if (indexIt != mtable->props.end()) + { + TypeId followed = follow(indexIt->second.type); + if (get(followed) || get(followed)) + autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen); + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + } } } } @@ -400,7 +499,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen = seen; - autocompleteProps(module, typeArena, ty, indexType, nodes, inner, innerSeen); + autocompleteProps(module, typeArena, rootTy, ty, indexType, nodes, inner, innerSeen); for (auto& pair : inner) result.insert(pair); @@ -423,14 +522,17 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId if (iter == endIter) return; - autocompleteProps(module, typeArena, *iter, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, rootTy, *iter, indexType, nodes, result, seen); ++iter; while (iter != endIter) { AutocompleteEntryMap inner; - std::unordered_set innerSeen = seen; + std::unordered_set innerSeen; + + if (!FFlag::LuauSelfCallAutocompleteFix) + innerSeen = seen; if (isNil(*iter)) { @@ -438,7 +540,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId continue; } - autocompleteProps(module, typeArena, *iter, indexType, nodes, inner, innerSeen); + autocompleteProps(module, typeArena, rootTy, *iter, indexType, nodes, inner, innerSeen); std::unordered_set toRemove; @@ -455,6 +557,18 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ++iter; } } + else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix) + { + if (pt->metatable) + { + if (auto mtable = get(*pt->metatable)) + fillMetatableProps(mtable); + } + } + else if (FFlag::LuauSelfCallAutocompleteFix && get(get(ty))) + { + autocompleteProps(module, typeArena, rootTy, getSingletonTypes().stringType, indexType, nodes, result, seen); + } } static void autocompleteKeywords( @@ -482,7 +596,7 @@ static void autocompleteProps( const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result) { std::unordered_set seen; - autocompleteProps(module, typeArena, ty, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, ty, ty, indexType, nodes, result, seen); } AutocompleteEntryMap autocompleteProps( @@ -1352,7 +1466,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - if (isString(ty)) + if (!FFlag::LuauSelfCallAutocompleteFix && isString(ty)) return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), finder.ancestry}; else diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 471b61ad..be3fcd7d 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -95,104 +95,104 @@ declare os: { declare function require(target: any): any -declare function getfenv(target: any?): { [string]: any } +declare function getfenv(target: any): { [string]: any } declare _G: any declare _VERSION: string declare function gcinfo(): number - declare function print(...: T...) +declare function print(...: T...) - declare function type(value: T): string - declare function typeof(value: T): string - - -- `assert` has a magic function attached that will give more detailed type information - declare function assert(value: T, errorMessage: string?): T +declare function type(value: T): string +declare function typeof(value: T): string - declare function error(message: T, level: number?) +-- `assert` has a magic function attached that will give more detailed type information +declare function assert(value: T, errorMessage: string?): T - declare function tostring(value: T): string - declare function tonumber(value: T, radix: number?): number? +declare function error(message: T, level: number?) - declare function rawequal(a: T1, b: T2): boolean - declare function rawget(tab: {[K]: V}, k: K): V - declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} +declare function tostring(value: T): string +declare function tonumber(value: T, radix: number?): number? - declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? +declare function rawequal(a: T1, b: T2): boolean +declare function rawget(tab: {[K]: V}, k: K): V +declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} - declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number) +declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? - declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) +declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number) - -- FIXME: The actual type of `xpcall` is: - -- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) - -- Since we can't represent the return value, we use (boolean, R1...). - declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) +declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) - -- `select` has a magic function attached to provide more detailed type information - declare function select(i: string | number, ...: A...): ...any +-- FIXME: The actual type of `xpcall` is: +-- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) +-- Since we can't represent the return value, we use (boolean, R1...). +declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) - -- FIXME: This type is not entirely correct - `loadstring` returns a function or - -- (nil, string). - declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) +-- `select` has a magic function attached to provide more detailed type information +declare function select(i: string | number, ...: A...): ...any - declare function newproxy(mt: boolean?): any +-- FIXME: This type is not entirely correct - `loadstring` returns a function or +-- (nil, string). +declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) - declare coroutine: { - create: ((A...) -> R...) -> thread, - resume: (thread, A...) -> (boolean, R...), - running: () -> thread, - status: (thread) -> string, - -- FIXME: This technically returns a function, but we can't represent this yet. - wrap: ((A...) -> R...) -> any, - yield: (A...) -> R..., - isyieldable: () -> boolean, - close: (thread) -> (boolean, any?) - } +declare function newproxy(mt: boolean?): any - declare table: { - concat: ({V}, string?, number?, number?) -> string, - insert: (({V}, V) -> ()) & (({V}, number, V) -> ()), - maxn: ({V}) -> number, - remove: ({V}, number?) -> V?, - sort: ({V}, ((V, V) -> boolean)?) -> (), - create: (number, V?) -> {V}, - find: ({V}, V, number?) -> number?, +declare coroutine: { + create: ((A...) -> R...) -> thread, + resume: (thread, A...) -> (boolean, R...), + running: () -> thread, + status: (thread) -> string, + -- FIXME: This technically returns a function, but we can't represent this yet. + wrap: ((A...) -> R...) -> any, + yield: (A...) -> R..., + isyieldable: () -> boolean, + close: (thread) -> (boolean, any) +} - unpack: ({V}, number?, number?) -> ...V, - pack: (...V) -> { n: number, [number]: V }, +declare table: { + concat: ({V}, string?, number?, number?) -> string, + insert: (({V}, V) -> ()) & (({V}, number, V) -> ()), + maxn: ({V}) -> number, + remove: ({V}, number?) -> V?, + sort: ({V}, ((V, V) -> boolean)?) -> (), + create: (number, V?) -> {V}, + find: ({V}, V, number?) -> number?, - getn: ({V}) -> number, - foreach: ({[K]: V}, (K, V) -> ()) -> (), - foreachi: ({V}, (number, V) -> ()) -> (), + unpack: ({V}, number?, number?) -> ...V, + pack: (...V) -> { n: number, [number]: V }, - move: ({V}, number, number, number, {V}?) -> {V}, - clear: ({[K]: V}) -> (), + getn: ({V}) -> number, + foreach: ({[K]: V}, (K, V) -> ()) -> (), + foreachi: ({V}, (number, V) -> ()) -> (), - isfrozen: ({[K]: V}) -> boolean, - } + move: ({V}, number, number, number, {V}?) -> {V}, + clear: ({[K]: V}) -> (), - declare debug: { - info: ((thread, number, string) -> R...) & ((number, string) -> R...) & (((A...) -> R1..., string) -> R2...), - traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), - } + isfrozen: ({[K]: V}) -> boolean, +} - declare utf8: { - char: (number, ...number) -> string, - charpattern: string, - codes: (string) -> ((string, number) -> (number, number), string, number), - -- FIXME - codepoint: (string, number?, number?) -> (number, ...number), - len: (string, number?, number?) -> (number?, number?), - offset: (string, number?, number?) -> number, - nfdnormalize: (string) -> string, - nfcnormalize: (string) -> string, - graphemes: (string, number?, number?) -> (() -> (number, number)), - } +declare debug: { + info: ((thread, number, string) -> R...) & ((number, string) -> R...) & (((A...) -> R1..., string) -> R2...), + traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), +} - -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. - declare function unpack(tab: {V}, i: number?, j: number?): ...V +declare utf8: { + char: (number, ...number) -> string, + charpattern: string, + codes: (string) -> ((string, number) -> (number, number), string, number), + -- FIXME + codepoint: (string, number?, number?) -> (number, ...number), + len: (string, number?, number?) -> (number?, number?), + offset: (string, number?, number?) -> number, + nfdnormalize: (string) -> string, + nfcnormalize: (string) -> string, + graphemes: (string, number?, number?) -> (() -> (number, number)), +} + +-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. +declare function unpack(tab: {V}, i: number?, j: number?): ...V )BUILTIN_SRC"; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 88069f1f..26d3b76d 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,6 +7,8 @@ #include +LUAU_FASTFLAGVARIABLE(BetterDiagnosticCodesInStudio, false); + static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { std::string s = "expects "; @@ -223,7 +225,14 @@ struct ErrorConverter std::string operator()(const Luau::SyntaxError& e) const { - return "Syntax error: " + e.message; + if (FFlag::BetterDiagnosticCodesInStudio) + { + return e.message; + } + else + { + return "Syntax error: " + e.message; + } } std::string operator()(const Luau::CodeTooComplex&) const diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 76dc72d2..a330a98d 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) // Remove with FFlagLuauImmutableTypes LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) LUAU_FASTFLAG(LuauImmutableTypes) namespace Luau @@ -536,6 +537,12 @@ bool Module::clonePublicInterface() if (get(follow(ty))) *asMutable(ty) = AnyTypeVar{}; + if (FFlag::LuauCloneDeclaredGlobals) + { + for (auto& [name, ty] : declaredGlobals) + ty = clone(ty, interfaceTypes, seenTypes, seenTypePacks, cloneState); + } + freeze(internalTypes); freeze(interfaceTypes); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 3fe4c90e..41e8ce55 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -29,22 +29,24 @@ LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) +LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify, false) +LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify2, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. -LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree) +LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree2) LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) +LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) namespace Luau { @@ -1099,7 +1101,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } - else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify) + else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify2) { TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); @@ -1111,7 +1113,10 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco reportError(TypeError{function.location, OnlyTablesCanHaveMethods{exprTy}}); } else if (ttv->state == TableState::Sealed) - reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); + { + if (!ttv->indexer || !isPrim(ttv->indexer->indexType, PrimitiveTypeVar::String)) + reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); + } ty = follow(ty); @@ -1134,7 +1139,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } - else if (FFlag::LuauStatFunctionSimplify) + else if (FFlag::LuauStatFunctionSimplify2) { LUAU_ASSERT(function.name->is()); @@ -1144,7 +1149,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else if (function.func->self) { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify2); AstExprIndexName* indexName = function.name->as(); if (!indexName) @@ -1183,7 +1188,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify2); TypeId leftType = checkLValueBinding(scope, *function.name); @@ -1410,6 +1415,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar { ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + + if (FFlag::LuauSelfCallAutocompleteFix) + ftv->hasSelf = true; } } @@ -1883,19 +1891,27 @@ std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) { if (const UnionTypeVar* utv = get(ty)) { - bool hasNil = false; - - for (TypeId option : utv) + if (FFlag::LuauAnyInIsOptionalIsOptional) { - if (isNil(option)) - { - hasNil = true; - break; - } + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; } + else + { + bool hasNil = false; - if (!hasNil) - return ty; + for (TypeId option : utv) + { + if (isNil(option)) + { + hasNil = true; + break; + } + } + + if (!hasNil) + return ty; + } std::vector result; @@ -1916,14 +1932,34 @@ std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) { - if (isOptional(ty)) + if (FFlag::LuauAnyInIsOptionalIsOptional) { - if (std::optional strippedUnion = tryStripUnionFromNil(follow(ty))) + ty = follow(ty); + + if (auto utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + + } + + if (std::optional strippedUnion = tryStripUnionFromNil(ty)) { reportError(location, OptionalValueAccess{ty}); return follow(*strippedUnion); } } + else + { + if (isOptional(ty)) + { + if (std::optional strippedUnion = tryStripUnionFromNil(follow(ty))) + { + reportError(location, OptionalValueAccess{ty}); + return follow(*strippedUnion); + } + } + } return ty; } @@ -2935,9 +2971,25 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T return errorRecoveryType(scope); } - // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check - if (lhsType->persistent || ttv->state == TableState::Sealed) - return errorRecoveryType(scope); + if (FFlag::LuauStatFunctionSimplify2) + { + if (lhsType->persistent) + return errorRecoveryType(scope); + + // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check + if (ttv->state == TableState::Sealed) + { + if (ttv->indexer && isPrim(ttv->indexer->indexType, PrimitiveTypeVar::String)) + return ttv->indexer->indexResultType; + else + return errorRecoveryType(scope); + } + } + else + { + if (lhsType->persistent || ttv->state == TableState::Sealed) + return errorRecoveryType(scope); + } Name name = indexName->index.value; @@ -3393,7 +3445,7 @@ void TypeChecker::checkArgumentList( else if (state.log.getMutable(t)) { } // ok - else if (isNonstrictMode() && state.log.getMutable(t)) + else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.getMutable(t)) { } // ok else @@ -3467,7 +3519,11 @@ void TypeChecker::checkArgumentList( } TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - state.tryUnify(varPack, tail); + if (FFlag::LuauWidenIfSupertypeIsFree2) + state.tryUnify(varPack, tail); + else + state.tryUnify(tail, varPack); + return; } else if (state.log.getMutable(tail)) @@ -3542,6 +3598,23 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = follow(actualFunctionType); + TypePackId retPack; + if (!FFlag::LuauWidenIfSupertypeIsFree2) + { + retPack = freshTypePack(scope->level); + } + else + { + if (auto free = get(actualFunctionType)) + { + retPack = freshTypePack(free->level); + TypePackId freshArgPack = freshTypePack(free->level); + *asMutable(actualFunctionType) = FunctionTypeVar(free->level, freshArgPack, retPack); + } + else + retPack = freshTypePack(scope->level); + } + // checkExpr will log the pre-instantiated type of the function. // That's not nearly as interesting as the instantiated type, which will include details about how // generic functions are being instantiated for this particular callsite. @@ -3550,8 +3623,6 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector overloads = flattenIntersection(actualFunctionType); - TypePackId retPack = freshTypePack(scope->level); - std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); @@ -3682,7 +3753,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - if (FFlag::LuauWidenIfSupertypeIsFree) + if (FFlag::LuauWidenIfSupertypeIsFree2) { UnifierOptions options; options.isFunctionCall = true; @@ -3772,7 +3843,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope { state.log.commit(); - if (isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) + if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) { // If we are running in nonstrict mode, passing fewer arguments than the function is declared to take AND // the function is declared with colon notation AND we use dot notation, warn. diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 5af2c8a6..89549535 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -26,6 +26,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) LUAU_FASTFLAG(LuauDiscriminableUnions2) +LUAU_FASTFLAGVARIABLE(LuauAnyInIsOptionalIsOptional, false) namespace Luau { @@ -201,11 +202,16 @@ bool isOptional(TypeId ty) if (isNil(ty)) return true; - auto utv = get(follow(ty)); + ty = follow(ty); + + if (FFlag::LuauAnyInIsOptionalIsOptional && get(ty)) + return true; + + auto utv = get(ty); if (!utv) return false; - return std::any_of(begin(utv), end(utv), isNil); + return std::any_of(begin(utv), end(utv), FFlag::LuauAnyInIsOptionalIsOptional ? isOptional : isNil); } bool isTableIntersection(TypeId ty) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 7b781f26..60a9c9a5 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -20,11 +20,13 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) -LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree, false) +LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) +LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) +LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) namespace Luau { @@ -272,12 +274,6 @@ TypePackId Widen::clean(TypePackId) bool Widen::ignoreChildren(TypeId ty) { - // Sometimes we unify ("hi") -> free1 with (free2) -> free3, so don't ignore functions. - // TODO: should we be doing this? we would need to rework how checkCallOverload does the unification. - if (log->is(ty)) - return false; - - // We only care about unions. return !log->is(ty); } @@ -990,7 +986,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (!log.getMutable(superTp)) { - log.replace(superTp, Unifiable::Bound(subTp)); + log.replace(superTp, Unifiable::Bound(widen(subTp))); } } else if (log.getMutable(subTp)) @@ -1107,7 +1103,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) + else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) { superIter.advance(); continue; @@ -1280,6 +1276,13 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal tryUnify_(subFunction->retType, superFunction->retType); } + if (FFlag::LuauTxnLogRefreshFunctionPointers) + { + // Updating the log may have invalidated the function pointers + superFunction = log.getMutable(superTy); + subFunction = log.getMutable(subTy); + } + if (!FFlag::LuauImmutableTypes) { if (superFunction->definition && !subFunction->definition && !subTy->persistent) @@ -1357,10 +1360,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - bool isAny = log.getMutable(log.follow(superProp.type)); + if (FFlag::LuauAnyInIsOptionalIsOptional) + { + if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) + missingProperties.push_back(propName); + } + else + { + bool isAny = log.getMutable(log.follow(superProp.type)); - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) - missingProperties.push_back(propName); + if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) + missingProperties.push_back(propName); + } } if (!missingProperties.empty()) @@ -1378,9 +1389,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto superIter = superTable->props.find(propName); - bool isAny = log.is(log.follow(subProp.type)); - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) - extraProperties.push_back(propName); + if (FFlag::LuauAnyInIsOptionalIsOptional) + { + if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || !isOptional(subProp.type))) + extraProperties.push_back(propName); + } + else + { + bool isAny = log.is(log.follow(subProp.type)); + if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) + extraProperties.push_back(propName); + } } if (!extraProperties.empty()) @@ -1424,6 +1443,12 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); } + else if (FFlag::LuauAnyInIsOptionalIsOptional && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) + // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` + // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. + // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) + { + } else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get(follow(prop.type)))) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. @@ -1497,6 +1522,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (variance == Covariant) { } + else if (FFlag::LuauAnyInIsOptionalIsOptional && !FFlag::LuauSubtypingAddOptPropsToUnsealedTables && isOptional(prop.type)) + { + } else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && (isOptional(prop.type) || get(follow(prop.type)))) { } @@ -1618,7 +1646,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TypeId Unifier::widen(TypeId ty) { - if (!FFlag::LuauWidenIfSupertypeIsFree) + if (!FFlag::LuauWidenIfSupertypeIsFree2) return ty; Widen widen{types}; @@ -1627,10 +1655,21 @@ TypeId Unifier::widen(TypeId ty) return result.value_or(ty); } +TypePackId Unifier::widen(TypePackId tp) +{ + if (!FFlag::LuauWidenIfSupertypeIsFree2) + return tp; + + Widen widen{types}; + std::optional result = widen.substitute(tp); + // TODO: what does it mean for substitution to fail to widen? + return result.value_or(tp); +} + TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { ty = follow(ty); - if (get(ty)) + if (!FFlag::LuauAnyInIsOptionalIsOptional && get(ty)) return ty; else if (isOptional(ty)) return ty; @@ -1744,7 +1783,10 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { if (auto subProp = findTablePropertyRespectingMeta(subTy, freeName)) { - tryUnify_(freeProp.type, *subProp); + if (FFlag::LuauWidenIfSupertypeIsFree2) + tryUnify_(*subProp, freeProp.type); + else + tryUnify_(freeProp.type, *subProp); /* * TypeVars are commonly cyclic, so it is entirely possible diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 1cb8f134..941a3ea4 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -11,18 +11,11 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauParseAllHotComments, false) LUAU_FASTFLAGVARIABLE(LuauTableFieldFunctionDebugname, false) namespace Luau { -static bool isComment(const Lexeme& lexeme) -{ - LUAU_ASSERT(!FFlag::LuauParseAllHotComments); - return lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment; -} - ParseError::ParseError(const Location& location, const std::string& message) : location(location) , message(message) @@ -146,54 +139,13 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n { LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser"); - Parser p(buffer, bufferSize, names, allocator, FFlag::LuauParseAllHotComments ? options : ParseOptions()); + Parser p(buffer, bufferSize, names, allocator, options); try { - if (FFlag::LuauParseAllHotComments) - { - AstStatBlock* root = p.parseChunk(); + AstStatBlock* root = p.parseChunk(); - return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; - } - else - { - std::vector hotcomments; - - while (isComment(p.lexer.current()) || p.lexer.current().type == Lexeme::BrokenComment) - { - const char* text = p.lexer.current().data; - unsigned int length = p.lexer.current().length; - - if (length && text[0] == '!') - { - unsigned int end = length; - while (end > 0 && isSpace(text[end - 1])) - --end; - - hotcomments.push_back({true, p.lexer.current().location, std::string(text + 1, text + end)}); - } - - const Lexeme::Type type = p.lexer.current().type; - const Location loc = p.lexer.current().location; - - if (options.captureComments) - p.commentLocations.push_back(Comment{type, loc}); - - if (type == Lexeme::BrokenComment) - break; - - p.lexer.next(); - } - - p.lexer.setSkipComments(true); - - p.options = options; - - AstStatBlock* root = p.parseChunk(); - - return ParseResult{root, hotcomments, p.parseErrors, std::move(p.commentLocations)}; - } + return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; } catch (ParseError& err) { @@ -225,10 +177,11 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0); matchRecoveryStopOnToken[Lexeme::Type::Eof] = 1; - if (FFlag::LuauParseAllHotComments) - lexer.setSkipComments(true); + // required for lookahead() to work across a comment boundary and for nextLexeme() to work when captureComments is false + lexer.setSkipComments(true); - // read first lexeme + // read first lexeme (any hot comments get .header = true) + LUAU_ASSERT(hotcommentHeader); nextLexeme(); // all hot comments parsed after the first non-comment lexeme are special in that they don't affect type checking / linting mode @@ -2831,49 +2784,31 @@ void Parser::nextLexeme() { if (options.captureComments) { - if (FFlag::LuauParseAllHotComments) + Lexeme::Type type = lexer.next(/* skipComments= */ false).type; + + while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) { - Lexeme::Type type = lexer.next(/* skipComments= */ false).type; + const Lexeme& lexeme = lexer.current(); + commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) + // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. + // The parser will turn this into a proper syntax error. + if (lexeme.type == Lexeme::BrokenComment) + return; + + // Comments starting with ! are called "hot comments" and contain directives for type checking / linting + if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') { - const Lexeme& lexeme = lexer.current(); - commentLocations.push_back(Comment{lexeme.type, lexeme.location}); + const char* text = lexeme.data; - // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. - // The parser will turn this into a proper syntax error. - if (lexeme.type == Lexeme::BrokenComment) - return; + unsigned int end = lexeme.length; + while (end > 0 && isSpace(text[end - 1])) + --end; - // Comments starting with ! are called "hot comments" and contain directives for type checking / linting - if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') - { - const char* text = lexeme.data; - - unsigned int end = lexeme.length; - while (end > 0 && isSpace(text[end - 1])) - --end; - - hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); - } - - type = lexer.next(/* skipComments= */ false).type; - } - } - else - { - while (true) - { - const Lexeme& lexeme = lexer.next(/*skipComments*/ false); - // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. - // The parser will turn this into a proper syntax error. - if (lexeme.type == Lexeme::BrokenComment) - commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - if (isComment(lexeme)) - commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - else - return; + hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); } + + type = lexer.next(/* skipComments= */ false).type; } } else diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 26360c49..ff753112 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,8 +4,6 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" -LUAU_FASTFLAGVARIABLE(LuauCompileSelectBuiltin2, false) - namespace Luau { namespace Compile @@ -64,7 +62,7 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) if (builtin.isGlobal("unpack")) return LBF_TABLE_UNPACK; - if (FFlag::LuauCompileSelectBuiltin2 && builtin.isGlobal("select")) + if (builtin.isGlobal("select")) return LBF_SELECT_VARARG; if (builtin.object == "math") diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 656a9926..6330bf1f 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -15,8 +15,6 @@ #include #include -LUAU_FASTFLAG(LuauCompileSelectBuiltin2) - namespace Luau { @@ -265,7 +263,6 @@ struct Compiler void compileExprSelectVararg(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs) { - LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin2); LUAU_ASSERT(targetCount == 1); LUAU_ASSERT(!expr->self); LUAU_ASSERT(expr->args.size == 2 && expr->args.data[1]->is()); @@ -407,7 +404,6 @@ struct Compiler if (bfid == LBF_SELECT_VARARG) { - LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin2); // Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly // note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is()) diff --git a/Sources.cmake b/Sources.cmake index 615641eb..59b38497 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -232,11 +232,18 @@ if(TARGET Luau.UnitTest) tests/Transpiler.test.cpp tests/TypeInfer.aliases.test.cpp tests/TypeInfer.annotations.test.cpp + tests/TypeInfer.anyerror.test.cpp tests/TypeInfer.builtins.test.cpp tests/TypeInfer.classes.test.cpp tests/TypeInfer.definitions.test.cpp + tests/TypeInfer.functions.test.cpp tests/TypeInfer.generics.test.cpp tests/TypeInfer.intersectionTypes.test.cpp + tests/TypeInfer.loops.test.cpp + tests/TypeInfer.modules.test.cpp + tests/TypeInfer.oop.test.cpp + tests/TypeInfer.operators.test.cpp + tests/TypeInfer.primitives.test.cpp tests/TypeInfer.provisional.test.cpp tests/TypeInfer.refinements.test.cpp tests/TypeInfer.singletons.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index 274c4ed9..d08b73ea 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -172,9 +172,12 @@ LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...); LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont); LUA_API void lua_pushboolean(lua_State* L, int b); -LUA_API void lua_pushlightuserdata(lua_State* L, void* p); LUA_API int lua_pushthread(lua_State* L); +LUA_API void lua_pushlightuserdata(lua_State* L, void* p); +LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); +LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); + /* ** get functions (Lua -> stack) */ @@ -189,8 +192,6 @@ LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); LUA_API int lua_getreadonly(lua_State* L, int idx); LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled); -LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); -LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); LUA_API int lua_getmetatable(lua_State* L, int objindex); LUA_API void lua_getfenv(lua_State* L, int idx); @@ -276,6 +277,14 @@ enum lua_GCOp LUA_API int lua_gc(lua_State* L, int what, int data); +/* +** memory statistics +** all allocated bytes are attributed to the memory category of the running thread (0..LUA_MEMORY_CATEGORIES-1) +*/ + +LUA_API void lua_setmemcat(lua_State* L, int category); +LUA_API size_t lua_totalbytes(lua_State* L, int category); + /* ** miscellaneous functions */ diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index f7f15442..3c087314 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -35,8 +35,8 @@ const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n" static Table* getcurrenv(lua_State* L) { - if (L->ci == L->base_ci) /* no enclosing function? */ - return L->gt; /* use global table as environment */ + if (L->ci == L->base_ci) /* no enclosing function? */ + return L->gt; /* use global table as environment */ else return curr_func(L)->env; } @@ -1188,7 +1188,7 @@ void lua_concat(lua_State* L, int n) void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) { - api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT || tag == UTAG_PROXY); luaC_checkGC(L); luaC_checkthreadsleep(L); Udata* u = luaU_newudata(L, sz, tag); @@ -1317,7 +1317,7 @@ void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)) L->global->udatagc[tag] = dtor; } -LUA_API void lua_clonefunction(lua_State* L, int idx) +void lua_clonefunction(lua_State* L, int idx) { StkId p = index2addr(L, idx); api_check(L, isLfunction(p)); @@ -1333,3 +1333,15 @@ lua_Callbacks* lua_callbacks(lua_State* L) { return &L->global->cb; } + +void lua_setmemcat(lua_State* L, int category) +{ + api_check(L, unsigned(category) < LUA_MEMORY_CATEGORIES); + L->activememcat = uint8_t(category); +} + +size_t lua_totalbytes(lua_State* L, int category) +{ + api_check(L, category < LUA_MEMORY_CATEGORIES); + return category < 0 ? L->global->totalbytes : L->global->memcatbytes[category]; +} diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 2307598e..96ad493b 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -5,6 +5,7 @@ #include "lstate.h" #include "lapi.h" #include "ldo.h" +#include "ludata.h" #include #include @@ -190,6 +191,7 @@ static int luaB_type(lua_State* L) luaL_checkany(L, 1); if (DFFlag::LuauMorePreciseLuaLTypeName) { + /* resulting name doesn't differentiate between userdata types */ lua_pushstring(L, lua_typename(L, lua_type(L, 1))); } else @@ -202,8 +204,16 @@ static int luaB_type(lua_State* L) static int luaB_typeof(lua_State* L) { luaL_checkany(L, 1); - const TValue* obj = luaA_toobject(L, 1); - lua_pushstring(L, luaT_objtypename(L, obj)); + if (DFFlag::LuauMorePreciseLuaLTypeName) + { + /* resulting name returns __type if specified unless the input is a newproxy-created userdata */ + lua_pushstring(L, luaL_typename(L, 1)); + } + else + { + const TValue* obj = luaA_toobject(L, 1); + lua_pushstring(L, luaT_objtypename(L, obj)); + } return 1; } @@ -412,7 +422,7 @@ static int luaB_newproxy(lua_State* L) bool needsmt = lua_toboolean(L, 1); - lua_newuserdata(L, 0); + lua_newuserdatatagged(L, 0, UTAG_PROXY); if (needsmt) { diff --git a/VM/src/lgc.h b/VM/src/lgc.h index ad8ee78a..ebf999b5 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -9,9 +9,9 @@ /* ** Default settings for GC tunables (settable via lua_gc) */ -#define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ +#define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ #define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ -#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ +#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ /* ** Possible states of the Garbage Collector diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 899cb0c0..3cbdafff 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -98,7 +98,7 @@ */ #if defined(__APPLE__) #define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : gcc32) -#elif defined(__i386__) +#elif defined(__i386__) && !defined(_MSC_VER) #define ABISWITCH(x64, ms32, gcc32) (gcc32) #else // Android somehow uses a similar ABI to MSVC, *not* to iOS... diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index 87250146..c0cd3e26 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -53,7 +53,7 @@ void luaS_resize(lua_State* L, int newsize) { TString* p = tb->hash[i]; while (p) - { /* for each node in the list */ + { /* for each node in the list */ TString* next = p->next; /* save next */ unsigned int h = p->hash; int h1 = lmod(h, newsize); /* new position */ diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index ef0b4b93..2deec2b9 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -24,6 +24,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false) + // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -351,6 +353,22 @@ static void setnodevector(lua_State* L, Table* t, int size) t->lastfree = size; /* all positions are free */ } +static TValue* newkey(lua_State* L, Table* t, const TValue* key); + +static TValue* arrayornewkey(lua_State* L, Table* t, const TValue* key) +{ + if (ttisnumber(key)) + { + int k; + double n = nvalue(key); + luai_num2int(k, n); + if (luai_numeq(cast_num(k), n) && cast_to(unsigned int, k - 1) < cast_to(unsigned int, t->sizearray)) + return &t->array[k - 1]; + } + + return newkey(L, t, key); +} + static void resize(lua_State* L, Table* t, int nasize, int nhsize) { if (nasize > MAXSIZE || nhsize > MAXSIZE) @@ -369,22 +387,50 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) for (int i = nasize; i < oldasize; i++) { if (!ttisnil(&t->array[i])) - setobjt2t(L, luaH_setnum(L, t, i + 1), &t->array[i]); + { + if (FFlag::LuauTableRehashRework) + { + TValue ok; + setnvalue(&ok, cast_num(i + 1)); + setobjt2t(L, newkey(L, t, &ok), &t->array[i]); + } + else + { + setobjt2t(L, luaH_setnum(L, t, i + 1), &t->array[i]); + } + } } /* shrink array */ luaM_reallocarray(L, t->array, oldasize, nasize, TValue, t->memcat); } /* re-insert elements from hash part */ - for (int i = twoto(oldhsize) - 1; i >= 0; i--) + if (FFlag::LuauTableRehashRework) { - LuaNode* old = nold + i; - if (!ttisnil(gval(old))) + for (int i = twoto(oldhsize) - 1; i >= 0; i--) { - TValue ok; - getnodekey(L, &ok, old); - setobjt2t(L, luaH_set(L, t, &ok), gval(old)); + LuaNode* old = nold + i; + if (!ttisnil(gval(old))) + { + TValue ok; + getnodekey(L, &ok, old); + setobjt2t(L, arrayornewkey(L, t, &ok), gval(old)); + } } } + else + { + for (int i = twoto(oldhsize) - 1; i >= 0; i--) + { + LuaNode* old = nold + i; + if (!ttisnil(gval(old))) + { + TValue ok; + getnodekey(L, &ok, old); + setobjt2t(L, luaH_set(L, t, &ok), gval(old)); + } + } + } + if (nold != dummynode) luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); /* free old array */ } @@ -482,7 +528,16 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) if (n == NULL) { /* cannot find a free place? */ rehash(L, t, key); /* grow table */ - return luaH_set(L, t, key); /* re-insert key into grown table */ + + if (!FFlag::LuauTableRehashRework) + { + return luaH_set(L, t, key); /* re-insert key into grown table */ + } + else + { + // after rehash, numeric keys might be located in the new array part, but won't be found in the node part + return arrayornewkey(L, t, key); + } } LUAU_ASSERT(n != dummynode); TValue mk; diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index a77a7c72..106efb2b 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -4,6 +4,7 @@ #include "lstate.h" #include "lstring.h" +#include "ludata.h" #include "ltable.h" #include "lgc.h" @@ -116,7 +117,7 @@ const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event) const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) { - if (ttisuserdata(o) && uvalue(o)->tag && uvalue(o)->metatable) + if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) { const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index 0dfac508..819d1863 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -22,13 +22,11 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) void luaU_freeudata(lua_State* L, Udata* u, lua_Page* page) { - LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); - void (*dtor)(void*) = nullptr; - if (u->tag == UTAG_IDTOR) - memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); - else if (u->tag) + if (u->tag < LUA_UTAG_LIMIT) dtor = L->global->udatagc[u->tag]; + else if (u->tag == UTAG_IDTOR) + memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); if (dtor) dtor(u->data); diff --git a/VM/src/ludata.h b/VM/src/ludata.h index ec374c28..f24e4a32 100644 --- a/VM/src/ludata.h +++ b/VM/src/ludata.h @@ -7,6 +7,9 @@ /* special tag value is used for user data with inline dtors */ #define UTAG_IDTOR LUA_UTAG_LIMIT +/* special tag value is used for newproxy-created user data (all other user data objects are host-exposed) */ +#define UTAG_PROXY (LUA_UTAG_LIMIT + 1) + #define sizeudata(len) (offsetof(Udata, data) + len) LUAI_FUNC Udata* luaU_newudata(lua_State* L, size_t s, int tag); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 6c31d36f..96a87b7e 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -77,6 +77,11 @@ if (LUAU_UNLIKELY(!!interrupt)) \ { /* the interrupt hook is called right before we advance pc */ \ VM_PROTECT(L->ci->savedpc++; interrupt(L, -1)); \ + if (L->status != 0) \ + { \ + L->ci->savedpc--; \ + goto exit; \ + } \ } \ } #endif diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 55b0618a..17fd6b13 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2755,4 +2755,166 @@ local abc = b@1 CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside); } +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_on_class") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + loadDefinition(R"( +declare class Foo + function one(self): number + two: () -> number +end + )"); + + { + check(R"( +local t: Foo +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("one")); + REQUIRE(ac.entryMap.count("two")); + CHECK(!ac.entryMap["one"].wrongIndexType); + CHECK(ac.entryMap["two"].wrongIndexType); + } + + { + check(R"( +local t: Foo +t.@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("one")); + REQUIRE(ac.entryMap.count("two")); + CHECK(ac.entryMap["one"].wrongIndexType); + CHECK(!ac.entryMap["two"].wrongIndexType); + } +} + +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +local t = {} +function t.m() end +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + CHECK(ac.entryMap["m"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +local f: (() -> number) & ((number) -> number) = function(x: number?) return 2 end +local t = {} +t.f = f +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("f")); + CHECK(ac.entryMap["f"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_provisional") +{ + check(R"( +local t = {} +function t.m(x: typeof(t)) end +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + // We can make changes to mark this as a wrong way to call even though it's compatible + CHECK(!ac.entryMap["m"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +local s = "hello" +s:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == true); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == false); +} + +TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +local s = "hello" +s.@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == true); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == true); +} + +TEST_CASE_FIXTURE(ACFixture, "string_library_non_self_calls_are_fine") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +string.@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == false); +} + +TEST_CASE_FIXTURE(ACFixture, "string_library_self_calls_are_invalid") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +string:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == true); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == true); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == true); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index f982c86f..3dc57da0 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -2819,8 +2819,6 @@ RETURN R1 -1 TEST_CASE("FastcallSelect") { - ScopedFastFlag sff("LuauCompileSelectBuiltin2", true); - // select(_, ...) compiles to a builtin call CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"( LOADK R1 K0 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 63fbb363..9e4cb4a5 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -569,6 +569,11 @@ TEST_CASE("Debugger") CHECK(lua_tointeger(L, -1) == 50); lua_pop(L, 1); + int v = lua_getargument(L, 0, 2); + REQUIRE(v); + CHECK(lua_tointeger(L, -1) == 42); + lua_pop(L, 1); + // test lua_getlocal const char* l = lua_getlocal(L, 0, 1); REQUIRE(l); @@ -652,31 +657,6 @@ TEST_CASE("SameHash") CHECK(luaS_hash(buf + 1, 120) == luaS_hash(buf + 2, 120)); } -TEST_CASE("InlineDtor") -{ - static int dtorhits = 0; - - dtorhits = 0; - - { - StateRef globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - void* u1 = lua_newuserdatadtor(L, 4, [](void* data) { - dtorhits += *(int*)data; - }); - - void* u2 = lua_newuserdatadtor(L, 1, [](void* data) { - dtorhits += *(char*)data; - }); - - *(int*)u1 = 39; - *(char*)u2 = 3; - } - - CHECK(dtorhits == 42); -} - TEST_CASE("Reference") { static int dtorhits = 0; @@ -969,7 +949,7 @@ TEST_CASE("StringConversion") TEST_CASE("GCDump") { // internal function, declared in lgc.h - not exposed via lua.h - extern void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); + extern void luaC_dump(lua_State * L, void* file, const char* (*categoryName)(lua_State * L, uint8_t memcat)); StateRef globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); @@ -1015,4 +995,114 @@ TEST_CASE("GCDump") fclose(f); } +TEST_CASE("Interrupt") +{ + static const int expectedhits[] = { + 2, + 9, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 6, + 11, + }; + static int index; + + index = 0; + + runConformance( + "interrupt.lua", + [](lua_State* L) { + auto* cb = lua_callbacks(L); + + // note: for simplicity here we setup the interrupt callback once + // however, this carries a noticeable performance cost. in a real application, + // it's advised to set interrupt callback on a timer from a different thread, + // and set it back to nullptr once the interrupt triggered. + cb->interrupt = [](lua_State* L, int gc) { + if (gc >= 0) + return; + + CHECK(index < int(std::size(expectedhits))); + + lua_Debug ar = {}; + lua_getinfo(L, 0, "l", &ar); + + CHECK(ar.currentline == expectedhits[index]); + + index++; + + // check that we can yield inside an interrupt + if (index == 5) + lua_yield(L, 0); + }; + }, + [](lua_State* L) { + CHECK(index == 5); // a single yield point + }); + + CHECK(index == int(std::size(expectedhits))); +} + +TEST_CASE("UserdataApi") +{ + static int dtorhits = 0; + + dtorhits = 0; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // setup dtor for tag 42 (created later) + lua_setuserdatadtor(L, 42, [](void* data) { + dtorhits += *(int*)data; + }); + + // light user data + int lud; + lua_pushlightuserdata(L, &lud); + + CHECK(lua_touserdata(L, -1) == &lud); + CHECK(lua_topointer(L, -1) == &lud); + + // regular user data + int* ud1 = (int*)lua_newuserdata(L, 4); + *ud1 = 42; + + CHECK(lua_touserdata(L, -1) == ud1); + CHECK(lua_topointer(L, -1) == ud1); + + // tagged user data + int* ud2 = (int*)lua_newuserdatatagged(L, 4, 42); + *ud2 = -4; + + CHECK(lua_touserdatatagged(L, -1, 42) == ud2); + CHECK(lua_touserdatatagged(L, -1, 41) == nullptr); + CHECK(lua_userdatatag(L, -1) == 42); + + // user data with inline dtor + void* ud3 = lua_newuserdatadtor(L, 4, [](void* data) { + dtorhits += *(int*)data; + }); + + void* ud4 = lua_newuserdatadtor(L, 1, [](void* data) { + dtorhits += *(char*)data; + }); + + *(int*)ud3 = 43; + *(char*)ud4 = 3; + + globalState.reset(); + + CHECK(dtorhits == 42); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 4d6c207c..91b23197 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1667,8 +1667,6 @@ _ = (math.random() < 0.5 and false) or 42 -- currently ignored TEST_CASE_FIXTURE(Fixture, "WrongComment") { - ScopedFastFlag sff("LuauParseAllHotComments", true); - LintResult result = lint(R"( --!strict --!struct diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index e3993cc5..82b7a350 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -248,10 +248,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") TEST_CASE_FIXTURE(Fixture, "clone_self_property") { + ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; + fileResolver.source["Module/A"] = R"( --!nonstrict local a = {} - function a:foo(x) + function a:foo(x: number) return -x; end return a; @@ -267,10 +269,10 @@ TEST_CASE_FIXTURE(Fixture, "clone_self_property") )"; result = frontend.check("Module/B"); - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), "This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a " - "dot or pass 1 extra nil to suppress this warning"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("This function must be called with self. Did you mean to use a colon instead of a dot?", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 5bad9901..d3faea2a 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -126,6 +126,8 @@ TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") { + ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; + CheckResult result = check(R"( --!nonstrict local T = {} @@ -136,31 +138,25 @@ TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") T:staticmethod() )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(std::any_of(result.errors.begin(), result.errors.end(), [](const TypeError& e) { - return get(e); - })); - CHECK(std::any_of(result.errors.begin(), result.errors.end(), [](const TypeError& e) { - return get(e); - })); + CHECK_EQ("This function does not take self. Did you mean to use a dot instead of a colon?", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") { + ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; + CheckResult result = check(R"( --!nonstrict local T = {} - function T:method() end - T.method() + function T:method(x: number) end + T.method(5) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - auto e = get(result.errors[0]); - REQUIRE(e != nullptr); - - REQUIRE_EQ(1, e->requiredExtraNils); + CHECK_EQ("This function must be called with self. Did you mean to use a colon instead of a dot?", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "table_props_are_any") diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 711c0aa1..b2e76052 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -676,4 +676,25 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") +{ + CheckResult result = check(R"( + type Tree = { data: T, children: {Tree} } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- this would be an infinite type if we allowed it + type Tree = { data: T, children: {Tree<{T}>} } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp new file mode 100644 index 00000000..5224b5d8 --- /dev/null +++ b/tests/TypeInfer.anyerror.test.cpp @@ -0,0 +1,335 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferAnyError"); + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") +{ + CheckResult result = check(R"( + function bar(): any + return true + end + + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(typeChecker.anyType, requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") +{ + CheckResult result = check(R"( + function bar(): any + return true + end + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") +{ + CheckResult result = check(R"( + local bar: any + + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") +{ + CheckResult result = check(R"( + local bar: any + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") +{ + CheckResult result = check(R"( + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("*unknown*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") +{ + CheckResult result = check(R"( + function bar(c) return c end + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("*unknown*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local l = #this_is_not_defined + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local originalReward = unknown.Parent.Reward:GetChildren()[1] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "dot_on_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local foo = (true).x + foo.x = foo.y + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "any_type_propagates") +{ + CheckResult result = check(R"( + local foo: any + local bar = foo:method("argument") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "can_subscript_any") +{ + CheckResult result = check(R"( + local foo: any + local bar = foo[5] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("bar"))); +} + +// Not strictly correct: metatables permit overriding this +TEST_CASE_FIXTURE(Fixture, "can_get_length_of_any") +{ + CheckResult result = check(R"( + local foo: any = {} + local bar = #foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") +{ + CheckResult result = check(R"( + local f: any + local T = {} + + T.prop = f() + + return T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* ttv = getMutable(requireType("T")); + REQUIRE(ttv); + REQUIRE(ttv->props.count("prop")); + + REQUIRE_EQ("any", toString(ttv->props["prop"].type)); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") +{ + CheckResult result = check(R"( + local A : any + function A.B() end + A:C() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId aType = requireType("A"); + CHECK_EQ(aType, typeChecker.anyType); +} + +TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") +{ + CheckResult result = check(R"( + local a = unknown.Parent.Reward.GetChildren() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownSymbol* err = get(result.errors[0]); + REQUIRE(err != nullptr); + + CHECK_EQ("unknown", err->name); + + CHECK_EQ("*unknown*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") +{ + CheckResult result = check(R"( + local a = Utility.Create "Foo" {} + )"); + + CHECK_EQ("*unknown*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") +{ + CheckResult result = check(R"( + local a: any + local b + for _, i in pairs(a) do + b = i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_to_any_yields_any") +{ + CheckResult result = check(R"( + local a: any + local b = a() + )"); + + REQUIRE_EQ("any", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfAny") +{ + CheckResult result = check(R"( +local x: any = {} +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") +{ + CheckResult result = check(R"( +local x = (true).foo +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "metatable_of_any_can_be_a_table") +{ + CheckResult result = check(R"( +--!strict +local T: any +T = {} +T.__index = T +function T.new(...) + local self = {} + setmetatable(self, T) + self:construct(...) + return self +end +function T:construct(index) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_error_addition") +{ + CheckResult result = check(R"( +--!strict +local foo = makesandwich() +local bar = foo.nutrition + 100 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + // We should definitely get this error + CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0])); + // We get this error if makesandwich() returns a free type + // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") +{ + CheckResult result = check(R"( + local function f(thing: any | string) + local foo = thing.SomeRandomKey + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 8da655b3..ec20a2c7 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -871,6 +871,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, {"LuauDiscriminableUnions2", true}, + {"LuauWidenIfSupertypeIsFree2", true}, }; CheckResult result = check(R"( @@ -879,6 +880,26 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") end )"); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauAssertStripsFalsyTypes", true}, + {"LuauDiscriminableUnions2", true}, + {"LuauWidenIfSupertypeIsFree2", true}, + }; + + CheckResult result = check(R"( + local function f(x: (number | boolean)?): number | true + return assert(x) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); } @@ -958,4 +979,43 @@ a:b({}) CHECK_EQ(result.errors[1], (TypeError{Location{{3, 0}, {3, 5}}, CountMismatch{2, 1}})); } +TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") +{ + CheckResult result = check(R"( +local function f(a: typeof(f)) end +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") +{ + TypeId mathTy = requireType(typeChecker.globalScope, "math"); + REQUIRE(mathTy); + TableTypeVar* ttv = getMutable(mathTy); + REQUIRE(ttv); + const FunctionTypeVar* ftv = get(ttv->props["frexp"].type); + REQUIRE(ftv); + auto original = ftv->level; + + CheckResult result = check("local a = math.frexp"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(ftv->level.level == original.level); + CHECK(ftv->level.subLevel == original.subLevel); +} + +TEST_CASE_FIXTURE(Fixture, "global_singleton_types_are_sealed") +{ + CheckResult result = check(R"( +local function f(x: string) + local p = x:split('a') + p = table.pack(table.unpack(p, 1, #p - 1)) + return p +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index c6d55793..898d8902 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -293,4 +293,22 @@ TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_type CHECK_EQ(ty->type->documentationSymbol, std::nullopt); } +TEST_CASE_FIXTURE(Fixture, "single_class_type_identity_in_global_types") +{ + ScopedFastFlag luauCloneDeclaredGlobals{"LuauCloneDeclaredGlobals", true}; + + loadDefinition(R"( +declare class Cls +end + +declare GetCls: () -> (Cls) + )"); + + CheckResult result = check(R"( +local s : Cls = GetCls() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp new file mode 100644 index 00000000..4288098a --- /dev/null +++ b/tests/TypeInfer.functions.test.cpp @@ -0,0 +1,1338 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferFunctions"); + +TEST_CASE_FIXTURE(Fixture, "tc_function") +{ + CheckResult result = check("function five() return 5 end"); + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* fiveType = get(requireType("five")); + REQUIRE(fiveType != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "check_function_bodies") +{ + CheckResult result = check("function myFunction() local a = 0 a = true end"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 44}, Position{0, 48}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.booleanType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_type") +{ + CheckResult result = check("function take_five() return 5 end"); + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* takeFiveType = get(requireType("take_five")); + REQUIRE(takeFiveType != nullptr); + + std::vector retVec = flatten(takeFiveType->retType).first; + REQUIRE(!retVec.empty()); + + REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") +{ + CheckResult result = check("function take_five() return 5 end local five = take_five()"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *follow(requireType("five"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") +{ + CheckResult result = check(R"( + function take_five() + return 5 + end + + take_five().prop = 888 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); +} + +TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and_size") +{ + CheckResult result = check(R"( + function f(...) end + + f(1) + f("foo", 2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") +{ + CheckResult result = check(R"( + local T = {} + function T.f(...) + local result = {} + + for i = 1, select("#", ...) do + local dictionary = select(i, ...) + for key, value in pairs(dictionary) do + result[key] = value + end + end + + return result + end + + return T + )"); + + auto r = first(getMainModule()->getModuleScope()->returnType); + REQUIRE(r); + + TableTypeVar* ttv = getMutable(*r); + REQUIRE(ttv); + + TypeId k = ttv->props["f"].type; + REQUIRE(k); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_count") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply("") + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); + + ExtraInformation* ei = get(result.errors[1]); + REQUIRE(ei); + CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); +} + +TEST_CASE_FIXTURE(Fixture, "list_all_overloads_if_no_overload_takes_given_argument_count") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply() + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("No overload for function accepts 0 arguments.", ge->message); + + ExtraInformation* ei = get(result.errors[1]); + REQUIRE(ei); + CHECK_EQ("Available overloads: (number) -> number; (number) -> string; and (number, number) -> number", ei->message); +} + +TEST_CASE_FIXTURE(Fixture, "dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply(1, "") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") +{ + CheckResult result = check(R"( + type T = {method: ((T, number) -> number) & ((number) -> string)} + local T: T + + local a = T.method(T, 4) + local b = T.method(5) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("string", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_arguments") +{ + CheckResult result = check(R"( + --!nonstrict + + function g(a: number) end + + g() + + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(1, acm->expected); + CHECK_EQ(0, acm->actual); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_function") +{ + CheckResult result = check(R"( + function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "lambda_form_of_local_function_cannot_be_recursive") +{ + CheckResult result = check(R"( + local f = function() return f() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_local_function") +{ + CheckResult result = check(R"( + local function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// FIXME: This and the above case get handled very differently. It's pretty dumb. +// We really should unify the two code paths, probably by deleting AstStatFunction. +TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") +{ + CheckResult result = check(R"( + local count + function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") +{ + CheckResult result = check(R"( + function f() + return f + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = () -> t1", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") +{ + CheckResult result = check(R"( + function f(g) + return f(f) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") +{ + CheckResult result = check(R"( + local Get_des + function Get_des(func) + Get_des(func) + end + + local function f(d) + d:IsA("BasePart") + d.Parent:FindFirstChild("Humanoid") + d:IsA("Decal") + end + Get_des(f) + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_other_higher_order_function") +{ + CheckResult result = check(R"( + local d + d:foo() + d:foo() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "local_function") +{ + CheckResult result = check(R"( + function f() + return 8 + end + + function g() + local function f() + return 'hello' + end + return f + end + + local h = g() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId h = follow(requireType("h")); + + const FunctionTypeVar* ftv = get(h); + REQUIRE(ftv != nullptr); + + std::optional rt = first(ftv->retType); + REQUIRE(bool(rt)); + + TypeId retType = follow(*rt); + CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(retType)); +} + +TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") +{ + CheckResult result = check(R"( + local p = function(x) return x end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + const Luau::FunctionTypeVar* fn = get(requireType("p")); + REQUIRE(fn); + auto ret = first(fn->retType); + REQUIRE(ret); + REQUIRE(get(follow(*ret))); +} + +TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") +{ + CheckResult result = check(R"( + local T = {} + function T.new(a: number?, b: number?, c: number?) return 5 end + local m = T.new() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "it_is_ok_not_to_supply_enough_retvals") +{ + CheckResult result = check(R"( + function get_two() return 5, 6 end + + local a = get_two() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions2") +{ + CheckResult result = check(R"( + function foo() end + + function bar() + local function foo() end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions_allowed_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + function foo() end + + function foo() end + + function bar() + local function foo() end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions_with_different_signatures_not_allowed_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + function foo(): number + return 1 + end + foo() + + function foo(n: number): number + return 2 + end + foo() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> number", toString(tm->wantedType)); + CHECK_EQ("(number) -> number", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotation") +{ + CheckResult result = check(R"( + local i = 0 + function most_of_the_natural_numbers(): number? + if i < 10 then + i = i + 1 + return i + else + return nil + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); + + std::optional retType = first(functionType->retType); + REQUIRE(retType); + CHECK(get(*retType)); +} + +TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") +{ + CheckResult result = check(R"( + function apply(f, x) + return f(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("apply")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(2, argVec.size()); + + const FunctionTypeVar* fType = get(follow(argVec[0])); + REQUIRE(fType != nullptr); + + std::vector fArgs = flatten(fType->argTypes).first; + + TypeId xType = argVec[1]; + + CHECK_EQ(1, fArgs.size()); + CHECK_EQ(xType, fArgs[0]); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") +{ + CheckResult result = check(R"( + function bottomupmerge(comp, a, b, left, mid, right) + local i, j = left, mid + for k = left, right do + if i < mid and (j > right or not comp(a[j], a[i])) then + b[k] = a[i] + i = i + 1 + else + b[k] = a[j] + j = j + 1 + end + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("bottomupmerge")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(6, argVec.size()); + + const FunctionTypeVar* fType = get(follow(argVec[0])); + REQUIRE(fType != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") +{ + CheckResult result = check(R"( + function swap(p) + local t = p[0] + p[0] = p[1] + p[1] = t + return nil + end + + function swapTwice(p) + swap(p) + swap(p) + return p + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("swapTwice")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(1, argVec.size()); + + const TableTypeVar* argType = get(follow(argVec[0])); + REQUIRE(argType != nullptr); + + CHECK(bool(argType->indexer)); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") +{ + CheckResult result = check(R"( + function bottomupmerge(comp, a, b, left, mid, right) + local i, j = left, mid + for k = left, right do + if i < mid and (j > right or not comp(a[j], a[i])) then + b[k] = a[i] + i = i + 1 + else + b[k] = a[j] + j = j + 1 + end + end + end + + function mergesort(arr, comp) + local work = {} + for i = 1, #arr do + work[i] = arr[i] + end + local width = 1 + while width < #arr do + for i = 1, #arr, 2*width do + bottomupmerge(comp, arr, work, i, math.min(i+width, #arr), math.min(i+2*width-1, #arr)) + end + local temp = work + work = arr + arr = temp + width = width * 2 + end + return arr + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + /* + * mergesort takes two arguments: an array of some type T and a function that takes two Ts. + * We must assert that these two types are in fact the same type. + * In other words, comp(arr[x], arr[y]) is well-typed. + */ + + const FunctionTypeVar* ftv = get(requireType("mergesort")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(2, argVec.size()); + + const TableTypeVar* arg0 = get(follow(argVec[0])); + REQUIRE(arg0 != nullptr); + REQUIRE(bool(arg0->indexer)); + + const FunctionTypeVar* arg1 = get(follow(argVec[1])); + REQUIRE(arg1 != nullptr); + REQUIRE_EQ(2, size(arg1->argTypes)); + + std::vector arg1Args = flatten(arg1->argTypes).first; + + CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[0]); + CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); +} + +TEST_CASE_FIXTURE(Fixture, "mutual_recursion") +{ + CheckResult result = check(R"( + --!strict + + function newPlayerCharacter() + startGui() -- Unknown symbol 'startGui' + end + + local characterAddedConnection: any + function startGui() + characterAddedConnection = game:GetService("Players").LocalPlayer.CharacterAdded:connect(newPlayerCharacter) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") +{ + CheckResult result = check(R"( + --!strict + local x = nil + function f() g() end + -- make sure print(x) doesn't get toposorted here, breaking the mutual block + function g() x = f end + print(x) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") +{ + CheckResult result = check(R"( + --!nonstrict + + function f() + return 114 + end + + return function() + return f():andThen() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") +{ + CheckResult result = check(R"( + function onerror() end + function foo() end + xpcall(foo, onerror) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments") +{ + CheckResult result = check(R"( + local mycb: (number, number) -> () + + function f() end + + mycb = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + + local function f1(v): number? + if v then + return 1 + end + end + + local function f2(v) + if v then + return 1 + end + end + + local function f3(v): () + if v then + return + end + end + + local function f4(v) + if v then + return + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + FunctionExitsWithoutReturning* err = get(result.errors[0]); + CHECK(err); +} + +TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_strict") +{ + CheckResult result = check(R"( + --!strict + + local function f1(v): number? + if v then + return 1 + end + end + + local function f2(v) + if v then + return 1 + end + end + + local function f3(v): () + if v then + return + end + end + + local function f4(v) + if v then + return + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + FunctionExitsWithoutReturning* annotatedErr = get(result.errors[0]); + CHECK(annotatedErr); + + FunctionExitsWithoutReturning* inferredErr = get(result.errors[1]); + CHECK(inferredErr); +} + +TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields_errors_spanning_argument") +{ + CheckResult result = check(R"( + function foo(a: number, b: string) end + + foo("Test", 123) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(result.errors[0], (TypeError{Location{Position{3, 12}, Position{3, 18}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.stringType, + }})); + + CHECK_EQ(result.errors[1], (TypeError{Location{Position{3, 20}, Position{3, 23}}, TypeMismatch{ + typeChecker.stringType, + typeChecker.numberType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") +{ + CheckResult result = check(R"( + --!nonstrict + + function Test(a) + return 1, "" + end + + + local tab = {} + table.insert(tab, Test(1)); + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + opts.maxTableLength = 0; + + CHECK_EQ("{any}", toString(requireType("tab"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_return_values") +{ + CheckResult result = check(R"( + --!strict + + function f() + return 55 + end + + local a, b = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 2); +} + +TEST_CASE_FIXTURE(Fixture, "ignored_return_values") +{ + CheckResult result = check(R"( + --!strict + + function f() + return 55, "" + end + + local a = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") +{ + CheckResult result = check(R"( + --!strict + + function f(): (number, string) + return 55 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Return); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 1); +} + +TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") +{ + CheckResult result = check(R"( + function foo(a, b): number + return 0 + end + + local a: (string)->number = foo + local b: (number, number)->(number, number) = foo + + local c: (string, number)->number = foo -- no error + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + auto tm1 = get(result.errors[0]); + REQUIRE(tm1); + + CHECK_EQ("(string) -> number", toString(tm1->wantedType)); + CHECK_EQ("(string, *unknown*) -> number", toString(tm1->givenType)); + + auto tm2 = get(result.errors[1]); + REQUIRE(tm2); + + CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); + CHECK_EQ("(string, *unknown*) -> number", toString(tm2->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") +{ + CheckResult result = check(R"( + --!strict + local tbl = {} + function tbl:abc(a: number, b: number) + return a + end + tbl:abc(1, 2) -- Line 6 + -- | Column 14 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + TypeId type = requireTypeAtPosition(Position(6, 14)); + CHECK_EQ("(tbl, number, number) -> number", toString(type)); + auto ftv = get(type); + REQUIRE(ftv); + CHECK(ftv->hasSelf); +} + +TEST_CASE_FIXTURE(Fixture, "record_matching_overload") +{ + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number) -> number) + local abc: Overload + abc(1) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // AstExprCall is the node that has the overload stored on it. + // findTypeAtPosition will look at the AstExprLocal, but this is not what + // we want to look at. + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(3, 10)); + REQUIRE_GE(ancestry.size(), 2); + AstExpr* parentExpr = ancestry[ancestry.size() - 2]->asExpr(); + REQUIRE(bool(parentExpr)); + REQUIRE(parentExpr->is()); + + ModulePtr module = getMainModule(); + auto it = module->astOverloadResolvedTypes.find(parentExpr); + REQUIRE(it); + CHECK_EQ(toString(*it), "(number) -> number"); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number, number) -> number) + local abc: Overload + local x = abc(true) + local y = abc(true,true) + local z = abc(true,true,true) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("number", toString(requireType("y"))); + // Should this be string|number? + CHECK_EQ("string", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +{ + // Simple direct arg to arg propagation + CheckResult result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // An optional function is accepted, but since we already provide a function, nil can be ignored + result = check(R"( +type Table = { x: number, y: number } +local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Make sure self calls match correct index + result = check(R"( +type Table = { x: number, y: number } +local x = {} +x.b = {x = 1, y = 2} +function x:f(a: (Table) -> number) return a(self.b) end +x:f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Mix inferred and explicit argument types + result = check(R"( +function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end +f(function(a: number, b, c) return c and a + b or b - a end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Anonymous function has a variadic pack + result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(...) return select(1, ...).z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Can't accept more arguments than provided + result = check(R"( +function f(a: (a: number, b: number) -> number) return a(1, 2) end +f(function(a, b, c, ...) return a + b end) + )"); + + LUAU_REQUIRE_ERRORS(result); + + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + + // Infer from variadic packs into elements + result = check(R"( +function f(a: (...number) -> number) return a(1, 2) end +f(function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Infer from variadic packs into variadic packs + result = check(R"( +type Table = { x: number, y: number } +function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end +f(function(a, ...) local b = ... return b.z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Return type inference + result = check(R"( +type Table = { x: number, y: number } +function f(a: (number) -> Table) return a(4) end +f(function(x) return x * 2 end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); + + // Return type doesn't inference 'nil' + result = check(R"( +function f(a: (number) -> nil) return a(4) end +f(function(x) print(x) end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +{ + // Simple direct arg to arg propagation + CheckResult result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // An optional function is accepted, but since we already provide a function, nil can be ignored + result = check(R"( +type Table = { x: number, y: number } +local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Make sure self calls match correct index + result = check(R"( +type Table = { x: number, y: number } +local x = {} +x.b = {x = 1, y = 2} +function x:f(a: (Table) -> number) return a(self.b) end +x:f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Mix inferred and explicit argument types + result = check(R"( +function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end +f(function(a: number, b, c) return c and a + b or b - a end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Anonymous function has a variadic pack + result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(...) return select(1, ...).z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Can't accept more arguments than provided + result = check(R"( +function f(a: (a: number, b: number) -> number) return a(1, 2) end +f(function(a, b, c, ...) return a + b end) + )"); + + LUAU_REQUIRE_ERRORS(result); + + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + + // Infer from variadic packs into elements + result = check(R"( +function f(a: (...number) -> number) return a(1, 2) end +f(function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Infer from variadic packs into variadic packs + result = check(R"( +type Table = { x: number, y: number } +function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end +f(function(a, ...) local b = ... return b.z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Return type inference + result = check(R"( +type Table = { x: number, y: number } +function f(a: (number) -> Table) return a(4) end +f(function(x) return x * 2 end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); + + // Return type doesn't inference 'nil' + result = check(R"( +function f(a: (number) -> nil) return a(4) end +f(function(x) print(x) end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") +{ + CheckResult result = check(R"( +type Table = { x: number, y: number } +local f: (Table) -> number = function(t) return t.x + t.y end + +type TableWithFunc = { x: number, y: number, f: (number, number) -> number } +local a: TableWithFunc = { x = 3, y = 4, f = function(a, b) return a + b end } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") +{ + CheckResult result = check(R"( +local function f(): {string|number} + return {1, "b", 3} +end + +local function g(): (number, {string|number}) + return 4, {1, "b", 3} +end + +local function h(): ...{string|number} + return {4}, {1, "b", 3}, {"s"} +end + +local function i(): ...{string|number} + return {1, "b", 3}, h() +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") +{ + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") +{ + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, string) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' +caused by: + Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") +{ + CheckResult result = check(R"( +type A = (number, number) -> (number) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Function only returns 1 value. 2 are required here)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") +{ + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, number) -> number + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' +caused by: + Return type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") +{ + CheckResult result = check(R"( +type A = (number, number) -> (number, string) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); +} + +TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + + fileResolver.source["game/isAMagicMock"] = R"( +--!nonstrict +return function(value) + return false +end + )"; + + CheckResult result = check(R"( +--!nonstrict +local MagicMock = {} +MagicMock.is = require(game.isAMagicMock) + +function MagicMock.is(value) + return false +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + + CheckResult result = check(R"( +function string.len(): number + return 1 +end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") +{ + ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; + + CheckResult result = check(R"( + local function f(x: any) end + f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + + CheckResult result = check(R"( +local t: {[string]: () -> number} = {} + +function t.a() return 1 end -- OK +function t:b() return 2 end -- not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '(*unknown*) -> number' could not be converted into '() -> number' +caused by: + Argument count mismatch. Function expects 1 argument, but none are specified)", + toString(result.errors[0])); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 547fbab1..f360a77c 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -1,6 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" +#include "Luau/Scope.h" + +#include #include "Fixture.h" @@ -830,5 +833,303 @@ wrapper(test2, 1, "", 3) CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 4 are specified)"); } +TEST_CASE_FIXTURE(Fixture, "generic_function") +{ + CheckResult result = check(R"( + function id(x) return x end + local a = id(55) + local b = id(nil) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("a")); + CHECK_EQ(*typeChecker.nilType, *requireType("b")); +} + +TEST_CASE_FIXTURE(Fixture, "generic_table_method") +{ + CheckResult result = check(R"( + local T = {} + + function T:bar(i) + return i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId tType = requireType("T"); + TableTypeVar* tTable = getMutable(tType); + REQUIRE(tTable != nullptr); + + TypeId barType = tTable->props["bar"].type; + REQUIRE(barType != nullptr); + + const FunctionTypeVar* ftv = get(follow(barType)); + REQUIRE_MESSAGE(ftv != nullptr, "Should be a function: " << *barType); + + std::vector args = flatten(ftv->argTypes).first; + TypeId argType = args.at(1); + + CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") +{ + CheckResult result = check(R"( + local T = {} + + function T:foo() + return T:bar(5) + end + + function T:bar(i) + return i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + const TableTypeVar* t = get(requireType("T")); + REQUIRE(t != nullptr); + + std::optional fooProp = get(t->props, "foo"); + REQUIRE(bool(fooProp)); + + const FunctionTypeVar* foo = get(follow(fooProp->type)); + REQUIRE(bool(foo)); + + std::optional ret_ = first(foo->retType); + REQUIRE(bool(ret_)); + TypeId ret = follow(*ret_); + + REQUIRE_EQ(getPrimitiveType(ret), PrimitiveTypeVar::Number); +} + +/* + * We had a bug in instantiation where the argument types of 'f' and 'g' would be inferred as + * f {+ method: function(): (t2, T3...) +} + * g {+ method: function({+ method: function(): (t2, T3...) +}): (t5, T6...) +} + * + * The type of 'g' is totally wrong as t2 and t5 should be unified, as should T3 with T6. + * + * The correct unification of the argument to 'g' is + * + * {+ method: function(): (t5, T6...) +} + */ +TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") +{ + auto result = check(R"( + function f(o) + o:method() + end + + function g(o) + f(o) + end + )"); + + TypeId g = requireType("g"); + const FunctionTypeVar* gFun = get(g); + REQUIRE(gFun != nullptr); + + auto optionArg = first(gFun->argTypes); + REQUIRE(bool(optionArg)); + + TypeId arg = follow(*optionArg); + const TableTypeVar* argTable = get(arg); + REQUIRE(argTable != nullptr); + + std::optional methodProp = get(argTable->props, "method"); + REQUIRE(bool(methodProp)); + + const FunctionTypeVar* methodFunction = get(methodProp->type); + REQUIRE(methodFunction != nullptr); + + std::optional methodArg = first(methodFunction->argTypes); + REQUIRE(bool(methodArg)); + + REQUIRE_EQ(follow(*methodArg), follow(arg)); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") +{ + CheckResult result = check(R"( + function foo(a, b) + return a(b) + end + + function bar() + local c: ((number)->number, number)->number = foo -- no error + c = foo -- no error + local d: ((number)->number, string)->number = foo -- error from arg 2 (string) not being convertable to number from the call a(b) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); + CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") +{ + CheckResult result = check(R"( + function foo(a, b) + return a(b) + end + + function bar() + local _: (string, string)->number = foo -- string cannot be converted to (string)->number + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); + CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + // Mutability in type function application right now can create strange recursive types + CheckResult result = check(R"( +type Table = { a: number } +type Self = T +local a: Self
+ )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a")), "Table"); +} + +TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") +{ + CheckResult result = check(R"( + function _(l0:t0): (any, ()->()) + end + + type t0 = t0 | {} + )"); + + CHECK_LE(0, result.errors.size()); + + std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); + REQUIRE(t0); + CHECK_EQ("*unknown*", toString(t0->type)); + + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { + return get(err); + }); + CHECK(it != result.errors.end()); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + + CheckResult result = check(R"( +local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end +return sum(2, 3, function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( +local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end +local a = {1, 2, 3} +local r = map(a, function(a) return a + a > 100 end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{boolean}", toString(requireType("r"))); + + check(R"( +local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end +local a = {1, 2, 3} +local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") +{ + CheckResult result = check(R"( +local function g1(a: T, f: (T) -> T) return f(a) end +local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + +local g12: typeof(g1) & typeof(g2) + +g12(1, function(x) return x + x end) +g12(1, 2, function(x, y) return x + y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( +local function g1(a: T, f: (T) -> T) return f(a) end +local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + +local g12: typeof(g1) & typeof(g2) + +g12({x=1}, function(x) return {x=-x.x} end) +g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") +{ + CheckResult result = check(R"( +local a = {{x=4}, {x=7}, {x=1}} +table.sort(a, function(x, y) return x.x < y.x end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") +{ + CheckResult result = check(R"( +local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end + +local function sumrec(f: typeof(sum)) + return sum(2, 3, function(a, b) return a + b end) +end + +local b = sumrec(sum) -- ok +local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") +{ + CheckResult result = check(R"( +type A = { x: number } +local a: A = { x = 1 } +local b = a +type B = typeof(b) +type X = T +local c: X + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 26881b5c..d146f4e8 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -377,4 +377,32 @@ local b: number = a CHECK_EQ(toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); } +TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") +{ + check(R"( +--!nonstrict +function _(...):((typeof(not _))&(typeof(not _)))&((typeof(not _))&(typeof(not _))) +_(...)(setfenv,_,not _,"")[_] = nil +end +do end +_(...)(...,setfenv,_):_G() +)"); +} + +TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") +{ + CheckResult result = check(R"( + local l0,l0 + repeat + type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) + function _(l0):(t0)&(t0) + while nil do + end + end + until _(_)(_)._ + )"); + + CHECK_LE(0, result.errors.size()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp new file mode 100644 index 00000000..30df717b --- /dev/null +++ b/tests/TypeInfer.loops.test.cpp @@ -0,0 +1,473 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferLoops"); + +TEST_CASE_FIXTURE(Fixture, "for_loop") +{ + CheckResult result = check(R"( + local q + for i=0, 50, 2 do + q = i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("q")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop") +{ + CheckResult result = check(R"( + local n + local s + for i, v in pairs({ "foo" }) do + n = i + s = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*typeChecker.stringType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_next") +{ + CheckResult result = check(R"( + local n + local s + for i, v in next, { "foo", "bar" } do + n = i + s = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*typeChecker.stringType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") +{ + CheckResult result = check(R"( + local it: any + local a, b + for i, v in it do + a, b = i, v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") +{ + CheckResult result = check(R"( + local foo = "bar" + for i, v in foo do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") +{ + CheckResult result = check(R"( + local function keys(dictionary) + local new = {} + local index = 1 + + for key in pairs(dictionary) do + new[index] = key + index = index + 1 + end + + return new + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_with_a_custom_iterator_should_type_check") +{ + CheckResult result = check(R"( + local function range(l, h): () -> number + return function() + return l + end + end + + for n: string in range(1, 10) do + print(n) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") +{ + CheckResult result = check(R"( + function f(x) + gobble.prop = x.otherprop + end + + local p + for _, part in i_am_not_defined do + p = part + f(part) + part.thirdprop = false + end + )"); + + CHECK_EQ(2, result.errors.size()); + + TypeId p = requireType("p"); + CHECK_EQ("*unknown*", toString(p)); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") +{ + CheckResult result = check(R"( + local bad_iter = 5 + + for a in bad_iter() do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") +{ + CheckResult result = check(R"( + local function hasDivisors(value: number, table) + return false + end + + function prime_iter(state, index) + while hasDivisors(index, state) do + index += 1 + end + + state[index] = true + return index + end + + function primes1() + return prime_iter, {} + end + + function primes2() + return prime_iter, {}, "" + end + + function primes3() + return prime_iter, {}, 2 + end + + for p in primes1() do print(p) end -- mismatch in argument count + + for p in primes2() do print(p) end -- mismatch in argument types, prime_iter takes {}, number, we are given {}, string + + for p in primes3() do print(p) end -- no error + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Arg); + CHECK_EQ(2, acm->expected); + CHECK_EQ(1, acm->actual); + + TypeMismatch* tm = get(result.errors[1]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") +{ + CheckResult result = check(R"( + function prime_iter(state, index) + return 1 + end + + for p in prime_iter do print(p) end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Arg); + CHECK_EQ(2, acm->expected); + CHECK_EQ(0, acm->actual); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") +{ + CheckResult result = check(R"( + function primes() + return function (state: number) end, 2 + end + + for p, q in primes do + q = "" + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "while_loop") +{ + CheckResult result = check(R"( + local i + while true do + i = 8 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("i")); +} + +TEST_CASE_FIXTURE(Fixture, "repeat_loop") +{ + CheckResult result = check(R"( + local i + repeat + i = 'hi' + until true + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.stringType, *requireType("i")); +} + +TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") +{ + CheckResult result = check(R"( + repeat + local x = true + until x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") +{ + CheckResult result = check(R"( + repeat + local x = true + until x + + print(x) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") +{ + CheckResult result = check(R"( + local T = {} + + function T.f(p) + for i, v in pairs(p) do + T.f(v) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") +{ + // In this case, we cannot know the element type of the table {}. It could be anything. + // We therefore must initially ascribe a free typevar to iter. + CheckResult result = check(R"( + for iter in pairs({}) do + iter:g().p = true + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") +{ + CheckResult result = check(R"( + while true do + local a = 1 + end + + print(a) -- oops! + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownSymbol* us = get(result.errors[0]); + REQUIRE(us); + CHECK_EQ(us->name, "a"); +} + +TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indices") +{ + CheckResult result = check(R"( + local key + for i, e in ipairs({}) do key = i end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("number", toString(requireType("key"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") +{ + // This code doesn't pass typechecking. We just care that it doesn't crash. + (void)check(R"( + --!nonstrict + function _:_(...) + end + + repeat + if _ then + else + _ = ... + end + until _ + + for _ in _() do + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") +{ + { + CheckResult result = check(R"( + function unreachablecodepath(a): number + while true do + if a then return 10 end + end + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } + + { + CheckResult result = check(R"( + function reachablecodepath(a): number + while true do + if a then break end + return 10 + end + + print("x") -- correct error + end + reachablecodepath(4) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK(get(result.errors[0])); + } + + { + CheckResult result = check(R"( + function unreachablecodepath(a): number + repeat + if a then return 10 end + until false + + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } + + { + CheckResult result = check(R"( + function reachablecodepath(a, b): number + repeat + if a then break end + + if b then return 10 end + until false + + print("x") -- correct error + end + reachablecodepath(4) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK(get(result.errors[0])); + } + + { + CheckResult result = check(R"( + function unreachablecodepath(a: number?): number + repeat + return 10 + until a ~= nil + + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } +} + +TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") +{ + CheckResult result = check(R"( + local t = {} + for _ in t do + for _ in assert(missing()) do + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp new file mode 100644 index 00000000..63643610 --- /dev/null +++ b/tests/TypeInfer.modules.test.cpp @@ -0,0 +1,310 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferModules"); + +TEST_CASE_FIXTURE(Fixture, "require") +{ + fileResolver.source["game/A"] = R"( + local function hooty(x: number): string + return "Hi there!" + end + + return {hooty=hooty} + )"; + + fileResolver.source["game/B"] = R"( + local Hooty = require(game.A) + + local h -- free! + local i = Hooty.hooty(h) + )"; + + CheckResult aResult = frontend.check("game/A"); + dumpErrors(aResult); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = frontend.check("game/B"); + dumpErrors(bResult); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ModulePtr b = frontend.moduleResolver.modules["game/B"]; + + REQUIRE(b != nullptr); + + dumpErrors(bResult); + + std::optional iType = requireType(b, "i"); + REQUIRE_EQ("string", toString(*iType)); + + std::optional hType = requireType(b, "h"); + REQUIRE_EQ("number", toString(*hType)); +} + +TEST_CASE_FIXTURE(Fixture, "require_types") +{ + fileResolver.source["workspace/A"] = R"( + export type Point = {x: number, y: number} + + return {} + )"; + + fileResolver.source["workspace/B"] = R"( + local Hooty = require(workspace.A) + + local h: Hooty.Point + )"; + + CheckResult bResult = frontend.check("workspace/B"); + dumpErrors(bResult); + + ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; + REQUIRE(b != nullptr); + + TypeId hType = requireType(b, "h"); + REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); +} + +TEST_CASE_FIXTURE(Fixture, "require_a_variadic_function") +{ + fileResolver.source["game/A"] = R"( + local T = {} + function T.f(...) end + return T + )"; + + fileResolver.source["game/B"] = R"( + local A = require(game.A) + local f = A.f + )"; + + CheckResult result = frontend.check("game/B"); + + ModulePtr bModule = frontend.moduleResolver.getModule("game/B"); + REQUIRE(bModule != nullptr); + + TypeId f = follow(requireType(bModule, "f")); + + const FunctionTypeVar* ftv = get(f); + REQUIRE(ftv); + + auto iter = begin(ftv->argTypes); + auto endIter = end(ftv->argTypes); + + REQUIRE(iter == endIter); + REQUIRE(iter.tail()); + + CHECK(get(*iter.tail())); +} + +TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") +{ + CheckResult result = check(R"( + local p: SomeModule.DoesNotExist + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE_EQ(result.errors[0], (TypeError{Location{{1, 17}, {1, 40}}, UnknownSymbol{"SomeModule.DoesNotExist"}})); +} + +TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") +{ + const std::string sourceA = R"( + )"; + + const std::string sourceB = R"( + local Hooty = require(script.Parent.A) + )"; + + fileResolver.source["game/Workspace/A"] = sourceA; + fileResolver.source["game/Workspace/B"] = sourceB; + + frontend.check("game/Workspace/A"); + frontend.check("game/Workspace/B"); + + ModulePtr aModule = frontend.moduleResolver.modules["game/Workspace/A"]; + ModulePtr bModule = frontend.moduleResolver.modules["game/Workspace/B"]; + + CHECK(aModule->errors.empty()); + REQUIRE_EQ(1, bModule->errors.size()); + CHECK_MESSAGE(get(bModule->errors[0]), "Should be IllegalRequire: " << toString(bModule->errors[0])); + + auto hootyType = requireType(bModule, "Hooty"); + + CHECK_EQ("*unknown*", toString(hootyType)); +} + +TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") +{ + fileResolver.source["Modules/A"] = ""; + fileResolver.sourceTypes["Modules/A"] = SourceCode::Local; + + fileResolver.source["Modules/B"] = R"( + local M = require(script.Parent.A) + )"; + + CheckResult result = frontend.check("Modules/B"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_call_expression") +{ + fileResolver.source["game/A"] = R"( +--!strict +return { def = 4 } + )"; + + fileResolver.source["game/B"] = R"( +--!strict +local tbl = { abc = require(game.A) } +local a : string = "" +a = tbl.abc.def + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_type_mismatch") +{ + fileResolver.source["game/A"] = R"( +return { def = 4 } + )"; + + fileResolver.source["game/B"] = R"( +local tbl: string = require(game.A) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "bound_free_table_export_is_ok") +{ + CheckResult result = check(R"( +local n = {} +function n:Clone() end + +local m = {} + +function m.a(x) + x:Clone() +end + +function m.b() + m.a(n) +end + +return m +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "custom_require_global") +{ + CheckResult result = check(R"( +--!nonstrict +require = function(a) end + +local crash = require(game.A) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "require_failed_module") +{ + fileResolver.source["game/A"] = R"( +return unfortunately() + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(aResult); + + CheckResult result = check(R"( +local ModuleA = require(game.A) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional oty = requireType("ModuleA"); + CHECK_EQ("*unknown*", toString(*oty)); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types") +{ + fileResolver.source["game/A"] = R"( +export type Type = { unrelated: boolean } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = {} +function x:Destroy(): () end + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") +{ + ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; + + fileResolver.source["game/A"] = R"( +export type Type = { x: { a: number } } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = { x = { a = 2 } } +type Rename = typeof(x.x) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") +{ + ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; + + fileResolver.source["game/A"] = R"( +local y = setmetatable({}, {}) +export type Type = { x: typeof(y) } +return { x = y } + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = types +type Rename = typeof(x.x) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp new file mode 100644 index 00000000..40831bf6 --- /dev/null +++ b/tests/TypeInfer.oop.test.cpp @@ -0,0 +1,275 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferOOP"); + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon") +{ + CheckResult result = check(R"( + local someTable = {} + + someTable.Function1 = function(Arg1) + end + + someTable.Function1() -- Argument count mismatch + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") +{ + CheckResult result = check(R"( + local someTable = {} + + someTable.Function2 = function(Arg1, Arg2) + end + + someTable.Function2() -- Argument count mismatch + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") +{ + CheckResult result = check(R"( + type T = {method: ((T, number) -> number) & ((number) -> number)} + local T: T + + T.method(4) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "method_depends_on_table") +{ + CheckResult result = check(R"( + -- This catches a bug where x:m didn't count as a use of x + -- so toposort would happily reorder a definition of + -- function x:m before the definition of x. + function g() f() end + local x = {} + function x:m() end + function f() x:m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "methods_are_topologically_sorted") +{ + CheckResult result = check(R"( + local T = {} + + function T:foo() + return T:bar(999), T:bar("hi") + end + + function T:bar(i) + return i + end + + local a, b = T:foo() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("a"))); + CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_methods_defined_using_dot_syntax_and_explicit_self_parameter") +{ + check(R"( + local T = {} + + function T.method(self) + self:method() + end + + function T.method2(self) + self:method() + end + + T:method2() + )"); +} + +TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocate_memory") +{ + CheckResult result = check(R"( + ("foo") + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + )"); + + ModulePtr module = getMainModule(); + CHECK_GE(50, module->internalTypes.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") +{ + // CLI-30902 + CheckResult result = check(R"( + --!strict + + type Foo = { + fooConn: () -> () | nil + } + + local Foo = {} + Foo.__index = Foo + + function Foo.new() + local self: Foo = { + fooConn = nil, + } + setmetatable(self, Foo) + + self.fooConn = function() + self:method() -- Key 'method' not found in table self + end + + return self + end + + function Foo:method() + print("foo") + end + + local foo = Foo.new() + + -- TODO This is the best our current refinement support can offer :( + local bar = foo.fooConn + if bar then bar() end + + -- foo.fooConn() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfSealed") +{ + CheckResult result = check(R"( +local x: {prop: number} = {prop=9999} +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") +{ + CheckResult result = check(R"( +--!nonstrict +local f = {} +function f:foo(a: number, b: number) end + +function bar(...) + f.foo(f, 1, ...) +end + +bar(2) +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table") +{ + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + + Base64FileReader() + + function ReadMidiEvents(data) + + local reader = Base64FileReader(data) + + while reader:HasMore() do + (reader:Byte() % 128) + end + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "table_oop") +{ + CheckResult result = check(R"( + --!strict +local Class = {} +Class.__index = Class + +type Class = typeof(setmetatable({} :: { x: number }, Class)) + +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end + +function Class.getx(self: Class) + return self.x +end + +function test() + local c = Class.new(42) + local n = c:getx() + local nn = c.x + + print(string.format("%d %d", n, nn)) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp new file mode 100644 index 00000000..baa25978 --- /dev/null +++ b/tests/TypeInfer.operators.test.cpp @@ -0,0 +1,759 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferOperators"); + +TEST_CASE_FIXTURE(Fixture, "or_joins_types") +{ + CheckResult result = check(R"( + local s = "a" or 10 + local x:string|number = s + )"); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("x")), "number | string"); +} + +TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") +{ + CheckResult result = check(R"( + local s = "a" or 10 + local x:number|string = s + local y = x or "s" + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("y")), "number | string"); +} + +TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") +{ + CheckResult result = check(R"( + local s = "a" or "b" + local x:string = s + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(*requireType("s"), *typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "and_adds_boolean") +{ + CheckResult result = check(R"( + local s = "a" and 10 + local x:boolean|number = s + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "boolean | number"); +} + +TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") +{ + CheckResult result = check(R"( + local s = "a" and true + local x:boolean = s + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(*requireType("x"), *typeChecker.booleanType); +} + +TEST_CASE_FIXTURE(Fixture, "and_or_ternary") +{ + CheckResult result = check(R"( + local s = (1/2) > 0.5 and "a" or 10 + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "number | string"); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") +{ + CheckResult result = check(R"( + function add(a: number, b: string) + return a + (tonumber(b) :: number), a .. b + end + local n, s = add(2,"3") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* functionType = get(requireType("add")); + + std::optional retType = first(functionType->retType); + CHECK_EQ(std::optional(typeChecker.numberType), retType); + CHECK_EQ(requireType("n"), typeChecker.numberType); + CHECK_EQ(requireType("s"), typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") +{ + CheckResult result = check(R"( + local PI=3.1415926535897931 + local SOLAR_MASS=4*PI * PI + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(requireType("SOLAR_MASS"), typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable") +{ + CheckResult result = check(R"( + function add(a: number, b: any) + return a + b + end + local t = add(1,2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") +{ + CheckResult result = check(R"( + local a = 4 + 8 + local b = a + 9 + local s = 'hotdogs' + local t = s .. s + local c = b - a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("string", toString(requireType("s"))); + CHECK_EQ("string", toString(requireType("t"))); + CHECK_EQ("number", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + --!strict + local Vec3 = {} + Vec3.__index = Vec3 + function Vec3.new() + return setmetatable({x=0, y=0, z=0}, Vec3) + end + + export type Vec3 = typeof(Vec3.new()) + + local thefun: any = function(self, o) return self end + + local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun + + Vec3.__mul = multiply + + local a = Vec3.new() + local b = Vec3.new() + local c = a * b + local d = a * 2 + local e = a * 'cabbage' + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Vec3", toString(requireType("a"))); + CHECK_EQ("Vec3", toString(requireType("b"))); + CHECK_EQ("Vec3", toString(requireType("c"))); + CHECK_EQ("Vec3", toString(requireType("d"))); + CHECK_EQ("Vec3", toString(requireType("e"))); +} + +TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + --!strict + local Vec3 = {} + Vec3.__index = Vec3 + function Vec3.new() + return setmetatable({x=0, y=0, z=0}, Vec3) + end + + export type Vec3 = typeof(Vec3.new()) + + local thefun: any = function(self, o) return self end + + local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun + + Vec3.__mul = multiply + + local a = Vec3.new() + local b = Vec3.new() + local c = b * a + local d = 2 * a + local e = 'cabbage' * a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Vec3", toString(requireType("a"))); + CHECK_EQ("Vec3", toString(requireType("b"))); + CHECK_EQ("Vec3", toString(requireType("c"))); + CHECK_EQ("Vec3", toString(requireType("d"))); + CHECK_EQ("Vec3", toString(requireType("e"))); +} + +TEST_CASE_FIXTURE(Fixture, "compare_numbers") +{ + CheckResult result = check(R"( + local a = 441 + local b = 0 + local c = a < b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "compare_strings") +{ + CheckResult result = check(R"( + local a = '441' + local b = '0' + local c = a < b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_metatable") +{ + CheckResult result = check(R"( + local a = {} + local b = {} + local c = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* gen = get(result.errors[0]); + + REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") +{ + CheckResult result = check(R"( + local M = {} + function M.new() + return setmetatable({}, M) + end + type M = typeof(M.new()) + + local a = M.new() + local b = M.new() + local c = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* gen = get(result.errors[0]); + REQUIRE(gen != nullptr); + REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") +{ + CheckResult result = check(R"( + --!strict + local M = {} + function M.new() + return setmetatable({}, M) + end + function M.__lt(left, right) return true end + + local a = M.new() + local b = {} + local c = a < b -- line 10 + local d = b < a -- line 11 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + REQUIRE_EQ((Location{{10, 18}, {10, 23}}), result.errors[0].location); + + REQUIRE_EQ((Location{{11, 18}, {11, 23}}), result.errors[1].location); +} + +TEST_CASE_FIXTURE(Fixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") +{ + CheckResult result = check(R"( + --!strict + local M = {} + function M.new() + return setmetatable({}, M) + end + function M.__lt(left, right) return true end + type M = typeof(M.new()) + + local a = M.new() + local b = {} + local c = a < b -- line 10 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = get(result.errors[0]); + REQUIRE(err != nullptr); + + // Frail. :| + REQUIRE_EQ("Types M and b cannot be compared with < because they do not have the same metatable", err->message); +} + +TEST_CASE_FIXTURE(Fixture, "in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators") +{ + CheckResult result = check(R"( + --!nonstrict + + function maybe_a_number(): number? + return 50 + end + + local a = maybe_a_number() < maybe_a_number() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_basic") +{ + CheckResult result = check(R"( + local s = 10 + s += 20 + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "number"); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op") +{ + CheckResult result = check(R"( + local s = 10 + s += true + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}})); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") +{ + CheckResult result = check(R"( + local s = 'hello' + s += 10 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") +{ + CheckResult result = check(R"( + --!strict + type V2B = { x: number, y: number } + local v2b: V2B = { x = 0, y = 0 } + local VMT = {} + type V2 = typeof(setmetatable(v2b, VMT)) + + function VMT.__add(a: V2, b: V2): V2 + return setmetatable({ x = a.x + b.x, y = a.y + b.y }, VMT) + end + + local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) + local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) + v1 += v2 + )"); + CHECK_EQ(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_metatable") +{ + CheckResult result = check(R"( + --!strict + type V2B = { x: number, y: number } + local v2b: V2B = { x = 0, y = 0 } + local VMT = {} + type V2 = typeof(setmetatable(v2b, VMT)) + + function VMT.__mod(a: V2, b: V2): number + return a.x * b.x + a.y * b.y + end + + local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) + local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) + v1 %= v2 + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + CHECK_EQ(*tm->wantedType, *requireType("v2")); + CHECK_EQ(*tm->givenType, *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") +{ + CheckResult result = check(R"( +function f() return 1; end +function g() return 2; end +(f or g)() +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "CallAndOrOfFunctions") +{ + CheckResult result = check(R"( +function f() return 1; end +function g() return 2; end +local x = false +(x and f or g)() +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") +{ + CheckResult result = check(R"( + --!strict + local foo = { + value = 10 + } + local mt = {} + setmetatable(foo, mt) + + mt.__unm = function(val: typeof(foo)): string + return val.value .. "test" + end + + local a = -foo + + local b = 1+-1 + + local bar = { + value = 10 + } + local c = -bar -- disallowed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("string", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + + GenericError* gen = get(result.errors[0]); + REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); +} + +TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") +{ + CheckResult result = check(R"( + local b = not "string" + local c = not (math.random() > 0.5 and "string" or 7) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("boolean", toString(requireType("b"))); + REQUIRE_EQ("boolean", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") +{ + CheckResult result = check(R"( + --!strict + local a = "1.24" + 123 -- not allowed + + local foo = { + value = 10 + } + + local b = foo + 1 -- not allowed + + local bar = { + value = 1 + } + + local mt = {} + + setmetatable(bar, mt) + + mt.__add = function(a: typeof(bar), b: number): number + return a.value + b + end + + local c = bar + 1 -- allowed + + local d = bar + foo -- not allowed + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); + REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); + + TypeMismatch* tm2 = get(result.errors[2]); + CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); + CHECK_EQ(*tm2->givenType, *requireType("foo")); + + GenericError* gen2 = get(result.errors[1]); + REQUIRE_EQ(gen2->message, "Binary operator '+' not supported by types 'foo' and 'number'"); +} + +// CLI-29033 +TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") +{ + CheckResult result = check(R"( + function merge(lower, greater) + if lower.y == greater.y then + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs") +{ + CheckResult result = check(R"( + local function f(x) + return x .. "y" + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") +{ + CheckResult result = check(R"( + local function f(x) + return "foo" .. x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(string) -> string", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") +{ + std::vector ops = {"+", "-", "*", "/", "%", "^", ".."}; + + std::string src = R"( + function foo(a, b) + )"; + + for (const auto& op : ops) + src += "local _ = a " + op + "b\n"; + + src += "end"; + + CheckResult result = check(src); + LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); + + CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "and_binexps_dont_unify") +{ + CheckResult result = check(R"( + --!strict + local t = {} + while true and t[1] do + print(t[1].test) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") +{ + CheckResult result = check(R"( + local a: boolean = true + local b: boolean = false + local foo = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") +{ + CheckResult result = check(R"( + local a: number | string = "" + local b: number | string = 1 + local foo = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") +{ + CheckResult result = check(R"( + --!strict + local _ + _ += _ and _ or _ and _ or _ and _ + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") +{ + // In non-strict mode, global definition is still allowed + { + CheckResult result = check(R"( + --!nonstrict + a = a + 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } + + // In strict mode we no longer generate two errors from lhs + { + CheckResult result = check(R"( + --!strict + a += 1 + print(a) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } + + // In non-strict mode, compound assignment is not a definition, it's a modification + { + CheckResult result = check(R"( + --!nonstrict + a += 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator") +{ + CheckResult result = check(R"( +--!strict +local a: number? = nil +local b: number = a or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator2") +{ + CheckResult result = check(R"( +--!nonstrict +local a: number? = nil +local b: number = a or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_strip_nil_from_rhs_or_operator") +{ + CheckResult result = check(R"( +--!strict +local a: number? = nil +local b: number = 1 or a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ("number?", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") +{ + CheckResult result = check(R"( + type Array = { [number]: T } + type Fiber = { id: number } + type null = {} + + local fiberStack: Array = {} + local index = 0 + + local function f(fiber: Fiber) + local a = fiber ~= fiberStack[index] + local b = fiberStack[index] ~= fiber + end + + return f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local function f(a: string | number, b: boolean | number) + return a == b + end + )"); + + // This doesn't produce any errors but for the wrong reasons. + // This unit test serves as a reminder to not try and unify the operands on `==`/`~=`. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "refine_and_or") +{ + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t and t.x or 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("u"))); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp new file mode 100644 index 00000000..44b7b0d0 --- /dev/null +++ b/tests/TypeInfer.primitives.test.cpp @@ -0,0 +1,100 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferPrimitives"); + +TEST_CASE_FIXTURE(Fixture, "cannot_call_primitives") +{ + CheckResult result = check("local foo = 5 foo()"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0]) != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "string_length") +{ + CheckResult result = check(R"( + local s = "Hello, World!" + local t = #s + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "string_index") +{ + CheckResult result = check(R"( + local s = "Hello, World!" + local t = s[4] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + NotATable* nat = get(result.errors[0]); + REQUIRE(nat); + CHECK_EQ("string", toString(nat->ty)); + + CHECK_EQ("*unknown*", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "string_method") +{ + CheckResult result = check(R"( + local p = ("tacos"):len() + )"); + CHECK_EQ(0, result.errors.size()); + + CHECK_EQ(*requireType("p"), *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "string_function_indirect") +{ + CheckResult result = check(R"( + local s:string + local l = s.lower + local p = l(s) + )"); + CHECK_EQ(0, result.errors.size()); + + CHECK_EQ(*requireType("p"), *typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "string_function_other") +{ + CheckResult result = check(R"( + local s:string + local p = s:match("foo") + )"); + CHECK_EQ(0, result.errors.size()); + + CHECK_EQ(toString(requireType("p")), "string?"); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( +local x: number = 9999 +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index a5147d56..9b347921 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1298,4 +1298,22 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); } +TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") +{ + const std::string code = R"( + function f(a) + if type(a) == "boolean" then + local a1 = a + elseif a.fn() then + local a2 = a + else + local a3 = a + end + end + )"; + CheckResult result = check(code); + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 3ed536ea..7f8d8fec 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -5,6 +5,8 @@ #include "doctest.h" #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAG(BetterDiagnosticCodesInStudio) + using namespace Luau; TEST_SUITE_BEGIN("TypeSingletons"); @@ -353,7 +355,14 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); + if (FFlag::BetterDiagnosticCodesInStudio) + { + CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); + } + else + { + CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") @@ -445,7 +454,7 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si {"LuauSingletonTypes", true}, {"LuauEqConstraint", true}, {"LuauDiscriminableUnions2", true}, - {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWidenIfSupertypeIsFree2", true}, {"LuauWeakEqConstraint", false}, }; @@ -472,9 +481,9 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauEqConstraint", true}, - {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWidenIfSupertypeIsFree2", true}, {"LuauWeakEqConstraint", false}, - {"LuauDoNotAccidentallyDependOnPointerOrdering", true} + {"LuauDoNotAccidentallyDependOnPointerOrdering", true}, }; CheckResult result = check(R"( @@ -497,7 +506,7 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") ScopedFastFlag sff[]{ {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, - {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWidenIfSupertypeIsFree2", true}, }; CheckResult result = check(R"( @@ -515,7 +524,7 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, - {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWidenIfSupertypeIsFree2", true}, }; CheckResult result = check(R"( @@ -544,7 +553,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") ScopedFastFlag sff[]{ {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, - {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWidenIfSupertypeIsFree2", true}, }; CheckResult result = check(R"( @@ -565,4 +574,97 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") CHECK_EQ("{string}", toString(requireType("t"))); } +TEST_CASE_FIXTURE(Fixture, "functions_are_not_to_be_widened") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauWidenIfSupertypeIsFree2", true}, + }; + + CheckResult result = check(R"( + local function foo(my_enum: "A" | "B") end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"(("A" | "B") -> ())", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" then + local x = a:byte() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 22}))); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" or a == "bye" then + local x = a:byte() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 22}))); +} + +TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" then + local x = #a + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); +} + +TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" or a == "bye" then + local x = #a + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 23}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index a5eba5df..91140aaa 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2384,4 +2384,504 @@ _ = (_.cos) LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "cannot_call_tables") +{ + CheckResult result = check("local foo = {} foo()"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0]) != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "table_length") +{ + CheckResult result = check(R"( + local t = {} + local s = #t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(nullptr != get(requireType("t"))); + CHECK_EQ(*typeChecker.numberType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") +{ + CheckResult result = check("local a = {} a[0] = 7 a[0] = nil"); + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") +{ + CheckResult result = check("local a = {} a[0] = 7 a[0] = 't'"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.stringType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") +{ + CheckResult result = check("local a = {a=1, b=2} a['a'] = nil"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.nilType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") +{ + check(R"( + local o + local v = o:i() + + function g(u) + v = u + end + + o:f(g) + o:h() + o:h() + )"); +} + +TEST_CASE_FIXTURE(Fixture, "table_unifies_into_map") +{ + CheckResult result = check(R"( + local Instance: any + local UDim2: any + + function Create(instanceType) + return function(data) + local obj = Instance.new(instanceType) + for k, v in pairs(data) do + if type(k) == 'number' then + --v.Parent = obj + else + obj[k] = v + end + end + return obj + end + end + + local topbarShadow = Create'ImageLabel'{ + Name = "TopBarShadow"; + Size = UDim2.new(1, 0, 0, 3); + Position = UDim2.new(0, 0, 1, 0); + Image = "rbxasset://textures/ui/TopBar/dropshadow.png"; + BackgroundTransparency = 1; + Active = false; + Visible = false; + }; + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals") +{ + CheckResult result = check(R"( + local T = {} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("T", toString(requireType("T"))); +} + +TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") +{ + CheckResult result = check(R"( + function foo(arr) + local work = {} + for i = 1, #arr do + work[i] = arr[i] + end + + return arr + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + const FunctionTypeVar* fooType = get(requireType("foo")); + REQUIRE(fooType); + + std::optional fooArg1 = first(fooType->argTypes); + REQUIRE(fooArg1); + + const TableTypeVar* fooArg1Table = get(*fooArg1); + REQUIRE(fooArg1Table); + + CHECK_EQ(fooArg1Table->state, TableState::Generic); +} + +/* + * This test case exposed an oversight in the treatment of free tables. + * Free tables, like free TypeVars, need to record the scope depth where they were created so that + * we do not erroneously let-generalize them when they are used in a nested lambda. + * + * For more information about let-generalization, see + * + * The important idea here is that the return type of Counter.new is a table with some metatable. + * That metatable *must* be the same TypeVar as the type of Counter. If it is a copy (produced by + * the generalization process), then it loses the knowledge that its metatable will have an :incr() + * method. + */ +TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") +{ + CheckResult result = check(R"( + local Counter = {} + Counter.__index = Counter + + function Counter.new() + local self = setmetatable({count=0}, Counter) + return self + end + + function Counter:incr() + self.count = 1 + return self.count + end + + local self = Counter.new() + print(self:incr()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* counterType = getMutable(requireType("Counter")); + REQUIRE(counterType); + + const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); + REQUIRE(newType); + + std::optional newRetType = *first(newType->retType); + REQUIRE(newRetType); + + const MetatableTypeVar* newRet = get(follow(*newRetType)); + REQUIRE(newRet); + + const TableTypeVar* newRetMeta = get(newRet->metatable); + REQUIRE(newRetMeta); + + CHECK(newRetMeta->props.count("incr")); + CHECK_EQ(follow(newRet->metatable), follow(requireType("Counter"))); +} + +// TODO: CLI-39624 +TEST_CASE_FIXTURE(Fixture, "instantiate_tables_at_scope_level") +{ + CheckResult result = check(R"( + --!strict + local Option = {} + Option.__index = Option + function Option.Is(obj) + return (type(obj) == "table" and getmetatable(obj) == Option) + end + return Option + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") +{ + CheckResult result = check(R"( + --!strict + function f(U) + U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() + end + )"); + + ModulePtr module = getMainModule(); + CHECK_GE(100, module->internalTypes.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers") +{ + CheckResult result = check(R"( +local x = {} +x.a = "a" +x[0] = true +x.b = 37 +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "setmetatable_cant_be_used_to_mutate_global_types") +{ + { + Fixture fix; + + // inherit env from parent fixture checker + fix.typeChecker.globalScope = typeChecker.globalScope; + + fix.check(R"( +--!nonstrict +type MT = typeof(setmetatable) +function wtf(arg: {MT}): typeof(table) + arg = wtf(arg) +end +)"); + } + + // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down + // note: it's important for typeck to be destroyed at this point! + { + for (auto& p : typeChecker.globalScope->bindings) + { + toString(p.second.typeId); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas + } + } +} + +TEST_CASE_FIXTURE(Fixture, "evil_table_unification") +{ + // this code re-infers the type of _ while processing fields of _, which can cause use-after-free + check(R"( +--!nonstrict +_ = ... +_:table(_,string)[_:gsub(_,...,n0)],_,_:gsub(_,string)[""],_:split(_,...,table)._,n0 = nil +do end +)"); +} + +TEST_CASE_FIXTURE(Fixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") +{ + CheckResult result = check("local x = setmetatable({})"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning") +{ + CheckResult result = check(R"( +--!nonstrict +local l0:any,l61:t0 = _,math +while _ do +_() +end +function _():t0 +end +type t0 = any +)"); + + std::optional ty = requireType("math"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_2") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + CheckResult result = check(R"( +type X = T +type K = X +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("math"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_3") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + CheckResult result = check(R"( +type X = T +local a = {} +a.x = 4 +local b: X +a.y = 5 +local c: X +c = b +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("a"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "table_indexing_error_location") +{ + CheckResult result = check(R"( +local foo = {42} +local bar: number? +local baz = foo[bar] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); +} + +TEST_CASE_FIXTURE(Fixture, "table_simple_call") +{ + CheckResult result = check(R"( +local a = setmetatable({ x = 2 }, { + __call = function(self) + return (self.x :: number) * 2 -- should work without annotation in the future + end +}) +local b = a() +local c = a(2) -- too many arguments + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "access_index_metamethod_that_returns_variadic") +{ + CheckResult result = check(R"( + type Foo = {x: string} + local t = {} + setmetatable(t, { + __index = function(x: string): ...Foo + return {x = x} + end + }) + + local foo = t.bar + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = true; + CHECK_EQ("{| x: string |}", toString(requireType("foo"), o)); +} + +TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") +{ + fileResolver.source["Module/Backend/Types"] = R"( + export type Fiber = { + return_: Fiber? + } + return {} + )"; + + fileResolver.source["Module/Backend"] = R"( + local Types = require(script.Types) + type Fiber = Types.Fiber + type ReactRenderer = { findFiberByHostInstance: () -> Fiber? } + + local function attach(renderer): () + local function getPrimaryFiber(fiber) + local alternate = fiber.alternate + return fiber + end + + local function getFiberIDForNative() + local fiber = renderer.findFiberByHostInstance() + fiber = fiber.return_ + return getPrimaryFiber(fiber) + end + end + + function culprit(renderer: ReactRenderer): () + attach(renderer) + end + + return culprit + )"; + + CheckResult result = frontend.check("Module/Backend"); +} + +TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") +{ + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t.x and t or 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); + CHECK_EQ("number | {| x: number? |}", toString(requireType("u"))); +} + +TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") +{ + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t and t.x == 5 or t.x == 31337 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); + CHECK_EQ("boolean", toString(requireType("u"))); +} + +/* + * We had an issue where part of the type of pairs() was an unsealed table. + * This test depends on FFlagDebugLuauFreezeArena to trigger it. + */ +TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") +{ + check(R"( + function _(l0:{n0:any}) + _ = pairs + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") +{ + CheckResult result = check(R"( +local t = {} + +function t.x(value) + for k,v in pairs(t) do end +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * When we add new properties to an unsealed table, we should do a level check and promote the property type to be at + * the level of the table. + */ +TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the_same_TypeLevel_of_that_table") +{ + CheckResult result = check(R"( + --!strict + local T = {} + + local function f(prop) + T[1] = { + prop = prop, + } + end + + local function g() + local l = T[1].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index d7bbad20..571d0f8d 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -60,15 +60,6 @@ TEST_CASE_FIXTURE(Fixture, "tc_error_2") }})); } -TEST_CASE_FIXTURE(Fixture, "tc_function") -{ - CheckResult result = check("function five() return 5 end"); - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* fiveType = get(requireType("five")); - REQUIRE(fiveType != nullptr); -} - TEST_CASE_FIXTURE(Fixture, "infer_locals_with_nil_value") { CheckResult result = check("local f = nil; f = 'hello world'"); @@ -108,462 +99,12 @@ TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "check_function_bodies") -{ - CheckResult result = check("function myFunction() local a = 0 a = true end"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 44}, Position{0, 48}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.booleanType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "infer_return_type") -{ - CheckResult result = check("function take_five() return 5 end"); - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* takeFiveType = get(requireType("take_five")); - REQUIRE(takeFiveType != nullptr); - - std::vector retVec = flatten(takeFiveType->retType).first; - REQUIRE(!retVec.empty()); - - REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") -{ - CheckResult result = check("function take_five() return 5 end local five = take_five()"); - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *follow(requireType("five"))); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_call_primitives") -{ - CheckResult result = check("local foo = 5 foo()"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0]) != nullptr); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_call_tables") -{ - CheckResult result = check("local foo = {} foo()"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK(get(result.errors[0]) != nullptr); -} - -TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") -{ - CheckResult result = check(R"( - function take_five() - return 5 - end - - take_five().prop = 888 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); -} - TEST_CASE_FIXTURE(Fixture, "expr_statement") { CheckResult result = check("local foo = 5 foo()"); LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "generic_function") -{ - CheckResult result = check(R"( - function id(x) return x end - local a = id(55) - local b = id(nil) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("a")); - CHECK_EQ(*typeChecker.nilType, *requireType("b")); -} - -TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and_size") -{ - CheckResult result = check(R"( - function f(...) end - - f(1) - f("foo", 2) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") -{ - CheckResult result = check(R"( - local T = {} - function T.f(...) - local result = {} - - for i = 1, select("#", ...) do - local dictionary = select(i, ...) - for key, value in pairs(dictionary) do - result[key] = value - end - end - - return result - end - - return T - )"); - - auto r = first(getMainModule()->getModuleScope()->returnType); - REQUIRE(r); - - TableTypeVar* ttv = getMutable(*r); - REQUIRE(ttv); - - TypeId k = ttv->props["f"].type; - REQUIRE(k); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "for_loop") -{ - CheckResult result = check(R"( - local q - for i=0, 50, 2 do - q = i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("q")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop") -{ - CheckResult result = check(R"( - local n - local s - for i, v in pairs({ "foo" }) do - n = i - s = v - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_next") -{ - CheckResult result = check(R"( - local n - local s - for i, v in next, { "foo", "bar" } do - n = i - s = v - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") -{ - CheckResult result = check(R"( - local it: any - local a, b - for i, v in it do - a, b = i, v - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") -{ - CheckResult result = check(R"( - local foo = "bar" - for i, v in foo do - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") -{ - CheckResult result = check(R"( - local function keys(dictionary) - local new = {} - local index = 1 - - for key in pairs(dictionary) do - new[index] = key - index = index + 1 - end - - return new - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_with_a_custom_iterator_should_type_check") -{ - CheckResult result = check(R"( - local function range(l, h): () -> number - return function() - return l - end - end - - for n: string in range(1, 10) do - print(n) - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") -{ - CheckResult result = check(R"( - function f(x) - gobble.prop = x.otherprop - end - - local p - for _, part in i_am_not_defined do - p = part - f(part) - part.thirdprop = false - end - )"); - - CHECK_EQ(2, result.errors.size()); - - TypeId p = requireType("p"); - CHECK_EQ("*unknown*", toString(p)); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") -{ - CheckResult result = check(R"( - local bad_iter = 5 - - for a in bad_iter() do - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") -{ - CheckResult result = check(R"( - local function hasDivisors(value: number, table) - return false - end - - function prime_iter(state, index) - while hasDivisors(index, state) do - index += 1 - end - - state[index] = true - return index - end - - function primes1() - return prime_iter, {} - end - - function primes2() - return prime_iter, {}, "" - end - - function primes3() - return prime_iter, {}, 2 - end - - for p in primes1() do print(p) end -- mismatch in argument count - - for p in primes2() do print(p) end -- mismatch in argument types, prime_iter takes {}, number, we are given {}, string - - for p in primes3() do print(p) end -- no error - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Arg); - CHECK_EQ(2, acm->expected); - CHECK_EQ(1, acm->actual); - - TypeMismatch* tm = get(result.errors[1]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") -{ - CheckResult result = check(R"( - function prime_iter(state, index) - return 1 - end - - for p in prime_iter do print(p) end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Arg); - CHECK_EQ(2, acm->expected); - CHECK_EQ(0, acm->actual); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") -{ - CheckResult result = check(R"( - function bar(): any - return true - end - - local a - for b in bar do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(typeChecker.anyType, requireType("a")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") -{ - CheckResult result = check(R"( - function bar(): any - return true - end - - local a - for b in bar() do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") -{ - CheckResult result = check(R"( - local bar: any - - local a - for b in bar do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") -{ - CheckResult result = check(R"( - local bar: any - - local a - for b in bar() do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") -{ - CheckResult result = check(R"( - local a - for b in bar do - a = b - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("*unknown*", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") -{ - CheckResult result = check(R"( - function bar(c) return c end - - local a - for b in bar() do - a = b - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("*unknown*", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") -{ - CheckResult result = check(R"( - function primes() - return function (state: number) end, 2 - end - - for p, q in primes do - q = "" - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); -} - TEST_CASE_FIXTURE(Fixture, "if_statement") { CheckResult result = check(R"( @@ -583,474 +124,6 @@ TEST_CASE_FIXTURE(Fixture, "if_statement") CHECK_EQ(*typeChecker.numberType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "while_loop") -{ - CheckResult result = check(R"( - local i - while true do - i = 8 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("i")); -} - -TEST_CASE_FIXTURE(Fixture, "repeat_loop") -{ - CheckResult result = check(R"( - local i - repeat - i = 'hi' - until true - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.stringType, *requireType("i")); -} - -TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") -{ - CheckResult result = check(R"( - repeat - local x = true - until x - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") -{ - CheckResult result = check(R"( - repeat - local x = true - until x - - print(x) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "table_length") -{ - CheckResult result = check(R"( - local t = {} - local s = #t - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK(nullptr != get(requireType("t"))); - CHECK_EQ(*typeChecker.numberType, *requireType("s")); -} - -TEST_CASE_FIXTURE(Fixture, "string_length") -{ - CheckResult result = check(R"( - local s = "Hello, World!" - local t = #s - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireType("t"))); -} - -TEST_CASE_FIXTURE(Fixture, "string_index") -{ - CheckResult result = check(R"( - local s = "Hello, World!" - local t = s[4] - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - NotATable* nat = get(result.errors[0]); - REQUIRE(nat); - CHECK_EQ("string", toString(nat->ty)); - - CHECK_EQ("*unknown*", toString(requireType("t"))); -} - -TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") -{ - CheckResult result = check(R"( - local l = #this_is_not_defined - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_error_type_does_not_produce_an_error") -{ - CheckResult result = check(R"( - local originalReward = unknown.Parent.Reward:GetChildren()[1] - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") -{ - CheckResult result = check("local a = {} a[0] = 7 a[0] = nil"); - LUAU_REQUIRE_ERROR_COUNT(0, result); -} - -TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") -{ - CheckResult result = check("local a = {} a[0] = 7 a[0] = 't'"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.stringType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") -{ - CheckResult result = check("local a = {a=1, b=2} a['a'] = nil"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.nilType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "dot_on_error_type_does_not_produce_an_error") -{ - CheckResult result = check(R"( - local foo = (true).x - foo.x = foo.y - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon") -{ - CheckResult result = check(R"( - local someTable = {} - - someTable.Function1 = function(Arg1) - end - - someTable.Function1() -- Argument count mismatch - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") -{ - CheckResult result = check(R"( - local someTable = {} - - someTable.Function2 = function(Arg1, Arg2) - end - - someTable.Function2() -- Argument count mismatch - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") -{ - CheckResult result = check(R"( - type T = {method: ((T, number) -> number) & ((number) -> number)} - local T: T - - T.method(4) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_count") -{ - CheckResult result = check(R"( - local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) - multiply("") - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); - - ExtraInformation* ei = get(result.errors[1]); - REQUIRE(ei); - CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); -} - -TEST_CASE_FIXTURE(Fixture, "list_all_overloads_if_no_overload_takes_given_argument_count") -{ - CheckResult result = check(R"( - local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) - multiply() - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("No overload for function accepts 0 arguments.", ge->message); - - ExtraInformation* ei = get(result.errors[1]); - REQUIRE(ei); - CHECK_EQ("Available overloads: (number) -> number; (number) -> string; and (number, number) -> number", ei->message); -} - -TEST_CASE_FIXTURE(Fixture, "dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists") -{ - CheckResult result = check(R"( - local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) - multiply(1, "") - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") -{ - CheckResult result = check(R"( - type T = {method: ((T, number) -> number) & ((number) -> string)} - local T: T - - local a = T.method(T, 4) - local b = T.method(5) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireType("a"))); - CHECK_EQ("string", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "too_many_arguments") -{ - CheckResult result = check(R"( - --!nonstrict - - function g(a: number) end - - g() - - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto err = result.errors[0]; - auto acm = get(err); - REQUIRE(acm); - - CHECK_EQ(1, acm->expected); - CHECK_EQ(0, acm->actual); -} - -TEST_CASE_FIXTURE(Fixture, "any_type_propagates") -{ - CheckResult result = check(R"( - local foo: any - local bar = foo:method("argument") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("bar"))); -} - -TEST_CASE_FIXTURE(Fixture, "can_subscript_any") -{ - CheckResult result = check(R"( - local foo: any - local bar = foo[5] - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("bar"))); -} - -// Not strictly correct: metatables permit overriding this -TEST_CASE_FIXTURE(Fixture, "can_get_length_of_any") -{ - CheckResult result = check(R"( - local foo: any = {} - local bar = #foo - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("bar"))); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_function") -{ - CheckResult result = check(R"( - function count(n: number) - if n == 0 then - return 0 - else - return count(n - 1) - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "lambda_form_of_local_function_cannot_be_recursive") -{ - CheckResult result = check(R"( - local f = function() return f() end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_local_function") -{ - CheckResult result = check(R"( - local function count(n: number) - if n == 0 then - return 0 - else - return count(n - 1) - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -// FIXME: This and the above case get handled very differently. It's pretty dumb. -// We really should unify the two code paths, probably by deleting AstStatFunction. -TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") -{ - CheckResult result = check(R"( - local count - function count(n: number) - if n == 0 then - return 0 - else - return count(n - 1) - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") -{ - CheckResult result = check(R"( - function f() - return f - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = () -> t1", toString(requireType("f"))); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") -{ - CheckResult result = check(R"( - function f(g) - return f(f) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); -} - -// TODO: File a Jira about this -/* -TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") -{ - CheckResult result = check(R"( - function a(x) return 1 end - a(...) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - REQUIRE(bool(getMainModule()->getModuleScope()->varargPack)); - - TypePackId varargPack = *getMainModule()->getModuleScope()->varargPack; - - auto iter = begin(varargPack); - auto endIter = end(varargPack); - - CHECK(iter != endIter); - ++iter; - CHECK(iter == endIter); - - CHECK(!iter.tail()); -} -*/ - -TEST_CASE_FIXTURE(Fixture, "method_depends_on_table") -{ - CheckResult result = check(R"( - -- This catches a bug where x:m didn't count as a use of x - -- so toposort would happily reorder a definition of - -- function x:m before the definition of x. - function g() f() end - local x = {} - function x:m() end - function f() x:m() end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") -{ - CheckResult result = check(R"( - local Get_des - function Get_des(func) - Get_des(func) - end - - local function f(d) - d:IsA("BasePart") - d.Parent:FindFirstChild("Humanoid") - d:IsA("Decal") - end - Get_des(f) - - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "another_other_higher_order_function") -{ - CheckResult result = check(R"( - local d - d:foo() - d:foo() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") { CheckResult result = check(R"( @@ -1067,121 +140,6 @@ TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "generic_table_method") -{ - CheckResult result = check(R"( - local T = {} - - function T:bar(i) - return i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId tType = requireType("T"); - TableTypeVar* tTable = getMutable(tType); - REQUIRE(tTable != nullptr); - - TypeId barType = tTable->props["bar"].type; - REQUIRE(barType != nullptr); - - const FunctionTypeVar* ftv = get(follow(barType)); - REQUIRE_MESSAGE(ftv != nullptr, "Should be a function: " << *barType); - - std::vector args = flatten(ftv->argTypes).first; - TypeId argType = args.at(1); - - CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); -} - -TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") -{ - CheckResult result = check(R"( - local T = {} - - function T:foo() - return T:bar(5) - end - - function T:bar(i) - return i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - const TableTypeVar* t = get(requireType("T")); - REQUIRE(t != nullptr); - - std::optional fooProp = get(t->props, "foo"); - REQUIRE(bool(fooProp)); - - const FunctionTypeVar* foo = get(follow(fooProp->type)); - REQUIRE(bool(foo)); - - std::optional ret_ = first(foo->retType); - REQUIRE(bool(ret_)); - TypeId ret = follow(*ret_); - - REQUIRE_EQ(getPrimitiveType(ret), PrimitiveTypeVar::Number); -} - -TEST_CASE_FIXTURE(Fixture, "methods_are_topologically_sorted") -{ - CheckResult result = check(R"( - local T = {} - - function T:foo() - return T:bar(999), T:bar("hi") - end - - function T:bar(i) - return i - end - - local a, b = T:foo() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("a"))); - CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "local_function") -{ - CheckResult result = check(R"( - function f() - return 8 - end - - function g() - local function f() - return 'hello' - end - return f - end - - local h = g() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId h = follow(requireType("h")); - - const FunctionTypeVar* ftv = get(h); - REQUIRE(ftv != nullptr); - - std::optional rt = first(ftv->retType); - REQUIRE(bool(rt)); - - TypeId retType = follow(*rt); - CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(retType)); -} - TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") { CheckResult result = check(R"( @@ -1193,267 +151,10 @@ TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") o = p )"); -} - -/* - * We had a bug in instantiation where the argument types of 'f' and 'g' would be inferred as - * f {+ method: function(): (t2, T3...) +} - * g {+ method: function({+ method: function(): (t2, T3...) +}): (t5, T6...) +} - * - * The type of 'g' is totally wrong as t2 and t5 should be unified, as should T3 with T6. - * - * The correct unification of the argument to 'g' is - * - * {+ method: function(): (t5, T6...) +} - */ -TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") -{ - auto result = check(R"( - function f(o) - o:method() - end - - function g(o) - f(o) - end - )"); - - TypeId g = requireType("g"); - const FunctionTypeVar* gFun = get(g); - REQUIRE(gFun != nullptr); - - auto optionArg = first(gFun->argTypes); - REQUIRE(bool(optionArg)); - - TypeId arg = follow(*optionArg); - const TableTypeVar* argTable = get(arg); - REQUIRE(argTable != nullptr); - - std::optional methodProp = get(argTable->props, "method"); - REQUIRE(bool(methodProp)); - - const FunctionTypeVar* methodFunction = get(methodProp->type); - REQUIRE(methodFunction != nullptr); - - std::optional methodArg = first(methodFunction->argTypes); - REQUIRE(bool(methodArg)); - - REQUIRE_EQ(follow(*methodArg), follow(arg)); -} - -TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") -{ - CheckResult result = check(R"( - local T = {} - - function T.f(p) - for i, v in pairs(p) do - T.f(v) - end - end - )"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") -{ - // In this case, we cannot know the element type of the table {}. It could be anything. - // We therefore must initially ascribe a free typevar to iter. - CheckResult result = check(R"( - for iter in pairs({}) do - iter:g().p = true - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "quantify_methods_defined_using_dot_syntax_and_explicit_self_parameter") -{ - check(R"( - local T = {} - - function T.method(self) - self:method() - end - - function T.method2(self) - self:method() - end - - T:method2() - )"); -} - -TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") -{ - check(R"( - local o - local v = o:i() - - function g(u) - v = u - end - - o:f(g) - o:h() - o:h() - )"); -} - -TEST_CASE_FIXTURE(Fixture, "require") -{ - fileResolver.source["game/A"] = R"( - local function hooty(x: number): string - return "Hi there!" - end - - return {hooty=hooty} - )"; - - fileResolver.source["game/B"] = R"( - local Hooty = require(game.A) - - local h -- free! - local i = Hooty.hooty(h) - )"; - - CheckResult aResult = frontend.check("game/A"); - dumpErrors(aResult); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = frontend.check("game/B"); - dumpErrors(bResult); - LUAU_REQUIRE_NO_ERRORS(bResult); - - ModulePtr b = frontend.moduleResolver.modules["game/B"]; - - REQUIRE(b != nullptr); - - dumpErrors(bResult); - - std::optional iType = requireType(b, "i"); - REQUIRE_EQ("string", toString(*iType)); - - std::optional hType = requireType(b, "h"); - REQUIRE_EQ("number", toString(*hType)); -} - -TEST_CASE_FIXTURE(Fixture, "require_types") -{ - fileResolver.source["workspace/A"] = R"( - export type Point = {x: number, y: number} - - return {} - )"; - - fileResolver.source["workspace/B"] = R"( - local Hooty = require(workspace.A) - - local h: Hooty.Point - )"; - - CheckResult bResult = frontend.check("workspace/B"); - dumpErrors(bResult); - - ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; - REQUIRE(b != nullptr); - - TypeId hType = requireType(b, "h"); - REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); -} - -TEST_CASE_FIXTURE(Fixture, "require_a_variadic_function") -{ - fileResolver.source["game/A"] = R"( - local T = {} - function T.f(...) end - return T - )"; - - fileResolver.source["game/B"] = R"( - local A = require(game.A) - local f = A.f - )"; - - CheckResult result = frontend.check("game/B"); - - ModulePtr bModule = frontend.moduleResolver.getModule("game/B"); - REQUIRE(bModule != nullptr); - - TypeId f = follow(requireType(bModule, "f")); - - const FunctionTypeVar* ftv = get(f); - REQUIRE(ftv); - - auto iter = begin(ftv->argTypes); - auto endIter = end(ftv->argTypes); - - REQUIRE(iter == endIter); - REQUIRE(iter.tail()); - - CHECK(get(*iter.tail())); -} - -TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") -{ - CheckResult result = check(R"( - local f: any - local T = {} - - T.prop = f() - - return T - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TableTypeVar* ttv = getMutable(requireType("T")); - REQUIRE(ttv); - REQUIRE(ttv->props.count("prop")); - - REQUIRE_EQ("any", toString(ttv->props["prop"].type)); -} - -TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") -{ - CheckResult result = check(R"( - local p: SomeModule.DoesNotExist - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE_EQ(result.errors[0], (TypeError{Location{{1, 17}, {1, 40}}, UnknownSymbol{"SomeModule.DoesNotExist"}})); -} - -TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") -{ - const std::string sourceA = R"( - )"; - - const std::string sourceB = R"( - local Hooty = require(script.Parent.A) - )"; - - fileResolver.source["game/Workspace/A"] = sourceA; - fileResolver.source["game/Workspace/B"] = sourceB; - - frontend.check("game/Workspace/A"); - frontend.check("game/Workspace/B"); - - ModulePtr aModule = frontend.moduleResolver.modules["game/Workspace/A"]; - ModulePtr bModule = frontend.moduleResolver.modules["game/Workspace/B"]; - - CHECK(aModule->errors.empty()); - REQUIRE_EQ(1, bModule->errors.size()); - CHECK_MESSAGE(get(bModule->errors[0]), "Should be IllegalRequire: " << toString(bModule->errors[0])); - - auto hootyType = requireType(bModule, "Hooty"); - - CHECK_EQ("*unknown*", toString(hootyType)); -} - TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") { CheckResult result = check(R"( @@ -1468,144 +169,6 @@ TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") REQUIRE_EQ("parent", ed->symbol); } -TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") -{ - CheckResult result = check(R"( - local A : any - function A.B() end - A:C() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId aType = requireType("A"); - CHECK_EQ(aType, typeChecker.anyType); -} - -TEST_CASE_FIXTURE(Fixture, "table_unifies_into_map") -{ - CheckResult result = check(R"( - local Instance: any - local UDim2: any - - function Create(instanceType) - return function(data) - local obj = Instance.new(instanceType) - for k, v in pairs(data) do - if type(k) == 'number' then - --v.Parent = obj - else - obj[k] = v - end - end - return obj - end - end - - local topbarShadow = Create'ImageLabel'{ - Name = "TopBarShadow"; - Size = UDim2.new(1, 0, 0, 3); - Position = UDim2.new(0, 0, 1, 0); - Image = "rbxasset://textures/ui/TopBar/dropshadow.png"; - BackgroundTransparency = 1; - Active = false; - Visible = false; - }; - - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") -{ - CheckResult result = check(R"( - local p = function(x) return x end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - const Luau::FunctionTypeVar* fn = get(requireType("p")); - REQUIRE(fn); - auto ret = first(fn->retType); - REQUIRE(ret); - REQUIRE(get(follow(*ret))); -} - -TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") -{ - CheckResult result = check(R"( - function foo(a, b) - return a(b) - end - - function bar() - local c: ((number)->number, number)->number = foo -- no error - c = foo -- no error - local d: ((number)->number, string)->number = foo -- error from arg 2 (string) not being convertable to number from the call a(b) - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); - CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") -{ - CheckResult result = check(R"( - function foo(a, b) - return a(b) - end - - function bar() - local _: (string, string)->number = foo -- string cannot be converted to (string)->number - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); - CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "string_method") -{ - CheckResult result = check(R"( - local p = ("tacos"):len() - )"); - CHECK_EQ(0, result.errors.size()); - - CHECK_EQ(*requireType("p"), *typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "string_function_indirect") -{ - CheckResult result = check(R"( - local s:string - local l = s.lower - local p = l(s) - )"); - CHECK_EQ(0, result.errors.size()); - - CHECK_EQ(*requireType("p"), *typeChecker.stringType); -} - -TEST_CASE_FIXTURE(Fixture, "string_function_other") -{ - CheckResult result = check(R"( - local s:string - local p = s:match("foo") - )"); - CHECK_EQ(0, result.errors.size()); - - CHECK_EQ(toString(requireType("p")), "string?"); -} - TEST_CASE_FIXTURE(Fixture, "weird_case") { CheckResult result = check(R"( @@ -1617,80 +180,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_case") dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "or_joins_types") -{ - CheckResult result = check(R"( - local s = "a" or 10 - local x:string|number = s - )"); - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*requireType("s")), "number | string"); - CHECK_EQ(toString(*requireType("x")), "number | string"); -} - -TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") -{ - CheckResult result = check(R"( - local s = "a" or 10 - local x:number|string = s - local y = x or "s" - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "number | string"); - CHECK_EQ(toString(*requireType("y")), "number | string"); -} - -TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") -{ - CheckResult result = check(R"( - local s = "a" or "b" - local x:string = s - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(*requireType("s"), *typeChecker.stringType); -} - -TEST_CASE_FIXTURE(Fixture, "and_adds_boolean") -{ - CheckResult result = check(R"( - local s = "a" and 10 - local x:boolean|number = s - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "boolean | number"); -} - -TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") -{ - CheckResult result = check(R"( - local s = "a" and true - local x:boolean = s - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(*requireType("x"), *typeChecker.booleanType); -} - -TEST_CASE_FIXTURE(Fixture, "and_or_ternary") -{ - CheckResult result = check(R"( - local s = (1/2) > 0.5 and "a" or 10 - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "number | string"); -} - -TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") -{ - CheckResult result = check(R"( - local T = {} - function T.new(a: number?, b: number?, c: number?) return 5 end - local m = T.new() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") { CheckResult result = check(R"( @@ -1726,303 +215,6 @@ TEST_CASE_FIXTURE(Fixture, "crazy_complexity") } #endif -// We had a bug where a cyclic union caused a stack overflow. -// ex type U = number | U -TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") -{ - CheckResult result = check(R"( - --!strict - - function f(a, b) - a:g(b or {}) - a:g(b) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "it_is_ok_not_to_supply_enough_retvals") -{ - CheckResult result = check(R"( - function get_two() return 5, 6 end - - local a = get_two() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "duplicate_functions2") -{ - CheckResult result = check(R"( - function foo() end - - function bar() - local function foo() end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); -} - -TEST_CASE_FIXTURE(Fixture, "duplicate_functions_allowed_in_nonstrict") -{ - CheckResult result = check(R"( - --!nonstrict - function foo() end - - function foo() end - - function bar() - local function foo() end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "duplicate_functions_with_different_signatures_not_allowed_in_nonstrict") -{ - CheckResult result = check(R"( - --!nonstrict - function foo(): number - return 1 - end - foo() - - function foo(n: number): number - return 2 - end - foo() - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("() -> number", toString(tm->wantedType)); - CHECK_EQ("(number) -> number", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals") -{ - CheckResult result = check(R"( - local T = {} - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("T", toString(requireType("T"))); -} - -TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") -{ - CheckResult result = check(R"( - function foo(arr) - local work = {} - for i = 1, #arr do - work[i] = arr[i] - end - - return arr - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - const FunctionTypeVar* fooType = get(requireType("foo")); - REQUIRE(fooType); - - std::optional fooArg1 = first(fooType->argTypes); - REQUIRE(fooArg1); - - const TableTypeVar* fooArg1Table = get(*fooArg1); - REQUIRE(fooArg1Table); - - CHECK_EQ(fooArg1Table->state, TableState::Generic); -} - -TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotation") -{ - CheckResult result = check(R"( - local i = 0 - function most_of_the_natural_numbers(): number? - if i < 10 then - i = i + 1 - return i - else - return nil - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); - - std::optional retType = first(functionType->retType); - REQUIRE(retType); - CHECK(get(*retType)); -} - -TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") -{ - CheckResult result = check(R"( - function apply(f, x) - return f(x) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* ftv = get(requireType("apply")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(2, argVec.size()); - - const FunctionTypeVar* fType = get(follow(argVec[0])); - REQUIRE(fType != nullptr); - - std::vector fArgs = flatten(fType->argTypes).first; - - TypeId xType = argVec[1]; - - CHECK_EQ(1, fArgs.size()); - CHECK_EQ(xType, fArgs[0]); -} - -TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") -{ - CheckResult result = check(R"( - function bottomupmerge(comp, a, b, left, mid, right) - local i, j = left, mid - for k = left, right do - if i < mid and (j > right or not comp(a[j], a[i])) then - b[k] = a[i] - i = i + 1 - else - b[k] = a[j] - j = j + 1 - end - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* ftv = get(requireType("bottomupmerge")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(6, argVec.size()); - - const FunctionTypeVar* fType = get(follow(argVec[0])); - REQUIRE(fType != nullptr); -} - -TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") -{ - CheckResult result = check(R"( - function swap(p) - local t = p[0] - p[0] = p[1] - p[1] = t - return nil - end - - function swapTwice(p) - swap(p) - swap(p) - return p - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* ftv = get(requireType("swapTwice")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(1, argVec.size()); - - const TableTypeVar* argType = get(follow(argVec[0])); - REQUIRE(argType != nullptr); - - CHECK(bool(argType->indexer)); -} - -TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") -{ - CheckResult result = check(R"( - function bottomupmerge(comp, a, b, left, mid, right) - local i, j = left, mid - for k = left, right do - if i < mid and (j > right or not comp(a[j], a[i])) then - b[k] = a[i] - i = i + 1 - else - b[k] = a[j] - j = j + 1 - end - end - end - - function mergesort(arr, comp) - local work = {} - for i = 1, #arr do - work[i] = arr[i] - end - local width = 1 - while width < #arr do - for i = 1, #arr, 2*width do - bottomupmerge(comp, arr, work, i, math.min(i+width, #arr), math.min(i+2*width-1, #arr)) - end - local temp = work - work = arr - arr = temp - width = width * 2 - end - return arr - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - /* - * mergesort takes two arguments: an array of some type T and a function that takes two Ts. - * We must assert that these two types are in fact the same type. - * In other words, comp(arr[x], arr[y]) is well-typed. - */ - - const FunctionTypeVar* ftv = get(requireType("mergesort")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(2, argVec.size()); - - const TableTypeVar* arg0 = get(follow(argVec[0])); - REQUIRE(arg0 != nullptr); - REQUIRE(bool(arg0->indexer)); - - const FunctionTypeVar* arg1 = get(follow(argVec[1])); - REQUIRE(arg1 != nullptr); - REQUIRE_EQ(2, size(arg1->argTypes)); - - std::vector arg1Args = flatten(arg1->argTypes).first; - - CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[0]); - CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); -} - TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") { CheckResult result = check(R"( @@ -2046,373 +238,6 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") CHECK_EQ("*unknown*", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") -{ - CheckResult result = check(R"( - local a = unknown.Parent.Reward.GetChildren() - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownSymbol* err = get(result.errors[0]); - REQUIRE(err != nullptr); - - CHECK_EQ("unknown", err->name); - - CHECK_EQ("*unknown*", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") -{ - CheckResult result = check(R"( - local a = Utility.Create "Foo" {} - )"); - - CHECK_EQ("*unknown*", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") -{ - CheckResult result = check(R"( - function add(a: number, b: string) - return a + (tonumber(b) :: number), a .. b - end - local n, s = add(2,"3") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* functionType = get(requireType("add")); - - std::optional retType = first(functionType->retType); - CHECK_EQ(std::optional(typeChecker.numberType), retType); - CHECK_EQ(requireType("n"), typeChecker.numberType); - CHECK_EQ(requireType("s"), typeChecker.stringType); -} - -TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") -{ - CheckResult result = check(R"( - local PI=3.1415926535897931 - local SOLAR_MASS=4*PI * PI - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(requireType("SOLAR_MASS"), typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable") -{ - CheckResult result = check(R"( - function add(a: number, b: any) - return a + b - end - local t = add(1,2) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("t"))); -} - -TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") -{ - CheckResult result = check(R"( - local a = 4 + 8 - local b = a + 9 - local s = 'hotdogs' - local t = s .. s - local c = b - a - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("number", toString(requireType("a"))); - CHECK_EQ("number", toString(requireType("b"))); - CHECK_EQ("string", toString(requireType("s"))); - CHECK_EQ("string", toString(requireType("t"))); - CHECK_EQ("number", toString(requireType("c"))); -} - -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") -{ - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - - CheckResult result = check(R"( - --!strict - local Vec3 = {} - Vec3.__index = Vec3 - function Vec3.new() - return setmetatable({x=0, y=0, z=0}, Vec3) - end - - export type Vec3 = typeof(Vec3.new()) - - local thefun: any = function(self, o) return self end - - local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun - - Vec3.__mul = multiply - - local a = Vec3.new() - local b = Vec3.new() - local c = a * b - local d = a * 2 - local e = a * 'cabbage' - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Vec3", toString(requireType("a"))); - CHECK_EQ("Vec3", toString(requireType("b"))); - CHECK_EQ("Vec3", toString(requireType("c"))); - CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK_EQ("Vec3", toString(requireType("e"))); -} - -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") -{ - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - - CheckResult result = check(R"( - --!strict - local Vec3 = {} - Vec3.__index = Vec3 - function Vec3.new() - return setmetatable({x=0, y=0, z=0}, Vec3) - end - - export type Vec3 = typeof(Vec3.new()) - - local thefun: any = function(self, o) return self end - - local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun - - Vec3.__mul = multiply - - local a = Vec3.new() - local b = Vec3.new() - local c = b * a - local d = 2 * a - local e = 'cabbage' * a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Vec3", toString(requireType("a"))); - CHECK_EQ("Vec3", toString(requireType("b"))); - CHECK_EQ("Vec3", toString(requireType("c"))); - CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK_EQ("Vec3", toString(requireType("e"))); -} - -TEST_CASE_FIXTURE(Fixture, "compare_numbers") -{ - CheckResult result = check(R"( - local a = 441 - local b = 0 - local c = a < b - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "compare_strings") -{ - CheckResult result = check(R"( - local a = '441' - local b = '0' - local c = a < b - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_metatable") -{ - CheckResult result = check(R"( - local a = {} - local b = {} - local c = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* gen = get(result.errors[0]); - - REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") -{ - CheckResult result = check(R"( - local M = {} - function M.new() - return setmetatable({}, M) - end - type M = typeof(M.new()) - - local a = M.new() - local b = M.new() - local c = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* gen = get(result.errors[0]); - REQUIRE(gen != nullptr); - REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") -{ - CheckResult result = check(R"( - --!strict - local M = {} - function M.new() - return setmetatable({}, M) - end - function M.__lt(left, right) return true end - - local a = M.new() - local b = {} - local c = a < b -- line 10 - local d = b < a -- line 11 - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - REQUIRE_EQ((Location{{10, 18}, {10, 23}}), result.errors[0].location); - - REQUIRE_EQ((Location{{11, 18}, {11, 23}}), result.errors[1].location); -} - -TEST_CASE_FIXTURE(Fixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") -{ - CheckResult result = check(R"( - --!strict - local M = {} - function M.new() - return setmetatable({}, M) - end - function M.__lt(left, right) return true end - type M = typeof(M.new()) - - local a = M.new() - local b = {} - local c = a < b -- line 10 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto err = get(result.errors[0]); - REQUIRE(err != nullptr); - - // Frail. :| - REQUIRE_EQ("Types M and b cannot be compared with < because they do not have the same metatable", err->message); -} - -TEST_CASE_FIXTURE(Fixture, "in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators") -{ - CheckResult result = check(R"( - --!nonstrict - - function maybe_a_number(): number? - return 50 - end - - local a = maybe_a_number() < maybe_a_number() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -/* - * This test case exposed an oversight in the treatment of free tables. - * Free tables, like free TypeVars, need to record the scope depth where they were created so that - * we do not erroneously let-generalize them when they are used in a nested lambda. - * - * For more information about let-generalization, see - * - * The important idea here is that the return type of Counter.new is a table with some metatable. - * That metatable *must* be the same TypeVar as the type of Counter. If it is a copy (produced by - * the generalization process), then it loses the knowledge that its metatable will have an :incr() - * method. - */ -TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") -{ - CheckResult result = check(R"( - local Counter = {} - Counter.__index = Counter - - function Counter.new() - local self = setmetatable({count=0}, Counter) - return self - end - - function Counter:incr() - self.count = 1 - return self.count - end - - local self = Counter.new() - print(self:incr()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TableTypeVar* counterType = getMutable(requireType("Counter")); - REQUIRE(counterType); - - const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); - REQUIRE(newType); - - std::optional newRetType = *first(newType->retType); - REQUIRE(newRetType); - - const MetatableTypeVar* newRet = get(follow(*newRetType)); - REQUIRE(newRet); - - const TableTypeVar* newRetMeta = get(newRet->metatable); - REQUIRE(newRetMeta); - - CHECK(newRetMeta->props.count("incr")); - CHECK_EQ(follow(newRet->metatable), follow(requireType("Counter"))); -} - -// TODO: CLI-39624 -TEST_CASE_FIXTURE(Fixture, "instantiate_tables_at_scope_level") -{ - CheckResult result = check(R"( - --!strict - local Option = {} - Option.__index = Option - function Option.Is(obj) - return (type(obj) == "table" and getmetatable(obj) == Option) - end - return Option - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") -{ - const std::string code = R"( - function f(a) - if type(a) == "boolean" then - local a1 = a - elseif a.fn() then - local a2 = a - else - local a3 = a - end - end - )"; - CheckResult result = check(code); - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowing") { CheckResult result = check(R"( @@ -2428,47 +253,6 @@ TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowi LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocate_memory") -{ - CheckResult result = check(R"( - ("foo") - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - )"); - - ModulePtr module = getMainModule(); - CHECK_GE(50, module->internalTypes.typeVars.size()); -} - -TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") -{ - CheckResult result = check(R"( - --!strict - function f(U) - U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() - end - )"); - - ModulePtr module = getMainModule(); - CHECK_GE(100, module->internalTypes.typeVars.size()); -} - TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") { CheckResult result = check(R"( @@ -2507,96 +291,6 @@ TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") CHECK_GE(5, module->interfaceTypes.typeVars.size()); } -TEST_CASE_FIXTURE(Fixture, "mutual_recursion") -{ - CheckResult result = check(R"( - --!strict - - function newPlayerCharacter() - startGui() -- Unknown symbol 'startGui' - end - - local characterAddedConnection: any - function startGui() - characterAddedConnection = game:GetService("Players").LocalPlayer.CharacterAdded:connect(newPlayerCharacter) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") -{ - CheckResult result = check(R"( - --!strict - local x = nil - function f() g() end - -- make sure print(x) doesn't get toposorted here, breaking the mutual block - function g() x = f end - print(x) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") -{ - // CLI-30902 - CheckResult result = check(R"( - --!strict - - type Foo = { - fooConn: () -> () | nil - } - - local Foo = {} - Foo.__index = Foo - - function Foo.new() - local self: Foo = { - fooConn = nil, - } - setmetatable(self, Foo) - - self.fooConn = function() - self:method() -- Key 'method' not found in table self - end - - return self - end - - function Foo:method() - print("foo") - end - - local foo = Foo.new() - - -- TODO This is the best our current refinement support can offer :( - local bar = foo.fooConn - if bar then bar() end - - -- foo.fooConn() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") -{ - CheckResult result = check(R"( - local a: any - local b - for _, i in pairs(a) do - b = i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("b"))); -} - // In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in type // checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") @@ -2653,146 +347,6 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") CHECK(nullptr != get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "compound_assign_basic") -{ - CheckResult result = check(R"( - local s = 10 - s += 20 - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "number"); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op") -{ - CheckResult result = check(R"( - local s = 10 - s += true - )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}})); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") -{ - CheckResult result = check(R"( - local s = 'hello' - s += 10 - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") -{ - CheckResult result = check(R"( - --!strict - type V2B = { x: number, y: number } - local v2b: V2B = { x = 0, y = 0 } - local VMT = {} - type V2 = typeof(setmetatable(v2b, VMT)) - - function VMT.__add(a: V2, b: V2): V2 - return setmetatable({ x = a.x + b.x, y = a.y + b.y }, VMT) - end - - local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) - local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) - v1 += v2 - )"); - CHECK_EQ(0, result.errors.size()); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_metatable") -{ - CheckResult result = check(R"( - --!strict - type V2B = { x: number, y: number } - local v2b: V2B = { x = 0, y = 0 } - local VMT = {} - type V2 = typeof(setmetatable(v2b, VMT)) - - function VMT.__mod(a: V2, b: V2): number - return a.x * b.x + a.y * b.y - end - - local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) - local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) - v1 %= v2 - )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - CHECK_EQ(*tm->wantedType, *requireType("v2")); - CHECK_EQ(*tm->givenType, *typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "dont_ice_if_a_TypePack_is_an_error") -{ - CheckResult result = check(R"( - --!strict - function f(s) - print(s) - return f - end - - f("foo")("bar") - )"); -} - -TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") -{ - CheckResult result = check(R"( - --!nonstrict - - function f() - return 114 - end - - return function() - return f():andThen() - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") -{ - CheckResult result = check(R"( - function onerror() end - function foo() end - xpcall(foo, onerror) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments") -{ - CheckResult result = check(R"( - local mycb: (number, number) -> () - - function f() end - - mycb = f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "call_to_any_yields_any") -{ - CheckResult result = check(R"( - local a: any - local b = a() - )"); - - REQUIRE_EQ("any", toString(requireType("b"))); -} - TEST_CASE_FIXTURE(Fixture, "globals") { CheckResult result = check(R"( @@ -2853,91 +407,6 @@ TEST_CASE_FIXTURE(Fixture, "globals_everywhere") CHECK_EQ("any", toString(requireType("bar"))); } -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfAny") -{ - CheckResult result = check(R"( -local x: any = {} -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfSealed") -{ - CheckResult result = check(R"( -local x: {prop: number} = {prop=9999} -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") -{ - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - - CheckResult result = check(R"( -local x: number = 9999 -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") -{ - CheckResult result = check(R"( -local x = (true).foo -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") -{ - CheckResult result = check(R"( -function f() return 1; end -function g() return 2; end -(f or g)() -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "CallAndOrOfFunctions") -{ - CheckResult result = check(R"( -function f() return 1; end -function g() return 2; end -local x = false -(x and f or g)() -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers") -{ - CheckResult result = check(R"( -local x = {} -x.a = "a" -x[0] = true -x.b = 37 -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") { CheckResult result = check(R"( @@ -2955,35 +424,6 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") CHECK_EQ(us->name, "a"); } -TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") -{ - CheckResult result = check(R"( - while true do - local a = 1 - end - - print(a) -- oops! - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownSymbol* us = get(result.errors[0]); - REQUIRE(us); - CHECK_EQ(us->name, "a"); -} - -TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indices") -{ - CheckResult result = check(R"( - local key - for i, e in ipairs({}) do key = i end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - REQUIRE_EQ("number", toString(requireType("key"))); -} - TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") { CHECK_NOTHROW(check(R"( @@ -2997,79 +437,6 @@ TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") CHECK_EQ("any", toString(requireType("value"))); } -TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_nonstrict") -{ - CheckResult result = check(R"( - --!nonstrict - - local function f1(v): number? - if v then - return 1 - end - end - - local function f2(v) - if v then - return 1 - end - end - - local function f3(v): () - if v then - return - end - end - - local function f4(v) - if v then - return - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - FunctionExitsWithoutReturning* err = get(result.errors[0]); - CHECK(err); -} - -TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_strict") -{ - CheckResult result = check(R"( - --!strict - - local function f1(v): number? - if v then - return 1 - end - end - - local function f2(v) - if v then - return 1 - end - end - - local function f3(v): () - if v then - return - end - end - - local function f4(v) - if v then - return - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - FunctionExitsWithoutReturning* annotatedErr = get(result.errors[0]); - CHECK(annotatedErr); - - FunctionExitsWithoutReturning* inferredErr = get(result.errors[1]); - CHECK(inferredErr); -} - // TEST_CASE_FIXTURE(Fixture, "infer_method_signature_of_argument") // { // CheckResult result = check(R"( @@ -3085,363 +452,6 @@ TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_strict") // CHECK_EQ("A", toString(requireType("f"))); // } -TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") -{ - fileResolver.source["Modules/A"] = ""; - fileResolver.sourceTypes["Modules/A"] = SourceCode::Local; - - fileResolver.source["Modules/B"] = R"( - local M = require(script.Parent.A) - )"; - - CheckResult result = frontend.check("Modules/B"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields_errors_spanning_argument") -{ - CheckResult result = check(R"( - function foo(a: number, b: string) end - - foo("Test", 123) - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ(result.errors[0], (TypeError{Location{Position{3, 12}, Position{3, 18}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.stringType, - }})); - - CHECK_EQ(result.errors[1], (TypeError{Location{Position{3, 20}, Position{3, 23}}, TypeMismatch{ - typeChecker.stringType, - typeChecker.numberType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") -{ - CheckResult result = check(R"( - --!nonstrict - - function Test(a) - return 1, "" - end - - - local tab = {} - table.insert(tab, Test(1)); - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions opts; - opts.exhaustive = true; - opts.maxTableLength = 0; - - CHECK_EQ("{any}", toString(requireType("tab"), opts)); -} - -TEST_CASE_FIXTURE(Fixture, "too_many_return_values") -{ - CheckResult result = check(R"( - --!strict - - function f() - return 55 - end - - local a, b = f() - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Result); - CHECK_EQ(acm->expected, 1); - CHECK_EQ(acm->actual, 2); -} - -TEST_CASE_FIXTURE(Fixture, "ignored_return_values") -{ - CheckResult result = check(R"( - --!strict - - function f() - return 55, "" - end - - local a = f() - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); -} - -TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") -{ - CheckResult result = check(R"( - --!strict - - function f(): (number, string) - return 55 - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Return); - CHECK_EQ(acm->expected, 2); - CHECK_EQ(acm->actual, 1); -} - -TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") -{ - CheckResult result = check(R"( - --!strict - local foo = { - value = 10 - } - local mt = {} - setmetatable(foo, mt) - - mt.__unm = function(val: typeof(foo)): string - return val.value .. "test" - end - - local a = -foo - - local b = 1+-1 - - local bar = { - value = 10 - } - local c = -bar -- disallowed - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("string", toString(requireType("a"))); - CHECK_EQ("number", toString(requireType("b"))); - - GenericError* gen = get(result.errors[0]); - REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); -} - -TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") -{ - CheckResult result = check(R"( - local b = not "string" - local c = not (math.random() > 0.5 and "string" or 7) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("boolean", toString(requireType("b"))); - REQUIRE_EQ("boolean", toString(requireType("c"))); -} - -TEST_CASE_FIXTURE(Fixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") -{ - CheckResult result = check(R"( - --!strict - local a = "1.24" + 123 -- not allowed - - local foo = { - value = 10 - } - - local b = foo + 1 -- not allowed - - local bar = { - value = 1 - } - - local mt = {} - - setmetatable(bar, mt) - - mt.__add = function(a: typeof(bar), b: number): number - return a.value + b - end - - local c = bar + 1 -- allowed - - local d = bar + foo -- not allowed - )"); - - LUAU_REQUIRE_ERROR_COUNT(3, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); - REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); - - TypeMismatch* tm2 = get(result.errors[2]); - CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); - CHECK_EQ(*tm2->givenType, *requireType("foo")); - - GenericError* gen2 = get(result.errors[1]); - REQUIRE_EQ(gen2->message, "Binary operator '+' not supported by types 'foo' and 'number'"); -} - -// CLI-29033 -TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") -{ - CheckResult result = check(R"( - function merge(lower, greater) - if lower.y == greater.y then - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs") -{ - CheckResult result = check(R"( - local function f(x) - return x .. "y" - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") -{ - CheckResult result = check(R"( - local function f(x) - return "foo" .. x - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("(string) -> string", toString(requireType("f"))); -} - -TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") -{ - std::vector ops = {"+", "-", "*", "/", "%", "^", ".."}; - - std::string src = R"( - function foo(a, b) - )"; - - for (const auto& op : ops) - src += "local _ = a " + op + "b\n"; - - src += "end"; - - CheckResult result = check(src); - LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); - - CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") -{ - CheckResult result = check(R"( - function foo(a, b): number - return 0 - end - - local a: (string)->number = foo - local b: (number, number)->(number, number) = foo - - local c: (string, number)->number = foo -- no error - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - auto tm1 = get(result.errors[0]); - REQUIRE(tm1); - - CHECK_EQ("(string) -> number", toString(tm1->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm1->givenType)); - - auto tm2 = get(result.errors[1]); - REQUIRE(tm2); - - CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm2->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "setmetatable_cant_be_used_to_mutate_global_types") -{ - { - Fixture fix; - - // inherit env from parent fixture checker - fix.typeChecker.globalScope = typeChecker.globalScope; - - fix.check(R"( ---!nonstrict -type MT = typeof(setmetatable) -function wtf(arg: {MT}): typeof(table) - arg = wtf(arg) -end -)"); - } - - // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down - // note: it's important for typeck to be destroyed at this point! - { - for (auto& p : typeChecker.globalScope->bindings) - { - toString(p.second.typeId); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas - } - } -} - -TEST_CASE_FIXTURE(Fixture, "evil_table_unification") -{ - // this code re-infers the type of _ while processing fields of _, which can cause use-after-free - check(R"( ---!nonstrict -_ = ... -_:table(_,string)[_:gsub(_,...,n0)],_,_:gsub(_,string)[""],_:split(_,...,table)._,n0 = nil -do end -)"); -} - -TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") -{ - check(R"( ---!nonstrict -function _(...):((typeof(not _))&(typeof(not _)))&((typeof(not _))&(typeof(not _))) -_(...)(setfenv,_,not _,"")[_] = nil -end -do end -_(...)(...,setfenv,_):_G() -)"); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") -{ - // this has a risk of creating cyclic type packs, causing infinite loops / OOMs - check(R"( ---!nonstrict -_ += _(_,...) -repeat -_ += _(...) -until ... + _ -)"); - - check(R"( ---!nonstrict -_ += _(_(...,...),_(...)) -repeat -until _ -)"); -} - TEST_CASE_FIXTURE(Fixture, "cyclic_follow") { check(R"( @@ -3470,19 +480,6 @@ end )"); } -TEST_CASE_FIXTURE(Fixture, "and_binexps_dont_unify") -{ - CheckResult result = check(R"( - --!strict - local t = {} - while true and t[1] do - print(t[1].test) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - struct FindFreeTypeVars { bool foundOne = false; @@ -3506,32 +503,6 @@ struct FindFreeTypeVars } }; -TEST_CASE_FIXTURE(Fixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") -{ - CheckResult result = check("local x = setmetatable({})"); - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") -{ - // This code doesn't pass typechecking. We just care that it doesn't crash. - (void)check(R"( - --!nonstrict - function _:_(...) - end - - repeat - if _ then - else - _ = ... - end - until _ - - for _ in _() do - end - )"); -} - TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery") { CheckResult result = check(R"( @@ -3604,36 +575,6 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error } } -TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") -{ - CheckResult result = check(R"( - local a: boolean = true - local b: boolean = false - local foo = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); -} - -TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") -{ - CheckResult result = check(R"( - local a: number | string = "" - local b: number | string = 1 - local foo = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); -} - TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") { CheckResult result = check(R"( @@ -3650,100 +591,6 @@ TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") CHECK_EQ("x", up->key); } -TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") -{ - { - CheckResult result = check(R"( - function unreachablecodepath(a): number - while true do - if a then return 10 end - end - -- unreachable - end - unreachablecodepath(4) - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); - } - - { - CheckResult result = check(R"( - function reachablecodepath(a): number - while true do - if a then break end - return 10 - end - - print("x") -- correct error - end - reachablecodepath(4) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK(get(result.errors[0])); - } - - { - CheckResult result = check(R"( - function unreachablecodepath(a): number - repeat - if a then return 10 end - until false - - -- unreachable - end - unreachablecodepath(4) - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); - } - - { - CheckResult result = check(R"( - function reachablecodepath(a, b): number - repeat - if a then break end - - if b then return 10 end - until false - - print("x") -- correct error - end - reachablecodepath(4) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK(get(result.errors[0])); - } - - { - CheckResult result = check(R"( - function unreachablecodepath(a: number?): number - repeat - return 10 - until a ~= nil - - -- unreachable - end - unreachablecodepath(4) - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); - } -} - -TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") -{ - CheckResult result = check(R"( - --!strict - local _ - _ += _ and _ or _ and _ or _ and _ - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); -} - TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") { CheckResult result = check(R"( @@ -3759,58 +606,6 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); } -TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") -{ - // In non-strict mode, global definition is still allowed - { - CheckResult result = check(R"( - --!nonstrict - a = a + 1 - print(a) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); - } - - // In strict mode we no longer generate two errors from lhs - { - CheckResult result = check(R"( - --!strict - a += 1 - print(a) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); - } - - // In non-strict mode, compound assignment is not a definition, it's a modification - { - CheckResult result = check(R"( - --!nonstrict - a += 1 - print(a) - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); - } -} - -TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") -{ - CheckResult result = check(R"( - local t = {} - for _ in t do - for _ in assert(missing()) do - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") { CheckResult result = check(R"( @@ -3821,26 +616,6 @@ TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "metatable_of_any_can_be_a_table") -{ - CheckResult result = check(R"( ---!strict -local T: any -T = {} -T.__index = T -function T.new(...) - local self = {} - setmetatable(self, T) - self:construct(...) - return self -end -function T:construct(index) -end -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") { CheckResult result = check(R"( @@ -3868,64 +643,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_ice_on_astexprerror") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator") -{ - CheckResult result = check(R"( ---!strict -local a: number? = nil -local b: number = a or 1 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator2") -{ - CheckResult result = check(R"( ---!nonstrict -local a: number? = nil -local b: number = a or 1 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "dont_strip_nil_from_rhs_or_operator") -{ - CheckResult result = check(R"( ---!strict -local a: number? = nil -local b: number = 1 or a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ("number?", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") -{ - CheckResult result = check(R"( - --!strict - local tbl = {} - function tbl:abc(a: number, b: number) - return a - end - tbl:abc(1, 2) -- Line 6 - -- | Column 14 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - TypeId type = requireTypeAtPosition(Position(6, 14)); - CHECK_EQ("(tbl, number, number) -> number", toString(type)); - auto ftv = get(type); - REQUIRE(ftv); - CHECK(ftv->hasSelf); -} - TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") { CheckResult result = check(R"( @@ -3943,326 +660,6 @@ TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") REQUIRE_MESSAGE(get(e) != nullptr, "Expected UnknownSymbol, but got " << e); } -TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") -{ - CheckResult result = check(R"( - type Array = { [number]: T } - type Fiber = { id: number } - type null = {} - - local fiberStack: Array = {} - local index = 0 - - local function f(fiber: Fiber) - local a = fiber ~= fiberStack[index] - local b = fiberStack[index] ~= fiber - end - - return f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "general_require_call_expression") -{ - fileResolver.source["game/A"] = R"( ---!strict -return { def = 4 } - )"; - - fileResolver.source["game/B"] = R"( ---!strict -local tbl = { abc = require(game.A) } -local a : string = "" -a = tbl.abc.def - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "general_require_type_mismatch") -{ - fileResolver.source["game/A"] = R"( -return { def = 4 } - )"; - - fileResolver.source["game/B"] = R"( -local tbl: string = require(game.A) - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") -{ - CheckResult result = check(R"( ---!nonstrict -local f = {} -function f:foo(a: number, b: number) end - -function bar(...) - f.foo(f, 1, ...) -end - -bar(2) -)"); - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") -{ - CheckResult result = check(R"( -local function f(a: typeof(f)) end -)"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning") -{ - CheckResult result = check(R"( ---!nonstrict -local l0:any,l61:t0 = _,math -while _ do -_() -end -function _():t0 -end -type t0 = any -)"); - - std::optional ty = requireType("math"); - REQUIRE(ty); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_2") -{ - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - - CheckResult result = check(R"( -type X = T -type K = X -)"); - - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("math"); - REQUIRE(ty); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_3") -{ - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - - CheckResult result = check(R"( -type X = T -local a = {} -a.x = 4 -local b: X -a.y = 5 -local c: X -c = b -)"); - - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("a"); - REQUIRE(ty); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "bound_free_table_export_is_ok") -{ - CheckResult result = check(R"( -local n = {} -function n:Clone() end - -local m = {} - -function m.a(x) - x:Clone() -end - -function m.b() - m.a(n) -end - -return m -)"); - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") -{ - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - - // Mutability in type function application right now can create strange recursive types - CheckResult result = check(R"( -type Table = { a: number } -type Self = T -local a: Self
- )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "Table"); -} - -TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") -{ - TypeId mathTy = requireType(typeChecker.globalScope, "math"); - REQUIRE(mathTy); - TableTypeVar* ttv = getMutable(mathTy); - REQUIRE(ttv); - const FunctionTypeVar* ftv = get(ttv->props["frexp"].type); - REQUIRE(ftv); - auto original = ftv->level; - - CheckResult result = check("local a = math.frexp"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK(ftv->level.level == original.level); - CHECK(ftv->level.subLevel == original.subLevel); -} - -TEST_CASE_FIXTURE(Fixture, "table_indexing_error_location") -{ - CheckResult result = check(R"( -local foo = {42} -local bar: number? -local baz = foo[bar] - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); -} - -TEST_CASE_FIXTURE(Fixture, "table_simple_call") -{ - CheckResult result = check(R"( -local a = setmetatable({ x = 2 }, { - __call = function(self) - return (self.x :: number) * 2 -- should work without annotation in the future - end -}) -local b = a() -local c = a(2) -- too many arguments - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "custom_require_global") -{ - CheckResult result = check(R"( ---!nonstrict -require = function(a) end - -local crash = require(game.A) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap") -{ - ScopedFastFlag sff1{"LuauEqConstraint", true}; - - CheckResult result = check(R"( - local function f(a: string | number, b: boolean | number) - return a == b - end - )"); - - // This doesn't produce any errors but for the wrong reasons. - // This unit test serves as a reminder to not try and unify the operands on `==`/`~=`. - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "access_index_metamethod_that_returns_variadic") -{ - CheckResult result = check(R"( - type Foo = {x: string} - local t = {} - setmetatable(t, { - __index = function(x: string): ...Foo - return {x = x} - end - }) - - local foo = t.bar - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions o; - o.exhaustive = true; - CHECK_EQ("{| x: string |}", toString(requireType("foo"), o)); -} - -TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks") -{ - CheckResult result = check(R"( - type ( ... ) ( ) ; - ( ... ) ( - - ... ) ( - ... ) - type = ( ... ) ; - ( ... ) ( ) ( ... ) ; - ( ... ) "" - )"); - - CHECK_LE(0, result.errors.size()); -} - -TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks2") -{ - CheckResult result = check(R"( - function _(l0:((typeof((pcall)))|((((t0)->())|(typeof(-67108864)))|(any)))|(any),...):(((typeof(0))|(any))|(any),typeof(-67108864),any) - xpcall(_,_,_) - _(_,_,_) - end - )"); - - CHECK_LE(0, result.errors.size()); -} - -TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") -{ - CheckResult result = check(R"( - function _(l0:t0): (any, ()->()) - end - - type t0 = t0 | {} - )"); - - CHECK_LE(0, result.errors.size()); - - std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); - REQUIRE(t0); - CHECK_EQ("*unknown*", toString(t0->type)); - - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { - return get(err); - }); - CHECK(it != result.errors.end()); -} - TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") { CheckResult result = check(R"( @@ -4316,22 +713,6 @@ TEST_CASE_FIXTURE(Fixture, "no_infinite_loop_when_trying_to_unify_uh_this") CHECK_LE(0, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") -{ - CheckResult result = check(R"( - local l0,l0 - repeat - type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) - function _(l0):(t0)&(t0) - while nil do - end - end - until _(_)(_)._ - )"); - - CHECK_LE(0, result.errors.size()); -} - TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") { CheckResult result = check(R"( @@ -4349,329 +730,6 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") CHECK_LE(0, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") -{ - fileResolver.source["Module/Backend/Types"] = R"( - export type Fiber = { - return_: Fiber? - } - return {} - )"; - - fileResolver.source["Module/Backend"] = R"( - local Types = require(script.Types) - type Fiber = Types.Fiber - type ReactRenderer = { findFiberByHostInstance: () -> Fiber? } - - local function attach(renderer): () - local function getPrimaryFiber(fiber) - local alternate = fiber.alternate - return fiber - end - - local function getFiberIDForNative() - local fiber = renderer.findFiberByHostInstance() - fiber = fiber.return_ - return getPrimaryFiber(fiber) - end - end - - function culprit(renderer: ReactRenderer): () - attach(renderer) - end - - return culprit - )"; - - CheckResult result = frontend.check("Module/Backend"); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") -{ - CheckResult result = check(R"( - type Tree = { data: T, children: {Tree} } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- this would be an infinite type if we allowed it - type Tree = { data: T, children: {Tree<{T}>} } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "record_matching_overload") -{ - CheckResult result = check(R"( - type Overload = ((string) -> string) & ((number) -> number) - local abc: Overload - abc(1) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // AstExprCall is the node that has the overload stored on it. - // findTypeAtPosition will look at the AstExprLocal, but this is not what - // we want to look at. - std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(3, 10)); - REQUIRE_GE(ancestry.size(), 2); - AstExpr* parentExpr = ancestry[ancestry.size() - 2]->asExpr(); - REQUIRE(bool(parentExpr)); - REQUIRE(parentExpr->is()); - - ModulePtr module = getMainModule(); - auto it = module->astOverloadResolvedTypes.find(parentExpr); - REQUIRE(it); - CHECK_EQ(toString(*it), "(number) -> number"); -} - -TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") -{ - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - - CheckResult result = check(R"( - type Overload = ((string) -> string) & ((number, number) -> number) - local abc: Overload - local x = abc(true) - local y = abc(true,true) - local z = abc(true,true,true) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("string", toString(requireType("x"))); - CHECK_EQ("number", toString(requireType("y"))); - // Should this be string|number? - CHECK_EQ("string", toString(requireType("z"))); -} - -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") -{ - // Simple direct arg to arg propagation - CheckResult result = check(R"( -type Table = { x: number, y: number } -local function f(a: (Table) -> number) return a({x = 1, y = 2}) end -f(function(a) return a.x + a.y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // An optional function is accepted, but since we already provide a function, nil can be ignored - result = check(R"( -type Table = { x: number, y: number } -local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end -f(function(a) return a.x + a.y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Make sure self calls match correct index - result = check(R"( -type Table = { x: number, y: number } -local x = {} -x.b = {x = 1, y = 2} -function x:f(a: (Table) -> number) return a(self.b) end -x:f(function(a) return a.x + a.y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Mix inferred and explicit argument types - result = check(R"( -function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end -f(function(a: number, b, c) return c and a + b or b - a end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Anonymous function has a variadic pack - result = check(R"( -type Table = { x: number, y: number } -local function f(a: (Table) -> number) return a({x = 1, y = 2}) end -f(function(...) return select(1, ...).z end) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); - - // Can't accept more arguments than provided - result = check(R"( -function f(a: (a: number, b: number) -> number) return a(1, 2) end -f(function(a, b, c, ...) return a + b end) - )"); - - LUAU_REQUIRE_ERRORS(result); - - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' -caused by: - Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); - - // Infer from variadic packs into elements - result = check(R"( -function f(a: (...number) -> number) return a(1, 2) end -f(function(a, b) return a + b end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Infer from variadic packs into variadic packs - result = check(R"( -type Table = { x: number, y: number } -function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end -f(function(a, ...) local b = ... return b.z end) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); - - // Return type inference - result = check(R"( -type Table = { x: number, y: number } -function f(a: (number) -> Table) return a(4) end -f(function(x) return x * 2 end) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") -{ - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - - CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end -return sum(2, 3, function(a, b) return a + b end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - result = check(R"( -local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end -local a = {1, 2, 3} -local r = map(a, function(a) return a + a > 100 end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{boolean}", toString(requireType("r"))); - - check(R"( -local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end -local a = {1, 2, 3} -local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); -} - -TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") -{ - CheckResult result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end - -local g12: typeof(g1) & typeof(g2) - -g12(1, function(x) return x + x end) -g12(1, 2, function(x, y) return x + y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end - -local g12: typeof(g1) & typeof(g2) - -g12({x=1}, function(x) return {x=-x.x} end) -g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") -{ - CheckResult result = check(R"( -local a = {{x=4}, {x=7}, {x=1}} -table.sort(a, function(x, y) return x.x < y.x end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") -{ - CheckResult result = check(R"( -type Table = { x: number, y: number } -local f: (Table) -> number = function(t) return t.x + t.y end - -type TableWithFunc = { x: number, y: number, f: (number, number) -> number } -local a: TableWithFunc = { x = 3, y = 4, f = function(a, b) return a + b end } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") -{ - CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end - -local function sumrec(f: typeof(sum)) - return sum(2, 3, function(a, b) return a + b end) -end - -local b = sumrec(sum) -- ok -local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " - "parameters", - toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") -{ - CheckResult result = check(R"( -local function f(): {string|number} - return {1, "b", 3} -end - -local function g(): (number, {string|number}) - return 4, {1, "b", 3} -end - -local function h(): ...{string|number} - return {4}, {1, "b", 3}, {"s"} -end - -local function i(): ...{string|number} - return {1, "b", 3}, h() -end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") { CheckResult result = check(R"( @@ -4719,56 +777,6 @@ f(((function(a, b) return a + b end))) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "refine_and_or") -{ - CheckResult result = check(R"( - local t: {x: number?}? = {x = nil} - local u = t and t.x or 5 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("number", toString(requireType("u"))); -} - -TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") -{ - CheckResult result = check(R"( - local t: {x: number?}? = {x = nil} - local u = t.x and t or 5 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); - CHECK_EQ("number | {| x: number? |}", toString(requireType("u"))); -} - -TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") -{ - CheckResult result = check(R"( - local t: {x: number?}? = {x = nil} - local u = t and t.x == 5 or t.x == 31337 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); - CHECK_EQ("boolean", toString(requireType("u"))); -} - -TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") -{ - CheckResult result = check(R"( -type A = { x: number } -local a: A = { x = 1 } -local b = a -type B = typeof(b) -type X = T -local c: X - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") { CheckResult result = check(R"(local a = if true then "true" else "false")"); @@ -4830,40 +838,6 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "type_error_addition") -{ - CheckResult result = check(R"( ---!strict -local foo = makesandwich() -local bar = foo.nutrition + 100 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - // We should definitely get this error - CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0])); - // We get this error if makesandwich() returns a free type - // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); -} - -TEST_CASE_FIXTURE(Fixture, "require_failed_module") -{ - fileResolver.source["game/A"] = R"( -return unfortunately() - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_ERRORS(aResult); - - CheckResult result = check(R"( -local ModuleA = require(game.A) - )"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional oty = requireType("ModuleA"); - CHECK_EQ("*unknown*", toString(*oty)); -} - /* * If it wasn't instantly obvious, we have the fuzzer to thank for this gem of a test. * @@ -4921,184 +895,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") )"); } -/* - * We had an issue where part of the type of pairs() was an unsealed table. - * This test depends on FFlagDebugLuauFreezeArena to trigger it. - */ -TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") -{ - check(R"( - function _(l0:{n0:any}) - _ = pairs - end - )"); -} - -TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table") -{ - check(R"( - function Base64FileReader(data) - local reader = {} - local index: number - - function reader:PeekByte() - return data:byte(index) - end - - function reader:Byte() - return data:byte(index - 1) - end - - return reader - end - - Base64FileReader() - - function ReadMidiEvents(data) - - local reader = Base64FileReader(data) - - while reader:HasMore() do - (reader:Byte() % 128) - end - end - )"); -} - -TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") -{ - CheckResult result = check(R"( -type A = (number, number) -> string -type B = (number) -> string - -local a: A -local b: B = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' -caused by: - Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); -} - -TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") -{ - CheckResult result = check(R"( -type A = (number, number) -> string -type B = (number, string) -> string - -local a: A -local b: B = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' -caused by: - Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); -} - -TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") -{ - CheckResult result = check(R"( -type A = (number, number) -> (number) -type B = (number, number) -> (number, boolean) - -local a: A -local b: B = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' -caused by: - Function only returns 1 value. 2 are required here)"); -} - -TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") -{ - CheckResult result = check(R"( -type A = (number, number) -> string -type B = (number, number) -> number - -local a: A -local b: B = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' -caused by: - Return type is not compatible. Type 'string' could not be converted into 'number')"); -} - -TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") -{ - CheckResult result = check(R"( -type A = (number, number) -> (number, string) -type B = (number, number) -> (number, boolean) - -local a: A -local b: B = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' -caused by: - Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); -} - -TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") -{ - CheckResult result = check(R"( - local function f(thing: any | string) - local foo = thing.SomeRandomKey - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") -{ - CheckResult result = check(R"( -local t = {} - -function t.x(value) - for k,v in pairs(t) do end -end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "table_oop") -{ - CheckResult result = check(R"( - --!strict -local Class = {} -Class.__index = Class - -type Class = typeof(setmetatable({} :: { x: number }, Class)) - -function Class.new(x: number): Class - return setmetatable({x = x}, Class) -end - -function Class.getx(self: Class) - return self.x -end - -function test() - local c = Class.new(42) - local n = c:getx() - local nn = c.x - - print(string.format("%d %d", n, nn)) -end -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") { CheckResult result = check(R"( @@ -5182,213 +978,4 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types") -{ - fileResolver.source["game/A"] = R"( -export type Type = { unrelated: boolean } -return {} - )"; - - fileResolver.source["game/B"] = R"( -local types = require(game.A) -type Type = types.Type -local x: Type = {} -function x:Destroy(): () end - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") -{ - ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; - - fileResolver.source["game/A"] = R"( -export type Type = { x: { a: number } } -return {} - )"; - - fileResolver.source["game/B"] = R"( -local types = require(game.A) -type Type = types.Type -local x: Type = { x = { a = 2 } } -type Rename = typeof(x.x) - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") -{ - ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; - - fileResolver.source["game/A"] = R"( -local y = setmetatable({}, {}) -export type Type = { x: typeof(y) } -return { x = y } - )"; - - fileResolver.source["game/B"] = R"( -local types = require(game.A) -type Type = types.Type -local x: Type = types -type Rename = typeof(x.x) - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") -{ - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, - }; - - CheckResult result = check(R"( - local a: string = "hi" - if a == "hi" then - local x = a:byte() - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 22}))); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") -{ - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, - }; - - CheckResult result = check(R"( - local a: string = "hi" - if a == "hi" or a == "bye" then - local x = a:byte() - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 22}))); -} - -TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") -{ - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, - }; - - CheckResult result = check(R"( - local a: string = "hi" - if a == "hi" then - local x = #a - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); -} - -TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") -{ - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, - }; - - CheckResult result = check(R"( - local a: string = "hi" - if a == "hi" or a == "bye" then - local x = #a - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 23}))); -} - -/* - * When we add new properties to an unsealed table, we should do a level check and promote the property type to be at - * the level of the table. - */ -TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the_same_TypeLevel_of_that_table") -{ - CheckResult result = check(R"( - --!strict - local T = {} - - local function f(prop) - T[1] = { - prop = prop, - } - end - - local function g() - local l = T[1].prop - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "global_singleton_types_are_sealed") -{ - CheckResult result = check(R"( -local function f(x: string) - local p = x:split('a') - p = table.pack(table.unpack(p, 1, #p - 1)) - return p -end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") -{ - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify", true}; - - fileResolver.source["game/isAMagicMock"] = R"( ---!nonstrict -return function(value) - return false -end - )"; - - CheckResult result = check(R"( ---!nonstrict -local MagicMock = {} -MagicMock.is = require(game.isAMagicMock) - -function MagicMock.is(value) - return false -end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") -{ - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify", true}; - - CheckResult result = check(R"( -function string.len(): number - return 1 -end - )"); - - LUAU_REQUIRE_ERRORS(result); -} - TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index fcc21c18..130f33d7 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -895,4 +895,87 @@ caused by: Type 'boolean' could not be converted into 'string')"); } +// TODO: File a Jira about this +/* +TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") +{ + CheckResult result = check(R"( + function a(x) return 1 end + a(...) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE(bool(getMainModule()->getModuleScope()->varargPack)); + + TypePackId varargPack = *getMainModule()->getModuleScope()->varargPack; + + auto iter = begin(varargPack); + auto endIter = end(varargPack); + + CHECK(iter != endIter); + ++iter; + CHECK(iter == endIter); + + CHECK(!iter.tail()); +} +*/ + +TEST_CASE_FIXTURE(Fixture, "dont_ice_if_a_TypePack_is_an_error") +{ + CheckResult result = check(R"( + --!strict + function f(s) + print(s) + return f + end + + f("foo")("bar") + )"); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") +{ + // this has a risk of creating cyclic type packs, causing infinite loops / OOMs + check(R"( +--!nonstrict +_ += _(_,...) +repeat +_ += _(...) +until ... + _ +)"); + + check(R"( +--!nonstrict +_ += _(_(...,...),_(...)) +repeat +until _ +)"); +} + +TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks") +{ + CheckResult result = check(R"( + type ( ... ) ( ) ; + ( ... ) ( - - ... ) ( - ... ) + type = ( ... ) ; + ( ... ) ( ) ( ... ) ; + ( ... ) "" + )"); + + CHECK_LE(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks2") +{ + CheckResult result = check(R"( + function _(l0:((typeof((pcall)))|((((t0)->())|(typeof(-67108864)))|(any)))|(any),...):(((typeof(0))|(any))|(any),typeof(-67108864),any) + xpcall(_,_,_) + _(_,_,_) + end + )"); + + CHECK_LE(0, result.errors.size()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ad4cecd8..ad1e31e5 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -496,4 +496,20 @@ caused by: None of the union options are compatible. For example: Table type 'a' not compatible with type 'X' because the former is missing field 'x')"); } +// We had a bug where a cyclic union caused a stack overflow. +// ex type U = number | U +TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") +{ + CheckResult result = check(R"( + --!strict + + function f(a, b) + a:g(b or {}) + a:g(b) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 78d90077..f803c319 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -833,6 +833,17 @@ assert((function() return sum end)() == 105) +-- shrinking array part +assert((function() + local t = table.create(100, 42) + for i=1,90 do t[i] = nil end + t[101] = 42 + local sum = 0 + for _,v in ipairs(t) do sum += v end + for _,v in pairs(t) do sum += v end + return sum +end)() == 462) + -- upvalues: recursive capture assert((function() local function fact(n) return n < 1 and 1 or n * fact(n-1) end return fact(5) end)() == 120) @@ -881,6 +892,14 @@ end)() == "6,8,10") -- typeof == type in absence of custom userdata assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number,nil,table,userdata") +-- type/typeof/newproxy interaction with metatables: __type doesn't work intentionally to avoid spoofing +assert((function() + local ud = newproxy(true) + getmetatable(ud).__type = "number" + + return concat(type(ud),typeof(ud)) +end)() == "userdata,userdata") + testgetfenv() -- DONT MOVE THIS LINE return 'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index 6ba99fb9..ec0b412e 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -3,14 +3,14 @@ print "testing debugger" -- note, this file can't run in isolation from C tests local a = 5 -function foo(b) +function foo(b, ...) print("in foo", b) a = 6 end breakpoint(8) -foo(50) +foo(50, 42) breakpoint(16) -- next line print("here") diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index 751188be..297cf011 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -305,4 +305,6 @@ assert(ecall(function() return "a" + "b" end) == "attempt to perform arithmetic assert(ecall(function() return 1 > nil end) == "attempt to compare nil < number") -- note reversed order (by design) assert(ecall(function() return "a" <= 5 end) == "attempt to compare string <= number") +assert(ecall(function() local t = {} setmetatable(t, { __newindex = function(t,i,v) end }) t[nil] = 2 end) == "table index is nil") + return('OK') diff --git a/tests/conformance/interrupt.lua b/tests/conformance/interrupt.lua new file mode 100644 index 00000000..2b127099 --- /dev/null +++ b/tests/conformance/interrupt.lua @@ -0,0 +1,11 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing interrupts") + +function foo() + for i=1,10 do end + return +end + +foo() + +return "OK" diff --git a/tools/gdb-printers.py b/tools/gdb_printers.py similarity index 100% rename from tools/gdb-printers.py rename to tools/gdb_printers.py diff --git a/tools/lldb-formatters.lldb b/tools/lldb-formatters.lldb deleted file mode 100644 index 3868ac20..00000000 --- a/tools/lldb-formatters.lldb +++ /dev/null @@ -1,2 +0,0 @@ -type synthetic add -x "^Luau::Variant<.+>$" -l LuauVisualize.LuauVariantSyntheticChildrenProvider -type summary add -x "^Luau::Variant<.+>$" -l LuauVisualize.luau_variant_summary diff --git a/tools/lldb_formatters.lldb b/tools/lldb_formatters.lldb new file mode 100644 index 00000000..f6fa6cf5 --- /dev/null +++ b/tools/lldb_formatters.lldb @@ -0,0 +1,2 @@ +type synthetic add -x "^Luau::Variant<.+>$" -l lldb_formatters.LuauVariantSyntheticChildrenProvider +type summary add -x "^Luau::Variant<.+>$" -l lldb_formatters.luau_variant_summary diff --git a/tools/LuauVisualize.py b/tools/lldb_formatters.py similarity index 100% rename from tools/LuauVisualize.py rename to tools/lldb_formatters.py From 57faf7aaf2d6dd20eb69bc4088e61549cfa6552f Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 17 Mar 2022 17:32:02 -0700 Subject: [PATCH 031/102] Lower the stack limit to make tests pass in debug --- tests/TypeInfer.test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 571d0f8d..660ddcfc 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -334,7 +334,7 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") #if defined(LUAU_ENABLE_ASAN) int limit = 250; #elif defined(_DEBUG) || defined(_NOOPT) - int limit = 350; + int limit = 300; #else int limit = 600; #endif From 373da161e915de2aa71ba83fe9baf23b269057f3 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 24 Mar 2022 14:49:08 -0700 Subject: [PATCH 032/102] Sync to upstream/release/520 --- Analysis/include/Luau/Error.h | 1 + Analysis/include/Luau/ToString.h | 3 +- Analysis/include/Luau/TypePack.h | 6 +- Analysis/include/Luau/TypeVar.h | 3 + Analysis/include/Luau/Unifier.h | 4 +- Analysis/include/Luau/UnifierSharedState.h | 2 + Analysis/src/Error.cpp | 30 ++- Analysis/src/Linter.cpp | 43 +++- Analysis/src/Module.cpp | 31 +-- Analysis/src/ToString.cpp | 72 +++++-- Analysis/src/TypeInfer.cpp | 69 ++++--- Analysis/src/TypePack.cpp | 19 +- Analysis/src/TypeVar.cpp | 18 ++ Analysis/src/Unifier.cpp | 173 +++++++++++----- Ast/src/Parser.cpp | 20 +- VM/src/lapi.cpp | 50 ++--- VM/src/ldo.cpp | 21 +- VM/src/ldo.h | 2 +- VM/src/lgc.cpp | 228 +++++++++++---------- VM/src/lgc.h | 2 +- VM/src/lstate.cpp | 26 +-- VM/src/lstate.h | 44 ++-- VM/src/ltable.cpp | 4 +- tests/Autocomplete.test.cpp | 5 - tests/Conformance.test.cpp | 2 - tests/Fixture.cpp | 19 +- tests/Linter.test.cpp | 85 ++------ tests/ToString.test.cpp | 25 +++ tests/Transpiler.test.cpp | 2 - tests/TypeInfer.builtins.test.cpp | 2 - tests/TypeInfer.functions.test.cpp | 76 +++++++ tests/TypeInfer.modules.test.cpp | 85 +++++++- tests/TypeInfer.operators.test.cpp | 26 +++ tests/TypeInfer.refinements.test.cpp | 11 - tests/TypeInfer.singletons.test.cpp | 118 +---------- tests/TypeInfer.tables.test.cpp | 38 ++++ tests/TypeInfer.test.cpp | 2 +- tests/TypeInfer.unionTypes.test.cpp | 1 + 38 files changed, 805 insertions(+), 563 deletions(-) diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 72350255..53b946a0 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -96,6 +96,7 @@ struct CountMismatch size_t expected; size_t actual; Context context = Arg; + bool isVariadic = false; bool operator==(const CountMismatch& rhs) const; }; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index a97bf6d6..49ee82fe 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -32,6 +32,7 @@ struct ToStringOptions size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' + std::vector namedFunctionOverrideArgNames; // If present, named function argument names will be overridden }; struct ToStringResult @@ -65,7 +66,7 @@ inline std::string toString(TypePackId ty) std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); -std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts = {}); +std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, const ToStringOptions& opts = {}); // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 946be356..85fa467f 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -119,9 +119,9 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); TypePackId follow(TypePackId tp, std::function mapper); -size_t size(TypePackId tp); -bool finite(TypePackId tp); -size_t size(const TypePack& tp); +size_t size(TypePackId tp, TxnLog* log = nullptr); +bool finite(TypePackId tp, TxnLog* log = nullptr); +size_t size(const TypePack& tp, TxnLog* log = nullptr); std::optional first(TypePackId tp); TypePackVar* asMutable(TypePackId tp); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 29578dcd..b8c4b362 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -488,6 +488,9 @@ const TableTypeVar* getTableType(TypeId type); // Returns nullptr if the type has no name. const std::string* getName(TypeId type); +// Returns name of the module where type was defined if type has that information +std::optional getDefinitionModuleName(TypeId type); + // Checks whether a union contains all types of another union. bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index f1ffbcc0..474af50c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -90,7 +90,9 @@ private: TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); - void cacheResult(TypeId subTy, TypeId superTy); + bool canCacheResult(TypeId subTy, TypeId superTy); + void cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount); + void cacheResult_DEPRECATED(TypeId subTy, TypeId superTy); public: void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index 88997c41..9a3ba56d 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/DenseHash.h" +#include "Luau/Error.h" #include "Luau/TypeVar.h" #include "Luau/TypePack.h" @@ -42,6 +43,7 @@ struct UnifierSharedState DenseHashSet seenAny{nullptr}; DenseHashMap skipCacheForType{nullptr}; DenseHashSet, TypeIdPairHash> cachedUnify{{nullptr, nullptr}}; + DenseHashMap, TypeErrorData, TypeIdPairHash> cachedUnifyError{{nullptr, nullptr}}; DenseHashSet tempSeenTy{nullptr}; DenseHashSet tempSeenTp{nullptr}; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 26d3b76d..210c0191 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(BetterDiagnosticCodesInStudio, false); +LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleName, false); static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { @@ -53,7 +54,32 @@ struct ErrorConverter { std::string operator()(const Luau::TypeMismatch& tm) const { - std::string result = "Type '" + Luau::toString(tm.givenType) + "' could not be converted into '" + Luau::toString(tm.wantedType) + "'"; + std::string givenTypeName = Luau::toString(tm.givenType); + std::string wantedTypeName = Luau::toString(tm.wantedType); + + std::string result; + + if (FFlag::LuauTypeMismatchModuleName) + { + if (givenTypeName == wantedTypeName) + { + if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) + { + if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) + { + result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + + "' from '" + *wantedDefinitionModule + "'"; + } + } + } + + if (result.empty()) + result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; + } + else + { + result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; + } if (tm.error) { @@ -147,7 +173,7 @@ struct ErrorConverter return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual, /*argPrefix*/ nullptr, e.isVariadic); } LUAU_ASSERT(!"Unknown context"); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 56c4e3e8..b7480e34 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,6 +14,7 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) +LUAU_FASTFLAGVARIABLE(LuauLintNoRobloxBits, false) namespace Luau { @@ -1135,16 +1136,20 @@ private: enum TypeKind { - Kind_Invalid, + Kind_Unknown, Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata. - Kind_Vector, // For 'vector' but only used when type is used - Kind_Userdata, // custom userdata type - Vector3/etc. + Kind_Vector, // 'vector' but only used when type is used + Kind_Userdata, // custom userdata type + + // TODO: remove these with LuauLintNoRobloxBits Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc. Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc. }; bool containsPropName(TypeId ty, const std::string& propName) { + LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); + if (auto ctv = get(ty)) return lookupClassProp(ctv, propName) != nullptr; @@ -1163,13 +1168,23 @@ private: if (name == "vector") return Kind_Vector; - if (std::optional maybeTy = context->scope->lookupType(name)) - // Kind_Userdata is probably not 100% precise but is close enough - return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata; - else if (std::optional maybeTy = context->scope->lookupImportedType("Enum", name)) - return Kind_Enum; + if (FFlag::LuauLintNoRobloxBits) + { + if (std::optional maybeTy = context->scope->lookupType(name)) + return Kind_Userdata; - return Kind_Invalid; + return Kind_Unknown; + } + else + { + if (std::optional maybeTy = context->scope->lookupType(name)) + // Kind_Userdata is probably not 100% precise but is close enough + return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata; + else if (std::optional maybeTy = context->scope->lookupImportedType("Enum", name)) + return Kind_Enum; + + return Kind_Unknown; + } } void validateType(AstExprConstantString* expr, std::initializer_list expected, const char* expectedString) @@ -1177,7 +1192,7 @@ private: std::string name(expr->value.data, expr->value.size); TypeKind kind = getTypeKind(name); - if (kind == Kind_Invalid) + if (kind == Kind_Unknown) { emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s'", name.c_str()); return; @@ -1189,7 +1204,7 @@ private: return; // as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type - if (ek == Kind_Userdata && (name == "Instance" || name == "EnumItem")) + if (!FFlag::LuauLintNoRobloxBits && ek == Kind_Userdata && (name == "Instance" || name == "EnumItem")) return; } @@ -1198,12 +1213,18 @@ private: bool acceptsClassName(AstName method) { + LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); + return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" || method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA"); } bool visit(AstExprCall* node) override { + // TODO: Simply remove the override + if (FFlag::LuauLintNoRobloxBits) + return true; + if (AstExprIndexName* index = node->func->as()) { AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as() : NULL; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index a330a98d..0787d3a4 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -12,10 +12,8 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) // Remove with FFlagLuauImmutableTypes LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) -LUAU_FASTFLAG(LuauImmutableTypes) namespace Luau { @@ -65,8 +63,7 @@ TypeId TypeArena::addTV(TypeVar&& tv) { TypeId allocated = typeVars.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -75,8 +72,7 @@ TypeId TypeArena::freshType(TypeLevel level) { TypeId allocated = typeVars.allocate(FreeTypeVar{level}); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -85,8 +81,7 @@ TypePackId TypeArena::addTypePack(std::initializer_list types) { TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -95,8 +90,7 @@ TypePackId TypeArena::addTypePack(std::vector types) { TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -105,8 +99,7 @@ TypePackId TypeArena::addTypePack(TypePack tp) { TypePackId allocated = typePacks.allocate(std::move(tp)); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -115,8 +108,7 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) { TypePackId allocated = typePacks.allocate(std::move(tp)); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -439,16 +431,9 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - if (FFlag::LuauImmutableTypes) - { - // Persistent types are not being cloned and we get the original type back which might be read-only - if (!res->persistent) - asMutable(res)->documentationSymbol = typeId->documentationSymbol; - } - else - { + // Persistent types are not being cloned and we get the original type back which might be read-only + if (!res->persistent) asMutable(res)->documentationSymbol = typeId->documentationSymbol; - } } return res; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 010ca361..59ee6de2 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -16,6 +16,7 @@ * Fair warning: Setting this will break a lot of Luau unit tests. */ LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) +LUAU_FASTFLAGVARIABLE(LuauDocFuncParameters, false) namespace Luau { @@ -769,6 +770,7 @@ struct TypePackStringifier else state.emit(", "); + // Do not respect opts.namedFunctionOverrideArgNames here if (elemIndex < elemNames.size() && elemNames[elemIndex]) { state.emit(elemNames[elemIndex]->name); @@ -1090,13 +1092,13 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } -std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, const ToStringOptions& opts) { ToStringResult result; StringifierState state(opts, result, opts.nameMap); TypeVarStringifier tvs{state}; - state.emit(prefix); + state.emit(funcName); if (!opts.hideNamedFunctionTypeParameters) tvs.stringify(ftv.generics, ftv.genericPacks); @@ -1104,28 +1106,59 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV state.emit("("); auto argPackIter = begin(ftv.argTypes); - auto argNameIter = ftv.argNames.begin(); bool first = true; - while (argPackIter != end(ftv.argTypes)) + if (FFlag::LuauDocFuncParameters) { - if (!first) - state.emit(", "); - first = false; - - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (argNameIter != ftv.argNames.end()) + size_t idx = 0; + while (argPackIter != end(ftv.argTypes)) { - state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": "); - ++argNameIter; - } - else - { - state.emit("_: "); - } + if (!first) + state.emit(", "); + first = false; - tvs.stringify(*argPackIter); - ++argPackIter; + // We don't respect opts.functionTypeArguments + if (idx < opts.namedFunctionOverrideArgNames.size()) + { + state.emit(opts.namedFunctionOverrideArgNames[idx] + ": "); + } + else if (idx < ftv.argNames.size() && ftv.argNames[idx]) + { + state.emit(ftv.argNames[idx]->name + ": "); + } + else + { + state.emit("_: "); + } + tvs.stringify(*argPackIter); + + ++argPackIter; + ++idx; + } + } + else + { + auto argNameIter = ftv.argNames.begin(); + while (argPackIter != end(ftv.argTypes)) + { + if (!first) + state.emit(", "); + first = false; + + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) + { + state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": "); + ++argNameIter; + } + else + { + state.emit("_: "); + } + + tvs.stringify(*argPackIter); + ++argPackIter; + } } if (argPackIter.tail()) @@ -1134,7 +1167,6 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV state.emit(", "); state.emit("...: "); - if (auto vtp = get(*argPackIter.tail())) tvs.stringify(vtp->ty); else diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 41e8ce55..9965d5aa 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -27,10 +27,8 @@ LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as fals LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) -LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) -LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) @@ -38,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify2, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) +LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. @@ -47,6 +46,8 @@ LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) +LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) +LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) namespace Luau { @@ -291,6 +292,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona // Clear unifier cache since it's keyed off internal types that get deallocated // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. unifierState.cachedUnify.clear(); + unifierState.cachedUnifyError.clear(); unifierState.skipCacheForType.clear(); if (FFlag::LuauTwoPassAliasDefinitionFix) @@ -1303,7 +1305,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // If the table is already named and we want to rename the type function, we have to bind new alias to a copy // Additionally, we can't modify types that come from other modules - if (ttv->name || (FFlag::LuauImmutableTypes && follow(ty)->owningArena != ¤tModule->internalTypes)) + if (ttv->name || follow(ty)->owningArena != ¤tModule->internalTypes) { bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), binding->typeParams.end(), [](auto&& itp, auto&& tp) { @@ -1315,7 +1317,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias }); // Copy can be skipped if this is an identical alias - if ((FFlag::LuauImmutableTypes && !ttv->name) || ttv->name != name || !sameTys || !sameTps) + if (!ttv->name || ttv->name != name || !sameTys || !sameTps) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1349,7 +1351,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias else if (auto mtv = getMutable(follow(ty))) { // We can't modify types that come from other modules - if (!FFlag::LuauImmutableTypes || follow(ty)->owningArena == ¤tModule->internalTypes) + if (follow(ty)->owningArena == ¤tModule->internalTypes) mtv->syntheticName = name; } @@ -1512,14 +1514,14 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) { - if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType)))) + if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) result = {singletonType(bexpr->value)}; else result = {booleanType}; } else if (const AstExprConstantString* sexpr = expr.as()) { - if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType)))) + if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; else result = {stringType}; @@ -2490,12 +2492,24 @@ TypeId TypeChecker::checkBinaryOperation( lhsType = follow(lhsType); rhsType = follow(rhsType); - if (!isNonstrictMode() && get(lhsType)) + if (FFlag::LuauDecoupleOperatorInferenceFromUnifiedTypeInference) { - auto name = getIdentifierOfBaseVar(expr.left); - reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); + if (!isNonstrictMode() && get(lhsType)) + { + auto name = getIdentifierOfBaseVar(expr.left); + reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); + // We will fall-through to the `return anyType` check below. + } + } + else + { + if (!isNonstrictMode() && get(lhsType)) + { + auto name = getIdentifierOfBaseVar(expr.left); + reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); + } } // If we know nothing at all about the lhs type, we can usually say nothing about the result. @@ -3452,7 +3466,8 @@ void TypeChecker::checkArgumentList( { if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) minParams = getMinParameterCount(&state.log, paramPack); - state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + bool isVariadic = FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic && !finite(paramPack, &state.log); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}}); return; } ++paramIter; @@ -4163,13 +4178,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return errorRecoveryType(scope); } - if (FFlag::LuauImmutableTypes) - return *moduleType; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - return clone(*moduleType, currentModule->internalTypes, seenTypes, seenTypePacks, cloneState); + return *moduleType; } void TypeChecker::tablify(TypeId type) @@ -4941,10 +4950,19 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (const auto& indexer = table->indexer) tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); - return addType(TableTypeVar{ - props, tableIndexer, scope->level, - TableState::Sealed // FIXME: probably want a way to annotate other kinds of tables maybe - }); + if (FFlag::LuauTypeMismatchModuleName) + { + TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; + ttv.definitionModuleName = currentModuleName; + return addType(std::move(ttv)); + } + else + { + return addType(TableTypeVar{ + props, tableIndexer, scope->level, + TableState::Sealed // FIXME: probably want a way to annotate other kinds of tables maybe + }); + } } else if (const auto& func = annotation.as()) { @@ -5206,6 +5224,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, { ttv->instantiatedTypeParams = typeParams; ttv->instantiatedTypePackParams = typePackParams; + + if (FFlag::LuauTypeMismatchModuleName) + ttv->definitionModuleName = currentModuleName; } return instantiated; diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 91123f46..5bb05234 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -222,20 +222,21 @@ TypePackId follow(TypePackId tp, std::function mapper) } } -size_t size(TypePackId tp) +size_t size(TypePackId tp, TxnLog* log) { - if (auto pack = get(follow(tp))) - return size(*pack); + tp = log ? log->follow(tp) : follow(tp); + if (auto pack = get(tp)) + return size(*pack, log); else return 0; } -bool finite(TypePackId tp) +bool finite(TypePackId tp, TxnLog* log) { - tp = follow(tp); + tp = log ? log->follow(tp) : follow(tp); if (auto pack = get(tp)) - return pack->tail ? finite(*pack->tail) : true; + return pack->tail ? finite(*pack->tail, log) : true; if (get(tp)) return false; @@ -243,14 +244,14 @@ bool finite(TypePackId tp) return true; } -size_t size(const TypePack& tp) +size_t size(const TypePack& tp, TxnLog* log) { size_t result = tp.head.size(); if (tp.tail) { - const TypePack* tail = get(follow(*tp.tail)); + const TypePack* tail = get(log ? log->follow(*tp.tail) : follow(*tp.tail)); if (tail) - result += size(*tail); + result += size(*tail, log); } return result; } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 89549535..36545ad9 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -290,6 +290,24 @@ const std::string* getName(TypeId type) return nullptr; } +std::optional getDefinitionModuleName(TypeId type) +{ + type = follow(type); + + if (auto ttv = get(type)) + { + if (!ttv->definitionModuleName.empty()) + return ttv->definitionModuleName; + } + else if (auto ftv = get(type)) + { + if (ftv->definition) + return ftv->definition->definitionModuleName; + } + + return std::nullopt; +} + bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub) { std::unordered_set superTypes; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 60a9c9a5..398dc9e2 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,10 +14,9 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTFLAG(LuauImmutableTypes) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); +LUAU_FASTFLAGVARIABLE(LuauExtendedIndexerError, false); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); -LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) @@ -26,6 +25,7 @@ LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) +LUAU_FASTFLAGVARIABLE(LuauUnifierCacheErrors, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) namespace Luau @@ -63,7 +63,7 @@ struct PromoteTypeLevels bool operator()(TID ty, const T&) { // Type levels of types from other modules are already global, so we don't need to promote anything inside - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return false; return true; @@ -83,7 +83,7 @@ struct PromoteTypeLevels bool operator()(TypeId ty, const FunctionTypeVar&) { // Type levels of types from other modules are already global, so we don't need to promote anything inside - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return false; promote(ty, log.getMutable(ty)); @@ -93,7 +93,7 @@ struct PromoteTypeLevels bool operator()(TypeId ty, const TableTypeVar& ttv) { // Type levels of types from other modules are already global, so we don't need to promote anything inside - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return false; if (ttv.state != TableState::Free && ttv.state != TableState::Generic) @@ -118,7 +118,7 @@ struct PromoteTypeLevels static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) { // Type levels of types from other modules are already global, so we don't need to promote anything inside - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return; PromoteTypeLevels ptl{log, typeArena, minLevel}; @@ -130,7 +130,7 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside - if (FFlag::LuauImmutableTypes && tp->owningArena != typeArena) + if (tp->owningArena != typeArena) return; PromoteTypeLevels ptl{log, typeArena, minLevel}; @@ -170,7 +170,7 @@ struct SkipCacheForType bool operator()(TypeId ty, const TableTypeVar&) { // Types from other modules don't contain mutable elements and are ok to cache - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return false; TableTypeVar& ttv = *getMutable(ty); @@ -194,7 +194,7 @@ struct SkipCacheForType bool operator()(TypeId ty, const T& t) { // Types from other modules don't contain mutable elements and are ok to cache - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return false; const bool* prev = skipCacheForType.find(ty); @@ -212,7 +212,7 @@ struct SkipCacheForType bool operator()(TypePackId tp, const T&) { // Types from other modules don't contain mutable elements and are ok to cache - if (FFlag::LuauImmutableTypes && tp->owningArena != typeArena) + if (tp->owningArena != typeArena) return false; return true; @@ -445,12 +445,33 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (get(subTy) || get(subTy)) return tryUnifyWithAny(superTy, subTy); - bool cacheEnabled = !isFunctionCall && !isIntersection; + bool cacheEnabled; auto& cache = sharedState.cachedUnify; // What if the types are immutable and we proved their relation before - if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) - return; + if (FFlag::LuauUnifierCacheErrors) + { + cacheEnabled = !isFunctionCall && !isIntersection && variance == Invariant; + + if (cacheEnabled) + { + if (cache.contains({subTy, superTy})) + return; + + if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) + { + reportError(TypeError{location, *error}); + return; + } + } + } + else + { + cacheEnabled = !isFunctionCall && !isIntersection; + + if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) + return; + } // If we have seen this pair of types before, we are currently recursing into cyclic types. // Here, we assume that the types unify. If they do not, we will find out as we roll back @@ -461,6 +482,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool log.pushSeen(superTy, subTy); + size_t errorCount = errors.size(); + if (const UnionTypeVar* uv = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, uv, superTy); @@ -480,8 +503,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyPrimitives(subTy, superTy); - else if (FFlag::LuauSingletonTypes && (log.getMutable(superTy) || log.getMutable(superTy)) && - log.getMutable(subTy)) + else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); else if (log.getMutable(superTy) && log.getMutable(subTy)) @@ -491,8 +513,11 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyTables(subTy, superTy, isIntersection); - if (cacheEnabled && errors.empty()) - cacheResult(subTy, superTy); + if (!FFlag::LuauUnifierCacheErrors) + { + if (cacheEnabled && errors.empty()) + cacheResult_DEPRECATED(subTy, superTy); + } } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. @@ -512,6 +537,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + if (FFlag::LuauUnifierCacheErrors && cacheEnabled) + cacheResult(subTy, superTy, errorCount); + log.popSeen(superTy, subTy); } @@ -646,10 +674,21 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { TypeId type = uv->options[i]; - if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) + if (FFlag::LuauUnifierCacheErrors) { - startIndex = i; - break; + if (cache.contains({subTy, type})) + { + startIndex = i; + break; + } + } + else + { + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) + { + startIndex = i; + break; + } } } } @@ -737,10 +776,21 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV { TypeId type = uv->parts[i]; - if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) + if (FFlag::LuauUnifierCacheErrors) { - startIndex = i; - break; + if (cache.contains({type, superTy})) + { + startIndex = i; + break; + } + } + else + { + if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) + { + startIndex = i; + break; + } } } } @@ -771,17 +821,17 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV } } -void Unifier::cacheResult(TypeId subTy, TypeId superTy) +bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) { bool* superTyInfo = sharedState.skipCacheForType.find(superTy); if (superTyInfo && *superTyInfo) - return; + return false; bool* subTyInfo = sharedState.skipCacheForType.find(subTy); if (subTyInfo && *subTyInfo) - return; + return false; auto skipCacheFor = [this](TypeId ty) { SkipCacheForType visitor{sharedState.skipCacheForType, types}; @@ -793,9 +843,33 @@ void Unifier::cacheResult(TypeId subTy, TypeId superTy) }; if (!superTyInfo && skipCacheFor(superTy)) - return; + return false; if (!subTyInfo && skipCacheFor(subTy)) + return false; + + return true; +} + +void Unifier::cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount) +{ + if (errors.size() == prevErrorCount) + { + if (canCacheResult(subTy, superTy)) + sharedState.cachedUnify.insert({subTy, superTy}); + } + else if (errors.size() == prevErrorCount + 1) + { + if (canCacheResult(subTy, superTy)) + sharedState.cachedUnifyError[{subTy, superTy}] = errors.back().data; + } +} + +void Unifier::cacheResult_DEPRECATED(TypeId subTy, TypeId superTy) +{ + LUAU_ASSERT(!FFlag::LuauUnifierCacheErrors); + + if (!canCacheResult(subTy, superTy)) return; sharedState.cachedUnify.insert({superTy, subTy}); @@ -1283,24 +1357,6 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal subFunction = log.getMutable(subTy); } - if (!FFlag::LuauImmutableTypes) - { - if (superFunction->definition && !subFunction->definition && !subTy->persistent) - { - PendingType* newSubTy = log.queue(subTy); - FunctionTypeVar* newSubFtv = getMutable(newSubTy); - LUAU_ASSERT(newSubFtv); - newSubFtv->definition = superFunction->definition; - } - else if (!superFunction->definition && subFunction->definition && !superTy->persistent) - { - PendingType* newSuperTy = log.queue(superTy); - FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); - LUAU_ASSERT(newSuperFtv); - newSuperFtv->definition = subFunction->definition; - } - } - ctx = context; if (FFlag::LuauTxnLogSeesTypePacks2) @@ -1563,8 +1619,25 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + + if (FFlag::LuauExtendedIndexerError) + { + innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); + + bool reported = !innerState.errors.empty(); + + checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); + + innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + + if (!reported) + checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); + } + else + { + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + } if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1771,6 +1844,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); TableTypeVar* freeTable = log.getMutable(superTy); TableTypeVar* subTable = log.getMutable(subTy); @@ -1840,6 +1914,7 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection) { + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); TableTypeVar* superTable = log.getMutable(superTy); TableTypeVar* subTable = log.getMutable(subTy); @@ -2120,6 +2195,8 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) void Unifier::tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer) { + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2 || !FFlag::LuauExtendedIndexerError); + tryUnify_(subIndexer.indexType, superIndexer.indexType); tryUnify_(subIndexer.indexResultType, superIndexer.indexResultType); } @@ -2211,7 +2288,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas queue.pop_back(); // Types from other modules don't have free types - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) continue; if (seen.find(ty)) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 941a3ea4..f6dfd904 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,8 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauTableFieldFunctionDebugname, false) namespace Luau { @@ -1233,8 +1231,7 @@ AstType* Parser::parseTableTypeAnnotation() while (lexer.current().type != '}') { - if (FFlag::LuauParseSingletonTypes && lexer.current().type == '[' && - (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) + if (lexer.current().type == '[' && (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) { const Lexeme begin = lexer.current(); nextLexeme(); // [ @@ -1500,17 +1497,17 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) nextLexeme(); return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } - else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedTrue) + else if (lexer.current().type == Lexeme::ReservedTrue) { nextLexeme(); return {allocator.alloc(begin, true)}; } - else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedFalse) + else if (lexer.current().type == Lexeme::ReservedFalse) { nextLexeme(); return {allocator.alloc(begin, false)}; } - else if (FFlag::LuauParseSingletonTypes && (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString)) + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { if (std::optional> value = parseCharArray()) { @@ -1520,7 +1517,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) else return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; } - else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::BrokenString) + else if (lexer.current().type == Lexeme::BrokenString) { Location location = lexer.current().location; nextLexeme(); @@ -2189,11 +2186,8 @@ AstExpr* Parser::parseTableConstructor() AstExpr* key = allocator.alloc(name.location, nameString); AstExpr* value = parseExpr(); - if (FFlag::LuauTableFieldFunctionDebugname) - { - if (AstExprFunction* func = value->as()) - func->debugname = name.name; - } + if (AstExprFunction* func = value->as()) + func->debugname = name.name; items.push_back({AstExprTable::Item::Record, key, value}); } diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 3c087314..46b10934 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,8 +14,6 @@ #include -LUAU_FASTFLAG(LuauGcAdditionalStats) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -1060,8 +1058,11 @@ int lua_gc(lua_State* L, int what, int data) g->GCthreshold = 0; bool waspaused = g->gcstate == GCSpause; - double startmarktime = g->gcstats.currcycle.marktime; - double startsweeptime = g->gcstats.currcycle.sweeptime; + +#ifdef LUAI_GCMETRICS + double startmarktime = g->gcmetrics.currcycle.marktime; + double startsweeptime = g->gcmetrics.currcycle.sweeptime; +#endif // track how much work the loop will actually perform size_t actualwork = 0; @@ -1079,31 +1080,30 @@ int lua_gc(lua_State* L, int what, int data) } } - if (FFlag::LuauGcAdditionalStats) +#ifdef LUAI_GCMETRICS + // record explicit step statistics + GCCycleMetrics* cyclemetrics = g->gcstate == GCSpause ? &g->gcmetrics.lastcycle : &g->gcmetrics.currcycle; + + double totalmarktime = cyclemetrics->marktime - startmarktime; + double totalsweeptime = cyclemetrics->sweeptime - startsweeptime; + + if (totalmarktime > 0.0) { - // record explicit step statistics - GCCycleStats* cyclestats = g->gcstate == GCSpause ? &g->gcstats.lastcycle : &g->gcstats.currcycle; + cyclemetrics->markexplicitsteps++; - double totalmarktime = cyclestats->marktime - startmarktime; - double totalsweeptime = cyclestats->sweeptime - startsweeptime; - - if (totalmarktime > 0.0) - { - cyclestats->markexplicitsteps++; - - if (totalmarktime > cyclestats->markmaxexplicittime) - cyclestats->markmaxexplicittime = totalmarktime; - } - - if (totalsweeptime > 0.0) - { - cyclestats->sweepexplicitsteps++; - - if (totalsweeptime > cyclestats->sweepmaxexplicittime) - cyclestats->sweepmaxexplicittime = totalsweeptime; - } + if (totalmarktime > cyclemetrics->markmaxexplicittime) + cyclemetrics->markmaxexplicittime = totalmarktime; } + if (totalsweeptime > 0.0) + { + cyclemetrics->sweepexplicitsteps++; + + if (totalsweeptime > cyclemetrics->sweepmaxexplicittime) + cyclemetrics->sweepmaxexplicittime = totalsweeptime; + } +#endif + // if cycle hasn't finished, advance threshold forward for the amount of extra work performed if (g->gcstate != GCSpause) { diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index b5ae496b..c133a59e 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,8 +17,6 @@ #include -LUAU_FASTFLAG(LuauReduceStackReallocs) - /* ** {====================================================== ** Error-recovery functions @@ -33,6 +31,15 @@ struct lua_jmpbuf jmp_buf buf; }; +/* use POSIX versions of setjmp/longjmp if possible: they don't save/restore signal mask and are therefore faster */ +#if defined(__linux__) || defined(__APPLE__) +#define LUAU_SETJMP(buf) _setjmp(buf) +#define LUAU_LONGJMP(buf, code) _longjmp(buf, code) +#else +#define LUAU_SETJMP(buf) setjmp(buf) +#define LUAU_LONGJMP(buf, code) longjmp(buf, code) +#endif + int luaD_rawrunprotected(lua_State* L, Pfunc f, void* ud) { lua_jmpbuf jb; @@ -40,7 +47,7 @@ int luaD_rawrunprotected(lua_State* L, Pfunc f, void* ud) jb.status = 0; L->global->errorjmp = &jb; - if (setjmp(jb.buf) == 0) + if (LUAU_SETJMP(jb.buf) == 0) f(L, ud); L->global->errorjmp = jb.prev; @@ -52,7 +59,7 @@ l_noret luaD_throw(lua_State* L, int errcode) if (lua_jmpbuf* jb = L->global->errorjmp) { jb->status = errcode; - longjmp(jb->buf, 1); + LUAU_LONGJMP(jb->buf, 1); } if (L->global->cb.panic) @@ -165,8 +172,8 @@ static void correctstack(lua_State* L, TValue* oldstack) void luaD_reallocstack(lua_State* L, int newsize) { TValue* oldstack = L->stack; - int realsize = newsize + (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK); - LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); + int realsize = newsize + EXTRA_STACK; + LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK); luaM_reallocarray(L, L->stack, L->stacksize, realsize, TValue, L->memcat); TValue* newstack = L->stack; for (int i = L->stacksize; i < realsize; i++) @@ -514,7 +521,7 @@ static void callerrfunc(lua_State* L, void* ud) static void restore_stack_limit(lua_State* L) { - LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); + LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK); if (L->size_ci > LUAI_MAXCALLS) { /* there was an overflow? */ int inuse = cast_int(L->ci - L->base_ci); diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 1c1480d6..6e16e6f1 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -11,7 +11,7 @@ if ((char*)L->stack_last - (char*)L->top <= (n) * (int)sizeof(TValue)) \ luaD_growstack(L, n); \ else \ - condhardstacktests(luaD_reallocstack(L, L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK))); + condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK)); #define incr_top(L) \ { \ diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index a656854e..8fc930d5 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -11,8 +11,6 @@ #include "lmem.h" #include "ludata.h" -LUAU_FASTFLAGVARIABLE(LuauGcAdditionalStats, false) - #include #define GC_SWEEPMAX 40 @@ -48,7 +46,8 @@ LUAU_FASTFLAGVARIABLE(LuauGcAdditionalStats, false) reallymarkobject(g, obj2gco(t)); \ } -static void recordGcStateTime(global_State* g, int startgcstate, double seconds, bool assist) +#ifdef LUAI_GCMETRICS +static void recordGcStateStep(global_State* g, int startgcstate, double seconds, bool assist, size_t work) { switch (startgcstate) { @@ -56,58 +55,76 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, // record root mark time if we have switched to next state if (g->gcstate == GCSpropagate) { - g->gcstats.currcycle.marktime += seconds; + g->gcmetrics.currcycle.marktime += seconds; - if (FFlag::LuauGcAdditionalStats && assist) - g->gcstats.currcycle.markassisttime += seconds; + if (assist) + g->gcmetrics.currcycle.markassisttime += seconds; } break; case GCSpropagate: case GCSpropagateagain: - g->gcstats.currcycle.marktime += seconds; + g->gcmetrics.currcycle.marktime += seconds; + g->gcmetrics.currcycle.markrequests += g->gcstepsize; - if (FFlag::LuauGcAdditionalStats && assist) - g->gcstats.currcycle.markassisttime += seconds; + if (assist) + g->gcmetrics.currcycle.markassisttime += seconds; break; case GCSatomic: - g->gcstats.currcycle.atomictime += seconds; + g->gcmetrics.currcycle.atomictime += seconds; break; case GCSsweep: - g->gcstats.currcycle.sweeptime += seconds; + g->gcmetrics.currcycle.sweeptime += seconds; + g->gcmetrics.currcycle.sweeprequests += g->gcstepsize; - if (FFlag::LuauGcAdditionalStats && assist) - g->gcstats.currcycle.sweepassisttime += seconds; + if (assist) + g->gcmetrics.currcycle.sweepassisttime += seconds; break; default: LUAU_ASSERT(!"Unexpected GC state"); } if (assist) - g->gcstats.stepassisttimeacc += seconds; + { + g->gcmetrics.stepassisttimeacc += seconds; + g->gcmetrics.currcycle.assistwork += work; + g->gcmetrics.currcycle.assistrequests += g->gcstepsize; + } else - g->gcstats.stepexplicittimeacc += seconds; + { + g->gcmetrics.stepexplicittimeacc += seconds; + g->gcmetrics.currcycle.explicitwork += work; + g->gcmetrics.currcycle.explicitrequests += g->gcstepsize; + } } -static void startGcCycleStats(global_State* g) +static double recordGcDeltaTime(double& timer) { - g->gcstats.currcycle.starttimestamp = lua_clock(); - g->gcstats.currcycle.pausetime = g->gcstats.currcycle.starttimestamp - g->gcstats.lastcycle.endtimestamp; + double now = lua_clock(); + double delta = now - timer; + timer = now; + return delta; } -static void finishGcCycleStats(global_State* g) +static void startGcCycleMetrics(global_State* g) { - g->gcstats.currcycle.endtimestamp = lua_clock(); - g->gcstats.currcycle.endtotalsizebytes = g->totalbytes; - - g->gcstats.completedcycles++; - g->gcstats.lastcycle = g->gcstats.currcycle; - g->gcstats.currcycle = GCCycleStats(); - - g->gcstats.cyclestatsacc.marktime += g->gcstats.lastcycle.marktime; - g->gcstats.cyclestatsacc.atomictime += g->gcstats.lastcycle.atomictime; - g->gcstats.cyclestatsacc.sweeptime += g->gcstats.lastcycle.sweeptime; + g->gcmetrics.currcycle.starttimestamp = lua_clock(); + g->gcmetrics.currcycle.pausetime = g->gcmetrics.currcycle.starttimestamp - g->gcmetrics.lastcycle.endtimestamp; } +static void finishGcCycleMetrics(global_State* g) +{ + g->gcmetrics.currcycle.endtimestamp = lua_clock(); + g->gcmetrics.currcycle.endtotalsizebytes = g->totalbytes; + + g->gcmetrics.completedcycles++; + g->gcmetrics.lastcycle = g->gcmetrics.currcycle; + g->gcmetrics.currcycle = GCCycleMetrics(); + + g->gcmetrics.currcycle.starttotalsizebytes = g->totalbytes; + g->gcmetrics.currcycle.heaptriggersizebytes = g->GCthreshold; +} +#endif + static void removeentry(LuaNode* n) { LUAU_ASSERT(ttisnil(gval(n))); @@ -598,20 +615,19 @@ static size_t atomic(lua_State* L) LUAU_ASSERT(g->gcstate == GCSatomic); size_t work = 0; + +#ifdef LUAI_GCMETRICS double currts = lua_clock(); - double prevts = currts; +#endif /* remark occasional upvalues of (maybe) dead threads */ work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ work += propagateall(g); - if (FFlag::LuauGcAdditionalStats) - { - currts = lua_clock(); - g->gcstats.currcycle.atomictimeupval += currts - prevts; - prevts = currts; - } +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomictimeupval += recordGcDeltaTime(currts); +#endif /* remark weak tables */ g->gray = g->weak; @@ -621,34 +637,26 @@ static size_t atomic(lua_State* L) markmt(g); /* mark basic metatables (again) */ work += propagateall(g); - if (FFlag::LuauGcAdditionalStats) - { - currts = lua_clock(); - g->gcstats.currcycle.atomictimeweak += currts - prevts; - prevts = currts; - } +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomictimeweak += recordGcDeltaTime(currts); +#endif /* remark gray again */ g->gray = g->grayagain; g->grayagain = NULL; work += propagateall(g); - if (FFlag::LuauGcAdditionalStats) - { - currts = lua_clock(); - g->gcstats.currcycle.atomictimegray += currts - prevts; - prevts = currts; - } +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomictimegray += recordGcDeltaTime(currts); +#endif /* remove collected objects from weak tables */ work += cleartable(L, g->weak); g->weak = NULL; - if (FFlag::LuauGcAdditionalStats) - { - currts = lua_clock(); - g->gcstats.currcycle.atomictimeclear += currts - prevts; - } +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomictimeclear += recordGcDeltaTime(currts); +#endif /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); @@ -742,8 +750,9 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) { - if (FFlag::LuauGcAdditionalStats) - g->gcstats.currcycle.propagatework = g->gcstats.currcycle.explicitwork + g->gcstats.currcycle.assistwork; +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.propagatework = g->gcmetrics.currcycle.explicitwork + g->gcmetrics.currcycle.assistwork; +#endif // perform one iteration over 'gray again' list g->gray = g->grayagain; @@ -762,9 +771,10 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { - if (FFlag::LuauGcAdditionalStats) - g->gcstats.currcycle.propagateagainwork = - g->gcstats.currcycle.explicitwork + g->gcstats.currcycle.assistwork - g->gcstats.currcycle.propagatework; +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.propagateagainwork = + g->gcmetrics.currcycle.explicitwork + g->gcmetrics.currcycle.assistwork - g->gcmetrics.currcycle.propagatework; +#endif g->gcstate = GCSatomic; } @@ -772,8 +782,13 @@ static size_t gcstep(lua_State* L, size_t limit) } case GCSatomic: { - g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomicstarttimestamp = lua_clock(); + g->gcmetrics.currcycle.atomicstarttotalsizebytes = g->totalbytes; +#endif + + g->gcstats.atomicstarttimestamp = lua_clock(); + g->gcstats.atomicstarttotalsizebytes = g->totalbytes; cost = atomic(L); /* finish mark phase */ @@ -809,18 +824,20 @@ static size_t gcstep(lua_State* L, size_t limit) return cost; } -static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCycleStats* cyclestats) +static int64_t getheaptriggererroroffset(global_State* g) { // adjust for error using Proportional-Integral controller // https://en.wikipedia.org/wiki/PID_controller - int32_t errorKb = int32_t((cyclestats->atomicstarttotalsizebytes - cyclestats->heapgoalsizebytes) / 1024); + int32_t errorKb = int32_t((g->gcstats.atomicstarttotalsizebytes - g->gcstats.heapgoalsizebytes) / 1024); // we use sliding window for the error integral to avoid error sum 'windup' when the desired target cannot be reached - int32_t* slot = &triggerstats->terms[triggerstats->termpos % triggerstats->termcount]; + const size_t triggertermcount = sizeof(g->gcstats.triggerterms) / sizeof(g->gcstats.triggerterms[0]); + + int32_t* slot = &g->gcstats.triggerterms[g->gcstats.triggertermpos % triggertermcount]; int32_t prev = *slot; *slot = errorKb; - triggerstats->integral += errorKb - prev; - triggerstats->termpos++; + g->gcstats.triggerintegral += errorKb - prev; + g->gcstats.triggertermpos++; // controller tuning // https://en.wikipedia.org/wiki/Ziegler%E2%80%93Nichols_method @@ -832,7 +849,7 @@ static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCyc const double Ki = 0.54 * Ku / Ti; // integral gain double proportionalTerm = Kp * errorKb; - double integralTerm = Ki * triggerstats->integral; + double integralTerm = Ki * g->gcstats.triggerintegral; double totalTerm = proportionalTerm + integralTerm; @@ -841,23 +858,20 @@ static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCyc static size_t getheaptrigger(global_State* g, size_t heapgoal) { - GCCycleStats* lastcycle = &g->gcstats.lastcycle; - GCCycleStats* currcycle = &g->gcstats.currcycle; - // adjust threshold based on a guess of how many bytes will be allocated between the cycle start and sweep phase // our goal is to begin the sweep when used memory has reached the heap goal const double durationthreshold = 1e-3; - double allocationduration = currcycle->atomicstarttimestamp - lastcycle->endtimestamp; + double allocationduration = g->gcstats.atomicstarttimestamp - g->gcstats.endtimestamp; // avoid measuring intervals smaller than 1ms if (allocationduration < durationthreshold) return heapgoal; - double allocationrate = (currcycle->atomicstarttotalsizebytes - lastcycle->endtotalsizebytes) / allocationduration; - double markduration = currcycle->atomicstarttimestamp - currcycle->starttimestamp; + double allocationrate = (g->gcstats.atomicstarttotalsizebytes - g->gcstats.endtotalsizebytes) / allocationduration; + double markduration = g->gcstats.atomicstarttimestamp - g->gcstats.starttimestamp; int64_t expectedgrowth = int64_t(markduration * allocationrate); - int64_t offset = getheaptriggererroroffset(&g->gcstats.triggerstats, currcycle); + int64_t offset = getheaptriggererroroffset(g); int64_t heaptrigger = heapgoal - (expectedgrowth + offset); // clamp the trigger between memory use at the end of the cycle and the heap goal @@ -868,11 +882,6 @@ void luaC_step(lua_State* L, bool assist) { global_State* g = L->global; - if (assist) - g->gcstats.currcycle.assistrequests += g->gcstepsize; - else - g->gcstats.currcycle.explicitrequests += g->gcstepsize; - int lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -881,24 +890,23 @@ void luaC_step(lua_State* L, bool assist) // at the start of the new cycle if (g->gcstate == GCSpause) - startGcCycleStats(g); + g->gcstats.starttimestamp = lua_clock(); + +#ifdef LUAI_GCMETRICS + if (g->gcstate == GCSpause) + startGcCycleMetrics(g); + + double lasttimestamp = lua_clock(); +#endif int lastgcstate = g->gcstate; - double lasttimestamp = lua_clock(); size_t work = gcstep(L, lim); + (void)work; - if (assist) - g->gcstats.currcycle.assistwork += work; - else - g->gcstats.currcycle.explicitwork += work; - - recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); - - if (lastgcstate == GCSpropagate) - g->gcstats.currcycle.markrequests += g->gcstepsize; - else if (lastgcstate == GCSsweep) - g->gcstats.currcycle.sweeprequests += g->gcstepsize; +#ifdef LUAI_GCMETRICS + recordGcStateStep(g, lastgcstate, lua_clock() - lasttimestamp, assist, work); +#endif // at the end of the last cycle if (g->gcstate == GCSpause) @@ -909,13 +917,13 @@ void luaC_step(lua_State* L, bool assist) g->GCthreshold = heaptrigger; - finishGcCycleStats(g); + g->gcstats.heapgoalsizebytes = heapgoal; + g->gcstats.endtimestamp = lua_clock(); + g->gcstats.endtotalsizebytes = g->totalbytes; - if (FFlag::LuauGcAdditionalStats) - g->gcstats.currcycle.starttotalsizebytes = g->totalbytes; - - g->gcstats.currcycle.heapgoalsizebytes = heapgoal; - g->gcstats.currcycle.heaptriggersizebytes = heaptrigger; +#ifdef LUAI_GCMETRICS + finishGcCycleMetrics(g); +#endif } else { @@ -933,8 +941,10 @@ void luaC_fullgc(lua_State* L) { global_State* g = L->global; +#ifdef LUAI_GCMETRICS if (g->gcstate == GCSpause) - startGcCycleStats(g); + startGcCycleMetrics(g); +#endif if (g->gcstate <= GCSatomic) { @@ -954,11 +964,12 @@ void luaC_fullgc(lua_State* L) gcstep(L, SIZE_MAX); } - finishGcCycleStats(g); +#ifdef LUAI_GCMETRICS + finishGcCycleMetrics(g); + startGcCycleMetrics(g); +#endif /* run a full collection cycle */ - startGcCycleStats(g); - markroot(L); while (g->gcstate != GCSpause) { @@ -980,10 +991,11 @@ void luaC_fullgc(lua_State* L) if (g->GCthreshold < g->totalbytes) g->GCthreshold = g->totalbytes; - finishGcCycleStats(g); + g->gcstats.heapgoalsizebytes = heapgoalsizebytes; - g->gcstats.currcycle.heapgoalsizebytes = heapgoalsizebytes; - g->gcstats.currcycle.heaptriggersizebytes = g->GCthreshold; +#ifdef LUAI_GCMETRICS + finishGcCycleMetrics(g); +#endif } void luaC_barrierupval(lua_State* L, GCObject* v) @@ -1075,21 +1087,21 @@ int64_t luaC_allocationrate(lua_State* L) if (g->gcstate <= GCSatomic) { - double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; + double duration = lua_clock() - g->gcstats.endtimestamp; if (duration < durationthreshold) return -1; - return int64_t((g->totalbytes - g->gcstats.lastcycle.endtotalsizebytes) / duration); + return int64_t((g->totalbytes - g->gcstats.endtotalsizebytes) / duration); } // totalbytes is unstable during the sweep, use the rate measured at the end of mark phase - double duration = g->gcstats.currcycle.atomicstarttimestamp - g->gcstats.lastcycle.endtimestamp; + double duration = g->gcstats.atomicstarttimestamp - g->gcstats.endtimestamp; if (duration < durationthreshold) return -1; - return int64_t((g->gcstats.currcycle.atomicstarttotalsizebytes - g->gcstats.lastcycle.endtotalsizebytes) / duration); + return int64_t((g->gcstats.atomicstarttotalsizebytes - g->gcstats.endtotalsizebytes) / duration); } void luaC_wakethread(lua_State* L) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index ebf999b5..dcd070b7 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -82,7 +82,7 @@ #define luaC_checkGC(L) \ { \ - condhardstacktests(luaD_reallocstack(L, L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK))); \ + condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK)); \ if (L->global->totalbytes >= L->global->GCthreshold) \ { \ condhardmemtests(luaC_validate(L), 1); \ diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index d4f3f0a1..fbc6fb1e 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -10,8 +10,6 @@ #include "ldo.h" #include "ldebug.h" -LUAU_FASTFLAGVARIABLE(LuauReduceStackReallocs, false) - /* ** Main thread combines a thread state and the global state */ @@ -35,7 +33,7 @@ static void stack_init(lua_State* L1, lua_State* L) for (int i = 0; i < BASIC_STACK_SIZE + EXTRA_STACK; i++) setnilvalue(stack + i); /* erase new stack */ L1->top = stack; - L1->stack_last = stack + (L1->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); + L1->stack_last = stack + (L1->stacksize - EXTRA_STACK); /* initialize first ci */ L1->ci->func = L1->top; setnilvalue(L1->top++); /* `function' entry for this `ci' */ @@ -141,30 +139,16 @@ void lua_resetthread(lua_State* L) ci->top = ci->base + LUA_MINSTACK; setnilvalue(ci->func); L->ci = ci; - if (FFlag::LuauReduceStackReallocs) - { - if (L->size_ci != BASIC_CI_SIZE) - luaD_reallocCI(L, BASIC_CI_SIZE); - } - else - { + if (L->size_ci != BASIC_CI_SIZE) luaD_reallocCI(L, BASIC_CI_SIZE); - } /* clear thread state */ L->status = LUA_OK; L->base = L->ci->base; L->top = L->ci->base; L->nCcalls = L->baseCcalls = 0; /* clear thread stack */ - if (FFlag::LuauReduceStackReallocs) - { - if (L->stacksize != BASIC_STACK_SIZE + EXTRA_STACK) - luaD_reallocstack(L, BASIC_STACK_SIZE); - } - else - { + if (L->stacksize != BASIC_STACK_SIZE + EXTRA_STACK) luaD_reallocstack(L, BASIC_STACK_SIZE); - } for (int i = 0; i < L->stacksize; i++) setnilvalue(L->stack + i); } @@ -234,6 +218,10 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->cb = lua_Callbacks(); g->gcstats = GCStats(); +#ifdef LUAI_GCMETRICS + g->gcmetrics = GCMetrics(); +#endif + if (luaD_rawrunprotected(L, f_luaopen, NULL) != 0) { /* memory allocation error: free partial state */ diff --git a/VM/src/lstate.h b/VM/src/lstate.h index b2bedb48..e7c37373 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -75,10 +75,26 @@ typedef struct CallInfo #define f_isLua(ci) (!ci_func(ci)->isC) #define isLua(ci) (ttisfunction((ci)->func) && f_isLua(ci)) -struct GCCycleStats +struct GCStats +{ + // data for proportional-integral controller of heap trigger value + int32_t triggerterms[32] = {0}; + uint32_t triggertermpos = 0; + int32_t triggerintegral = 0; + + size_t atomicstarttotalsizebytes = 0; + size_t endtotalsizebytes = 0; + size_t heapgoalsizebytes = 0; + + double starttimestamp = 0; + double atomicstarttimestamp = 0; + double endtimestamp = 0; +}; + +#ifdef LUAI_GCMETRICS +struct GCCycleMetrics { size_t starttotalsizebytes = 0; - size_t heapgoalsizebytes = 0; size_t heaptriggersizebytes = 0; double pausetime = 0.0; // time from end of the last cycle to the start of a new one @@ -120,16 +136,7 @@ struct GCCycleStats size_t endtotalsizebytes = 0; }; -// data for proportional-integral controller of heap trigger value -struct GCHeapTriggerStats -{ - static const unsigned termcount = 32; - int32_t terms[termcount] = {0}; - uint32_t termpos = 0; - int32_t integral = 0; -}; - -struct GCStats +struct GCMetrics { double stepexplicittimeacc = 0.0; double stepassisttimeacc = 0.0; @@ -137,14 +144,10 @@ struct GCStats // when cycle is completed, last cycle values are updated uint64_t completedcycles = 0; - GCCycleStats lastcycle; - GCCycleStats currcycle; - - // only step count and their time is accumulated - GCCycleStats cyclestatsacc; - - GCHeapTriggerStats triggerstats; + GCCycleMetrics lastcycle; + GCCycleMetrics currcycle; }; +#endif /* ** `global state', shared by all threads of this state @@ -206,6 +209,9 @@ typedef struct global_State GCStats gcstats; +#ifdef LUAI_GCMETRICS + GCMetrics gcmetrics; +#endif } global_State; // clang-format on diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 2deec2b9..431501f3 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -526,8 +526,8 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) LuaNode* othern; LuaNode* n = getfreepos(t); /* get a free place */ if (n == NULL) - { /* cannot find a free place? */ - rehash(L, t, key); /* grow table */ + { /* cannot find a free place? */ + rehash(L, t, key); /* grow table */ if (!FFlag::LuauTableRehashRework) { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 17fd6b13..1db782cc 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2726,11 +2726,6 @@ end TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") { - ScopedFastFlag sffs[] = { - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, - }; - check(R"( --!strict local foo: "hello" | "bye" = "hello" diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 9e4cb4a5..83d4518d 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -496,8 +496,6 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { - ScopedFastFlag luauTableFieldFunctionDebugname{"LuauTableFieldFunctionDebugname", true}; - runConformance("debug.lua"); } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index dbdd06a4..a7e7ea39 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -132,21 +132,9 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars } CheckResult Fixture::check(Mode mode, std::string source) -{ - configResolver.defaultConfig.mode = mode; - fileResolver.source[mainModuleName] = std::move(source); - - CheckResult result = frontend.check(fromString(mainModuleName)); - - configResolver.defaultConfig.mode = Mode::Strict; - - return result; -} - -CheckResult Fixture::check(const std::string& source) { ModuleName mm = fromString(mainModuleName); - configResolver.defaultConfig.mode = Mode::Strict; + configResolver.defaultConfig.mode = mode; fileResolver.source[mm] = std::move(source); frontend.markDirty(mm); @@ -155,6 +143,11 @@ CheckResult Fixture::check(const std::string& source) return result; } +CheckResult Fixture::check(const std::string& source) +{ + return check(Mode::Strict, source); +} + LintResult Fixture::lint(const std::string& source, const std::optional& lintOptions) { ParseOptions parseOptions; diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 91b23197..9ce9a4c2 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -597,6 +597,8 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { + ScopedFastFlag sff("LuauLintNoRobloxBits", true); + unfreeze(typeChecker.globalTypes); TableTypeVar::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, @@ -606,81 +608,26 @@ TEST_CASE_FIXTURE(Fixture, "UnknownType") TypeId instanceType = typeChecker.globalTypes.addType(instanceTable); TypeFun instanceTypeFun{{}, instanceType}; - ClassTypeVar::Props enumItemProps{ - {"EnumType", {typeChecker.anyType}}, - }; - - ClassTypeVar enumItemClass{"EnumItem", enumItemProps, std::nullopt, std::nullopt, {}, {}}; - TypeId enumItemType = typeChecker.globalTypes.addType(enumItemClass); - TypeFun enumItemTypeFun{{}, enumItemType}; - - ClassTypeVar normalIdClass{"NormalId", {}, enumItemType, std::nullopt, {}, {}}; - TypeId normalIdType = typeChecker.globalTypes.addType(normalIdClass); - TypeFun normalIdTypeFun{{}, normalIdType}; - - // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(typeChecker, "game", typeChecker.anyType, "@test"); - addGlobalBinding(typeChecker, "typeof", typeChecker.anyType, "@test"); typeChecker.globalScope->exportedTypeBindings["Part"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["Workspace"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["RunService"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["Instance"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["ColorSequence"] = TypeFun{{}, typeChecker.anyType}; - typeChecker.globalScope->exportedTypeBindings["EnumItem"] = enumItemTypeFun; - typeChecker.globalScope->importedTypeBindings["Enum"] = {{"NormalId", normalIdTypeFun}}; - freeze(typeChecker.globalTypes); LintResult result = lint(R"( -local _e01 = game:GetService("Foo") -local _e02 = game:GetService("NormalId") -local _e03 = game:FindService("table") -local _e04 = type(game) == "Part" -local _e05 = type(game) == "NormalId" -local _e06 = typeof(game) == "Bar" -local _e07 = typeof(game) == "Part" -local _e08 = typeof(game) == "vector" -local _e09 = typeof(game) == "NormalId" -local _e10 = game:IsA("ColorSequence") -local _e11 = game:IsA("Enum.NormalId") -local _e12 = game:FindFirstChildWhichIsA("function") +local game = ... +local _e01 = type(game) == "Part" +local _e02 = typeof(game) == "Bar" +local _e03 = typeof(game) == "vector" -local _o01 = game:GetService("Workspace") -local _o02 = game:FindService("RunService") -local _o03 = type(game) == "number" -local _o04 = type(game) == "vector" -local _o05 = typeof(game) == "string" -local _o06 = typeof(game) == "Instance" -local _o07 = typeof(game) == "EnumItem" -local _o08 = game:IsA("Part") -local _o09 = game:IsA("NormalId") -local _o10 = game:FindFirstChildWhichIsA("Part") +local _o01 = type(game) == "number" +local _o02 = type(game) == "vector" +local _o03 = typeof(game) == "Part" )"); - REQUIRE_EQ(result.warnings.size(), 12); - CHECK_EQ(result.warnings[0].location.begin.line, 1); - CHECK_EQ(result.warnings[0].text, "Unknown type 'Foo'"); - CHECK_EQ(result.warnings[1].location.begin.line, 2); - CHECK_EQ(result.warnings[1].text, "Unknown type 'NormalId' (expected class type)"); - CHECK_EQ(result.warnings[2].location.begin.line, 3); - CHECK_EQ(result.warnings[2].text, "Unknown type 'table' (expected class type)"); - CHECK_EQ(result.warnings[3].location.begin.line, 4); - CHECK_EQ(result.warnings[3].text, "Unknown type 'Part' (expected primitive type)"); - CHECK_EQ(result.warnings[4].location.begin.line, 5); - CHECK_EQ(result.warnings[4].text, "Unknown type 'NormalId' (expected primitive type)"); - CHECK_EQ(result.warnings[5].location.begin.line, 6); - CHECK_EQ(result.warnings[5].text, "Unknown type 'Bar'"); - CHECK_EQ(result.warnings[6].location.begin.line, 7); - CHECK_EQ(result.warnings[6].text, "Unknown type 'Part' (expected primitive or userdata type)"); - CHECK_EQ(result.warnings[7].location.begin.line, 8); - CHECK_EQ(result.warnings[7].text, "Unknown type 'vector' (expected primitive or userdata type)"); - CHECK_EQ(result.warnings[8].location.begin.line, 9); - CHECK_EQ(result.warnings[8].text, "Unknown type 'NormalId' (expected primitive or userdata type)"); - CHECK_EQ(result.warnings[9].location.begin.line, 10); - CHECK_EQ(result.warnings[9].text, "Unknown type 'ColorSequence' (expected class or enum type)"); - CHECK_EQ(result.warnings[10].location.begin.line, 11); - CHECK_EQ(result.warnings[10].text, "Unknown type 'Enum.NormalId'"); - CHECK_EQ(result.warnings[11].location.begin.line, 12); - CHECK_EQ(result.warnings[11].text, "Unknown type 'function' (expected class type)"); + REQUIRE_EQ(result.warnings.size(), 3); + CHECK_EQ(result.warnings[0].location.begin.line, 2); + CHECK_EQ(result.warnings[0].text, "Unknown type 'Part' (expected primitive type)"); + CHECK_EQ(result.warnings[1].location.begin.line, 3); + CHECK_EQ(result.warnings[1].text, "Unknown type 'Bar'"); + CHECK_EQ(result.warnings[2].location.begin.line, 4); + CHECK_EQ(result.warnings[2].text, "Unknown type 'vector' (expected primitive or userdata type)"); } TEST_CASE_FIXTURE(Fixture, "ForRangeTable") diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 6713a589..3051e209 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -470,6 +470,7 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function id(x) return x end )"); @@ -482,6 +483,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function map(arr, fn) local t = {} @@ -500,6 +502,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(a: number, b: string) end local function test(...: T...): U... @@ -516,6 +519,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") TEST_CASE("toStringNamedFunction_unit_f") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; TypePackVar empty{TypePack{}}; FunctionTypeVar ftv{&empty, &empty, {}, false}; CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); @@ -523,6 +527,7 @@ TEST_CASE("toStringNamedFunction_unit_f") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: a, ...): (a, a, b...) return x, x, ... @@ -537,6 +542,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): ...number return 1, 2, 3 @@ -551,6 +557,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): (string, ...number) return 'a', 1, 2, 3 @@ -565,6 +572,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local f: (number, y: number) -> number )"); @@ -577,6 +585,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_ar TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: T, g: (T) -> U)): () end @@ -590,4 +599,20 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") CHECK_EQ("f(x: T, g: (T) -> U): ()", toStringNamedFunction("f", *ftv, opts)); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") +{ + ScopedFastFlag flag{"LuauDocFuncParameters", true}; + + CheckResult result = check(R"( + local function test(a, b : string, ... : number) return a end + )"); + + TypeId ty = requireType("test"); + const FunctionTypeVar* ftv = get(follow(ty)); + + ToStringOptions opts; + opts.namedFunctionOverrideArgNames = {"first", "second", "third"}; + CHECK_EQ("test(first: a, second: string, ...: number): a", toStringNamedFunction("test", *ftv, opts)); +} + TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 5f0295b0..5ac45ff2 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -651,8 +651,6 @@ local a: Packed TEST_CASE_FIXTURE(Fixture, "transpile_singleton_types") { - ScopedFastFlag luauParseSingletonTypes{"LuauParseSingletonTypes", true}; - std::string code = R"( type t1 = 'hello' type t2 = true diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index ec20a2c7..c6fbebed 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -887,8 +887,6 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauAssertStripsFalsyTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 4288098a..da4ea074 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1335,4 +1335,80 @@ caused by: toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") +{ + ScopedFastFlag sff{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; + CheckResult result = check(R"( + function test(a: number, b: string, ...) + end + + test(1) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(2, acm->expected); + CHECK_EQ(1, acm->actual); + CHECK_EQ(CountMismatch::Context::Arg, acm->context); + CHECK(acm->isVariadic); +} + +TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic") +{ + ScopedFastFlag sff1{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; + ScopedFastFlag sff2{"LuauFixArgumentCountMismatchAmountWithGenericTypes", true}; + CheckResult result = check(R"( +function test(a: number, b: string, ...) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +wrapper(test) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(3, acm->expected); + CHECK_EQ(1, acm->actual); + CHECK_EQ(CountMismatch::Context::Arg, acm->context); + CHECK(acm->isVariadic); +} + +TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic2") +{ + ScopedFastFlag sff1{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; + ScopedFastFlag sff2{"LuauFixArgumentCountMismatchAmountWithGenericTypes", true}; + CheckResult result = check(R"( +function test(a: number, b: string, ...) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +pcall(wrapper, test) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(4, acm->expected); + CHECK_EQ(2, acm->actual); + CHECK_EQ(CountMismatch::Context::Arg, acm->context); + CHECK(acm->isVariadic); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 63643610..e5eeae31 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauTableSubtypingVariance2) + TEST_SUITE_BEGIN("TypeInferModules"); TEST_CASE_FIXTURE(Fixture, "require") @@ -268,8 +270,6 @@ function x:Destroy(): () end TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") { - ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; - fileResolver.source["game/A"] = R"( export type Type = { x: { a: number } } return {} @@ -288,8 +288,6 @@ type Rename = typeof(x.x) TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") { - ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; - fileResolver.source["game/A"] = R"( local y = setmetatable({}, {}) export type Type = { x: typeof(y) } @@ -307,4 +305,83 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "module_type_conflict") +{ + ScopedFastFlag luauTypeMismatchModuleName{"LuauTypeMismatchModuleName", true}; + + fileResolver.source["game/A"] = R"( +export type T = { x: number } +return {} + )"; + + fileResolver.source["game/B"] = R"( +export type T = { x: string } +return {} + )"; + + fileResolver.source["game/C"] = R"( +local A = require(game.A) +local B = require(game.B) +local a: A.T = { x = 2 } +local b: B.T = a + )"; + + CheckResult result = frontend.check("game/C"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauTableSubtypingVariance2) + { + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' +caused by: + Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); + } + else + { + CHECK_EQ(toString(result.errors[0]), "Type 'T' from 'game/A' could not be converted into 'T' from 'game/B'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "module_type_conflict_instantiated") +{ + ScopedFastFlag luauTypeMismatchModuleName{"LuauTypeMismatchModuleName", true}; + + fileResolver.source["game/A"] = R"( +export type Wrap = { x: T } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local A = require(game.A) +export type T = A.Wrap +return {} + )"; + + fileResolver.source["game/C"] = R"( +local A = require(game.A) +export type T = A.Wrap +return {} + )"; + + fileResolver.source["game/D"] = R"( +local A = require(game.B) +local B = require(game.C) +local a: A.T = { x = 2 } +local b: B.T = a + )"; + + CheckResult result = frontend.check("game/D"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauTableSubtypingVariance2) + { + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' +caused by: + Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); + } + else + { + CHECK_EQ(toString(result.errors[0]), "Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'"); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index baa25978..6a8a9d93 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -756,4 +756,30 @@ TEST_CASE_FIXTURE(Fixture, "refine_and_or") CHECK_EQ("number", toString(requireType("u"))); } +TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") +{ + ScopedFastFlag sff{"LuauDecoupleOperatorInferenceFromUnifiedTypeInference", true}; + + CheckResult result = check(Mode::Strict, R"( + local function f(x, y) + return x + y + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in + operation; consider adding a type annotation to 'x'"); + + result = check(Mode::Nonstrict, R"( + local function f(x, y) + return x + y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // When type inference is unified, we could add an assertion that + // the strict and nonstrict types are equivalent. This isn't actually + // the case right now, though. +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 9b347921..cddeab6e 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -435,7 +435,6 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { ScopedFastFlag sff[] = { {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -1002,8 +1001,6 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { ScopedFastFlag sff[] = { {"LuauDiscriminableUnions2", true}, - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -1028,8 +1025,6 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag") { ScopedFastFlag sff[] = { {"LuauDiscriminableUnions2", true}, - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -1066,8 +1061,6 @@ TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauAssertStripsFalsyTypes", true}, }; @@ -1091,8 +1084,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauAssertStripsFalsyTypes", true}, }; @@ -1134,8 +1125,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") { ScopedFastFlag sff[] = { {"LuauDiscriminableUnions2", true}, - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 7f8d8fec..d39341ea 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -13,11 +13,6 @@ TEST_SUITE_BEGIN("TypeSingletons"); TEST_CASE_FIXTURE(Fixture, "bool_singletons") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: true = true local b: false = false @@ -28,11 +23,6 @@ TEST_CASE_FIXTURE(Fixture, "bool_singletons") TEST_CASE_FIXTURE(Fixture, "string_singletons") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: "foo" = "foo" local b: "bar" = "bar" @@ -43,11 +33,6 @@ TEST_CASE_FIXTURE(Fixture, "string_singletons") TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: true = false )"); @@ -58,11 +43,6 @@ TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: "foo" = "bar" )"); @@ -73,11 +53,6 @@ TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: "\n" = "\000\r" )"); @@ -88,11 +63,6 @@ TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars") TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: true = true local b: boolean = a @@ -103,11 +73,6 @@ TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype") TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: "foo" = "foo" local b: string = a @@ -118,11 +83,6 @@ TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( function f(a: true, b: "foo") end f(true, "foo") @@ -133,11 +93,6 @@ TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( function f(a: true, b: "foo") end f(true, "bar") @@ -149,11 +104,6 @@ TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( function f(a, b) end local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) @@ -166,11 +116,6 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( function f(a, b) end local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) @@ -184,11 +129,6 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( type MyEnum = "foo" | "bar" | "baz" local a : MyEnum = "foo" @@ -201,11 +141,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( type MyEnum = "foo" | "bar" | "baz" local a : MyEnum = "bang" @@ -218,11 +153,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( type MyEnum1 = "foo" | "bar" type MyEnum2 = MyEnum1 | "baz" @@ -237,8 +167,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") { ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, {"LuauExpectedTypesOfProperties", true}, }; @@ -257,11 +185,6 @@ TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( type Dog = { tag: "Dog", howls: boolean } type Cat = { tag: "Cat", meows: boolean } @@ -274,11 +197,6 @@ TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( type Dog = { tag: "Dog", howls: boolean } type Cat = { tag: "Cat", meows: boolean } @@ -292,10 +210,6 @@ TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") { - ScopedFastFlag sffs[] = { - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( --!strict type T = { @@ -320,10 +234,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") } TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( --!strict type T = { @@ -341,10 +251,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch") TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") { - ScopedFastFlag sffs[] = { - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( --!strict type S = "bar" @@ -367,8 +273,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { - ScopedFastFlag sffs[] = { - {"LuauParseSingletonTypes", true}, + ScopedFastFlag sffs[]{ {"LuauUnsealedTableLiteral", true}, }; @@ -386,8 +291,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") { ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, {"LuauExpectedTypesOfProperties", true}, }; @@ -409,8 +312,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") { ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, {"LuauExpectedTypesOfProperties", true}, }; @@ -432,8 +333,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") { ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, {"LuauExpectedTypesOfProperties", true}, }; @@ -451,7 +350,6 @@ local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_singleton") { ScopedFastFlag sff[]{ - {"LuauSingletonTypes", true}, {"LuauEqConstraint", true}, {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, @@ -477,8 +375,6 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauEqConstraint", true}, {"LuauWidenIfSupertypeIsFree2", true}, @@ -504,8 +400,6 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -521,8 +415,6 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -551,8 +443,6 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -577,8 +467,6 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") TEST_CASE_FIXTURE(Fixture, "functions_are_not_to_be_widened") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -595,7 +483,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -614,7 +501,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -633,7 +519,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -652,7 +537,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 91140aaa..0cc12d19 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2078,6 +2078,44 @@ caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") +{ + ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path + ScopedFastFlag luauExtendedIndexerError{"LuauExtendedIndexerError", true}; + + CheckResult result = check(R"( + type A = { [number]: string } + type B = { [string]: string } + + local a: A = { 'a', 'b' } + local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") +{ + ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path + ScopedFastFlag luauExtendedIndexerError{"LuauExtendedIndexerError", true}; + + CheckResult result = check(R"( + type A = { [number]: number } + type B = { [number]: string } + + local a: A = { 1, 2, 3 } + local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string')"); +} + TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { ScopedFastFlag sffs[]{ diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 571d0f8d..660ddcfc 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -334,7 +334,7 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") #if defined(LUAU_ENABLE_ASAN) int limit = 250; #elif defined(_DEBUG) || defined(_NOOPT) - int limit = 350; + int limit = 300; #else int limit = 600; #endif diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ad1e31e5..68b7c4fb 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -296,6 +296,7 @@ return f() REQUIRE(acm); CHECK_EQ(1, acm->expected); CHECK_EQ(0, acm->actual); + CHECK_FALSE(acm->isVariadic); } TEST_CASE_FIXTURE(Fixture, "optional_field_access_error") From 75bccce3db6eb63404a9bd896f99d8ab2ab8aafe Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Tue, 29 Mar 2022 12:37:14 -0700 Subject: [PATCH 033/102] RFC: Lower Bounds Calculation (#388) Co-authored-by: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com> --- rfcs/lower-bounds-calculation.md | 217 +++++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 rfcs/lower-bounds-calculation.md diff --git a/rfcs/lower-bounds-calculation.md b/rfcs/lower-bounds-calculation.md new file mode 100644 index 00000000..a1793884 --- /dev/null +++ b/rfcs/lower-bounds-calculation.md @@ -0,0 +1,217 @@ +# Lower Bounds Calculation + +## Summary + +We propose adapting lower bounds calculation from Pierce's Local Type Inference paper into the Luau type inference algorithm. + +https://www.cis.upenn.edu/~bcpierce/papers/lti-toplas.pdf + +## Motivation + +There are a number of important scenarios that occur where Luau cannot infer a sensible type without annotations. + +Many of these revolve around type variables that occur in contravariant positions. + +### Function Return Types + +A very common thing to write in Luau is a function to try to find something in some data structure. These functions habitually return the relevant datum when it is successfully found, or `nil` in the case that it cannot. For instance: + +```lua +-- A.lua +function find_first_if(vec, f) + for i, e in ipairs(vec) do + if f(e) then + return i + end + end + + return nil +end +``` + +This function has two `return` statements: One returns `number` and the other `nil`. Today, Luau flags this as an error. We ask authors to add a return annotation to make this error go away. + +We would like to automatically infer `find_first_if : ({T}, (T) -> boolean) -> number?`. + +Higher order functions also present a similar problem. + +```lua +-- B.lua +function foo(f) + f(5) + f("string") +end +``` + +There is nothing wrong with the implementation of `foo` here, but Luau fails to typecheck it all the same because `f` is used in an inconsistent way. This too can be worked around by introducing a type annotation for `f`. + +The fact that the return type of `f` is never used confounds things a little, but for now it would be a big improvement if we inferred `f : ((number | string) -> T...) -> ()`. + +## Design + +We introduce a new kind of TypeVar, `ConstrainedTypeVar` to represent a TypeVar whose lower bounds are known. We will never expose syntax for a user to write these types: They only temporarily exist as type inference is being performed. + +When unifying some type with a `ConstrainedTypeVar` we _broaden_ the set of constraints that can be placed upon it. + +It may help to realize that what we have been doing up until now has been _upper bounds calculation_. + +When we `quantify` a function, we will _normalize_ each type and convert each `ConstrainedTypeVar` into a `UnionTypeVar`. + +### Normalization + +When computing lower bounds, we need to have some process by which we reduce types down to a minimal shape and canonicalize them, if only to have a clean way to flush out degenerate unions like `A | A`. Normalization is about reducing union and intersection types to a minimal, canonicalizable shape. + +A normalized union is one where there do not exist two branches on the union where one is a subtype of the other. It is quite straightforward to implement. + +A normalized intersection is a little bit more complicated: + +1. The tables of an intersection are always combined into a single table. Coincident properties are merged into intersections of their own. + * eg `normalize({x: number, y: string} & {y: number, z: number}) == {x: number, y: string & number, z: number}` + * This is recursive. eg `normalize({x: {y: number}} & {x: {y: string}}) == {x: {y: number & string}}` +1. If two functions in the intersection have a subtyping relationship, the normalization results only in the super-type-most function. (more on function subtyping later) + +### Function subtyping relationships + +If we are going to infer intersections of functions, then we need to be very careful about keeping combinatorics under control. We therefore need to be very deliberate about what subtyping rules we have for functions of differing arity. We have some important requirements: + +* We'd like some way to canonicalize intersections of functions, and yet +* optional function arguments are a great feature that we don't want to break + +A very important use case for us is the case where the user is providing a callback to some higher-order function, and that function will be invoked with extra arguments that the original customer doesn't actually care about. For example: + +```lua +-- C.lua +function map_array(arr, f) + local result = {} + for i, e in ipairs(arr) do + table.insert(result, f(e, i, arr)) + end + return result +end + +local example = {1, 2, 3, 4} +local example_result = map_array(example, function(i) return i * 2 end) +``` + +This function mirrors the actual `Array.map` function in JavaScript. It is very frequent for users of this function to provide a lambda that only accepts one argument. It would be annoying for callers to be forced to provide a lambda that accepts two unused arguments. This obviously becomes even worse if the function later changes to provide yet more optional information to the callback. + +This use case is very important for Roblox, as we have many APIs that accept callbacks. Implementors of those callbacks frequently omit arguments that they don't care about. + +Here is an example straight out of the Roblox developer documentation. ([full example here](https://developer.roblox.com/en-us/api-reference/event/BasePart/Touched)) + +```lua +-- D.lua +local part = script.Parent + +local function blink() + -- ... +end + +part.Touched:Connect(blink) +``` + +The `Touched` event actually passes a single argument: the part that touched the `Instance` in question. In this example, it is omitted from the callback handler. + +We therefore want _oversaturation_ of a function to be allowed, but this combines with optional function arguments to create a problem with soundness. Consider the following: + +```lua +-- E.lua +type Callback = (Instance) -> () + +local cb: Callback +function register_callback(c: Callback) + cb = c +end + +function invoke_callback(i: Instance) + cb(i) +end + +--- + +function bad_callback(x: number?) +end + +local obscured: () -> () = bad_callback + +register_callback(obscured) + +function good_callback() +end + +register_callback(good_callback) +``` + +The problem we run into is, if we allow the subtyping rule `(T?) -> () <: () -> ()` and also allow oversaturation of a function, it becomes easy to obscure an argument type and pass the wrong type of value to it. + +Next, consider the following type alias + +```lua +-- F.lua +type OldFunctionType = (any, any) -> any +type NewFunctionType = (any) -> any +type FunctionType = OldFunctionType & NewFunctionType +``` + +If we have a subtyping rule `(T0..TN) <: (T0..TN-1)` to permit the function subtyping relationship `(T0..TN-1) -> R <: (T0..TN) -> R`, then the above type alias normalizes to `(any) -> any`. In order to call the two-argument variation, we would need to permit oversaturation, which runs afoul of the soundness hole from the previous example. + +We need a solution here. + +To resolve this, let's reframe things in simpler terms: + +If there is never a subtyping relationship between packs of different length, then we don't have any soundness issues, but we find ourselves unable to register `good_callback`. + +To resolve _that_, consider that we are in truth being a bit hasty when we say `good_callback : () -> ()`. We can pass any number of arguments to this function safely. We could choose to type `good_callback : () -> () & (any) -> () & (any, any) -> () & ...`. Luau already has syntax for this particular sort of infinite intersection: `good_callback : (any...) -> ()`. + +So, we propose some different inference rules for functions: + +1. The AST fragment `function(arg0..argN) ... end` is typed `(T0..TN, any...) -> R` where `arg0..argN : T0..TN` and `R` is the inferred return type of the function body. Function statements are inferred the same way. +1. Type annotations are unchanged. `() -> ()` is still a nullary function. + +For reference, the subtyping rules for unions and functions are unchanged. We include them here for clarity. + +1. `A <: A | B` +1. `B <: A | B` +1. `A | B <: T` if `A <: T` or `B <: T` +1. `T -> R <: U -> S` if `U <: T` and `R <: S` + +We propose new subtyping rules for type packs: + +1. `(T0..TN) <: (U0..UN)` if, for each `T` and `U`, `T <: U` +1. `(U...)` is the same as `() | (U) | (U, U) | (U, U, U) | ...`, therefore +1. `(T0..TN) <: (U...)` if for each `T`, `T <: U`, therefore +1. `(U...) -> R <: (T0..TN) -> R` if for each `T`, `T <: U` + +The important difference is that we remove all subtyping rules that mention options. Functions of different arities are no longer considered subtypes of one another. Optional function arguments are still allowed, but function as a feature of function calls. + +Under these rules, functions of different arities can never be converted to one another, but actual functions are known to be safe to oversaturate with anything, and so gain a type that says so. + +Under these subtyping rules, snippets `C.lua` and `D.lua`, check the way we want: literal functions are implicitly safe to oversaturate, so it is fine to cast them as the necessary callback function type. + +`E.lua` also typechecks the way we need it to: `(Instance) -> () ()` and so `obscured` cannot receive the value `bad_callback`, which prevents it from being passed to `register_callback`. However, `good_callback : (any...) -> ()` and `(any...) -> () <: (Instance) -> ()` and so it is safe to register `good_callback`. + +Snippet `F.lua` is also fixed with this ruleset: There is no subtyping relationship between `(any) -> ()` and `(any, any) -> ()`, so the intersection is not combined under normalization. + +This works, but itself creates some small problems that we need to resolve: + +First, the `...` symbol still needs to be unavailable for functions that have been given this implicit `...any` type. This is actually taken care of in the Luau parser, so no code change is required. + +Secondly, we do not want to silently allow oversaturation of direct calls to a function if we know that the arguments will be ignored. We need to treat these variadic packs differently when unifying for function calls. + +Thirdly, we don't want to display this variadic in the signature if the author doesn't expect to see it. + +We solve these issues by adding a property `bool VariadicTypePack::hidden` to the implementation and switching on it in the above scenarios. The implementation is relatively straightforward for all 3 cases. + +## Drawbacks + +There is a potential cause for concern that we will be inferring unions of functions in cases where we previously did not. Unions are known to be potential sources of performance issues. One possibility is to allow Luau to be less intelligent and have it "give up" and produce less precise types. This would come at the cost of accuracy and soundness. + +If we allow functions to be oversaturated, we are going to miss out on opportunities to warn the user about legitimate problems with their program. I think we will have to work out some kind of special logic to detect when we are oversaturating a function whose exact definition is known and warn on that. + +Allowing indirect function calls to be oversaturated with `nil` values only should be safe, but a little bit unfortunate. As long as we statically know for certain that `nil` is actually a permissible value for that argument position, it should be safe. + +## Alternatives + +If we are willing to sacrifice soundness, we could adopt success typing and come up with an inference algorithm that produces less precise type information. + +We could also technically choose to do nothing, but this has some unpalatable consequences: Something I would like to do in the near future is to have the inference algorithm assume the same `self` type for all methods of a table. This will make inference of common OO patterns dramatically more intuitive and ergonomic, but inference of polymorphic methods requires some kind of lower bounds calculation to work correctly. From af64680a5ef527ae59adbf24fd8581e2c80575ad Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Tue, 29 Mar 2022 16:58:59 -0700 Subject: [PATCH 034/102] Mark singleton types and unsealed table literals RFCs as implemented (#438) --- rfcs/STATUS.md | 13 ++++++------- rfcs/function-bit32-countlz-countrz.md | 4 ++-- rfcs/function-coroutine-close.md | 4 ++-- rfcs/syntax-safe-navigation-operator.md | 2 ++ rfcs/syntax-singleton-types.md | 2 ++ rfcs/syntax-type-ascription-bidi.md | 4 ++-- rfcs/unsealed-table-assign-optional-property.md | 2 ++ rfcs/unsealed-table-literals.md | 4 +++- 8 files changed, 21 insertions(+), 14 deletions(-) diff --git a/rfcs/STATUS.md b/rfcs/STATUS.md index 93a09ece..e3e227a0 100644 --- a/rfcs/STATUS.md +++ b/rfcs/STATUS.md @@ -17,17 +17,10 @@ This document tracks unimplemented RFCs. ## Sealed/unsealed typing changes -[RFC: Unsealed table literals](https://github.com/Roblox/luau/blob/master/rfcs/unsealed-table-literals.md) | [RFC: Only strip optional properties from unsealed tables during subtyping](https://github.com/Roblox/luau/blob/master/rfcs/unsealed-table-subtyping-strips-optional-properties.md) **Status**: Implemented but not fully rolled out yet. -## Singleton types - -[RFC: Singleton types](https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md) - -**Status**: Implemented but not fully rolled out yet. - ## Safe navigation operator [RFC: Safe navigation postfix operator (?)](https://github.com/Roblox/luau/blob/master/rfcs/syntax-safe-navigation-operator.md) @@ -47,3 +40,9 @@ This document tracks unimplemented RFCs. [RFC: Generalized iteration](https://github.com/Roblox/luau/blob/master/rfcs/generalized-iteration.md) **Status**: Needs implementation + +## Lower Bounds Calculation + +[RFC: Lower bounds calculation](https://github.com/Roblox/luau/blob/master/rfcs/lower-bounds-calculation.md) + +**Status**: Needs implementation diff --git a/rfcs/function-bit32-countlz-countrz.md b/rfcs/function-bit32-countlz-countrz.md index d2439f72..b4ccb197 100644 --- a/rfcs/function-bit32-countlz-countrz.md +++ b/rfcs/function-bit32-countlz-countrz.md @@ -1,11 +1,11 @@ # bit32.countlz/countrz +**Status**: Implemented + ## Summary Add bit32.countlz (count left zeroes) and bit32.countrz (count right zeroes) to accelerate bit scanning -**Status**: Implemented - ## Motivation All CPUs have instructions to determine the position of first/last set bit in an integer. These instructions have a variety of uses, the popular ones being: diff --git a/rfcs/function-coroutine-close.md b/rfcs/function-coroutine-close.md index 6def1533..b9ffbf6f 100644 --- a/rfcs/function-coroutine-close.md +++ b/rfcs/function-coroutine-close.md @@ -1,11 +1,11 @@ # coroutine.close +**Status**: Implemented + ## Summary Add `coroutine.close` function from Lua 5.4 that takes a suspended coroutine and makes it "dead" (non-runnable). -**Status**: Implemented - ## Motivation When implementing various higher level objects on top of coroutines, such as promises, it can be useful to cancel the coroutine execution externally - when the caller is not diff --git a/rfcs/syntax-safe-navigation-operator.md b/rfcs/syntax-safe-navigation-operator.md index c98f3957..11c4b37f 100644 --- a/rfcs/syntax-safe-navigation-operator.md +++ b/rfcs/syntax-safe-navigation-operator.md @@ -1,5 +1,7 @@ # Safe navigation postfix operator (?) +**Note**: We have unresolved issues with interaction between this feature and Roblox instance hierarchy. This may affect the viability of this proposal. + ## Summary Introduce syntax to navigate through `nil` values, or short-circuit with `nil` if it was encountered. diff --git a/rfcs/syntax-singleton-types.md b/rfcs/syntax-singleton-types.md index 26ea3028..2c1f5442 100644 --- a/rfcs/syntax-singleton-types.md +++ b/rfcs/syntax-singleton-types.md @@ -2,6 +2,8 @@ > Note: this RFC was adapted from an internal proposal that predates RFC process +**Status**: Implemented + ## Summary Introduce a new kind of type variable, called singleton types. They are just like normal types but has the capability to represent a constant runtime value as a type. diff --git a/rfcs/syntax-type-ascription-bidi.md b/rfcs/syntax-type-ascription-bidi.md index bf37eca2..0831aba5 100644 --- a/rfcs/syntax-type-ascription-bidi.md +++ b/rfcs/syntax-type-ascription-bidi.md @@ -1,11 +1,11 @@ # Relaxing type assertions +**Status**: Implemented + ## Summary The way `::` works today is really strange. The best solution we can come up with is to allow `::` to convert between any two related types. -**Status**: Implemented - ## Motivation Due to an accident of the implementation, the Luau `::` operator can only be used for downcasts and casts to `any`. diff --git a/rfcs/unsealed-table-assign-optional-property.md b/rfcs/unsealed-table-assign-optional-property.md index ed037b14..477399c2 100644 --- a/rfcs/unsealed-table-assign-optional-property.md +++ b/rfcs/unsealed-table-assign-optional-property.md @@ -1,5 +1,7 @@ # Unsealed table assignment creates an optional property +**Status**: Implemented + ## Summary In Luau, tables have a state, which can, among others, be "unsealed". diff --git a/rfcs/unsealed-table-literals.md b/rfcs/unsealed-table-literals.md index 320bf7ca..669b67d4 100644 --- a/rfcs/unsealed-table-literals.md +++ b/rfcs/unsealed-table-literals.md @@ -1,5 +1,7 @@ # Unsealed table literals +**Status**: Implemented + ## Summary Currently the only way to create an unsealed table is as an empty table literal `{}`. @@ -73,4 +75,4 @@ We could introduce a new table state for unsealed-but-precise tables. The trade-off is that that would be more precise, at the cost of adding user-visible complexity to the type system. -We could continue to treat array-like tables as sealed. \ No newline at end of file +We could continue to treat array-like tables as sealed. From f3ea2f96f736f646e07f03bec6797a0a41021022 Mon Sep 17 00:00:00 2001 From: Alan Jeffrey <403333+asajeffrey@users.noreply.github.com> Date: Wed, 30 Mar 2022 18:38:55 -0500 Subject: [PATCH 035/102] Recap March 2022 (#439) * March Recap Co-authored-by: Arseny Kapoulkine --- .../2022-03-31-luau-recap-march-2022.md | 109 ++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 docs/_posts/2022-03-31-luau-recap-march-2022.md diff --git a/docs/_posts/2022-03-31-luau-recap-march-2022.md b/docs/_posts/2022-03-31-luau-recap-march-2022.md new file mode 100644 index 00000000..8ac88732 --- /dev/null +++ b/docs/_posts/2022-03-31-luau-recap-march-2022.md @@ -0,0 +1,109 @@ +--- +layout: single +title: "Luau Recap: March 2022" +--- + +Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). + +[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-march-2022/).] + +## Singleton types + +We added support for singleton types! These allow you to use string or +boolean literals in types. These types are only inhabited by the +literal, for example if a variable `x` has type `"foo"`, then `x == +"foo"` is guaranteed to be true. + +Singleton types are particularly useful when combined with union types, +for example: + +```lua +type Animals = "Dog" | "Cat" | "Bird" +``` + +or: + +```lua +type Falsey = false | nil +``` + +In particular, singleton types play well with unions of tables, +allowing tagged unions (also known as discriminated unions): + +```lua +type Ok = { type: "ok", value: T } +type Err = { type: "error", error: E } +type Result = Ok | Err + +local result: Result = ... +if result.type == "ok" then + -- result :: Ok + print(result.value) +else + -- result :: Err + error(result.error) +end +``` + +The RFC for singleton types is https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md + +## Width subtyping + +A common idiom for programming with tables is to provide a public interface type, but to keep some of the concrete implementation private, for example: + +```lua +type Interface = { + name: string, +} + +type Concrete = { + name: string, + id: number, +} +``` + +Within a module, a developer might use the concrete type, but export functions using the interface type: + +```lua +local x: Concrete = { + name = "foo", + id = 123, +} + +local function get(): Interface + return x +end +``` + +Previously examples like this did not typecheck but now they do! + +This language feature is called *width subtyping* (it allows tables to get *wider*, that is to have more properties). + +The RFC for width subtyping is https://github.com/Roblox/luau/blob/master/rfcs/sealed-table-subtyping.md + +## Typechecking improvements + + * Generic function type inference now works the same for generic types and generic type packs. + * We improved some error messages. + * There are now fewer crashes (hopefully none!) due to mutating types inside the Luau typechecker. + * We fixed a bug that could cause two incompatible copies of the same class to be created. + * Luau now copes better with cyclic metatable types (it gives a type error rather than hanging). + * Fixed a case where types are not properly bound to all of the subtype when the subtype is a union. + * We fixed a bug that confused union and intersection types of table properties. + * Functions declared as `function f(x : any)` can now be called as `f()` without a type error. + +## API improvements + + * Implement `table.clone` which takes a table and returns a new table that has the same keys/values/metatable. The cloning is shallow - if some keys refer to tables that need to be cloned, that can be done manually by modifying the resulting table. + +## Debugger improvements + + * Use the property name as the name of methods in the debugger. + +## Performance improvements + + * Optimize table rehashing (~15% faster dictionary table resize on average) + * Improve performance of freeing tables (~5% lift on some GC benchmarks) + * Improve gathering performance metrics for GC. + * Reduce stack memory reallocation. + From 06bbfd90b5d8496d3f63b082ce6dd26b7849e6e3 Mon Sep 17 00:00:00 2001 From: Alan Jeffrey <403333+asajeffrey@users.noreply.github.com> Date: Thu, 31 Mar 2022 09:31:06 -0500 Subject: [PATCH 036/102] Fix code sample in March 2022 Recap (#442) --- docs/_posts/2022-03-31-luau-recap-march-2022.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_posts/2022-03-31-luau-recap-march-2022.md b/docs/_posts/2022-03-31-luau-recap-march-2022.md index 8ac88732..ff3a4d0f 100644 --- a/docs/_posts/2022-03-31-luau-recap-march-2022.md +++ b/docs/_posts/2022-03-31-luau-recap-march-2022.md @@ -39,7 +39,7 @@ local result: Result = ... if result.type == "ok" then -- result :: Ok print(result.value) -else +elseif result.type == "error" then -- result :: Err error(result.error) end From ba60730e0f67c3b06d209cb406fd1d632866c742 Mon Sep 17 00:00:00 2001 From: Alexander McCord <11488393+alexmccord@users.noreply.github.com> Date: Thu, 31 Mar 2022 11:54:06 -0700 Subject: [PATCH 037/102] Add documentation on singleton types and tagged unions to typecheck.md. (#440) Update the typecheck.md page to talk about singleton types and their uses, tagged unions. As a driveby, improve the documentation on type refinements. And delete the unknown symbols part, this is really dated. * Update docs/_pages/typecheck.md to fix a typo Co-authored-by: Arseny Kapoulkine --- docs/_pages/typecheck.md | 124 ++++++++++++++++++++++++++++++--------- 1 file changed, 95 insertions(+), 29 deletions(-) diff --git a/docs/_pages/typecheck.md b/docs/_pages/typecheck.md index 3580d66e..9443f112 100644 --- a/docs/_pages/typecheck.md +++ b/docs/_pages/typecheck.md @@ -31,20 +31,6 @@ foo = 1 However, given the second snippet in strict mode, the type checker would be able to infer `number` for `foo`. -## Unknown symbols - -Consider how often you're likely to assign a new value to a local variable. What if you accidentally misspelled it? Oops, it's now assigned globally and your local variable is still using the old value. - -```lua -local someLocal = 1 - -soeLocal = 2 -- the bug - -print(someLocal) -``` - -Because of this, Luau type checker currently emits an error in strict mode; use local variables instead. - ## Structural type system Luau's type system is structural by default, which is to say that we inspect the shape of two tables to see if they are similar enough. This was the obvious choice because Lua 5.1 is inherently structural. @@ -267,6 +253,23 @@ Note: it's impossible to create an intersection type of some primitive types, e. Note: Luau still does not support user-defined overloaded functions. Some of Roblox and Lua 5.1 functions have different function signature, so inherently requires overloaded functions. +## Singleton types (aka literal types) + +Luau's type system also supports singleton types, which means it's a type that represents one single value at runtime. At this time, both string and booleans are representable in types. + +> We do not currently support numbers as types. For now, this is intentional. + +```lua +local foo: "Foo" = "Foo" -- ok +local bar: "Bar" = foo -- not ok +local baz: string = foo -- ok + +local t: true = true -- ok +local f: false = false -- ok +``` + +This happens all the time, especially through [type refinements](#type-refinements) and is also incredibly useful when you want to enforce program invariants in the type system! See [tagged unions](#tagged-unions) for more information. + ## Variadic types Luau permits assigning a type to the `...` variadic symbol like any other parameter: @@ -375,22 +378,40 @@ local account: Account = Account.new("Alexander", 500) --^^^^^^^ not ok, 'Account' does not exist ``` +## Tagged unions + +Tagged unions are just union types! In particular, they're union types of tables where they have at least _some_ common properties but the structure of the tables are different enough. Here's one example: + +```lua +type Result = { type: "ok", value: T } | { type: "err", error: E } +``` + +This `Result` type can be discriminated by using type refinements on the property `type`, like so: + +```lua +if result.type == "ok" then + -- result is known to be { type: "ok", value: T } + -- and attempting to index for error here will fail + print(result.value) +elseif result.type == "err" then + -- result is known to be { type: "err", error: E } + -- and attempting to index for value here will fail + print(result.error) +end +``` + +Which works out because `value: T` exists only when `type` is in actual fact `"ok"`, and `error: E` exists only when `type` is in actual fact `"err"`. + ## Type refinements -When we check the type of a value, what we're doing is we're refining the type, hence "type refinement." Currently, the support for this is somewhat basic. +When we check the type of any lvalue (a global, a local, or a property), what we're doing is we're refining the type, hence "type refinement." The support for this is arbitrarily complex, so go crazy! -Using `type` comparison: -```lua -local stringOrNumber: string | number = "foo" +Here are all the ways you can refine: +1. Truthy test: `if x then` will refine `x` to be truthy. +2. Type guards: `if type(x) == "number" then` will refine `x` to be `number`. +3. Equality: `x == "hello"` will refine `x` to be a singleton type `"hello"`. -if type(x) == "string" then - local onlyString: string = stringOrNumber -- ok - local onlyNumber: number = stringOrNumber -- not ok -end - -local onlyString: string = stringOrNumber -- not ok -local onlyNumber: number = stringOrNumber -- not ok -``` +And they can be composed with many of `and`/`or`/`not`. `not`, just like `~=`, will flip the resulting refinements, that is `not x` will refine `x` to be falsy. Using truthy test: ```lua @@ -398,10 +419,55 @@ local maybeString: string? = nil if maybeString then local onlyString: string = maybeString -- ok + local onlyNil: nil = maybeString -- not ok +end + +if not maybeString then + local onlyString: string = maybeString -- not ok + local onlyNil: nil = maybeString -- ok end ``` -And using `assert` will work with the above type guards: +Using `type` test: +```lua +local stringOrNumber: string | number = "foo" + +if type(stringOrNumber) == "string" then + local onlyString: string = stringOrNumber -- ok + local onlyNumber: number = stringOrNumber -- not ok +end + +if type(stringOrNumber) ~= "string" then + local onlyString: string = stringOrNumber -- not ok + local onlyNumber: number = stringOrNumber -- ok +end +``` + +Using equality test: +```lua +local myString: string = f() + +if myString == "hello" then + local hello: "hello" = myString -- ok because it is absolutely "hello"! + local copy: string = myString -- ok +end +``` + +And as said earlier, we can compose as many of `and`/`or`/`not` as we wish with these refinements: +```lua +local function f(x: any, y: any) + if (x == "hello" or x == "bye") and type(y) == "string" then + -- x is of type "hello" | "bye" + -- y is of type string + end + + if not (x ~= "hi") then + -- x is of type "hi" + end +end +``` + +`assert` can also be used to refine in all the same ways: ```lua local stringOrNumber: string | number = "foo" @@ -411,7 +477,7 @@ local onlyString: string = stringOrNumber -- ok local onlyNumber: number = stringOrNumber -- not ok ``` -## Typecasts +## Type casts Expressions may be typecast using `::`. Typecasting is useful for specifying the type of an expression when the automatically inferred type is too generic. @@ -487,4 +553,4 @@ There are some caveats here though. For instance, the require path must be resol Cyclic module dependencies can cause problems for the type checker. In order to break a module dependency cycle a typecast of the module to `any` may be used: ```lua local myModule = require(MyModule) :: any -``` \ No newline at end of file +``` From 83c1c48e09105d5c83055c2827b288071fe4e055 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 31 Mar 2022 13:37:49 -0700 Subject: [PATCH 038/102] Sync to upstream/release/521 --- Analysis/src/BuiltinDefinitions.cpp | 5 +-- Analysis/src/TypeInfer.cpp | 28 +++++++++---- Ast/include/Luau/TimeTrace.h | 20 ++++----- Ast/src/Lexer.cpp | 8 +++- Ast/src/TimeTrace.cpp | 5 +-- Compiler/include/Luau/Bytecode.h | 3 +- Compiler/src/BytecodeBuilder.cpp | 2 +- VM/src/laux.cpp | 13 +----- VM/src/lbaselib.cpp | 25 ++--------- VM/src/ldebug.cpp | 18 ++------ VM/src/lmem.h | 6 +-- VM/src/ltable.cpp | 64 +++++++++++++++++++++++------ VM/src/ltablib.cpp | 30 ++++++++++++-- VM/src/lvmexecute.cpp | 6 ++- VM/src/lvmload.cpp | 13 ++---- tests/Autocomplete.test.cpp | 5 +-- tests/Conformance.test.cpp | 4 -- tests/Parser.test.cpp | 14 +++++++ tests/TypeInfer.functions.test.cpp | 41 ++++++++++++++++-- tests/conformance/nextvar.lua | 31 ++++++++++++++ 20 files changed, 222 insertions(+), 119 deletions(-) diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index bf9ef303..3895b01b 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -9,7 +9,6 @@ #include LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) -LUAU_FASTFLAGVARIABLE(LuauTableCloneType, false) LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) /** FIXME: Many of these type definitions are not quite completely accurate. @@ -289,9 +288,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) { // tabTy is a generic table type which we can't express via declaration syntax yet ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); - - if (FFlag::LuauTableCloneType) - ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 9965d5aa..6df6bff0 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -27,6 +27,7 @@ LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as fals LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) +LUAU_FASTFLAGVARIABLE(LuauInferStatFunction, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) @@ -34,7 +35,7 @@ LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify2, false) +LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify3, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) @@ -463,7 +464,18 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } else if (auto fun = (*protoIter)->as()) { - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt); + std::optional expectedType; + + if (FFlag::LuauInferStatFunction && !fun->func->self) + { + if (auto name = fun->name->as()) + { + TypeId exprTy = checkExpr(scope, *name->expr).type; + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + } + } + + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, expectedType); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; @@ -1103,7 +1115,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } - else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify2) + else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify3) { TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); @@ -1116,7 +1128,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else if (ttv->state == TableState::Sealed) { - if (!ttv->indexer || !isPrim(ttv->indexer->indexType, PrimitiveTypeVar::String)) + if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false)) reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); } @@ -1141,7 +1153,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } - else if (FFlag::LuauStatFunctionSimplify2) + else if (FFlag::LuauStatFunctionSimplify3) { LUAU_ASSERT(function.name->is()); @@ -1151,7 +1163,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else if (function.func->self) { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify2); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify3); AstExprIndexName* indexName = function.name->as(); if (!indexName) @@ -1190,7 +1202,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify2); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify3); TypeId leftType = checkLValueBinding(scope, *function.name); @@ -2985,7 +2997,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T return errorRecoveryType(scope); } - if (FFlag::LuauStatFunctionSimplify2) + if (FFlag::LuauStatFunctionSimplify3) { if (lhsType->persistent) return errorRecoveryType(scope); diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 503eca61..5018456f 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -130,8 +130,8 @@ ThreadContext& getThreadContext(); struct Scope { - explicit Scope(ThreadContext& context, uint16_t token) - : context(context) + explicit Scope(uint16_t token) + : context(getThreadContext()) { if (!FFlag::DebugLuauTimeTracing) return; @@ -152,8 +152,8 @@ struct Scope struct OptionalTailScope { - explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold) - : context(context) + explicit OptionalTailScope(uint16_t token, uint32_t threshold) + : context(getThreadContext()) , token(token) , threshold(threshold) { @@ -188,27 +188,27 @@ struct OptionalTailScope uint32_t pos; }; -LUAU_NOINLINE std::pair createScopeData(const char* name, const char* category); +LUAU_NOINLINE uint16_t createScopeData(const char* name, const char* category); } // namespace TimeTrace } // namespace Luau // Regular scope #define LUAU_TIMETRACE_SCOPE(name, category) \ - static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ - Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) + static uint16_t lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::Scope lttScope(lttScopeStatic) // A scope without nested scopes that may be skipped if the time it took is less than the threshold #define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ - static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ - Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) + static uint16_t lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail, microsec) // Extra key/value data can be added to regular scopes #define LUAU_TIMETRACE_ARGUMENT(name, value) \ do \ { \ if (FFlag::DebugLuauTimeTracing) \ - lttScopeStatic.second.eventArgument(name, value); \ + lttScope.context.eventArgument(name, value); \ } while (false) #else diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index d56c8860..70c6c78d 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,6 +6,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkip, false) + namespace Luau { @@ -352,6 +354,8 @@ const Lexeme& Lexer::next() const Lexeme& Lexer::next(bool skipComments) { + bool first = true; + // in skipComments mode we reject valid comments do { @@ -359,9 +363,11 @@ const Lexeme& Lexer::next(bool skipComments) while (isSpace(peekch())) consume(); - prevLocation = lexeme.location; + if (!FFlag::LuauParseLocationIgnoreCommentSkip || first) + prevLocation = lexeme.location; lexeme = readNext(); + first = false; } while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment)); return lexeme; diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index 8079830b..19564f05 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -246,10 +246,9 @@ ThreadContext& getThreadContext() return context; } -std::pair createScopeData(const char* name, const char* category) +uint16_t createScopeData(const char* name, const char* category) { - uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category); - return {token, Luau::TimeTrace::getThreadContext()}; + return createToken(Luau::TimeTrace::getGlobalContext(), name, category); } } // namespace TimeTrace } // namespace Luau diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 679712f6..c6e5a03b 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -376,8 +376,7 @@ enum LuauOpcode enum LuauBytecodeTag { // Bytecode version - LBC_VERSION = 1, - LBC_VERSION_FUTURE = 2, // TODO: This will be removed in favor of LBC_VERSION with LuauBytecodeV2Force + LBC_VERSION = 2, // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 09f06b68..6944de0f 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -508,7 +508,7 @@ uint32_t BytecodeBuilder::getDebugPC() const void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); - bytecode = char(LBC_VERSION_FUTURE); + bytecode = char(LBC_VERSION); writeStringTable(bytecode); diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 9fe2ebb6..72169a86 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,8 +11,6 @@ #include -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauMorePreciseLuaLTypeName, false) - /* convert a stack index to positive */ #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -337,15 +335,8 @@ const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) const char* luaL_typename(lua_State* L, int idx) { - if (DFFlag::LuauMorePreciseLuaLTypeName) - { - const TValue* obj = luaA_toobject(L, idx); - return luaT_objtypename(L, obj); - } - else - { - return lua_typename(L, lua_type(L, idx)); - } + const TValue* obj = luaA_toobject(L, idx); + return luaT_objtypename(L, obj); } /* diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 96ad493b..f7917611 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -11,8 +11,6 @@ #include #include -LUAU_DYNAMIC_FASTFLAG(LuauMorePreciseLuaLTypeName) - static void writestring(const char* s, size_t l) { fwrite(s, 1, l, stdout); @@ -189,31 +187,16 @@ static int luaB_gcinfo(lua_State* L) static int luaB_type(lua_State* L) { luaL_checkany(L, 1); - if (DFFlag::LuauMorePreciseLuaLTypeName) - { - /* resulting name doesn't differentiate between userdata types */ - lua_pushstring(L, lua_typename(L, lua_type(L, 1))); - } - else - { - lua_pushstring(L, luaL_typename(L, 1)); - } + /* resulting name doesn't differentiate between userdata types */ + lua_pushstring(L, lua_typename(L, lua_type(L, 1))); return 1; } static int luaB_typeof(lua_State* L) { luaL_checkany(L, 1); - if (DFFlag::LuauMorePreciseLuaLTypeName) - { - /* resulting name returns __type if specified unless the input is a newproxy-created userdata */ - lua_pushstring(L, luaL_typename(L, 1)); - } - else - { - const TValue* obj = luaA_toobject(L, 1); - lua_pushstring(L, luaT_objtypename(L, obj)); - } + /* resulting name returns __type if specified unless the input is a newproxy-created userdata */ + lua_pushstring(L, luaL_typename(L, 1)); return 1; } diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 7a9947b7..e050050e 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauBytecodeV2Force) - static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -91,16 +89,6 @@ const char* lua_setlocal(lua_State* L, int level, int n) return name; } -static int getlinedefined(Proto* p) -{ - if (FFlag::LuauBytecodeV2Force) - return p->linedefined; - else if (p->linedefined >= 0) - return p->linedefined; - else - return luaG_getline(p, 0); -} - static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, CallInfo* ci) { int status = 1; @@ -120,7 +108,7 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, { ar->source = getstr(f->l.p->source); ar->what = "Lua"; - ar->linedefined = getlinedefined(f->l.p); + ar->linedefined = f->l.p->linedefined; } luaO_chunkid(ar->short_src, ar->source, LUA_IDSIZE); break; @@ -133,7 +121,7 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, } else { - ar->currentline = f->isC ? -1 : getlinedefined(f->l.p); + ar->currentline = f->isC ? -1 : f->l.p->linedefined; } break; @@ -424,7 +412,7 @@ static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* con } const char* debugname = p->debugname ? getstr(p->debugname) : NULL; - int linedefined = getlinedefined(p); + int linedefined = p->linedefined; callback(context, debugname, linedefined, depth, buffer, size); diff --git a/VM/src/lmem.h b/VM/src/lmem.h index 00788452..e552d739 100644 --- a/VM/src/lmem.h +++ b/VM/src/lmem.h @@ -10,12 +10,12 @@ union GCObject; #define luaM_newgco(L, t, size, memcat) cast_to(t*, luaM_newgco_(L, size, memcat)) #define luaM_freegco(L, p, size, memcat, page) luaM_freegco_(L, obj2gco(p), size, memcat, page) -#define luaM_arraysize_(n, e) ((cast_to(size_t, (n)) <= SIZE_MAX / (e)) ? (n) * (e) : (luaM_toobig(L), SIZE_MAX)) +#define luaM_arraysize_(L, n, e) ((cast_to(size_t, (n)) <= SIZE_MAX / (e)) ? (n) * (e) : (luaM_toobig(L), SIZE_MAX)) -#define luaM_newarray(L, n, t, memcat) cast_to(t*, luaM_new_(L, luaM_arraysize_(n, sizeof(t)), memcat)) +#define luaM_newarray(L, n, t, memcat) cast_to(t*, luaM_new_(L, luaM_arraysize_(L, n, sizeof(t)), memcat)) #define luaM_freearray(L, b, n, t, memcat) luaM_free_(L, (b), (n) * sizeof(t), memcat) #define luaM_reallocarray(L, v, oldn, n, t, memcat) \ - ((v) = cast_to(t*, luaM_realloc_(L, v, (oldn) * sizeof(t), luaM_arraysize_(n, sizeof(t)), memcat))) + ((v) = cast_to(t*, luaM_realloc_(L, v, (oldn) * sizeof(t), luaM_arraysize_(L, n, sizeof(t)), memcat))) LUAI_FUNC void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat); LUAI_FUNC GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat); diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 431501f3..1c75c0b0 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -2,17 +2,26 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details /* -** Implementation of tables (aka arrays, objects, or hash tables). -** Tables keep its elements in two parts: an array part and a hash part. -** Non-negative integer keys are all candidates to be kept in the array -** part. The actual size of the array is the largest `n' such that at -** least half the slots between 0 and n are in use. -** Hash uses a mix of chained scatter table with Brent's variation. -** A main invariant of these tables is that, if an element is not -** in its main position (i.e. the `original' position that its hash gives -** to it), then the colliding element is in its own main position. -** Hence even when the load factor reaches 100%, performance remains good. -*/ + * Implementation of tables (aka arrays, objects, or hash tables). + * + * Tables keep the elements in two parts: an array part and a hash part. + * Integer keys >=1 are all candidates to be kept in the array part. The actual size of the array is the + * largest n such that at least half the slots between 0 and n are in use. + * Hash uses a mix of chained scatter table with Brent's variation. + * + * A main invariant of these tables is that, if an element is not in its main position (i.e. the original + * position that its hash gives to it), then the colliding element is in its own main position. + * Hence even when the load factor reaches 100%, performance remains good. + * + * Table keys can be arbitrary values unless they contain NaN. Keys are hashed and compared using raw equality, + * so even if the key is a userdata with an overridden __eq, it's not used during hash lookups. + * + * Each table has a "boundary", defined as the index k where t[k] ~= nil and t[k+1] == nil. The boundary can be + * computed using a binary search and can be adjusted when the table is modified; crucially, Luau enforces an + * invariant where the boundary must be in the array part - this enforces a consistent iteration order through the + * prefix of the table when using pairs(), and allows to implement algorithms that access elements in 1..#t range + * more efficiently. + */ #include "ltable.h" @@ -25,6 +34,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false) +LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary, false) // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 @@ -460,7 +470,20 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) totaluse++; /* compute new size for array part */ int na = computesizes(nums, &nasize); + /* enforce the boundary invariant; for performance, only do hash lookups if we must */ + if (FFlag::LuauTableNewBoundary) + { + bool tbound = t->node != dummynode || nasize < t->sizearray; + int ekindex = ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; + /* move the array size up until the boundary is guaranteed to be inside the array part */ + while (nasize + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, nasize + 1)))) + { + nasize++; + na++; + } + } /* resize the table to new computed sizes */ + LUAU_ASSERT(na <= totaluse); resize(L, t, nasize, totaluse - na); } @@ -520,10 +543,18 @@ static LuaNode* getfreepos(Table* t) */ static TValue* newkey(lua_State* L, Table* t, const TValue* key) { + /* enforce boundary invariant */ + if (FFlag::LuauTableNewBoundary && ttisnumber(key) && nvalue(key) == t->sizearray + 1) + { + rehash(L, t, key); /* grow table */ + + // after rehash, numeric keys might be located in the new array part, but won't be found in the node part + return arrayornewkey(L, t, key); + } + LuaNode* mp = mainposition(t, key); if (!ttisnil(gval(mp)) || mp == dummynode) { - LuaNode* othern; LuaNode* n = getfreepos(t); /* get a free place */ if (n == NULL) { /* cannot find a free place? */ @@ -542,7 +573,7 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) LUAU_ASSERT(n != dummynode); TValue mk; getnodekey(L, &mk, mp); - othern = mainposition(t, &mk); + LuaNode* othern = mainposition(t, &mk); if (othern != mp) { /* is colliding node out of its main position? */ /* yes; move colliding node into free position */ @@ -704,6 +735,7 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key) static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j) { + LUAU_ASSERT(!FFlag::LuauTableNewBoundary); unsigned int i = j; /* i is zero or a present index */ j++; /* find `i' and `j' such that i is present and j is not */ @@ -788,6 +820,12 @@ int luaH_getn(Table* t) maybesetaboundary(t, boundary); return boundary; } + else if (FFlag::LuauTableNewBoundary) + { + /* validate boundary invariant */ + LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1))); + return j; + } /* else must find a boundary in hash part */ else if (t->node == dummynode) /* hash part is empty? */ return j; /* that is easy... */ diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 00753742..241a99e3 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,7 +10,9 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauTableClone, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry2, false) + +void (*lua_table_move_telemetry)(lua_State* L, int f, int e, int t, int nf, int nt); static int foreachi(lua_State* L) { @@ -197,6 +199,29 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); + void (*telemetrycb)(lua_State* L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; + + if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb) + { + int nf = lua_objlen(L, 1); + int nt = lua_objlen(L, tt); + + bool report = false; + + // source index range must be in bounds in source table unless the table is empty (permits 1..#t moves) + if (!(f == 1 || (f >= 1 && f <= nf))) + report = true; + if (!(e == nf || (e >= 1 && e <= nf))) + report = true; + + // destination index must be in bounds in dest table or be exactly at the first empty element (permits concats) + if (!(t == nt + 1 || (t >= 1 && t <= nt))) + report = true; + + if (report) + telemetrycb(L, f, e, t, nf, nt); + } + if (e >= f) { /* otherwise, nothing to move */ luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move"); @@ -512,9 +537,6 @@ static int tisfrozen(lua_State* L) static int tclone(lua_State* L) { - if (!FFlag::LuauTableClone) - luaG_runerror(L, "table.clone is not available"); - luaL_checktype(L, 1, LUA_TTABLE); luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 96a87b7e..34949efb 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,6 +16,8 @@ #include +LUAU_FASTFLAG(LuauTableNewBoundary) + // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -2266,9 +2268,9 @@ static void luau_execute(lua_State* L) VM_NEXT(); } } - else if (h->lsizenode == 0 && ttisnil(gval(h->node))) + else if (FFlag::LuauTableNewBoundary || (h->lsizenode == 0 && ttisnil(gval(h->node)))) { - // hash part is empty: fallthrough to exit + // fallthrough to exit VM_NEXT(); } else diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 4e5435b7..8b742f1c 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) - // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template struct TempBuffer @@ -156,12 +154,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size return 1; } - if (FFlag::LuauBytecodeV2Force ? (version != LBC_VERSION_FUTURE) : (version != LBC_VERSION && version != LBC_VERSION_FUTURE)) + if (version != LBC_VERSION) { char chunkid[LUA_IDSIZE]; luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); - lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, - FFlag::LuauBytecodeV2Force ? LBC_VERSION_FUTURE : LBC_VERSION, version); + lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, LBC_VERSION, version); return 1; } @@ -292,11 +289,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->p[j] = protos[fid]; } - if (FFlag::LuauBytecodeV2Force || version == LBC_VERSION_FUTURE) - p->linedefined = readVarInt(data, size, offset); - else - p->linedefined = -1; - + p->linedefined = readVarInt(data, size, offset); p->debugname = readString(strings, data, size, offset); uint8_t lineinfo = read(data, size, offset); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 1db782cc..4e8a1d55 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauTableCloneType) using namespace Luau; @@ -262,7 +261,7 @@ TEST_CASE_FIXTURE(ACFixture, "get_member_completions") auto ac = autocomplete('1'); - CHECK_EQ(FFlag::LuauTableCloneType ? 17 : 16, ac.entryMap.size()); + CHECK_EQ(17, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); @@ -2221,7 +2220,7 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") auto ac = autocompleteSource(frontend, source, Position{1, 24}, nullCallback).result; - CHECK_EQ(FFlag::LuauTableCloneType ? 17 : 16, ac.entryMap.size()); + CHECK_EQ(17, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 83d4518d..0ed7dc44 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -241,8 +241,6 @@ TEST_CASE("Math") TEST_CASE("Table") { - ScopedFastFlag sff("LuauTableClone", true); - runConformance("nextvar.lua"); } @@ -467,8 +465,6 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { - ScopedFastFlag sff("LuauTableCloneType", true); - runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 7f6a6c0d..7dacc669 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1604,6 +1604,20 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_of_functions_unions_and_intersections") CHECK_EQ((Position{3, 42}), block->body.data[2]->location.end); } +TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") +{ + ScopedFastFlag luauParseLocationIgnoreCommentSkip{"LuauParseLocationIgnoreCommentSkip", true}; + + AstStatBlock* block = parse(R"( + type F = number + --comment + print('hello') + )"); + + REQUIRE_EQ(2, block->body.size); + CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); +} + TEST_CASE_FIXTURE(Fixture, "parse_error_loop_control") { matchParseError("break", "break statement must be inside a loop"); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index da4ea074..dbae7b54 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1270,7 +1270,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; fileResolver.source["game/isAMagicMock"] = R"( --!nonstrict @@ -1294,7 +1294,7 @@ end TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; CheckResult result = check(R"( function string.len(): number @@ -1302,7 +1302,40 @@ function string.len(): number end )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + + // if 'string' library property was replaced with an internal module type, it will be freed and the next check will crash + frontend.clear(); + + result = check(R"( +print(string.len('hello')) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite_2") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; + ScopedFastFlag inferStatFunction{"LuauInferStatFunction", true}; + + CheckResult result = check(R"( +local t: { f: ((x: number) -> number)? } = {} + +function t.f(x) + print(x + 5) + return x .. "asd" +end + +t.f = function(x) + print(x + 5) + return x .. "asd" +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'string' could not be converted into 'number')"); + CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); } TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") @@ -1319,7 +1352,7 @@ TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; CheckResult result = check(R"( local t: {[string]: () -> number} = {} diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index e85fcbe8..ab9be42c 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -550,4 +550,35 @@ do assert(not pcall(table.clone, 42)) end +-- test boundary invariant maintenance during rehash +do + local arr = table.create(5, 42) + + arr[1] = nil + arr.a = 'a' -- trigger rehash + + assert(#arr == 5) -- technically 0 is also valid, but it happens to be 5 because array capacity is 5 +end + +-- test boundary invariant maintenance when replacing hash keys +do + local arr = {} + arr.a = 'a' + arr.a = nil + arr[1] = 1 -- should rehash and resize array part, otherwise # won't find the boundary in array part + + assert(#arr == 1) +end + +-- test boundary invariant maintenance when table is filled from the end +do + local arr = {} + for i=5,2,-1 do + arr[i] = i + assert(#arr == 0) + end + arr[1] = 1 + assert(#arr == 5) +end + return"OK" From 4c1f208d7a2172202fc3b424cca8c75532e98af9 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 31 Mar 2022 14:01:51 -0700 Subject: [PATCH 039/102] Sync to upstream/release/521 (#443) --- Analysis/src/BuiltinDefinitions.cpp | 5 +-- Analysis/src/TypeInfer.cpp | 28 +++++++++---- Ast/include/Luau/TimeTrace.h | 20 ++++----- Ast/src/Lexer.cpp | 8 +++- Ast/src/TimeTrace.cpp | 5 +-- Compiler/include/Luau/Bytecode.h | 3 +- Compiler/src/BytecodeBuilder.cpp | 2 +- VM/src/laux.cpp | 13 +----- VM/src/lbaselib.cpp | 25 ++--------- VM/src/ldebug.cpp | 18 ++------ VM/src/lmem.h | 6 +-- VM/src/ltable.cpp | 64 +++++++++++++++++++++++------ VM/src/ltablib.cpp | 30 ++++++++++++-- VM/src/lvmexecute.cpp | 6 ++- VM/src/lvmload.cpp | 13 ++---- tests/Autocomplete.test.cpp | 5 +-- tests/Conformance.test.cpp | 4 -- tests/Parser.test.cpp | 14 +++++++ tests/TypeInfer.functions.test.cpp | 41 ++++++++++++++++-- tests/conformance/nextvar.lua | 31 ++++++++++++++ 20 files changed, 222 insertions(+), 119 deletions(-) diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index bf9ef303..3895b01b 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -9,7 +9,6 @@ #include LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) -LUAU_FASTFLAGVARIABLE(LuauTableCloneType, false) LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) /** FIXME: Many of these type definitions are not quite completely accurate. @@ -289,9 +288,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) { // tabTy is a generic table type which we can't express via declaration syntax yet ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); - - if (FFlag::LuauTableCloneType) - ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 9965d5aa..6df6bff0 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -27,6 +27,7 @@ LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as fals LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) +LUAU_FASTFLAGVARIABLE(LuauInferStatFunction, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) @@ -34,7 +35,7 @@ LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify2, false) +LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify3, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) @@ -463,7 +464,18 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } else if (auto fun = (*protoIter)->as()) { - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt); + std::optional expectedType; + + if (FFlag::LuauInferStatFunction && !fun->func->self) + { + if (auto name = fun->name->as()) + { + TypeId exprTy = checkExpr(scope, *name->expr).type; + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + } + } + + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, expectedType); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; @@ -1103,7 +1115,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } - else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify2) + else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify3) { TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); @@ -1116,7 +1128,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else if (ttv->state == TableState::Sealed) { - if (!ttv->indexer || !isPrim(ttv->indexer->indexType, PrimitiveTypeVar::String)) + if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false)) reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); } @@ -1141,7 +1153,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } - else if (FFlag::LuauStatFunctionSimplify2) + else if (FFlag::LuauStatFunctionSimplify3) { LUAU_ASSERT(function.name->is()); @@ -1151,7 +1163,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else if (function.func->self) { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify2); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify3); AstExprIndexName* indexName = function.name->as(); if (!indexName) @@ -1190,7 +1202,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify2); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify3); TypeId leftType = checkLValueBinding(scope, *function.name); @@ -2985,7 +2997,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T return errorRecoveryType(scope); } - if (FFlag::LuauStatFunctionSimplify2) + if (FFlag::LuauStatFunctionSimplify3) { if (lhsType->persistent) return errorRecoveryType(scope); diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 503eca61..5018456f 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -130,8 +130,8 @@ ThreadContext& getThreadContext(); struct Scope { - explicit Scope(ThreadContext& context, uint16_t token) - : context(context) + explicit Scope(uint16_t token) + : context(getThreadContext()) { if (!FFlag::DebugLuauTimeTracing) return; @@ -152,8 +152,8 @@ struct Scope struct OptionalTailScope { - explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold) - : context(context) + explicit OptionalTailScope(uint16_t token, uint32_t threshold) + : context(getThreadContext()) , token(token) , threshold(threshold) { @@ -188,27 +188,27 @@ struct OptionalTailScope uint32_t pos; }; -LUAU_NOINLINE std::pair createScopeData(const char* name, const char* category); +LUAU_NOINLINE uint16_t createScopeData(const char* name, const char* category); } // namespace TimeTrace } // namespace Luau // Regular scope #define LUAU_TIMETRACE_SCOPE(name, category) \ - static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ - Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) + static uint16_t lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::Scope lttScope(lttScopeStatic) // A scope without nested scopes that may be skipped if the time it took is less than the threshold #define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ - static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ - Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) + static uint16_t lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail, microsec) // Extra key/value data can be added to regular scopes #define LUAU_TIMETRACE_ARGUMENT(name, value) \ do \ { \ if (FFlag::DebugLuauTimeTracing) \ - lttScopeStatic.second.eventArgument(name, value); \ + lttScope.context.eventArgument(name, value); \ } while (false) #else diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index d56c8860..70c6c78d 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,6 +6,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkip, false) + namespace Luau { @@ -352,6 +354,8 @@ const Lexeme& Lexer::next() const Lexeme& Lexer::next(bool skipComments) { + bool first = true; + // in skipComments mode we reject valid comments do { @@ -359,9 +363,11 @@ const Lexeme& Lexer::next(bool skipComments) while (isSpace(peekch())) consume(); - prevLocation = lexeme.location; + if (!FFlag::LuauParseLocationIgnoreCommentSkip || first) + prevLocation = lexeme.location; lexeme = readNext(); + first = false; } while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment)); return lexeme; diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index 8079830b..19564f05 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -246,10 +246,9 @@ ThreadContext& getThreadContext() return context; } -std::pair createScopeData(const char* name, const char* category) +uint16_t createScopeData(const char* name, const char* category) { - uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category); - return {token, Luau::TimeTrace::getThreadContext()}; + return createToken(Luau::TimeTrace::getGlobalContext(), name, category); } } // namespace TimeTrace } // namespace Luau diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 679712f6..c6e5a03b 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -376,8 +376,7 @@ enum LuauOpcode enum LuauBytecodeTag { // Bytecode version - LBC_VERSION = 1, - LBC_VERSION_FUTURE = 2, // TODO: This will be removed in favor of LBC_VERSION with LuauBytecodeV2Force + LBC_VERSION = 2, // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 09f06b68..6944de0f 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -508,7 +508,7 @@ uint32_t BytecodeBuilder::getDebugPC() const void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); - bytecode = char(LBC_VERSION_FUTURE); + bytecode = char(LBC_VERSION); writeStringTable(bytecode); diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 9fe2ebb6..72169a86 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,8 +11,6 @@ #include -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauMorePreciseLuaLTypeName, false) - /* convert a stack index to positive */ #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -337,15 +335,8 @@ const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) const char* luaL_typename(lua_State* L, int idx) { - if (DFFlag::LuauMorePreciseLuaLTypeName) - { - const TValue* obj = luaA_toobject(L, idx); - return luaT_objtypename(L, obj); - } - else - { - return lua_typename(L, lua_type(L, idx)); - } + const TValue* obj = luaA_toobject(L, idx); + return luaT_objtypename(L, obj); } /* diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 96ad493b..f7917611 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -11,8 +11,6 @@ #include #include -LUAU_DYNAMIC_FASTFLAG(LuauMorePreciseLuaLTypeName) - static void writestring(const char* s, size_t l) { fwrite(s, 1, l, stdout); @@ -189,31 +187,16 @@ static int luaB_gcinfo(lua_State* L) static int luaB_type(lua_State* L) { luaL_checkany(L, 1); - if (DFFlag::LuauMorePreciseLuaLTypeName) - { - /* resulting name doesn't differentiate between userdata types */ - lua_pushstring(L, lua_typename(L, lua_type(L, 1))); - } - else - { - lua_pushstring(L, luaL_typename(L, 1)); - } + /* resulting name doesn't differentiate between userdata types */ + lua_pushstring(L, lua_typename(L, lua_type(L, 1))); return 1; } static int luaB_typeof(lua_State* L) { luaL_checkany(L, 1); - if (DFFlag::LuauMorePreciseLuaLTypeName) - { - /* resulting name returns __type if specified unless the input is a newproxy-created userdata */ - lua_pushstring(L, luaL_typename(L, 1)); - } - else - { - const TValue* obj = luaA_toobject(L, 1); - lua_pushstring(L, luaT_objtypename(L, obj)); - } + /* resulting name returns __type if specified unless the input is a newproxy-created userdata */ + lua_pushstring(L, luaL_typename(L, 1)); return 1; } diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 7a9947b7..e050050e 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauBytecodeV2Force) - static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -91,16 +89,6 @@ const char* lua_setlocal(lua_State* L, int level, int n) return name; } -static int getlinedefined(Proto* p) -{ - if (FFlag::LuauBytecodeV2Force) - return p->linedefined; - else if (p->linedefined >= 0) - return p->linedefined; - else - return luaG_getline(p, 0); -} - static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, CallInfo* ci) { int status = 1; @@ -120,7 +108,7 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, { ar->source = getstr(f->l.p->source); ar->what = "Lua"; - ar->linedefined = getlinedefined(f->l.p); + ar->linedefined = f->l.p->linedefined; } luaO_chunkid(ar->short_src, ar->source, LUA_IDSIZE); break; @@ -133,7 +121,7 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, } else { - ar->currentline = f->isC ? -1 : getlinedefined(f->l.p); + ar->currentline = f->isC ? -1 : f->l.p->linedefined; } break; @@ -424,7 +412,7 @@ static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* con } const char* debugname = p->debugname ? getstr(p->debugname) : NULL; - int linedefined = getlinedefined(p); + int linedefined = p->linedefined; callback(context, debugname, linedefined, depth, buffer, size); diff --git a/VM/src/lmem.h b/VM/src/lmem.h index 00788452..e552d739 100644 --- a/VM/src/lmem.h +++ b/VM/src/lmem.h @@ -10,12 +10,12 @@ union GCObject; #define luaM_newgco(L, t, size, memcat) cast_to(t*, luaM_newgco_(L, size, memcat)) #define luaM_freegco(L, p, size, memcat, page) luaM_freegco_(L, obj2gco(p), size, memcat, page) -#define luaM_arraysize_(n, e) ((cast_to(size_t, (n)) <= SIZE_MAX / (e)) ? (n) * (e) : (luaM_toobig(L), SIZE_MAX)) +#define luaM_arraysize_(L, n, e) ((cast_to(size_t, (n)) <= SIZE_MAX / (e)) ? (n) * (e) : (luaM_toobig(L), SIZE_MAX)) -#define luaM_newarray(L, n, t, memcat) cast_to(t*, luaM_new_(L, luaM_arraysize_(n, sizeof(t)), memcat)) +#define luaM_newarray(L, n, t, memcat) cast_to(t*, luaM_new_(L, luaM_arraysize_(L, n, sizeof(t)), memcat)) #define luaM_freearray(L, b, n, t, memcat) luaM_free_(L, (b), (n) * sizeof(t), memcat) #define luaM_reallocarray(L, v, oldn, n, t, memcat) \ - ((v) = cast_to(t*, luaM_realloc_(L, v, (oldn) * sizeof(t), luaM_arraysize_(n, sizeof(t)), memcat))) + ((v) = cast_to(t*, luaM_realloc_(L, v, (oldn) * sizeof(t), luaM_arraysize_(L, n, sizeof(t)), memcat))) LUAI_FUNC void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat); LUAI_FUNC GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat); diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 431501f3..1c75c0b0 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -2,17 +2,26 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details /* -** Implementation of tables (aka arrays, objects, or hash tables). -** Tables keep its elements in two parts: an array part and a hash part. -** Non-negative integer keys are all candidates to be kept in the array -** part. The actual size of the array is the largest `n' such that at -** least half the slots between 0 and n are in use. -** Hash uses a mix of chained scatter table with Brent's variation. -** A main invariant of these tables is that, if an element is not -** in its main position (i.e. the `original' position that its hash gives -** to it), then the colliding element is in its own main position. -** Hence even when the load factor reaches 100%, performance remains good. -*/ + * Implementation of tables (aka arrays, objects, or hash tables). + * + * Tables keep the elements in two parts: an array part and a hash part. + * Integer keys >=1 are all candidates to be kept in the array part. The actual size of the array is the + * largest n such that at least half the slots between 0 and n are in use. + * Hash uses a mix of chained scatter table with Brent's variation. + * + * A main invariant of these tables is that, if an element is not in its main position (i.e. the original + * position that its hash gives to it), then the colliding element is in its own main position. + * Hence even when the load factor reaches 100%, performance remains good. + * + * Table keys can be arbitrary values unless they contain NaN. Keys are hashed and compared using raw equality, + * so even if the key is a userdata with an overridden __eq, it's not used during hash lookups. + * + * Each table has a "boundary", defined as the index k where t[k] ~= nil and t[k+1] == nil. The boundary can be + * computed using a binary search and can be adjusted when the table is modified; crucially, Luau enforces an + * invariant where the boundary must be in the array part - this enforces a consistent iteration order through the + * prefix of the table when using pairs(), and allows to implement algorithms that access elements in 1..#t range + * more efficiently. + */ #include "ltable.h" @@ -25,6 +34,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false) +LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary, false) // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 @@ -460,7 +470,20 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) totaluse++; /* compute new size for array part */ int na = computesizes(nums, &nasize); + /* enforce the boundary invariant; for performance, only do hash lookups if we must */ + if (FFlag::LuauTableNewBoundary) + { + bool tbound = t->node != dummynode || nasize < t->sizearray; + int ekindex = ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; + /* move the array size up until the boundary is guaranteed to be inside the array part */ + while (nasize + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, nasize + 1)))) + { + nasize++; + na++; + } + } /* resize the table to new computed sizes */ + LUAU_ASSERT(na <= totaluse); resize(L, t, nasize, totaluse - na); } @@ -520,10 +543,18 @@ static LuaNode* getfreepos(Table* t) */ static TValue* newkey(lua_State* L, Table* t, const TValue* key) { + /* enforce boundary invariant */ + if (FFlag::LuauTableNewBoundary && ttisnumber(key) && nvalue(key) == t->sizearray + 1) + { + rehash(L, t, key); /* grow table */ + + // after rehash, numeric keys might be located in the new array part, but won't be found in the node part + return arrayornewkey(L, t, key); + } + LuaNode* mp = mainposition(t, key); if (!ttisnil(gval(mp)) || mp == dummynode) { - LuaNode* othern; LuaNode* n = getfreepos(t); /* get a free place */ if (n == NULL) { /* cannot find a free place? */ @@ -542,7 +573,7 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) LUAU_ASSERT(n != dummynode); TValue mk; getnodekey(L, &mk, mp); - othern = mainposition(t, &mk); + LuaNode* othern = mainposition(t, &mk); if (othern != mp) { /* is colliding node out of its main position? */ /* yes; move colliding node into free position */ @@ -704,6 +735,7 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key) static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j) { + LUAU_ASSERT(!FFlag::LuauTableNewBoundary); unsigned int i = j; /* i is zero or a present index */ j++; /* find `i' and `j' such that i is present and j is not */ @@ -788,6 +820,12 @@ int luaH_getn(Table* t) maybesetaboundary(t, boundary); return boundary; } + else if (FFlag::LuauTableNewBoundary) + { + /* validate boundary invariant */ + LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1))); + return j; + } /* else must find a boundary in hash part */ else if (t->node == dummynode) /* hash part is empty? */ return j; /* that is easy... */ diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 00753742..241a99e3 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,7 +10,9 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauTableClone, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry2, false) + +void (*lua_table_move_telemetry)(lua_State* L, int f, int e, int t, int nf, int nt); static int foreachi(lua_State* L) { @@ -197,6 +199,29 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); + void (*telemetrycb)(lua_State* L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; + + if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb) + { + int nf = lua_objlen(L, 1); + int nt = lua_objlen(L, tt); + + bool report = false; + + // source index range must be in bounds in source table unless the table is empty (permits 1..#t moves) + if (!(f == 1 || (f >= 1 && f <= nf))) + report = true; + if (!(e == nf || (e >= 1 && e <= nf))) + report = true; + + // destination index must be in bounds in dest table or be exactly at the first empty element (permits concats) + if (!(t == nt + 1 || (t >= 1 && t <= nt))) + report = true; + + if (report) + telemetrycb(L, f, e, t, nf, nt); + } + if (e >= f) { /* otherwise, nothing to move */ luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move"); @@ -512,9 +537,6 @@ static int tisfrozen(lua_State* L) static int tclone(lua_State* L) { - if (!FFlag::LuauTableClone) - luaG_runerror(L, "table.clone is not available"); - luaL_checktype(L, 1, LUA_TTABLE); luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 96a87b7e..34949efb 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,6 +16,8 @@ #include +LUAU_FASTFLAG(LuauTableNewBoundary) + // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -2266,9 +2268,9 @@ static void luau_execute(lua_State* L) VM_NEXT(); } } - else if (h->lsizenode == 0 && ttisnil(gval(h->node))) + else if (FFlag::LuauTableNewBoundary || (h->lsizenode == 0 && ttisnil(gval(h->node)))) { - // hash part is empty: fallthrough to exit + // fallthrough to exit VM_NEXT(); } else diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 4e5435b7..8b742f1c 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) - // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template struct TempBuffer @@ -156,12 +154,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size return 1; } - if (FFlag::LuauBytecodeV2Force ? (version != LBC_VERSION_FUTURE) : (version != LBC_VERSION && version != LBC_VERSION_FUTURE)) + if (version != LBC_VERSION) { char chunkid[LUA_IDSIZE]; luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); - lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, - FFlag::LuauBytecodeV2Force ? LBC_VERSION_FUTURE : LBC_VERSION, version); + lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, LBC_VERSION, version); return 1; } @@ -292,11 +289,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->p[j] = protos[fid]; } - if (FFlag::LuauBytecodeV2Force || version == LBC_VERSION_FUTURE) - p->linedefined = readVarInt(data, size, offset); - else - p->linedefined = -1; - + p->linedefined = readVarInt(data, size, offset); p->debugname = readString(strings, data, size, offset); uint8_t lineinfo = read(data, size, offset); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 1db782cc..4e8a1d55 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauTableCloneType) using namespace Luau; @@ -262,7 +261,7 @@ TEST_CASE_FIXTURE(ACFixture, "get_member_completions") auto ac = autocomplete('1'); - CHECK_EQ(FFlag::LuauTableCloneType ? 17 : 16, ac.entryMap.size()); + CHECK_EQ(17, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); @@ -2221,7 +2220,7 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") auto ac = autocompleteSource(frontend, source, Position{1, 24}, nullCallback).result; - CHECK_EQ(FFlag::LuauTableCloneType ? 17 : 16, ac.entryMap.size()); + CHECK_EQ(17, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 83d4518d..0ed7dc44 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -241,8 +241,6 @@ TEST_CASE("Math") TEST_CASE("Table") { - ScopedFastFlag sff("LuauTableClone", true); - runConformance("nextvar.lua"); } @@ -467,8 +465,6 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { - ScopedFastFlag sff("LuauTableCloneType", true); - runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 7f6a6c0d..7dacc669 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1604,6 +1604,20 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_of_functions_unions_and_intersections") CHECK_EQ((Position{3, 42}), block->body.data[2]->location.end); } +TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") +{ + ScopedFastFlag luauParseLocationIgnoreCommentSkip{"LuauParseLocationIgnoreCommentSkip", true}; + + AstStatBlock* block = parse(R"( + type F = number + --comment + print('hello') + )"); + + REQUIRE_EQ(2, block->body.size); + CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); +} + TEST_CASE_FIXTURE(Fixture, "parse_error_loop_control") { matchParseError("break", "break statement must be inside a loop"); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index da4ea074..dbae7b54 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1270,7 +1270,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; fileResolver.source["game/isAMagicMock"] = R"( --!nonstrict @@ -1294,7 +1294,7 @@ end TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; CheckResult result = check(R"( function string.len(): number @@ -1302,7 +1302,40 @@ function string.len(): number end )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + + // if 'string' library property was replaced with an internal module type, it will be freed and the next check will crash + frontend.clear(); + + result = check(R"( +print(string.len('hello')) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite_2") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; + ScopedFastFlag inferStatFunction{"LuauInferStatFunction", true}; + + CheckResult result = check(R"( +local t: { f: ((x: number) -> number)? } = {} + +function t.f(x) + print(x + 5) + return x .. "asd" +end + +t.f = function(x) + print(x + 5) + return x .. "asd" +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'string' could not be converted into 'number')"); + CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); } TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") @@ -1319,7 +1352,7 @@ TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; CheckResult result = check(R"( local t: {[string]: () -> number} = {} diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index e85fcbe8..ab9be42c 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -550,4 +550,35 @@ do assert(not pcall(table.clone, 42)) end +-- test boundary invariant maintenance during rehash +do + local arr = table.create(5, 42) + + arr[1] = nil + arr.a = 'a' -- trigger rehash + + assert(#arr == 5) -- technically 0 is also valid, but it happens to be 5 because array capacity is 5 +end + +-- test boundary invariant maintenance when replacing hash keys +do + local arr = {} + arr.a = 'a' + arr.a = nil + arr[1] = 1 -- should rehash and resize array part, otherwise # won't find the boundary in array part + + assert(#arr == 1) +end + +-- test boundary invariant maintenance when table is filled from the end +do + local arr = {} + for i=5,2,-1 do + arr[i] = i + assert(#arr == 0) + end + arr[1] = 1 + assert(#arr == 5) +end + return"OK" From 916c83fdc47c03bcf238a8b7dfbb3e9a4385f842 Mon Sep 17 00:00:00 2001 From: Alan Jeffrey <403333+asajeffrey@users.noreply.github.com> Date: Thu, 31 Mar 2022 18:29:42 -0500 Subject: [PATCH 040/102] Prototype: Added a discussion of set-theoretic models of subtyping (#431) * Added a discussion of set-theoretic models of subtyping to the prototype --- prototyping/Properties/Subtyping.agda | 90 +++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/prototyping/Properties/Subtyping.agda b/prototyping/Properties/Subtyping.agda index 6a0b4203..0a20f244 100644 --- a/prototyping/Properties/Subtyping.agda +++ b/prototyping/Properties/Subtyping.agda @@ -9,6 +9,7 @@ open import Luau.Type using (Type; Scalar; strict; nil; number; string; boolean; open import Properties.Contradiction using (CONTRADICTION; ¬) open import Properties.Equality using (_≢_) open import Properties.Functions using (_∘_) +open import Properties.Product using (_×_; _,_) src = Luau.Type.src strict @@ -219,3 +220,92 @@ any-≮:-none = witness (scalar nil) any none function-≮:-none : ∀ {T U} → ((T ⇒ U) ≮: none) function-≮:-none = witness function function none + +-- A Gentle Introduction To Semantic Subtyping (https://www.cduce.org/papers/gentle.pdf) +-- defines a "set-theoretic" model (sec 2.5) +-- Unfortunately we don't quite have this property, due to uninhabited types, +-- for example (none -> T) is equivalent to (none -> U) +-- when types are interpreted as sets of syntactic values. + +_⊆_ : ∀ {A : Set} → (A → Set) → (A → Set) → Set +(P ⊆ Q) = ∀ a → (P a) → (Q a) + +_⊗_ : ∀ {A B : Set} → (A → Set) → (B → Set) → ((A × B) → Set) +(P ⊗ Q) (a , b) = (P a) × (Q b) + +Comp : ∀ {A : Set} → (A → Set) → (A → Set) +Comp P a = ¬(P a) + +set-theoretic-if : ∀ {S₁ T₁ S₂ T₂} → + + -- This is the "if" part of being a set-theoretic model + (Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂)) → + (∀ Q → Q ⊆ Comp((Language S₁) ⊗ Comp(Language T₁)) → Q ⊆ Comp((Language S₂) ⊗ Comp(Language T₂))) + +set-theoretic-if {S₁} {T₁} {S₂} {T₂} p Q q (t , u) Qtu (S₂t , ¬T₂u) = q (t , u) Qtu (S₁t , ¬T₁u) where + + S₁t : Language S₁ t + S₁t with dec-language S₁ t + S₁t | Left ¬S₁t with p (function-err t) (function-err ¬S₁t) + S₁t | Left ¬S₁t | function-err ¬S₂t = CONTRADICTION (language-comp t ¬S₂t S₂t) + S₁t | Right r = r + + ¬T₁u : ¬(Language T₁ u) + ¬T₁u T₁u with p (function-ok u) (function-ok T₁u) + ¬T₁u T₁u | function-ok T₂u = ¬T₂u T₂u + +not-quite-set-theoretic-only-if : ∀ {S₁ T₁ S₂ T₂} → + + -- We don't quite have that this is a set-theoretic model + -- it's only true when Language T₁ and ¬Language T₂ t₂ are inhabited + -- in particular it's not true when T₁ is none, or T₂ is any. + ∀ s₂ t₂ → Language S₂ s₂ → ¬Language T₂ t₂ → + + -- This is the "only if" part of being a set-theoretic model + (∀ Q → Q ⊆ Comp((Language S₁) ⊗ Comp(Language T₁)) → Q ⊆ Comp((Language S₂) ⊗ Comp(Language T₂))) → + (Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂)) + +not-quite-set-theoretic-only-if {S₁} {T₁} {S₂} {T₂} s₂ t₂ S₂s₂ ¬T₂t₂ p = r where + + Q : (Tree × Tree) → Set + Q (t , u) = Either (¬Language S₁ t) (Language T₁ u) + + q : Q ⊆ Comp((Language S₁) ⊗ Comp(Language T₁)) + q (t , u) (Left ¬S₁t) (S₁t , ¬T₁u) = language-comp t ¬S₁t S₁t + q (t , u) (Right T₂u) (S₁t , ¬T₁u) = ¬T₁u T₂u + + r : Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂) + r function function = function + r (function-err t) (function-err ¬S₁t) with dec-language S₂ t + r (function-err t) (function-err ¬S₁t) | Left ¬S₂t = function-err ¬S₂t + r (function-err t) (function-err ¬S₁t) | Right S₂t = CONTRADICTION (p Q q (t , t₂) (Left ¬S₁t) (S₂t , language-comp t₂ ¬T₂t₂)) + r (function-ok t) (function-ok T₁t) with dec-language T₂ t + r (function-ok t) (function-ok T₁t) | Left ¬T₂t = CONTRADICTION (p Q q (s₂ , t) (Right T₁t) (S₂s₂ , language-comp t ¬T₂t)) + r (function-ok t) (function-ok T₁t) | Right T₂t = function-ok T₂t + +-- A counterexample when the argument type is empty. + +set-theoretic-counterexample-one : (∀ Q → Q ⊆ Comp((Language none) ⊗ Comp(Language number)) → Q ⊆ Comp((Language none) ⊗ Comp(Language string))) +set-theoretic-counterexample-one Q q ((scalar s) , u) Qtu (scalar () , p) +set-theoretic-counterexample-one Q q ((function-err t) , u) Qtu (scalar-function-err () , p) + +set-theoretic-counterexample-two : (none ⇒ number) ≮: (none ⇒ string) +set-theoretic-counterexample-two = witness + (function-ok (scalar number)) (function-ok (scalar number)) + (function-ok (scalar-scalar number string (λ ()))) + +-- At some point we may deal with overloaded function resolution, which should fix this problem... +-- The reason why this is connected to overloaded functions is that currently we have that the type of +-- f(x) is (tgt T) where f:T. Really we should have the type depend on the type of x, that is use (tgt T U), +-- where U is the type of x. In particular (tgt (S => T) (U & V)) should be the same as (tgt ((S&U) => T) V) +-- and tgt(none => T) should be any. For example +-- +-- tgt((number => string) & (string => bool))(number) +-- is tgt(number => string)(number) & tgt(string => bool)(number) +-- is tgt(number => string)(number) & tgt(string => bool)(number&any) +-- is tgt(number => string)(number) & tgt(string&number => bool)(any) +-- is tgt(number => string)(number) & tgt(none => bool)(any) +-- is string & any +-- is string +-- +-- there's some discussion of this in the Gentle Introduction paper. From dc32a3253ed87ac817ccd6652f536f6f18d347e8 Mon Sep 17 00:00:00 2001 From: Alan Jeffrey <403333+asajeffrey@users.noreply.github.com> Date: Thu, 31 Mar 2022 18:44:41 -0500 Subject: [PATCH 041/102] Add short example of width subtyping (#444) --- docs/_pages/typecheck.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/_pages/typecheck.md b/docs/_pages/typecheck.md index 9443f112..20f73987 100644 --- a/docs/_pages/typecheck.md +++ b/docs/_pages/typecheck.md @@ -137,6 +137,15 @@ local t = {x = 1} -- {x: number} t.y = 2 -- not ok ``` +Sealed tables support *width subtyping*, which allows a table with more properties to be used as a table with fewer + +```lua +type Point1D = { x : number } +type Point2D = { x : number, y : number } +local p : Point2D = { x = 5, y = 37 } +local q : Point1D = p -- ok because Point2D has more properties than Point1D +``` + ### Generic tables This typically occurs when the symbol does not have any annotated types or were not inferred anything concrete. In this case, when you index on a parameter, you're requesting that there is a table with a matching interface. From ffff25a9e502df4555c3bb30feb5f3625fd7d135 Mon Sep 17 00:00:00 2001 From: Alexander McCord <11488393+alexmccord@users.noreply.github.com> Date: Thu, 7 Apr 2022 09:16:44 -0700 Subject: [PATCH 042/102] Improve the UX of reading tagged unions a smidge. (#449) The page is a little narrow, and having to scroll on this horizontally isn't too nice. This fixes the UX for this specific part. --- docs/_pages/typecheck.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/_pages/typecheck.md b/docs/_pages/typecheck.md index 20f73987..63e4c8bb 100644 --- a/docs/_pages/typecheck.md +++ b/docs/_pages/typecheck.md @@ -392,18 +392,20 @@ local account: Account = Account.new("Alexander", 500) Tagged unions are just union types! In particular, they're union types of tables where they have at least _some_ common properties but the structure of the tables are different enough. Here's one example: ```lua -type Result = { type: "ok", value: T } | { type: "err", error: E } +type Ok = { type: "ok", value: T } +type Err = { type: "err", error: E } +type Result = Ok | Err ``` This `Result` type can be discriminated by using type refinements on the property `type`, like so: ```lua if result.type == "ok" then - -- result is known to be { type: "ok", value: T } + -- result is known to be Ok -- and attempting to index for error here will fail print(result.value) elseif result.type == "err" then - -- result is known to be { type: "err", error: E } + -- result is known to be Err -- and attempting to index for value here will fail print(result.error) end From de1381e3f131f2c307f68033ea494d1c35f2c3ca Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 7 Apr 2022 14:29:01 -0700 Subject: [PATCH 043/102] Sync to upstream/release/522 (#450) --- Analysis/include/Luau/Clone.h | 25 ++ Analysis/include/Luau/Frontend.h | 23 +- Analysis/include/Luau/Module.h | 29 +- Analysis/include/Luau/TypeInfer.h | 10 + Analysis/include/Luau/TypeVar.h | 2 + Analysis/include/Luau/Variant.h | 36 +- Analysis/src/Autocomplete.cpp | 117 +++++- Analysis/src/Clone.cpp | 371 ++++++++++++++++++ Analysis/src/Error.cpp | 1 + Analysis/src/Frontend.cpp | 144 +++++-- Analysis/src/IostreamHelpers.cpp | 419 +++++++++------------ Analysis/src/Module.cpp | 359 +----------------- Analysis/src/TxnLog.cpp | 31 +- Analysis/src/TypeInfer.cpp | 156 +++++--- Analysis/src/TypeVar.cpp | 4 + Ast/include/Luau/TimeTrace.h | 11 +- Ast/src/Parser.cpp | 11 + Ast/src/TimeTrace.cpp | 19 +- Compiler/src/ConstantFolding.cpp | 3 +- Sources.cmake | 2 + VM/src/lfunc.h | 2 +- VM/src/ltablib.cpp | 2 +- tests/Autocomplete.test.cpp | 139 ++++++- tests/Compiler.test.cpp | 33 ++ tests/Fixture.cpp | 5 +- tests/Fixture.h | 2 +- tests/Frontend.test.cpp | 64 ++++ tests/Module.test.cpp | 1 + tests/Parser.test.cpp | 21 ++ tests/TypeInfer.functions.test.cpp | 39 +- tests/TypeInfer.intersectionTypes.test.cpp | 39 +- tests/TypeInfer.primitives.test.cpp | 2 + tests/TypeInfer.tables.test.cpp | 15 + tests/TypeInfer.tryUnify.test.cpp | 26 ++ tests/TypeInfer.unionTypes.test.cpp | 25 ++ tools/lldb_formatters.py | 4 +- 36 files changed, 1407 insertions(+), 785 deletions(-) create mode 100644 Analysis/include/Luau/Clone.h create mode 100644 Analysis/src/Clone.cpp diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h new file mode 100644 index 00000000..917ef801 --- /dev/null +++ b/Analysis/include/Luau/Clone.h @@ -0,0 +1,25 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeVar.h" + +#include + +namespace Luau +{ + +// Only exposed so they can be unit tested. +using SeenTypes = std::unordered_map; +using SeenTypePacks = std::unordered_map; + +struct CloneState +{ + int recursionCount = 0; + bool encounteredFreeType = false; +}; + +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 0bf8f362..2266f548 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAG(LuauSeparateTypechecks) + namespace Luau { @@ -55,10 +57,19 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleNa struct SourceNode { + bool isDirty(bool forAutocomplete) const + { + if (FFlag::LuauSeparateTypechecks) + return forAutocomplete ? dirtyAutocomplete : dirty; + else + return dirty; + } + ModuleName name; std::unordered_set requires; std::vector> requireLocations; bool dirty = true; + bool dirtyAutocomplete = true; }; struct FrontendOptions @@ -71,12 +82,16 @@ struct FrontendOptions // When true, we run typechecking twice, once in the regular mode, and once in strict mode // in order to get more precise type information (e.g. for autocomplete). - bool typecheckTwice = false; + bool typecheckTwice_DEPRECATED = false; + + // Run typechecking only in mode required for autocomplete (strict mode in order to get more precise type information) + bool forAutocomplete = false; }; struct CheckResult { std::vector errors; + std::vector timeoutHits; }; struct FrontendModuleResolver : ModuleResolver @@ -123,7 +138,7 @@ struct Frontend CheckResult check(const SourceModule& module); // OLD. TODO KILL LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); - bool isDirty(const ModuleName& name) const; + bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); /** Borrow a pointer into the SourceModule cache. @@ -147,10 +162,10 @@ struct Frontend void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); private: - std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); + std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); - bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root); + bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete); static LintResult classifyLints(const std::vector& warnings, const Config& config); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 6c689b7c..9a32f614 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -29,8 +29,8 @@ struct SourceModule std::optional environmentName; bool cyclic = false; - std::unique_ptr allocator; - std::unique_ptr names; + std::shared_ptr allocator; + std::shared_ptr names; std::vector parseErrors; AstStatBlock* root = nullptr; @@ -48,6 +48,12 @@ struct SourceModule bool isWithinComment(const SourceModule& sourceModule, Position pos); +struct RequireCycle +{ + Location location; + std::vector path; // one of the paths for a require() to go all the way back to the originating module +}; + struct TypeArena { TypedAllocator typeVars; @@ -77,20 +83,6 @@ struct TypeArena void freeze(TypeArena& arena); void unfreeze(TypeArena& arena); -// Only exposed so they can be unit tested. -using SeenTypes = std::unordered_map; -using SeenTypePacks = std::unordered_map; - -struct CloneState -{ - int recursionCount = 0; - bool encounteredFreeType = false; -}; - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); - struct Module { ~Module(); @@ -98,6 +90,10 @@ struct Module TypeArena interfaceTypes; TypeArena internalTypes; + // Scopes and AST types refer to parse data, so we need to keep that alive + std::shared_ptr allocator; + std::shared_ptr names; + std::vector> scopes; // never empty DenseHashMap astTypes{nullptr}; @@ -109,6 +105,7 @@ struct Module ErrorVec errors; Mode mode; SourceCode::Type type; + bool timeout = false; ScopePtr getModuleScope() const; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 839043cc..215da67f 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -124,6 +124,12 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; +class TimeLimitError : public std::exception +{ +public: + virtual const char* what() const throw(); +}; + // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -413,6 +419,10 @@ public: UnifierSharedState unifierState; + std::vector requireCycles; + + std::optional finishTime; + public: const TypeId nilType; const TypeId numberType; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index b8c4b362..f61e4044 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -513,6 +513,8 @@ struct SingletonTypes const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId trueType; + const TypeId falseType; const TypeId anyType; const TypeId optionalNumberType; diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index 63d5a65c..5efe89ed 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -2,45 +2,14 @@ #pragma once #include "Luau/Common.h" - -#ifndef LUAU_USE_STD_VARIANT -#define LUAU_USE_STD_VARIANT 0 -#endif - -#if LUAU_USE_STD_VARIANT -#include -#else #include #include #include #include -#endif namespace Luau { -#if LUAU_USE_STD_VARIANT -template -using Variant = std::variant; - -template -auto visit(Visitor&& vis, Variant&& var) -{ - // This change resolves the ABI issues with std::variant on libc++; std::visit normally throws bad_variant_access - // but it requires an update to libc++.dylib which ships with macOS 10.14. To work around this, we assert on valueless - // variants since we will never generate them and call into a libc++ function that doesn't throw. - LUAU_ASSERT(!var.valueless_by_exception()); - -#ifdef __APPLE__ - // See https://stackoverflow.com/a/53868971/503215 - return std::__variant_detail::__visitation::__variant::__visit_value(vis, var); -#else - return std::visit(vis, var); -#endif -} - -using std::get_if; -#else template class Variant { @@ -248,6 +217,8 @@ static void fnVisitV(Visitor& vis, std::conditional_t, const template auto visit(Visitor&& vis, const Variant& var) { + static_assert(std::conjunction_v...>, "visitor must accept every alternative as an argument"); + using Result = std::invoke_result_t::first_alternative>; static_assert(std::conjunction_v>...>, "visitor result type must be consistent between alternatives"); @@ -273,6 +244,8 @@ auto visit(Visitor&& vis, const Variant& var) template auto visit(Visitor&& vis, Variant& var) { + static_assert(std::conjunction_v...>, "visitor must accept every alternative as an argument"); + using Result = std::invoke_result_t::first_alternative&>; static_assert(std::conjunction_v>...>, "visitor result type must be consistent between alternatives"); @@ -294,7 +267,6 @@ auto visit(Visitor&& vis, Variant& var) return res; } } -#endif template inline constexpr bool always_false_v = false; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 492edf25..b7201ab3 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompleteSingletonTypes, false); LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) static const std::unordered_set kStatementStartingKeywords = { @@ -625,6 +626,31 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi return result; } +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result) +{ + auto formatKey = [addQuotes](const std::string& key) { + if (addQuotes) + return "\"" + escape(key) + "\""; + + return escape(key); + }; + + ty = follow(ty); + + if (auto ss = get(get(ty))) + { + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + else if (auto uty = get(ty)) + { + for (auto el : uty) + { + if (auto ss = get(get(el))) + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + } +}; + static bool canSuggestInferredType(ScopePtr scope, TypeId ty) { ty = follow(ty); @@ -1309,17 +1335,38 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul scope = scope->parent; } - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); - TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); - TypeCorrectKind correctForFunction = - functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + if (FFlag::LuauAutocompleteSingletonTypes) + { + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; - result["not"] = {AutocompleteEntryKind::Keyword}; - result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForFalse}; + result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; + result["not"] = {AutocompleteEntryKind::Keyword}; + result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + + if (auto ty = findExpectedTypeAt(module, node, position)) + autocompleteStringSingleton(*ty, true, result); + } + else + { + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); + TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; + result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; + result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; + result["not"] = {AutocompleteEntryKind::Keyword}; + result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + } } } @@ -1625,17 +1672,33 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } else if (node->is()) { + AutocompleteEntryMap result; + + if (FFlag::LuauAutocompleteSingletonTypes) + { + if (auto it = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*it, false, result); + } + if (finder.ancestry.size() >= 2) { if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) { if (auto it = module->astTypes.find(idxExpr->expr)) + autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result); + } + else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as(); + binExpr && FFlag::LuauAutocompleteSingletonTypes) + { + if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { - return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry}; + if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) + autocompleteStringSingleton(*it, false, result); } } } - return {}; + + return {result, finder.ancestry}; } if (node->is()) @@ -1653,18 +1716,31 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - // FIXME: We don't need to typecheck for script analysis here, just for autocomplete. - frontend.check(moduleName); + if (FFlag::LuauSeparateTypechecks) + { + // FIXME: We can improve performance here by parsing without checking. + // The old type graph is probably fine. (famous last words!) + FrontendOptions opts; + opts.forAutocomplete = true; + frontend.check(moduleName, opts); + } + else + { + // FIXME: We can improve performance here by parsing without checking. + // The old type graph is probably fine. (famous last words!) + // FIXME: We don't need to typecheck for script analysis here, just for autocomplete. + frontend.check(moduleName); + } const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; - TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + TypeChecker& typeChecker = + (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + ModulePtr module = + (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); if (!module) return {}; @@ -1692,7 +1768,8 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->mode = Mode::Strict; sourceModule->commentLocations = std::move(result.commentLocations); - TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + TypeChecker& typeChecker = + (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp new file mode 100644 index 00000000..ac9705a7 --- /dev/null +++ b/Analysis/src/Clone.cpp @@ -0,0 +1,371 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Clone.h" +#include "Luau/Module.h" +#include "Luau/RecursionCounter.h" +#include "Luau/TypePack.h" +#include "Luau/Unifiable.h" + +LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) + +namespace Luau +{ + +namespace +{ + +struct TypePackCloner; + +/* + * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. + * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. + */ + +struct TypeCloner +{ + TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) + : dest(dest) + , typeId(typeId) + , seenTypes(seenTypes) + , seenTypePacks(seenTypePacks) + , cloneState(cloneState) + { + } + + TypeArena& dest; + TypeId typeId; + SeenTypes& seenTypes; + SeenTypePacks& seenTypePacks; + CloneState& cloneState; + + template + void defaultClone(const T& t); + + void operator()(const Unifiable::Free& t); + void operator()(const Unifiable::Generic& t); + void operator()(const Unifiable::Bound& t); + void operator()(const Unifiable::Error& t); + void operator()(const PrimitiveTypeVar& t); + void operator()(const SingletonTypeVar& t); + void operator()(const FunctionTypeVar& t); + void operator()(const TableTypeVar& t); + void operator()(const MetatableTypeVar& t); + void operator()(const ClassTypeVar& t); + void operator()(const AnyTypeVar& t); + void operator()(const UnionTypeVar& t); + void operator()(const IntersectionTypeVar& t); + void operator()(const LazyTypeVar& t); +}; + +struct TypePackCloner +{ + TypeArena& dest; + TypePackId typePackId; + SeenTypes& seenTypes; + SeenTypePacks& seenTypePacks; + CloneState& cloneState; + + TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) + : dest(dest) + , typePackId(typePackId) + , seenTypes(seenTypes) + , seenTypePacks(seenTypePacks) + , cloneState(cloneState) + { + } + + template + void defaultClone(const T& t) + { + TypePackId cloned = dest.addTypePack(TypePackVar{t}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const Unifiable::Free& t) + { + cloneState.encounteredFreeType = true; + + TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); + TypePackId cloned = dest.addTypePack(*err); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const Unifiable::Generic& t) + { + defaultClone(t); + } + void operator()(const Unifiable::Error& t) + { + defaultClone(t); + } + + // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. + // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. + void operator()(const Unifiable::Bound& t) + { + TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const VariadicTypePack& t) + { + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const TypePack& t) + { + TypePackId cloned = dest.addTypePack(TypePack{}); + TypePack* destTp = getMutable(cloned); + LUAU_ASSERT(destTp != nullptr); + seenTypePacks[typePackId] = cloned; + + for (TypeId ty : t.head) + destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + + if (t.tail) + destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); + } +}; + +template +void TypeCloner::defaultClone(const T& t) +{ + TypeId cloned = dest.addType(t); + seenTypes[typeId] = cloned; +} + +void TypeCloner::operator()(const Unifiable::Free& t) +{ + cloneState.encounteredFreeType = true; + TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); + TypeId cloned = dest.addType(*err); + seenTypes[typeId] = cloned; +} + +void TypeCloner::operator()(const Unifiable::Generic& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const Unifiable::Bound& t) +{ + TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + seenTypes[typeId] = boundTo; +} + +void TypeCloner::operator()(const Unifiable::Error& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const PrimitiveTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const SingletonTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const FunctionTypeVar& t) +{ + TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); + FunctionTypeVar* ftv = getMutable(result); + LUAU_ASSERT(ftv != nullptr); + + seenTypes[typeId] = result; + + for (TypeId generic : t.generics) + ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); + + for (TypePackId genericPack : t.genericPacks) + ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); + + ftv->tags = t.tags; + ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); + ftv->argNames = t.argNames; + ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); +} + +void TypeCloner::operator()(const TableTypeVar& t) +{ + // If table is now bound to another one, we ignore the content of the original + if (t.boundTo) + { + TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + seenTypes[typeId] = boundTo; + return; + } + + TypeId result = dest.addType(TableTypeVar{}); + TableTypeVar* ttv = getMutable(result); + LUAU_ASSERT(ttv != nullptr); + + *ttv = t; + + seenTypes[typeId] = result; + + ttv->level = TypeLevel{0, 0}; + + for (const auto& [name, prop] : t.props) + ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + + if (t.indexer) + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), + clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; + + for (TypeId& arg : ttv->instantiatedTypeParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); + + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); + + if (ttv->state == TableState::Free) + { + cloneState.encounteredFreeType = true; + + ttv->state = TableState::Sealed; + } + + ttv->definitionModuleName = t.definitionModuleName; + ttv->methodDefinitionLocations = t.methodDefinitionLocations; + ttv->tags = t.tags; +} + +void TypeCloner::operator()(const MetatableTypeVar& t) +{ + TypeId result = dest.addType(MetatableTypeVar{}); + MetatableTypeVar* mtv = getMutable(result); + seenTypes[typeId] = result; + + mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); + mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); +} + +void TypeCloner::operator()(const ClassTypeVar& t) +{ + TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); + ClassTypeVar* ctv = getMutable(result); + + seenTypes[typeId] = result; + + for (const auto& [name, prop] : t.props) + ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + + if (t.parent) + ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); + + if (t.metatable) + ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); +} + +void TypeCloner::operator()(const AnyTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const UnionTypeVar& t) +{ + std::vector options; + options.reserve(t.options.size()); + + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + + TypeId result = dest.addType(UnionTypeVar{std::move(options)}); + seenTypes[typeId] = result; +} + +void TypeCloner::operator()(const IntersectionTypeVar& t) +{ + TypeId result = dest.addType(IntersectionTypeVar{}); + seenTypes[typeId] = result; + + IntersectionTypeVar* option = getMutable(result); + LUAU_ASSERT(option != nullptr); + + for (TypeId ty : t.parts) + option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); +} + +void TypeCloner::operator()(const LazyTypeVar& t) +{ + defaultClone(t); +} + +} // anonymous namespace + +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +{ + if (tp->persistent) + return tp; + + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + + TypePackId& res = seenTypePacks[tp]; + + if (res == nullptr) + { + TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; + Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. + } + + return res; +} + +TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +{ + if (typeId->persistent) + return typeId; + + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + + TypeId& res = seenTypes[typeId]; + + if (res == nullptr) + { + TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; + Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. + + // Persistent types are not being cloned and we get the original type back which might be read-only + if (!res->persistent) + asMutable(res)->documentationSymbol = typeId->documentationSymbol; + } + + return res; +} + +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +{ + TypeFun result; + + for (auto param : typeFun.typeParams) + { + TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typeParams.push_back({ty, defaultValue}); + } + + for (auto param : typeFun.typePackParams) + { + TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typePackParams.push_back({tp, defaultValue}); + } + + result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); + + return result; +} + +} // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 210c0191..5eb2ea2a 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Error.h" +#include "Luau/Clone.h" #include "Luau/Module.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index d8906f6e..000769fe 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -2,6 +2,7 @@ #include "Luau/Frontend.h" #include "Luau/Common.h" +#include "Luau/Clone.h" #include "Luau/Config.h" #include "Luau/FileResolver.h" #include "Luau/Parser.h" @@ -16,8 +17,11 @@ #include #include +LUAU_FASTFLAG(LuauCyclicModuleTypeSurface) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) +LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) +LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 0) namespace Luau { @@ -234,12 +238,6 @@ ErrorVec accumulateErrors( return result; } -struct RequireCycle -{ - Location location; - std::vector path; // one of the paths for a require() to go all the way back to the originating module -}; - // Given a source node (start), find all requires that start a transitive dependency path that ends back at start // For each such path, record the full path and the location of the require in the starting module. // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) @@ -356,33 +354,55 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond.dirty) + if (it != sourceNodes.end() && !it->second.isDirty(frontendOptions.forAutocomplete)) { // No recheck required. - auto it2 = moduleResolver.modules.find(name); - if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); + if (FFlag::LuauSeparateTypechecks) + { + if (frontendOptions.forAutocomplete) + { + auto it2 = moduleResolverForAutocomplete.modules.find(name); + if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); + } + else + { + auto it2 = moduleResolver.modules.find(name); + if (it2 == moduleResolver.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); + } - return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)}; + return CheckResult{accumulateErrors( + sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; + } + else + { + auto it2 = moduleResolver.modules.find(name); + if (it2 == moduleResolver.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); + + return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)}; + } } std::vector buildQueue; - bool cycleDetected = parseGraph(buildQueue, checkResult, name); - - FrontendOptions frontendOptions = optionOverride.value_or(options); + bool cycleDetected = parseGraph(buildQueue, checkResult, name, frontendOptions.forAutocomplete); // Keep track of which AST nodes we've reported cycles in std::unordered_set reportedCycles; + double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; + for (const ModuleName& moduleName : buildQueue) { LUAU_ASSERT(sourceNodes.count(moduleName)); SourceNode& sourceNode = sourceNodes[moduleName]; - if (!sourceNode.dirty) + if (!sourceNode.isDirty(frontendOptions.forAutocomplete)) continue; LUAU_ASSERT(sourceModules.count(moduleName)); @@ -408,13 +428,44 @@ CheckResult Frontend::check(const ModuleName& name, std::optionaltimeout) + checkResult.timeoutHits.push_back(moduleName); + + stats.timeCheck += getTimestamp() - timestamp; + stats.filesStrict += 1; + + sourceNode.dirtyAutocomplete = false; + continue; + } + + if (FFlag::LuauCyclicModuleTypeSurface) + typeChecker.requireCycles = requireCycles; + ModulePtr module = typeChecker.check(sourceModule, mode, environmentScope); // If we're typechecking twice, we do so. // The second typecheck is always in strict mode with DM awareness // to provide better typen information for IDE features. - if (frontendOptions.typecheckTwice) + if (!FFlag::LuauSeparateTypechecks && frontendOptions.typecheckTwice_DEPRECATED) { + if (FFlag::LuauCyclicModuleTypeSurface) + typeCheckerForAutocomplete.requireCycles = requireCycles; + ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; } @@ -467,7 +518,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& buildQueue, CheckResult& checkResult, const ModuleName& root) +bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete) { LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); @@ -486,7 +537,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec bool cyclic = false; { - auto [sourceNode, _] = getSourceNode(checkResult, root); + auto [sourceNode, _] = getSourceNode(checkResult, root, forAutocomplete); if (sourceNode) stack.push_back(sourceNode); } @@ -538,7 +589,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec // this relies on the fact that markDirty marks reverse-dependencies dirty as well // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // to be built, *and* can't form a cycle with any nodes we did process. - if (!it->second.dirty) + if (!it->second.isDirty(forAutocomplete)) continue; // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization @@ -550,7 +601,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec } } - auto [sourceNode, _] = getSourceNode(checkResult, dep); + auto [sourceNode, _] = getSourceNode(checkResult, dep, forAutocomplete); if (sourceNode) { stack.push_back(sourceNode); @@ -594,7 +645,7 @@ LintResult Frontend::lint(const ModuleName& name, std::optionalsecond.dirty; + return it == sourceNodes.end() || it->second.isDirty(forAutocomplete); } /* @@ -699,8 +750,16 @@ bool Frontend::isDirty(const ModuleName& name) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { - if (!moduleResolver.modules.count(name)) - return; + if (FFlag::LuauSeparateTypechecks) + { + if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name)) + return; + } + else + { + if (!moduleResolver.modules.count(name)) + return; + } std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) @@ -722,10 +781,21 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked if (markedDirty) markedDirty->push_back(next); - if (sourceNode.dirty) - continue; + if (FFlag::LuauSeparateTypechecks) + { + if (sourceNode.dirty && sourceNode.dirtyAutocomplete) + continue; - sourceNode.dirty = true; + sourceNode.dirty = true; + sourceNode.dirtyAutocomplete = true; + } + else + { + if (sourceNode.dirty) + continue; + + sourceNode.dirty = true; + } if (0 == reverseDeps.count(name)) continue; @@ -752,13 +822,13 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. -std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) +std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete) { LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.dirty) + if (it != sourceNodes.end() && !it->second.isDirty(forAutocomplete)) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) @@ -801,7 +871,19 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceNode.name = name; sourceNode.requires.clear(); sourceNode.requireLocations.clear(); - sourceNode.dirty = true; + + if (FFlag::LuauSeparateTypechecks) + { + if (it == sourceNodes.end()) + { + sourceNode.dirty = true; + sourceNode.dirtyAutocomplete = true; + } + } + else + { + sourceNode.dirty = true; + } for (const auto& [moduleName, location] : requireTrace.requires) sourceNode.requires.insert(moduleName); diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 19c2ddab..a8f67589 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -23,9 +23,178 @@ std::ostream& operator<<(std::ostream& stream, const AstName& name) return stream << ""; } -std::ostream& operator<<(std::ostream& stream, const TypeMismatch& tm) +template +static void errorToString(std::ostream& stream, const T& err) { - return stream << "TypeMismatch { " << toString(tm.wantedType) << ", " << toString(tm.givenType) << " }"; + if constexpr (false) + { + } + else if constexpr (std::is_same_v) + stream << "TypeMismatch { " << toString(err.wantedType) << ", " << toString(err.givenType) << " }"; + else if constexpr (std::is_same_v) + stream << "UnknownSymbol { " << err.name << " , context " << err.context << " }"; + else if constexpr (std::is_same_v) + stream << "UnknownProperty { " << toString(err.table) << ", key = " << err.key << " }"; + else if constexpr (std::is_same_v) + stream << "NotATable { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "CannotExtendTable { " << toString(err.tableType) << ", context " << err.context << ", prop \"" << err.prop << "\" }"; + else if constexpr (std::is_same_v) + stream << "OnlyTablesCanHaveMethods { " << toString(err.tableType) << " }"; + else if constexpr (std::is_same_v) + stream << "DuplicateTypeDefinition { " << err.name << " }"; + else if constexpr (std::is_same_v) + stream << "CountMismatch { expected " << err.expected << ", got " << err.actual << ", context " << err.context << " }"; + else if constexpr (std::is_same_v) + stream << "FunctionDoesNotTakeSelf { }"; + else if constexpr (std::is_same_v) + stream << "FunctionRequiresSelf { extraNils " << err.requiredExtraNils << " }"; + else if constexpr (std::is_same_v) + stream << "OccursCheckFailed { }"; + else if constexpr (std::is_same_v) + stream << "UnknownRequire { " << err.modulePath << " }"; + else if constexpr (std::is_same_v) + { + stream << "IncorrectGenericParameterCount { name = " << err.name; + + if (!err.typeFun.typeParams.empty() || !err.typeFun.typePackParams.empty()) + { + stream << "<"; + bool first = true; + for (auto param : err.typeFun.typeParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(param.ty); + } + + for (auto param : err.typeFun.typePackParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(param.tp); + } + + stream << ">"; + } + + stream << ", typeFun = " << toString(err.typeFun.type) << ", actualCount = " << err.actualParameters << " }"; + } + else if constexpr (std::is_same_v) + stream << "SyntaxError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "CodeTooComplex {}"; + else if constexpr (std::is_same_v) + stream << "UnificationTooComplex {}"; + else if constexpr (std::is_same_v) + { + stream << "UnknownPropButFoundLikeProp { key = '" << err.key << "', suggested = { "; + + bool first = true; + for (Name name : err.candidates) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << name << "'"; + } + + stream << " }, table = " << toString(err.table) << " } "; + } + else if constexpr (std::is_same_v) + stream << "GenericError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "CannotCallNonFunction { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "ExtraInformation { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "DeprecatedApiUsed { " << err.symbol << ", useInstead = " << err.useInstead << " }"; + else if constexpr (std::is_same_v) + { + stream << "ModuleHasCyclicDependency {"; + + bool first = true; + for (const ModuleName& name : err.cycle) + { + if (first) + first = false; + else + stream << ", "; + + stream << name; + } + + stream << "}"; + } + else if constexpr (std::is_same_v) + stream << "IllegalRequire { " << err.moduleName << ", reason = " << err.reason << " }"; + else if constexpr (std::is_same_v) + stream << "FunctionExitsWithoutReturning {" << toString(err.expectedReturnType) << "}"; + else if constexpr (std::is_same_v) + stream << "DuplicateGenericParameter { " + err.parameterName + " }"; + else if constexpr (std::is_same_v) + stream << "CannotInferBinaryOperation { op = " + toString(err.op) + ", suggested = '" + + (err.suggestedToAnnotate ? *err.suggestedToAnnotate : "") + "', kind " + << err.kind << "}"; + else if constexpr (std::is_same_v) + { + stream << "MissingProperties { superType = '" << toString(err.superType) << "', subType = '" << toString(err.subType) << "', properties = { "; + + bool first = true; + for (Name name : err.properties) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << name << "'"; + } + + stream << " }, context " << err.context << " } "; + } + else if constexpr (std::is_same_v) + stream << "SwappedGenericTypeParameter { name = '" + err.name + "', kind = " + std::to_string(err.kind) + " }"; + else if constexpr (std::is_same_v) + stream << "OptionalValueAccess { optional = '" + toString(err.optional) + "' }"; + else if constexpr (std::is_same_v) + { + stream << "MissingUnionProperty { type = '" + toString(err.type) + "', missing = { "; + + bool first = true; + for (auto ty : err.missing) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << toString(ty) << "'"; + } + + stream << " }, key = '" + err.key + "' }"; + } + else if constexpr (std::is_same_v) + stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; + else + static_assert(always_false_v, "Non-exhaustive type switch"); +} + +std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data) +{ + auto cb = [&](const auto& e) { + return errorToString(stream, e); + }; + visit(cb, data); + return stream; } std::ostream& operator<<(std::ostream& stream, const TypeError& error) @@ -33,241 +202,6 @@ std::ostream& operator<<(std::ostream& stream, const TypeError& error) return stream << "TypeError { \"" << error.moduleName << "\", " << error.location << ", " << error.data << " }"; } -std::ostream& operator<<(std::ostream& stream, const UnknownSymbol& error) -{ - return stream << "UnknownSymbol { " << error.name << " , context " << error.context << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownProperty& error) -{ - return stream << "UnknownProperty { " << toString(error.table) << ", key = " << error.key << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const NotATable& ge) -{ - return stream << "NotATable { " << toString(ge.ty) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotExtendTable& error) -{ - return stream << "CannotExtendTable { " << toString(error.tableType) << ", context " << error.context << ", prop \"" << error.prop << "\" }"; -} - -std::ostream& operator<<(std::ostream& stream, const OnlyTablesCanHaveMethods& error) -{ - return stream << "OnlyTablesCanHaveMethods { " << toString(error.tableType) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const DuplicateTypeDefinition& error) -{ - return stream << "DuplicateTypeDefinition { " << error.name << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CountMismatch& error) -{ - return stream << "CountMismatch { expected " << error.expected << ", got " << error.actual << ", context " << error.context << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionDoesNotTakeSelf&) -{ - return stream << "FunctionDoesNotTakeSelf { }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionRequiresSelf& error) -{ - return stream << "FunctionRequiresSelf { extraNils " << error.requiredExtraNils << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const OccursCheckFailed&) -{ - return stream << "OccursCheckFailed { }"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownRequire& error) -{ - return stream << "UnknownRequire { " << error.modulePath << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCount& error) -{ - stream << "IncorrectGenericParameterCount { name = " << error.name; - - if (!error.typeFun.typeParams.empty() || !error.typeFun.typePackParams.empty()) - { - stream << "<"; - bool first = true; - for (auto param : error.typeFun.typeParams) - { - if (first) - first = false; - else - stream << ", "; - - stream << toString(param.ty); - } - - for (auto param : error.typeFun.typePackParams) - { - if (first) - first = false; - else - stream << ", "; - - stream << toString(param.tp); - } - - stream << ">"; - } - - stream << ", typeFun = " << toString(error.typeFun.type) << ", actualCount = " << error.actualParameters << " }"; - return stream; -} - -std::ostream& operator<<(std::ostream& stream, const SyntaxError& ge) -{ - return stream << "SyntaxError { " << ge.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CodeTooComplex&) -{ - return stream << "CodeTooComplex {}"; -} - -std::ostream& operator<<(std::ostream& stream, const UnificationTooComplex&) -{ - return stream << "UnificationTooComplex {}"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownPropButFoundLikeProp& e) -{ - stream << "UnknownPropButFoundLikeProp { key = '" << e.key << "', suggested = { "; - - bool first = true; - for (Name name : e.candidates) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << name << "'"; - } - - return stream << " }, table = " << toString(e.table) << " } "; -} - -std::ostream& operator<<(std::ostream& stream, const GenericError& ge) -{ - return stream << "GenericError { " << ge.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotCallNonFunction& e) -{ - return stream << "CannotCallNonFunction { " << toString(e.ty) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionExitsWithoutReturning& error) -{ - return stream << "FunctionExitsWithoutReturning {" << toString(error.expectedReturnType) << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const ExtraInformation& e) -{ - return stream << "ExtraInformation { " << e.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const DeprecatedApiUsed& e) -{ - return stream << "DeprecatedApiUsed { " << e.symbol << ", useInstead = " << e.useInstead << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const ModuleHasCyclicDependency& e) -{ - stream << "ModuleHasCyclicDependency {"; - - bool first = true; - for (const ModuleName& name : e.cycle) - { - if (first) - first = false; - else - stream << ", "; - - stream << name; - } - - return stream << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const IllegalRequire& e) -{ - return stream << "IllegalRequire { " << e.moduleName << ", reason = " << e.reason << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const MissingProperties& e) -{ - stream << "MissingProperties { superType = '" << toString(e.superType) << "', subType = '" << toString(e.subType) << "', properties = { "; - - bool first = true; - for (Name name : e.properties) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << name << "'"; - } - - return stream << " }, context " << e.context << " } "; -} - -std::ostream& operator<<(std::ostream& stream, const DuplicateGenericParameter& error) -{ - return stream << "DuplicateGenericParameter { " + error.parameterName + " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotInferBinaryOperation& error) -{ - return stream << "CannotInferBinaryOperation { op = " + toString(error.op) + ", suggested = '" + - (error.suggestedToAnnotate ? *error.suggestedToAnnotate : "") + "', kind " - << error.kind << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const SwappedGenericTypeParameter& error) -{ - return stream << "SwappedGenericTypeParameter { name = '" + error.name + "', kind = " + std::to_string(error.kind) + " }"; -} - -std::ostream& operator<<(std::ostream& stream, const OptionalValueAccess& error) -{ - return stream << "OptionalValueAccess { optional = '" + toString(error.optional) + "' }"; -} - -std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error) -{ - stream << "MissingUnionProperty { type = '" + toString(error.type) + "', missing = { "; - - bool first = true; - for (auto ty : error.missing) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << toString(ty) << "'"; - } - - return stream << " }, key = '" + error.key + "' }"; -} - -std::ostream& operator<<(std::ostream& stream, const TypesAreUnrelated& error) -{ - stream << "TypesAreUnrelated { left = '" + toString(error.left) + "', right = '" + toString(error.right) + "' }"; - return stream; -} - std::ostream& operator<<(std::ostream& stream, const TableState& tv) { return stream << static_cast::type>(tv); @@ -283,15 +217,4 @@ std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv) return stream << toString(tv); } -std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted) -{ - Luau::visit( - [&](const auto& a) { - lhs << a; - }, - ted); - - return lhs; -} - } // namespace Luau diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 0787d3a4..6bb45245 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -2,6 +2,7 @@ #include "Luau/Module.h" #include "Luau/Common.h" +#include "Luau/Clone.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -12,7 +13,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) namespace Luau @@ -113,363 +113,6 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) return allocated; } -namespace -{ - -struct TypePackCloner; - -/* - * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. - * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. - */ - -struct TypeCloner -{ - TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) - : dest(dest) - , typeId(typeId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) - , cloneState(cloneState) - { - } - - TypeArena& dest; - TypeId typeId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; - - template - void defaultClone(const T& t); - - void operator()(const Unifiable::Free& t); - void operator()(const Unifiable::Generic& t); - void operator()(const Unifiable::Bound& t); - void operator()(const Unifiable::Error& t); - void operator()(const PrimitiveTypeVar& t); - void operator()(const SingletonTypeVar& t); - void operator()(const FunctionTypeVar& t); - void operator()(const TableTypeVar& t); - void operator()(const MetatableTypeVar& t); - void operator()(const ClassTypeVar& t); - void operator()(const AnyTypeVar& t); - void operator()(const UnionTypeVar& t); - void operator()(const IntersectionTypeVar& t); - void operator()(const LazyTypeVar& t); -}; - -struct TypePackCloner -{ - TypeArena& dest; - TypePackId typePackId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; - - TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) - : dest(dest) - , typePackId(typePackId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) - , cloneState(cloneState) - { - } - - template - void defaultClone(const T& t) - { - TypePackId cloned = dest.addTypePack(TypePackVar{t}); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const Unifiable::Free& t) - { - cloneState.encounteredFreeType = true; - - TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); - TypePackId cloned = dest.addTypePack(*err); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const Unifiable::Generic& t) - { - defaultClone(t); - } - void operator()(const Unifiable::Error& t) - { - defaultClone(t); - } - - // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. - // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. - void operator()(const Unifiable::Bound& t) - { - TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const VariadicTypePack& t) - { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const TypePack& t) - { - TypePackId cloned = dest.addTypePack(TypePack{}); - TypePack* destTp = getMutable(cloned); - LUAU_ASSERT(destTp != nullptr); - seenTypePacks[typePackId] = cloned; - - for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - - if (t.tail) - destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); - } -}; - -template -void TypeCloner::defaultClone(const T& t) -{ - TypeId cloned = dest.addType(t); - seenTypes[typeId] = cloned; -} - -void TypeCloner::operator()(const Unifiable::Free& t) -{ - cloneState.encounteredFreeType = true; - TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); - TypeId cloned = dest.addType(*err); - seenTypes[typeId] = cloned; -} - -void TypeCloner::operator()(const Unifiable::Generic& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const Unifiable::Bound& t) -{ - TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); - seenTypes[typeId] = boundTo; -} - -void TypeCloner::operator()(const Unifiable::Error& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const PrimitiveTypeVar& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const SingletonTypeVar& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const FunctionTypeVar& t) -{ - TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); - FunctionTypeVar* ftv = getMutable(result); - LUAU_ASSERT(ftv != nullptr); - - seenTypes[typeId] = result; - - for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); - - for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); - - ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); - ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); -} - -void TypeCloner::operator()(const TableTypeVar& t) -{ - // If table is now bound to another one, we ignore the content of the original - if (t.boundTo) - { - TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); - seenTypes[typeId] = boundTo; - return; - } - - TypeId result = dest.addType(TableTypeVar{}); - TableTypeVar* ttv = getMutable(result); - LUAU_ASSERT(ttv != nullptr); - - *ttv = t; - - seenTypes[typeId] = result; - - ttv->level = TypeLevel{0, 0}; - - for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; - - if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), - clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; - - for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); - - for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); - - if (ttv->state == TableState::Free) - { - cloneState.encounteredFreeType = true; - - ttv->state = TableState::Sealed; - } - - ttv->definitionModuleName = t.definitionModuleName; - ttv->methodDefinitionLocations = t.methodDefinitionLocations; - ttv->tags = t.tags; -} - -void TypeCloner::operator()(const MetatableTypeVar& t) -{ - TypeId result = dest.addType(MetatableTypeVar{}); - MetatableTypeVar* mtv = getMutable(result); - seenTypes[typeId] = result; - - mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); - mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); -} - -void TypeCloner::operator()(const ClassTypeVar& t) -{ - TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); - ClassTypeVar* ctv = getMutable(result); - - seenTypes[typeId] = result; - - for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; - - if (t.parent) - ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); - - if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); -} - -void TypeCloner::operator()(const AnyTypeVar& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const UnionTypeVar& t) -{ - std::vector options; - options.reserve(t.options.size()); - - for (TypeId ty : t.options) - options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - - TypeId result = dest.addType(UnionTypeVar{std::move(options)}); - seenTypes[typeId] = result; -} - -void TypeCloner::operator()(const IntersectionTypeVar& t) -{ - TypeId result = dest.addType(IntersectionTypeVar{}); - seenTypes[typeId] = result; - - IntersectionTypeVar* option = getMutable(result); - LUAU_ASSERT(option != nullptr); - - for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); -} - -void TypeCloner::operator()(const LazyTypeVar& t) -{ - defaultClone(t); -} - -} // anonymous namespace - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) -{ - if (tp->persistent) - return tp; - - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypePackId& res = seenTypePacks[tp]; - - if (res == nullptr) - { - TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; - Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. - } - - return res; -} - -TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) -{ - if (typeId->persistent) - return typeId; - - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypeId& res = seenTypes[typeId]; - - if (res == nullptr) - { - TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; - Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - - // Persistent types are not being cloned and we get the original type back which might be read-only - if (!res->persistent) - asMutable(res)->documentationSymbol = typeId->documentationSymbol; - } - - return res; -} - -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) -{ - TypeFun result; - - for (auto param : typeFun.typeParams) - { - TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); - std::optional defaultValue; - - if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); - - result.typeParams.push_back({ty, defaultValue}); - } - - for (auto param : typeFun.typePackParams) - { - TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); - std::optional defaultValue; - - if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); - - result.typePackParams.push_back({tp, defaultValue}); - } - - result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); - - return result; -} - ScopePtr Module::getModuleScope() const { LUAU_ASSERT(!scopes.empty()); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 876f5f05..5fbb596d 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,6 +7,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauTxnLogPreserveOwner, false) + namespace Luau { @@ -78,11 +80,32 @@ void TxnLog::concat(TxnLog rhs) void TxnLog::commit() { - for (auto& [ty, rep] : typeVarChanges) - *asMutable(ty) = rep.get()->pending; + if (FFlag::LuauTxnLogPreserveOwner) + { + for (auto& [ty, rep] : typeVarChanges) + { + TypeArena* owningArena = ty->owningArena; + TypeVar* mtv = asMutable(ty); + *mtv = rep.get()->pending; + mtv->owningArena = owningArena; + } - for (auto& [tp, rep] : typePackChanges) - *asMutable(tp) = rep.get()->pending; + for (auto& [tp, rep] : typePackChanges) + { + TypeArena* owningArena = tp->owningArena; + TypePackVar* mpv = asMutable(tp); + *mpv = rep.get()->pending; + mpv->owningArena = owningArena; + } + } + else + { + for (auto& [ty, rep] : typeVarChanges) + *asMutable(ty) = rep.get()->pending; + + for (auto& [tp, rep] : typePackChanges) + *asMutable(tp) = rep.get()->pending; + } clear(); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 6df6bff0..10930248 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -22,6 +22,9 @@ LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) +LUAU_FASTFLAG(LuauSeparateTypechecks) +LUAU_FASTFLAG(LuauAutocompleteSingletonTypes) +LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) @@ -35,7 +38,7 @@ LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify3, false) +LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) @@ -46,6 +49,7 @@ LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) +LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) @@ -53,6 +57,11 @@ LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) namespace Luau { +const char* TimeLimitError::what() const throw() +{ + return "Typeinfer failed to complete in allotted time"; +} + static bool typeCouldHaveMetatable(TypeId ty) { return get(follow(ty)) || get(follow(ty)) || get(follow(ty)); @@ -251,6 +260,12 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona currentModule.reset(new Module()); currentModule->type = module.type; + if (FFlag::LuauSeparateTypechecks) + { + currentModule->allocator = module.allocator; + currentModule->names = module.names; + } + iceHandler->moduleName = module.name; ScopePtr parentScope = environmentScope.value_or(globalScope); @@ -271,7 +286,21 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona if (prepareModuleScope) prepareModuleScope(module.name, currentModule->getModuleScope()); - checkBlock(moduleScope, *module.root); + if (FFlag::LuauSeparateTypechecks) + { + try + { + checkBlock(moduleScope, *module.root); + } + catch (const TimeLimitError&) + { + currentModule->timeout = true; + } + } + else + { + checkBlock(moduleScope, *module.root); + } if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); @@ -366,6 +395,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) } else ice("Unknown AstStat"); + + if (FFlag::LuauSeparateTypechecks && finishTime && TimeTrace::getClock() > *finishTime) + throw TimeLimitError(); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. @@ -1115,22 +1147,18 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } - else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify3) + else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify4) { TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); - if (!ttv) + + if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false)) { - if (isTableIntersection(exprTy)) + if (ttv || isTableIntersection(exprTy)) reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); - else if (!get(exprTy) && !get(exprTy)) + else reportError(TypeError{function.location, OnlyTablesCanHaveMethods{exprTy}}); } - else if (ttv->state == TableState::Sealed) - { - if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false)) - reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); - } ty = follow(ty); @@ -1153,7 +1181,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } - else if (FFlag::LuauStatFunctionSimplify3) + else if (FFlag::LuauStatFunctionSimplify4) { LUAU_ASSERT(function.name->is()); @@ -1163,7 +1191,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else if (function.func->self) { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify3); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify4); AstExprIndexName* indexName = function.name->as(); if (!indexName) @@ -1202,7 +1230,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify3); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify4); TypeId leftType = checkLValueBinding(scope, *function.name); @@ -2030,7 +2058,11 @@ TypeId TypeChecker::checkExprTable( indexer = expectedTable->indexer; if (indexer) + { + if (FFlag::LuauCheckImplicitNumbericKeys) + unify(numberType, indexer->indexType, value->location); unify(valueType, indexer->indexResultType, value->location); + } else indexer = TableIndexer{numberType, anyIfNonstrict(valueType)}; } @@ -2984,35 +3016,33 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T else if (auto indexName = funName.as()) { TypeId lhsType = checkExpr(scope, *indexName->expr).type; - if (get(lhsType) || get(lhsType)) + + if (!FFlag::LuauStatFunctionSimplify4 && (get(lhsType) || get(lhsType))) return lhsType; TableTypeVar* ttv = getMutableTableType(lhsType); - if (!ttv) + + if (FFlag::LuauStatFunctionSimplify4) { - if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) - // This error now gets reported when we check the function body. - reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); - - return errorRecoveryType(scope); - } - - if (FFlag::LuauStatFunctionSimplify3) - { - if (lhsType->persistent) - return errorRecoveryType(scope); - - // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check - if (ttv->state == TableState::Sealed) + if (!ttv || ttv->state == TableState::Sealed) { - if (ttv->indexer && isPrim(ttv->indexer->indexType, PrimitiveTypeVar::String)) - return ttv->indexer->indexResultType; - else - return errorRecoveryType(scope); + if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, false)) + return *ty; + + return errorRecoveryType(scope); } } else { + if (!ttv) + { + if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) + // This error now gets reported when we check the function body. + reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); + + return errorRecoveryType(scope); + } + if (lhsType->persistent || ttv->state == TableState::Sealed) return errorRecoveryType(scope); } @@ -3020,7 +3050,12 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T Name name = indexName->index.value; if (ttv->props.count(name)) - return errorRecoveryType(scope); + { + if (FFlag::LuauStatFunctionSimplify4) + return ttv->props[name].type; + else + return errorRecoveryType(scope); + } Property& property = ttv->props[name]; @@ -4155,6 +4190,20 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return anyType; } + // Types of requires that transitively refer to current module have to be replaced with 'any' + std::string humanReadableName; + + if (FFlag::LuauCyclicModuleTypeSurface) + { + humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); + + for (const auto& [location, path] : requireCycles) + { + if (!path.empty() && path.front() == humanReadableName) + return anyType; + } + } + ModulePtr module = resolver->getModule(moduleInfo.name); if (!module) { @@ -4163,8 +4212,15 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // we will already have reported the error. if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) { - std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(TypeError{location, UnknownRequire{reportedModulePath}}); + if (FFlag::LuauCyclicModuleTypeSurface) + { + reportError(TypeError{location, UnknownRequire{humanReadableName}}); + } + else + { + std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); + reportError(TypeError{location, UnknownRequire{reportedModulePath}}); + } } return errorRecoveryType(scope); @@ -4172,8 +4228,15 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (module->type != SourceCode::Module) { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); + if (FFlag::LuauCyclicModuleTypeSurface) + { + reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); + } + else + { + std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); + reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); + } return errorRecoveryType(scope); } @@ -4185,8 +4248,15 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module std::optional moduleType = first(modulePack); if (!moduleType) { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); + if (FFlag::LuauCyclicModuleTypeSurface) + { + reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); + } + else + { + std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); + reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); + } return errorRecoveryType(scope); } @@ -4629,7 +4699,9 @@ TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::singletonType(bool value) { - // TODO: cache singleton types + if (FFlag::LuauAutocompleteSingletonTypes) + return value ? getSingletonTypes().trueType : getSingletonTypes().falseType; + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BooleanSingleton{value}))); } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 36545ad9..dbc412fc 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -652,6 +652,8 @@ static TypeVar numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persist static TypeVar stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}; static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}; static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}; +static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}; +static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; static TypeVar anyType_{AnyTypeVar{}}; static TypeVar errorType_{ErrorTypeVar{}}; static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}}; @@ -665,6 +667,8 @@ SingletonTypes::SingletonTypes() , stringType(&stringType_) , booleanType(&booleanType_) , threadType(&threadType_) + , trueType(&trueType_) + , falseType(&falseType_) , anyType(&anyType_) , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 5018456f..9f7b2bdf 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -9,14 +9,21 @@ LUAU_FASTFLAG(DebugLuauTimeTracing) +namespace Luau +{ +namespace TimeTrace +{ +double getClock(); +uint32_t getClockMicroseconds(); +} // namespace TimeTrace +} // namespace Luau + #if defined(LUAU_ENABLE_TIME_TRACE) namespace Luau { namespace TimeTrace { -uint32_t getClockMicroseconds(); - struct Token { const char* name; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index f6dfd904..f9d32178 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,6 +10,7 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) +LUAU_FASTFLAGVARIABLE(LuauParseRecoverUnexpectedPack, false) namespace Luau { @@ -1420,6 +1421,11 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isIntersection = true; } + else if (FFlag::LuauParseRecoverUnexpectedPack && c == Lexeme::Dot3) + { + report(lexer.current().location, "Unexpected '...' after type annotation"); + nextLexeme(); + } else break; } @@ -1536,6 +1542,11 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) prefix = name.name; name = parseIndexName("field name", pointPosition); } + else if (FFlag::LuauParseRecoverUnexpectedPack && lexer.current().type == Lexeme::Dot3) + { + report(lexer.current().location, "Unexpected '...' after type name; type pack is not allowed in this context"); + nextLexeme(); + } else if (name.name == "typeof") { Lexeme typeofBegin = lexer.current(); diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index 19564f05..e3807683 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -26,9 +26,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false) - -#if defined(LUAU_ENABLE_TIME_TRACE) - namespace Luau { namespace TimeTrace @@ -67,6 +64,14 @@ static double getClockTimestamp() #endif } +double getClock() +{ + static double period = getClockPeriod(); + static double start = getClockTimestamp(); + + return (getClockTimestamp() - start) * period; +} + uint32_t getClockMicroseconds() { static double period = getClockPeriod() * 1e6; @@ -74,7 +79,15 @@ uint32_t getClockMicroseconds() return uint32_t((getClockTimestamp() - start) * period); } +} // namespace TimeTrace +} // namespace Luau +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ struct GlobalContext { GlobalContext() = default; diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 60a7c169..35ea0bf0 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -290,7 +290,8 @@ struct ConstantVisitor : AstVisitor Constant la = analyze(expr->left); Constant ra = analyze(expr->right); - if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + // note: ra doesn't need to be constant to fold and/or + if (la.type != Constant::Type_Unknown) foldBinary(result, expr->op, la, ra); } else if (AstExprTypeAssertion* expr = node->as()) diff --git a/Sources.cmake b/Sources.cmake index 59b38497..6f110f1f 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -47,6 +47,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Autocomplete.h Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Config.h + Analysis/include/Luau/Clone.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h @@ -85,6 +86,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Autocomplete.cpp Analysis/src/BuiltinDefinitions.cpp Analysis/src/Config.cpp + Analysis/src/Clone.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp Analysis/src/IostreamHelpers.cpp diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h index 8047cebe..a260d00a 100644 --- a/VM/src/lfunc.h +++ b/VM/src/lfunc.h @@ -14,6 +14,6 @@ LUAI_FUNC UpVal* luaF_findupval(lua_State* L, StkId level); LUAI_FUNC void luaF_close(lua_State* L, StkId level); LUAI_FUNC void luaF_freeproto(lua_State* L, Proto* f, struct lua_Page* page); LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c, struct lua_Page* page); -void luaF_unlinkupval(UpVal* uv); +LUAI_FUNC void luaF_unlinkupval(UpVal* uv); LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv, struct lua_Page* page); LUAI_FUNC const LocVar* luaF_getlocal(const Proto* func, int local_number, int pc); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 241a99e3..41887f4b 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -201,7 +201,7 @@ static int tmove(lua_State* L) void (*telemetrycb)(lua_State* L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; - if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb) + if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb && e >= f) { int nf = lua_objlen(L, 1); int nt = lua_objlen(L, tt); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 4e8a1d55..2e7902f5 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) +LUAU_FASTFLAG(LuauSeparateTypechecks) using namespace Luau; @@ -25,6 +26,11 @@ static std::optional nullCallback(std::string tag, std::op template struct ACFixtureImpl : BaseType { + ACFixtureImpl() + : Fixture(true, true) + { + } + AutocompleteResult autocomplete(unsigned row, unsigned column) { return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); @@ -72,7 +78,25 @@ struct ACFixtureImpl : BaseType } LUAU_ASSERT("Digit expected after @ symbol" && prevChar != '@'); - return Fixture::check(filteredSource); + return BaseType::check(filteredSource); + } + + LoadDefinitionFileResult loadDefinition(const std::string& source) + { + if (FFlag::LuauSeparateTypechecks) + { + TypeChecker& typeChecker = this->frontend.typeCheckerForAutocomplete; + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, source, "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); + return result; + } + else + { + return BaseType::loadDefinition(source); + } } const Position& getPosition(char marker) const @@ -2496,7 +2520,7 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(ACFixture, "autocomplete_documentation_symbols") { loadDefinition(R"( declare y: { @@ -2504,13 +2528,11 @@ TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") } )"); - fileResolver.source["Module/A"] = R"( - local a = y. - )"; + check(R"( + local a = y.@1 + )"); - frontend.check("Module/A"); - - auto ac = autocomplete(frontend, "Module/A", Position{1, 21}, nullCallback); + auto ac = autocomplete('1'); REQUIRE(ac.entryMap.count("x")); CHECK_EQ(ac.entryMap["x"].documentationSymbol, "@test/global/y.x"); @@ -2736,6 +2758,107 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") CHECK(ac.entryMap.count("format")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") +{ + ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; + ScopedFastFlag luauExpectedTypesOfProperties{"LuauExpectedTypesOfProperties", true}; + + check(R"( + type tag = "cat" | "dog" + local function f(a: tag) end + f("@1") + f(@2) + local x: tag = "@3" + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("\"cat\"")); + CHECK(ac.entryMap.count("\"dog\"")); + + ac = autocomplete('3'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + check(R"( + type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} + local x: tagged = {tag="@4"} + )"); + + ac = autocomplete('4'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") +{ + ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; + + check(R"( + type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} + local x: tagged = {tag="cat", fieldx=2} + if x.tag == "@1" or "@2" ~= x.tag then end + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + // CLI-48823: assignment to x.tag should also autocomplete, but union l-values are not supported yet +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_boolean_singleton") +{ + ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; + + check(R"( +local function f(x: true) end +f(@1) + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("true")); + CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); + REQUIRE(ac.entryMap.count("false")); + CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") +{ + ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; + + check(R"( + type tag = "strange\t\"cat\"" | 'nice\t"dog"' + local function f(x: tag) end + f(@1) + f("@2") + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("\"strange\\t\\\"cat\\\"\"")); + CHECK(ac.entryMap.count("\"nice\\t\\\"dog\\\"\"")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("strange\\t\\\"cat\\\"")); + CHECK(ac.entryMap.count("nice\\t\\\"dog\\\"")); +} + TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") { check(R"( diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 3dc57da0..83dad729 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -1074,6 +1074,39 @@ RETURN R1 1 )"); } +TEST_CASE("AndOrFoldLeft") +{ + // constant folding and/or expression is possible even if just the left hand is constant + CHECK_EQ("\n" + compileFunction0("local a = false if a and b then b() end"), R"( +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if a or b then b() end"), R"( +GETIMPORT R0 1 +CALL R0 0 0 +RETURN R0 0 +)"); + + // however, if right hand side is constant we can't constant fold the entire expression + // (note that we don't need to evaluate the right hand side, but we do need a branch) + CHECK_EQ("\n" + compileFunction0("local a = false if b and a then b() end"), R"( +GETIMPORT R0 1 +JUMPIFNOT R0 +4 +RETURN R0 0 +GETIMPORT R0 1 +CALL R0 0 0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if b or a then b() end"), R"( +GETIMPORT R0 1 +JUMPIF R0 +0 +GETIMPORT R0 1 +CALL R0 0 0 +RETURN R0 0 +)"); +} + TEST_CASE("AndOrChainCodegen") { const char* source = R"( diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index a7e7ea39..9dc9feee 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -83,7 +83,7 @@ std::optional TestFileResolver::getEnvironmentForModule(const Modul return std::nullopt; } -Fixture::Fixture(bool freeze) +Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true}) , typeChecker(frontend.typeChecker) @@ -93,8 +93,11 @@ Fixture::Fixture(bool freeze) configResolver.defaultConfig.parseOptions.captureComments = true; registerBuiltinTypes(frontend.typeChecker); + if (prepareAutocomplete) + registerBuiltinTypes(frontend.typeCheckerForAutocomplete); registerTestTypes(); Luau::freeze(frontend.typeChecker.globalTypes); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); Luau::setPrintLine([](auto s) {}); } diff --git a/tests/Fixture.h b/tests/Fixture.h index 4e45a952..0d1233bf 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -91,7 +91,7 @@ struct TestConfigResolver : ConfigResolver struct Fixture { - explicit Fixture(bool freeze = true); + explicit Fixture(bool freeze = true, bool prepareAutocomplete = false); ~Fixture(); // Throws Luau::ParseErrors if the parse fails. diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 8a59acd1..9fc0a005 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -384,6 +384,70 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_error_paths") CHECK_EQ(ce2->cycle[1], "game/Gui/Modules/A"); } +TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface") +{ + ScopedFastFlag luauCyclicModuleTypeSurface{"LuauCyclicModuleTypeSurface", true}; + + fileResolver.source["game/A"] = R"( + return {hello = 2} + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/A"] = R"( + local me = require(game.A) + return {hello = 2} + )"; + frontend.markDirty("game/A"); + + result = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(result); + + auto ty = requireType("game/A", "me"); + CHECK_EQ(toString(ty), "any"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") +{ + ScopedFastFlag luauCyclicModuleTypeSurface{"LuauCyclicModuleTypeSurface", true}; + + fileResolver.source["game/A"] = R"( + return {mod_a = 2} + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/B"] = R"( + local me = require(game.A) + return {mod_b = 4} + )"; + + result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/A"] = R"( + local me = require(game.B) + return {mod_a_prime = 3} + )"; + + frontend.markDirty("game/A"); + frontend.markDirty("game/B"); + + result = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(result); + + TypeId tyA = requireType("game/A", "me"); + CHECK_EQ(toString(tyA), "any"); + + result = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(result); + + TypeId tyB = requireType("game/B", "me"); + CHECK_EQ(toString(tyB), "any"); +} + TEST_CASE_FIXTURE(FrontendFixture, "dont_reparse_clean_file_when_linting") { fileResolver.source["Modules/A"] = R"( diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 82b7a350..de063121 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Clone.h" #include "Luau/Module.h" #include "Luau/Scope.h" diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 7dacc669..79f9ecab 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2022,6 +2022,15 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); } +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_errors") +{ + ScopedFastFlag luauParseRecoverUnexpectedPack{"LuauParseRecoverUnexpectedPack", true}; + + matchParseError("type Y = {a: T..., b: number}", "Unexpected '...' after type name; type pack is not allowed in this context", + Location{{0, 20}, {0, 23}}); + matchParseError("type Y = {a: (number | string)...", "Unexpected '...' after type annotation", Location{{0, 36}, {0, 39}}); +} + TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") { { @@ -2590,4 +2599,16 @@ type Y = (T...) -> U... CHECK_EQ(1, result.errors.size()); } +TEST_CASE_FIXTURE(Fixture, "recover_unexpected_type_pack") +{ + ScopedFastFlag luauParseRecoverUnexpectedPack{"LuauParseRecoverUnexpectedPack", true}; + + ParseResult result = tryParse(R"( +type X = { a: T..., b: number } +type Y = { a: T..., b: number } +type Z = { a: string | T..., b: number } + )"); + REQUIRE_EQ(3, result.errors.size()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index dbae7b54..1713216a 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1270,7 +1270,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; fileResolver.source["game/isAMagicMock"] = R"( --!nonstrict @@ -1294,7 +1294,7 @@ end TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; CheckResult result = check(R"( function string.len(): number @@ -1316,7 +1316,7 @@ print(string.len('hello')) TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite_2") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; ScopedFastFlag inferStatFunction{"LuauInferStatFunction", true}; CheckResult result = check(R"( @@ -1324,12 +1324,12 @@ local t: { f: ((x: number) -> number)? } = {} function t.f(x) print(x + 5) - return x .. "asd" + return x .. "asd" -- 1st error: we know that return type is a number, not a string end t.f = function(x) print(x + 5) - return x .. "asd" + return x .. "asd" -- 2nd error: we know that return type is a number, not a string end )"); @@ -1338,6 +1338,33 @@ end CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); } +TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_unsealed_overwrite") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; + ScopedFastFlag inferStatFunction{"LuauInferStatFunction", true}; + + CheckResult result = check(R"( +local t = { f = nil :: ((x: number) -> number)? } + +function t.f(x: string): string -- 1st error: new function value type is incompatible + return x .. "asd" +end + +t.f = function(x) + print(x + 5) + return x .. "asd" -- 2nd error: we know that return type is a number, not a string +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string) -> string' could not be converted into '((number) -> number)?' +caused by: + None of the union options are compatible. For example: Type '(string) -> string' could not be converted into '(number) -> number' +caused by: + Argument #1 type is not compatible. Type 'number' could not be converted into 'string')"); + CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); +} + TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") { ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; @@ -1352,7 +1379,7 @@ TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; CheckResult result = check(R"( local t: {[string]: () -> number} = {} diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index d146f4e8..ac7a6532 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -311,6 +311,8 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") { + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; + CheckResult result = check(R"( type X = { x: (number) -> number } type Y = { y: (string) -> string } @@ -326,10 +328,39 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") function xy:w(a:number) return a * 10 end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'y' to table 'X & Y'"); - CHECK_EQ(toString(result.errors[2]), "Cannot add property 'w' to table 'X & Y'"); + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); +} + +TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; + + // After normalization, previous 'table_intersection_write_sealed_indirect' is identical to this one + CheckResult result = check(R"( + type XY = { x: (number) -> number, y: (string) -> string } + + local xy : XY = { + x = function(a: number) return -a end, + y = function(a: string) return a .. "b" end + } + function xy.z(a:number) return a * 10 end + function xy:y(a:number) return a * 10 end + function xy:w(a:number) return a * 10 end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'XY'"); + CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'XY'"); } TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 44b7b0d0..3ddf9813 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -95,6 +95,8 @@ end )"); LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Cannot add method to non-table type 'number'"); + CHECK_EQ(toString(result.errors[1]), "Type 'number' could not be converted into 'string'"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 0cc12d19..0484351d 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2922,4 +2922,19 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") +{ + ScopedFastFlag sff{"LuauCheckImplicitNumbericKeys", true}; + + CheckResult result = check(R"( + local t: { [string]: number } = { 5, 6, 7 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index d8de2594..c21e1625 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -242,4 +242,30 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") state.tryUnify(&func, typeChecker.anyType); } +TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") +{ + ScopedFastFlag luauTxnLogPreserveOwner{"LuauTxnLogPreserveOwner", true}; + + TypeId a = arena.addType(TypeVar{FreeTypeVar{TypeLevel{}}}); + TypeId b = typeChecker.numberType; + + state.tryUnify(a, b); + state.log.commit(); + + CHECK_EQ(a->owningArena, &arena); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") +{ + ScopedFastFlag luauTxnLogPreserveOwner{"LuauTxnLogPreserveOwner", true}; + + TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); + TypePackId b = typeChecker.anyTypePack; + + state.tryUnify(a, b); + state.log.commit(); + + CHECK_EQ(a->owningArena, &arena); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 68b7c4fb..ff207a18 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -513,4 +513,29 @@ TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_union_write_indirect") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; + + CheckResult result = check(R"( + type A = { x: number, y: (number) -> string } | { z: number, y: (number) -> string } + + local a:A = nil + + function a.y(x) + return tostring(x * 2) + end + + function a.y(x: string): number + return tonumber(x) or 0 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + // NOTE: union normalization will improve this message + CHECK_EQ(toString(result.errors[0]), + R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); +} + + TEST_SUITE_END(); diff --git a/tools/lldb_formatters.py b/tools/lldb_formatters.py index 40f8d6be..b3d2b4f5 100644 --- a/tools/lldb_formatters.py +++ b/tools/lldb_formatters.py @@ -37,7 +37,7 @@ def getType(target, typeName): return ty def luau_variant_summary(valobj, internal_dict, options): - type_id = valobj.GetChildMemberWithName("typeid").GetValueAsUnsigned() + type_id = valobj.GetChildMemberWithName("typeId").GetValueAsUnsigned() storage = valobj.GetChildMemberWithName("storage") params = templateParams(valobj.GetType().GetCanonicalType().GetName()) stored_type = params[type_id] @@ -89,7 +89,7 @@ class LuauVariantSyntheticChildrenProvider: return None def update(self): - self.type_index = self.valobj.GetChildMemberWithName("typeid").GetValueAsSigned() + self.type_index = self.valobj.GetChildMemberWithName("typeId").GetValueAsSigned() self.type_params = templateParams(self.valobj.GetType().GetCanonicalType().GetName()) if len(self.type_params) > self.type_index: From 510aed7d3ffcb509aaad264f6fd763c4b9c7f6b2 Mon Sep 17 00:00:00 2001 From: Lily Brown Date: Fri, 8 Apr 2022 11:26:47 -0700 Subject: [PATCH 044/102] Fix JsonEncoder for AstExprTable (#454) JsonEncoder wasn't producing valid JSON for `AstExprTable`s. This PR fixes it. The new output looks like ```json { "type": "AstStatBlock", "location": "0,0 - 6,4", "body": [ { "type": "AstStatLocal", "location": "1,8 - 5,9", "vars": [ { "name": "x", "location": "1,14 - 1,15" } ], "values": [ { "type": "AstExprTable", "location": "3,12 - 5,9", "items": [ { "kind": "record", "key": { "type": "AstExprConstantString", "location": "4,12 - 4,15", "value": "foo" }, "value": { "type": "AstExprConstantNumber", "location": "4,18 - 4,21", "value": 123 } } ] } ] } ] } ``` --- Analysis/src/JsonEncoder.cpp | 21 ++++++--------------- tests/JsonEncoder.test.cpp | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 811e7c24..829ffa02 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -403,35 +403,26 @@ struct AstJsonEncoder : public AstVisitor void write(const AstExprTable::Item& item) { writeRaw("{"); - bool comma = pushComma(); + bool c = pushComma(); write("kind", item.kind); switch (item.kind) { case AstExprTable::Item::List: - write(item.value); + write("value", item.value); break; default: - write(item.key); - writeRaw(","); - write(item.value); + write("key", item.key); + write("value", item.value); break; } - popComma(comma); + popComma(c); writeRaw("}"); } void write(class AstExprTable* node) { writeNode(node, "AstExprTable", [&]() { - bool comma = false; - for (const auto& prop : node->items) - { - if (comma) - writeRaw(","); - else - comma = true; - write(prop); - } + PROP(items); }); } diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp index cb508072..6711d979 100644 --- a/tests/JsonEncoder.test.cpp +++ b/tests/JsonEncoder.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Ast.h" #include "Luau/JsonEncoder.h" +#include "Luau/Parser.h" #include "doctest.h" @@ -50,4 +51,24 @@ TEST_CASE("encode_AstStatBlock") toJson(&block)); } +TEST_CASE("encode_tables") +{ + std::string src = R"( + local x: { + foo: number + } = { + foo = 123, + } + )"; + + Allocator allocator; + AstNameTable names(allocator); + ParseResult parseResult = Parser::parse(src.c_str(), src.length(), names, allocator); + + REQUIRE(parseResult.errors.size() == 0); + std::string json = toJson(parseResult.root); + + CHECK(json == R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); +} + TEST_SUITE_END(); From d37d0c857ba543ea47f0b8fce5678f7aadf5239e Mon Sep 17 00:00:00 2001 From: Alan Jeffrey <403333+asajeffrey@users.noreply.github.com> Date: Sat, 9 Apr 2022 00:07:08 -0500 Subject: [PATCH 045/102] Prototype: Renamed any/none to unknown/never (#447) * Renamed any/none to unknown/never * Pin hackage version * Update Agda version --- .github/workflows/prototyping.yml | 10 +- prototyping/Luau/StrictMode.agda | 2 +- prototyping/Luau/Subtyping.agda | 6 +- prototyping/Luau/Type.agda | 104 +++++++-------- prototyping/Luau/Type/FromJSON.agda | 6 +- prototyping/Luau/Type/ToString.agda | 6 +- prototyping/Luau/TypeCheck.agda | 14 +- prototyping/Properties/StrictMode.agda | 62 ++++----- prototyping/Properties/Subtyping.agda | 122 +++++++++--------- prototyping/Properties/TypeCheck.agda | 20 +-- .../Tests/PrettyPrinter/smoke_test/out.txt | 6 +- 11 files changed, 181 insertions(+), 177 deletions(-) diff --git a/.github/workflows/prototyping.yml b/.github/workflows/prototyping.yml index 6bc8a81b..ff66881d 100644 --- a/.github/workflows/prototyping.yml +++ b/.github/workflows/prototyping.yml @@ -10,7 +10,9 @@ jobs: linux: strategy: matrix: - agda: [2.6.2.1] + agda: [2.6.2.2] + hackageDate: ["2022-04-07"] + hackageTime: ["23:06:28"] name: prototyping runs-on: ubuntu-latest steps: @@ -18,7 +20,7 @@ jobs: - uses: actions/cache@v2 with: path: ~/.cabal/store - key: prototyping-${{ runner.os }}-${{ matrix.agda }} + key: "prototyping-${{ runner.os }}-${{ matrix.agda }}-${{ matrix.hackageDate }}-${{ matrix.hackageTime }}" - uses: actions/cache@v2 id: luau-ast-cache with: @@ -28,12 +30,12 @@ jobs: run: sudo apt-get install -y cabal-install - name: cabal update working-directory: prototyping - run: cabal update + run: cabal v2-update "hackage.haskell.org,${{ matrix.hackageDate }}T${{ matrix.hackageTime }}Z" - name: cabal install working-directory: prototyping run: | - cabal install Agda-${{ matrix.agda }} cabal install --lib scientific vector aeson --package-env . + cabal install --allow-newer Agda-${{ matrix.agda }} - name: check targets working-directory: prototyping run: | diff --git a/prototyping/Luau/StrictMode.agda b/prototyping/Luau/StrictMode.agda index 0b5fe0da..b6769f01 100644 --- a/prototyping/Luau/StrictMode.agda +++ b/prototyping/Luau/StrictMode.agda @@ -5,7 +5,7 @@ module Luau.StrictMode where open import Agda.Builtin.Equality using (_≡_) open import FFI.Data.Maybe using (just; nothing) open import Luau.Syntax using (Expr; Stat; Block; BinaryOperator; yes; nil; addr; var; binexp; var_∈_; _⟨_⟩∈_; function_is_end; _$_; block_is_end; local_←_; _∙_; done; return; name; +; -; *; /; <; >; <=; >=; ··) -open import Luau.Type using (Type; strict; nil; number; string; boolean; none; any; _⇒_; _∪_; _∩_; tgt) +open import Luau.Type using (Type; strict; nil; number; string; boolean; _⇒_; _∪_; _∩_; tgt) open import Luau.Subtyping using (_≮:_) open import Luau.Heap using (Heap; function_is_end) renaming (_[_] to _[_]ᴴ) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_) renaming (_[_] to _[_]ⱽ) diff --git a/prototyping/Luau/Subtyping.agda b/prototyping/Luau/Subtyping.agda index 7d67eb43..943f459b 100644 --- a/prototyping/Luau/Subtyping.agda +++ b/prototyping/Luau/Subtyping.agda @@ -1,6 +1,6 @@ {-# OPTIONS --rewriting #-} -open import Luau.Type using (Type; Scalar; nil; number; string; boolean; none; any; _⇒_; _∪_; _∩_) +open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) open import Properties.Equality using (_≢_) module Luau.Subtyping where @@ -29,7 +29,7 @@ data Language where left : ∀ {T U t} → Language T t → Language (T ∪ U) t right : ∀ {T U u} → Language U u → Language (T ∪ U) u _,_ : ∀ {T U t} → Language T t → Language U t → Language (T ∩ U) t - any : ∀ {t} → Language any t + unknown : ∀ {t} → Language unknown t data ¬Language where @@ -42,7 +42,7 @@ data ¬Language where _,_ : ∀ {T U t} → ¬Language T t → ¬Language U t → ¬Language (T ∪ U) t left : ∀ {T U t} → ¬Language T t → ¬Language (T ∩ U) t right : ∀ {T U u} → ¬Language U u → ¬Language (T ∩ U) u - none : ∀ {t} → ¬Language none t + never : ∀ {t} → ¬Language never t -- Subtyping as language inclusion diff --git a/prototyping/Luau/Type.agda b/prototyping/Luau/Type.agda index 87815ddb..30c45388 100644 --- a/prototyping/Luau/Type.agda +++ b/prototyping/Luau/Type.agda @@ -9,8 +9,8 @@ open import FFI.Data.Maybe using (Maybe; just; nothing) data Type : Set where nil : Type _⇒_ : Type → Type → Type - none : Type - any : Type + never : Type + unknown : Type boolean : Type number : Type string : Type @@ -29,8 +29,8 @@ lhs (T ⇒ _) = T lhs (T ∪ _) = T lhs (T ∩ _) = T lhs nil = nil -lhs none = none -lhs any = any +lhs never = never +lhs unknown = unknown lhs number = number lhs boolean = boolean lhs string = string @@ -40,8 +40,8 @@ rhs (_ ⇒ T) = T rhs (_ ∪ T) = T rhs (_ ∩ T) = T rhs nil = nil -rhs none = none -rhs any = any +rhs never = never +rhs unknown = unknown rhs number = number rhs boolean = boolean rhs string = string @@ -49,16 +49,16 @@ rhs string = string _≡ᵀ_ : ∀ (T U : Type) → Dec(T ≡ U) nil ≡ᵀ nil = yes refl nil ≡ᵀ (S ⇒ T) = no (λ ()) -nil ≡ᵀ none = no (λ ()) -nil ≡ᵀ any = no (λ ()) +nil ≡ᵀ never = no (λ ()) +nil ≡ᵀ unknown = no (λ ()) nil ≡ᵀ number = no (λ ()) nil ≡ᵀ boolean = no (λ ()) nil ≡ᵀ (S ∪ T) = no (λ ()) nil ≡ᵀ (S ∩ T) = no (λ ()) nil ≡ᵀ string = no (λ ()) (S ⇒ T) ≡ᵀ string = no (λ ()) -none ≡ᵀ string = no (λ ()) -any ≡ᵀ string = no (λ ()) +never ≡ᵀ string = no (λ ()) +unknown ≡ᵀ string = no (λ ()) boolean ≡ᵀ string = no (λ ()) number ≡ᵀ string = no (λ ()) (S ∪ T) ≡ᵀ string = no (λ ()) @@ -68,48 +68,48 @@ number ≡ᵀ string = no (λ ()) (S ⇒ T) ≡ᵀ (S ⇒ T) | yes refl | yes refl = yes refl (S ⇒ T) ≡ᵀ (U ⇒ V) | _ | no p = no (λ q → p (cong rhs q)) (S ⇒ T) ≡ᵀ (U ⇒ V) | no p | _ = no (λ q → p (cong lhs q)) -(S ⇒ T) ≡ᵀ none = no (λ ()) -(S ⇒ T) ≡ᵀ any = no (λ ()) +(S ⇒ T) ≡ᵀ never = no (λ ()) +(S ⇒ T) ≡ᵀ unknown = no (λ ()) (S ⇒ T) ≡ᵀ number = no (λ ()) (S ⇒ T) ≡ᵀ boolean = no (λ ()) (S ⇒ T) ≡ᵀ (U ∪ V) = no (λ ()) (S ⇒ T) ≡ᵀ (U ∩ V) = no (λ ()) -none ≡ᵀ nil = no (λ ()) -none ≡ᵀ (U ⇒ V) = no (λ ()) -none ≡ᵀ none = yes refl -none ≡ᵀ any = no (λ ()) -none ≡ᵀ number = no (λ ()) -none ≡ᵀ boolean = no (λ ()) -none ≡ᵀ (U ∪ V) = no (λ ()) -none ≡ᵀ (U ∩ V) = no (λ ()) -any ≡ᵀ nil = no (λ ()) -any ≡ᵀ (U ⇒ V) = no (λ ()) -any ≡ᵀ none = no (λ ()) -any ≡ᵀ any = yes refl -any ≡ᵀ number = no (λ ()) -any ≡ᵀ boolean = no (λ ()) -any ≡ᵀ (U ∪ V) = no (λ ()) -any ≡ᵀ (U ∩ V) = no (λ ()) +never ≡ᵀ nil = no (λ ()) +never ≡ᵀ (U ⇒ V) = no (λ ()) +never ≡ᵀ never = yes refl +never ≡ᵀ unknown = no (λ ()) +never ≡ᵀ number = no (λ ()) +never ≡ᵀ boolean = no (λ ()) +never ≡ᵀ (U ∪ V) = no (λ ()) +never ≡ᵀ (U ∩ V) = no (λ ()) +unknown ≡ᵀ nil = no (λ ()) +unknown ≡ᵀ (U ⇒ V) = no (λ ()) +unknown ≡ᵀ never = no (λ ()) +unknown ≡ᵀ unknown = yes refl +unknown ≡ᵀ number = no (λ ()) +unknown ≡ᵀ boolean = no (λ ()) +unknown ≡ᵀ (U ∪ V) = no (λ ()) +unknown ≡ᵀ (U ∩ V) = no (λ ()) number ≡ᵀ nil = no (λ ()) number ≡ᵀ (T ⇒ U) = no (λ ()) -number ≡ᵀ none = no (λ ()) -number ≡ᵀ any = no (λ ()) +number ≡ᵀ never = no (λ ()) +number ≡ᵀ unknown = no (λ ()) number ≡ᵀ number = yes refl number ≡ᵀ boolean = no (λ ()) number ≡ᵀ (T ∪ U) = no (λ ()) number ≡ᵀ (T ∩ U) = no (λ ()) boolean ≡ᵀ nil = no (λ ()) boolean ≡ᵀ (T ⇒ U) = no (λ ()) -boolean ≡ᵀ none = no (λ ()) -boolean ≡ᵀ any = no (λ ()) +boolean ≡ᵀ never = no (λ ()) +boolean ≡ᵀ unknown = no (λ ()) boolean ≡ᵀ boolean = yes refl boolean ≡ᵀ number = no (λ ()) boolean ≡ᵀ (T ∪ U) = no (λ ()) boolean ≡ᵀ (T ∩ U) = no (λ ()) string ≡ᵀ nil = no (λ ()) string ≡ᵀ (x ⇒ x₁) = no (λ ()) -string ≡ᵀ none = no (λ ()) -string ≡ᵀ any = no (λ ()) +string ≡ᵀ never = no (λ ()) +string ≡ᵀ unknown = no (λ ()) string ≡ᵀ boolean = no (λ ()) string ≡ᵀ number = no (λ ()) string ≡ᵀ string = yes refl @@ -117,8 +117,8 @@ string ≡ᵀ (U ∪ V) = no (λ ()) string ≡ᵀ (U ∩ V) = no (λ ()) (S ∪ T) ≡ᵀ nil = no (λ ()) (S ∪ T) ≡ᵀ (U ⇒ V) = no (λ ()) -(S ∪ T) ≡ᵀ none = no (λ ()) -(S ∪ T) ≡ᵀ any = no (λ ()) +(S ∪ T) ≡ᵀ never = no (λ ()) +(S ∪ T) ≡ᵀ unknown = no (λ ()) (S ∪ T) ≡ᵀ number = no (λ ()) (S ∪ T) ≡ᵀ boolean = no (λ ()) (S ∪ T) ≡ᵀ (U ∪ V) with (S ≡ᵀ U) | (T ≡ᵀ V) @@ -128,8 +128,8 @@ string ≡ᵀ (U ∩ V) = no (λ ()) (S ∪ T) ≡ᵀ (U ∩ V) = no (λ ()) (S ∩ T) ≡ᵀ nil = no (λ ()) (S ∩ T) ≡ᵀ (U ⇒ V) = no (λ ()) -(S ∩ T) ≡ᵀ none = no (λ ()) -(S ∩ T) ≡ᵀ any = no (λ ()) +(S ∩ T) ≡ᵀ never = no (λ ()) +(S ∩ T) ≡ᵀ unknown = no (λ ()) (S ∩ T) ≡ᵀ number = no (λ ()) (S ∩ T) ≡ᵀ boolean = no (λ ()) (S ∩ T) ≡ᵀ (U ∪ V) = no (λ ()) @@ -151,29 +151,29 @@ data Mode : Set where nonstrict : Mode src : Mode → Type → Type -src m nil = none -src m number = none -src m boolean = none -src m string = none +src m nil = never +src m number = never +src m boolean = never +src m string = never src m (S ⇒ T) = S -- In nonstrict mode, functions are covaraiant, in strict mode they're contravariant src strict (S ∪ T) = (src strict S) ∩ (src strict T) src nonstrict (S ∪ T) = (src nonstrict S) ∪ (src nonstrict T) src strict (S ∩ T) = (src strict S) ∪ (src strict T) src nonstrict (S ∩ T) = (src nonstrict S) ∩ (src nonstrict T) -src strict none = any -src nonstrict none = none -src strict any = none -src nonstrict any = any +src strict never = unknown +src nonstrict never = never +src strict unknown = never +src nonstrict unknown = unknown tgt : Type → Type -tgt nil = none +tgt nil = never tgt (S ⇒ T) = T -tgt none = none -tgt any = any -tgt number = none -tgt boolean = none -tgt string = none +tgt never = never +tgt unknown = unknown +tgt number = never +tgt boolean = never +tgt string = never tgt (S ∪ T) = (tgt S) ∪ (tgt T) tgt (S ∩ T) = (tgt S) ∩ (tgt T) diff --git a/prototyping/Luau/Type/FromJSON.agda b/prototyping/Luau/Type/FromJSON.agda index 2d6ba689..e3d1e8e7 100644 --- a/prototyping/Luau/Type/FromJSON.agda +++ b/prototyping/Luau/Type/FromJSON.agda @@ -2,7 +2,7 @@ module Luau.Type.FromJSON where -open import Luau.Type using (Type; nil; _⇒_; _∪_; _∩_; any; number; string) +open import Luau.Type using (Type; nil; _⇒_; _∪_; _∩_; unknown; never; number; string) open import Agda.Builtin.List using (List; _∷_; []) open import Agda.Builtin.Bool using (true; false) @@ -42,7 +42,9 @@ typeFromJSON (object o) | just (string "AstTypeFunction") | nothing | nothing = typeFromJSON (object o) | just (string "AstTypeReference") with lookup name o typeFromJSON (object o) | just (string "AstTypeReference") | just (string "nil") = Right nil -typeFromJSON (object o) | just (string "AstTypeReference") | just (string "any") = Right any +typeFromJSON (object o) | just (string "AstTypeReference") | just (string "any") = Right unknown -- not quite right +typeFromJSON (object o) | just (string "AstTypeReference") | just (string "unknown") = Right unknown +typeFromJSON (object o) | just (string "AstTypeReference") | just (string "never") = Right never typeFromJSON (object o) | just (string "AstTypeReference") | just (string "number") = Right number typeFromJSON (object o) | just (string "AstTypeReference") | just (string "string") = Right string typeFromJSON (object o) | just (string "AstTypeReference") | _ = Left "Unknown referenced type" diff --git a/prototyping/Luau/Type/ToString.agda b/prototyping/Luau/Type/ToString.agda index 2efe6632..a41ecec2 100644 --- a/prototyping/Luau/Type/ToString.agda +++ b/prototyping/Luau/Type/ToString.agda @@ -1,7 +1,7 @@ module Luau.Type.ToString where open import FFI.Data.String using (String; _++_) -open import Luau.Type using (Type; nil; _⇒_; none; any; number; boolean; string; _∪_; _∩_; normalizeOptional) +open import Luau.Type using (Type; nil; _⇒_; never; unknown; number; boolean; string; _∪_; _∩_; normalizeOptional) {-# TERMINATING #-} typeToString : Type → String @@ -10,8 +10,8 @@ typeToStringᴵ : Type → String typeToString nil = "nil" typeToString (S ⇒ T) = "(" ++ (typeToString S) ++ ") -> " ++ (typeToString T) -typeToString none = "none" -typeToString any = "any" +typeToString never = "never" +typeToString unknown = "unknown" typeToString number = "number" typeToString boolean = "boolean" typeToString string = "string" diff --git a/prototyping/Luau/TypeCheck.agda b/prototyping/Luau/TypeCheck.agda index c22618bc..aea6507a 100644 --- a/prototyping/Luau/TypeCheck.agda +++ b/prototyping/Luau/TypeCheck.agda @@ -10,7 +10,7 @@ open import Luau.Syntax using (Expr; Stat; Block; BinaryOperator; yes; nil; addr open import Luau.Var using (Var) open import Luau.Addr using (Addr) open import Luau.Heap using (Heap; Object; function_is_end) renaming (_[_] to _[_]ᴴ) -open import Luau.Type using (Type; Mode; nil; any; number; boolean; string; _⇒_; tgt) +open import Luau.Type using (Type; Mode; nil; unknown; number; boolean; string; _⇒_; tgt) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_) renaming (_[_] to _[_]ⱽ) open import FFI.Data.Vector using (Vector) open import FFI.Data.Maybe using (Maybe; just; nothing) @@ -19,9 +19,9 @@ open import Properties.Product using (_×_; _,_) src : Type → Type src = Luau.Type.src m -orAny : Maybe Type → Type -orAny nothing = any -orAny (just T) = T +orUnknown : Maybe Type → Type +orUnknown nothing = unknown +orUnknown (just T) = T srcBinOp : BinaryOperator → Type srcBinOp + = number @@ -30,8 +30,8 @@ srcBinOp * = number srcBinOp / = number srcBinOp < = number srcBinOp > = number -srcBinOp == = any -srcBinOp ~= = any +srcBinOp == = unknown +srcBinOp ~= = unknown srcBinOp <= = number srcBinOp >= = number srcBinOp ·· = string @@ -89,7 +89,7 @@ data _⊢ᴱ_∈_ where var : ∀ {x T Γ} → - T ≡ orAny(Γ [ x ]ⱽ) → + T ≡ orUnknown(Γ [ x ]ⱽ) → ---------------- Γ ⊢ᴱ (var x) ∈ T diff --git a/prototyping/Properties/StrictMode.agda b/prototyping/Properties/StrictMode.agda index 2ff2b153..1165fdaa 100644 --- a/prototyping/Properties/StrictMode.agda +++ b/prototyping/Properties/StrictMode.agda @@ -9,10 +9,10 @@ open import FFI.Data.Maybe using (Maybe; just; nothing) open import Luau.Heap using (Heap; Object; function_is_end; defn; alloc; ok; next; lookup-not-allocated) renaming (_≡_⊕_↦_ to _≡ᴴ_⊕_↦_; _[_] to _[_]ᴴ; ∅ to ∅ᴴ) open import Luau.StrictMode using (Warningᴱ; Warningᴮ; Warningᴼ; Warningᴴ; UnallocatedAddress; UnboundVariable; FunctionCallMismatch; app₁; app₂; BinOpMismatch₁; BinOpMismatch₂; bin₁; bin₂; BlockMismatch; block₁; return; LocalVarMismatch; local₁; local₂; FunctionDefnMismatch; function₁; function₂; heap; expr; block; addr) open import Luau.Substitution using (_[_/_]ᴮ; _[_/_]ᴱ; _[_/_]ᴮunless_; var_[_/_]ᴱwhenever_) -open import Luau.Subtyping using (_≮:_; witness; any; none; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_; Tree; Language; ¬Language) +open import Luau.Subtyping using (_≮:_; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_; Tree; Language; ¬Language) open import Luau.Syntax using (Expr; yes; var; val; var_∈_; _⟨_⟩∈_; _$_; addr; number; bool; string; binexp; nil; function_is_end; block_is_end; done; return; local_←_; _∙_; fun; arg; name; ==; ~=) -open import Luau.Type using (Type; strict; nil; number; boolean; string; _⇒_; none; any; _∩_; _∪_; tgt; _≡ᵀ_; _≡ᴹᵀ_) -open import Luau.TypeCheck(strict) using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; _⊢ᴴᴮ_▷_∈_; _⊢ᴴᴱ_▷_∈_; nil; var; addr; app; function; block; done; return; local; orAny; srcBinOp; tgtBinOp) +open import Luau.Type using (Type; strict; nil; number; boolean; string; _⇒_; never; unknown; _∩_; _∪_; tgt; _≡ᵀ_; _≡ᴹᵀ_) +open import Luau.TypeCheck(strict) using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; _⊢ᴴᴮ_▷_∈_; _⊢ᴴᴱ_▷_∈_; nil; var; addr; app; function; block; done; return; local; orUnknown; srcBinOp; tgtBinOp) open import Luau.Var using (_≡ⱽ_) open import Luau.Addr using (_≡ᴬ_) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_; ⊕-lookup-miss; ⊕-swap; ⊕-over) renaming (_[_] to _[_]ⱽ) @@ -22,7 +22,7 @@ open import Properties.Equality using (_≢_; sym; cong; trans; subst₁) open import Properties.Dec using (Dec; yes; no) open import Properties.Contradiction using (CONTRADICTION; ¬) open import Properties.Functions using (_∘_) -open import Properties.Subtyping using (any-≮:; ≡-trans-≮:; ≮:-trans-≡; none-tgt-≮:; tgt-none-≮:; src-any-≮:; any-src-≮:; ≮:-trans; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-none; any-≮:-scalar; scalar-≮:-none; any-≮:-none) +open import Properties.Subtyping using (unknown-≮:; ≡-trans-≮:; ≮:-trans-≡; never-tgt-≮:; tgt-never-≮:; src-unknown-≮:; unknown-src-≮:; ≮:-trans; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-never; unknown-≮:-scalar; scalar-≮:-never; unknown-≮:-never) open import Properties.TypeCheck(strict) using (typeOfᴼ; typeOfᴹᴼ; typeOfⱽ; typeOfᴱ; typeOfᴮ; typeCheckᴱ; typeCheckᴮ; typeCheckᴼ; typeCheckᴴ) open import Luau.OpSem using (_⟦_⟧_⟶_; _⊢_⟶*_⊣_; _⊢_⟶ᴮ_⊣_; _⊢_⟶ᴱ_⊣_; app₁; app₂; function; beta; return; block; done; local; subst; binOp₀; binOp₁; binOp₂; refl; step; +; -; *; /; <; >; ==; ~=; <=; >=; ··) open import Luau.RuntimeError using (BinOpError; RuntimeErrorᴱ; RuntimeErrorᴮ; FunctionMismatch; BinOpMismatch₁; BinOpMismatch₂; UnboundVariable; SEGV; app₁; app₂; bin₁; bin₂; block; local; return; +; -; *; /; <; >; <=; >=; ··) @@ -68,12 +68,12 @@ heap-weakeningᴱ Γ H (var x) h p = p heap-weakeningᴱ Γ H (val nil) h p = p heap-weakeningᴱ Γ H (val (addr a)) refl p = p heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) p with a ≡ᴬ b -heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = a} defn) p | yes refl = any-≮: p -heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) p | no r = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ (lookup-not-allocated q r))) p +heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = a} defn) p | yes refl = unknown-≮: p +heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) p | no r = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ (lookup-not-allocated q r))) p heap-weakeningᴱ Γ H (val (number x)) h p = p heap-weakeningᴱ Γ H (val (bool x)) h p = p heap-weakeningᴱ Γ H (val (string x)) h p = p -heap-weakeningᴱ Γ H (M $ N) h p = none-tgt-≮: (heap-weakeningᴱ Γ H M h (tgt-none-≮: p)) +heap-weakeningᴱ Γ H (M $ N) h p = never-tgt-≮: (heap-weakeningᴱ Γ H M h (tgt-never-≮: p)) heap-weakeningᴱ Γ H (function f ⟨ var x ∈ T ⟩∈ U is B end) h p = p heap-weakeningᴱ Γ H (block var b ∈ T is B end) h p = p heap-weakeningᴱ Γ H (binexp M op N) h p = p @@ -94,11 +94,11 @@ substitutivityᴮ-unless-no : ∀ {Γ Γ′ T V} H B v x y (r : x ≢ y) → (Γ substitutivityᴱ H (var y) v x p = substitutivityᴱ-whenever H v x y (x ≡ⱽ y) p substitutivityᴱ H (val w) v x p = Left p substitutivityᴱ H (binexp M op N) v x p = Left p -substitutivityᴱ H (M $ N) v x p = mapL none-tgt-≮: (substitutivityᴱ H M v x (tgt-none-≮: p)) +substitutivityᴱ H (M $ N) v x p = mapL never-tgt-≮: (substitutivityᴱ H M v x (tgt-never-≮: p)) substitutivityᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x p = Left p substitutivityᴱ H (block var b ∈ T is B end) v x p = Left p substitutivityᴱ-whenever H v x x (yes refl) q = swapLR (≮:-trans q) -substitutivityᴱ-whenever H v x y (no p) q = Left (≡-trans-≮: (cong orAny (sym (⊕-lookup-miss x y _ _ p))) q) +substitutivityᴱ-whenever H v x y (no p) q = Left (≡-trans-≮: (cong orUnknown (sym (⊕-lookup-miss x y _ _ p))) q) substitutivityᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x p = substitutivityᴮ-unless H B v x f (x ≡ⱽ f) p substitutivityᴮ H (local var y ∈ T ← M ∙ B) v x p = substitutivityᴮ-unless H B v x y (x ≡ⱽ y) p @@ -125,9 +125,9 @@ binOpPreservation H (·· v w) = refl reflect-subtypingᴱ : ∀ H M {H′ M′ T} → (H ⊢ M ⟶ᴱ M′ ⊣ H′) → (typeOfᴱ H′ ∅ M′ ≮: T) → Either (typeOfᴱ H ∅ M ≮: T) (Warningᴱ H (typeCheckᴱ H ∅ M)) reflect-subtypingᴮ : ∀ H B {H′ B′ T} → (H ⊢ B ⟶ᴮ B′ ⊣ H′) → (typeOfᴮ H′ ∅ B′ ≮: T) → Either (typeOfᴮ H ∅ B ≮: T) (Warningᴮ H (typeCheckᴮ H ∅ B)) -reflect-subtypingᴱ H (M $ N) (app₁ s) p = mapLR none-tgt-≮: app₁ (reflect-subtypingᴱ H M s (tgt-none-≮: p)) -reflect-subtypingᴱ H (M $ N) (app₂ v s) p = Left (none-tgt-≮: (heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) (tgt-none-≮: p))) -reflect-subtypingᴱ H (M $ N) (beta (function f ⟨ var y ∈ T ⟩∈ U is B end) v refl q) p = Left (≡-trans-≮: (cong tgt (cong orAny (cong typeOfᴹᴼ q))) p) +reflect-subtypingᴱ H (M $ N) (app₁ s) p = mapLR never-tgt-≮: app₁ (reflect-subtypingᴱ H M s (tgt-never-≮: p)) +reflect-subtypingᴱ H (M $ N) (app₂ v s) p = Left (never-tgt-≮: (heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) (tgt-never-≮: p))) +reflect-subtypingᴱ H (M $ N) (beta (function f ⟨ var y ∈ T ⟩∈ U is B end) v refl q) p = Left (≡-trans-≮: (cong tgt (cong orUnknown (cong typeOfᴹᴼ q))) p) reflect-subtypingᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) p = Left p reflect-subtypingᴱ H (block var b ∈ T is B end) (block s) p = Left p reflect-subtypingᴱ H (block var b ∈ T is return (val v) ∙ B end) (return v) p = mapR BlockMismatch (swapLR (≮:-trans p)) @@ -152,8 +152,8 @@ reflect-substitutionᴱ H (var y) v x W = reflect-substitutionᴱ-whenever H v x reflect-substitutionᴱ H (val (addr a)) v x (UnallocatedAddress r) = Left (UnallocatedAddress r) reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) with substitutivityᴱ H N v x p reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Right W = Right (Right W) -reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q with substitutivityᴱ H M v x (src-any-≮: q) -reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q | Left r = Left ((FunctionCallMismatch ∘ any-src-≮: q) r) +reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q with substitutivityᴱ H M v x (src-unknown-≮: q) +reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q | Left r = Left ((FunctionCallMismatch ∘ unknown-src-≮: q) r) reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q | Right W = Right (Right W) reflect-substitutionᴱ H (M $ N) v x (app₁ W) = mapL app₁ (reflect-substitutionᴱ H M v x W) reflect-substitutionᴱ H (M $ N) v x (app₂ W) = mapL app₂ (reflect-substitutionᴱ H N v x W) @@ -187,7 +187,7 @@ reflect-weakeningᴮ : ∀ Γ H B {H′} → (H ⊑ H′) → Warningᴮ H′ (t reflect-weakeningᴱ Γ H (var x) h (UnboundVariable p) = (UnboundVariable p) reflect-weakeningᴱ Γ H (val (addr a)) h (UnallocatedAddress p) = UnallocatedAddress (lookup-⊑-nothing a h p) -reflect-weakeningᴱ Γ H (M $ N) h (FunctionCallMismatch p) = FunctionCallMismatch (heap-weakeningᴱ Γ H N h (any-src-≮: p (heap-weakeningᴱ Γ H M h (src-any-≮: p)))) +reflect-weakeningᴱ Γ H (M $ N) h (FunctionCallMismatch p) = FunctionCallMismatch (heap-weakeningᴱ Γ H N h (unknown-src-≮: p (heap-weakeningᴱ Γ H M h (src-unknown-≮: p)))) reflect-weakeningᴱ Γ H (M $ N) h (app₁ W) = app₁ (reflect-weakeningᴱ Γ H M h W) reflect-weakeningᴱ Γ H (M $ N) h (app₂ W) = app₂ (reflect-weakeningᴱ Γ H N h W) reflect-weakeningᴱ Γ H (binexp M op N) h (BinOpMismatch₁ p) = BinOpMismatch₁ (heap-weakeningᴱ Γ H M h p) @@ -214,19 +214,19 @@ reflect-weakeningᴼ H (just function f ⟨ var x ∈ T ⟩∈ U is B end) h (fu reflectᴱ : ∀ H M {H′ M′} → (H ⊢ M ⟶ᴱ M′ ⊣ H′) → Warningᴱ H′ (typeCheckᴱ H′ ∅ M′) → Either (Warningᴱ H (typeCheckᴱ H ∅ M)) (Warningᴴ H (typeCheckᴴ H)) reflectᴮ : ∀ H B {H′ B′} → (H ⊢ B ⟶ᴮ B′ ⊣ H′) → Warningᴮ H′ (typeCheckᴮ H′ ∅ B′) → Either (Warningᴮ H (typeCheckᴮ H ∅ B)) (Warningᴴ H (typeCheckᴴ H)) -reflectᴱ H (M $ N) (app₁ s) (FunctionCallMismatch p) = cond (Left ∘ FunctionCallMismatch ∘ heap-weakeningᴱ ∅ H N (rednᴱ⊑ s) ∘ any-src-≮: p) (Left ∘ app₁) (reflect-subtypingᴱ H M s (src-any-≮: p)) +reflectᴱ H (M $ N) (app₁ s) (FunctionCallMismatch p) = cond (Left ∘ FunctionCallMismatch ∘ heap-weakeningᴱ ∅ H N (rednᴱ⊑ s) ∘ unknown-src-≮: p) (Left ∘ app₁) (reflect-subtypingᴱ H M s (src-unknown-≮: p)) reflectᴱ H (M $ N) (app₁ s) (app₁ W′) = mapL app₁ (reflectᴱ H M s W′) reflectᴱ H (M $ N) (app₁ s) (app₂ W′) = Left (app₂ (reflect-weakeningᴱ ∅ H N (rednᴱ⊑ s) W′)) -reflectᴱ H (M $ N) (app₂ p s) (FunctionCallMismatch q) = cond (λ r → Left (FunctionCallMismatch (any-src-≮: r (heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) (src-any-≮: r))))) (Left ∘ app₂) (reflect-subtypingᴱ H N s q) +reflectᴱ H (M $ N) (app₂ p s) (FunctionCallMismatch q) = cond (λ r → Left (FunctionCallMismatch (unknown-src-≮: r (heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) (src-unknown-≮: r))))) (Left ∘ app₂) (reflect-subtypingᴱ H N s q) reflectᴱ H (M $ N) (app₂ p s) (app₁ W′) = Left (app₁ (reflect-weakeningᴱ ∅ H M (rednᴱ⊑ s) W′)) reflectᴱ H (M $ N) (app₂ p s) (app₂ W′) = mapL app₂ (reflectᴱ H N s W′) reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) with substitutivityᴮ H B v x q reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) | Left r = Right (addr a p (FunctionDefnMismatch r)) -reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) | Right r = Left (FunctionCallMismatch (≮:-trans-≡ r ((cong src (cong orAny (cong typeOfᴹᴼ (sym p))))))) +reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) | Right r = Left (FunctionCallMismatch (≮:-trans-≡ r ((cong src (cong orUnknown (cong typeOfᴹᴼ (sym p))))))) reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) with reflect-substitutionᴮ _ B v x W′ reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Left W = Right (addr a p (function₁ W)) reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Right (Left W) = Left (app₂ W) -reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Right (Right q) = Left (FunctionCallMismatch (≮:-trans-≡ q (cong src (cong orAny (cong typeOfᴹᴼ (sym p)))))) +reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Right (Right q) = Left (FunctionCallMismatch (≮:-trans-≡ q (cong src (cong orUnknown (cong typeOfᴹᴼ (sym p)))))) reflectᴱ H (block var b ∈ T is B end) (block s) (BlockMismatch p) = Left (cond BlockMismatch block₁ (reflect-subtypingᴮ H B s p)) reflectᴱ H (block var b ∈ T is B end) (block s) (block₁ W′) = mapL block₁ (reflectᴮ H B s W′) reflectᴱ H (block var b ∈ T is B end) (return v) W′ = Left (block₁ (return W′)) @@ -283,8 +283,8 @@ reflect* H B (step s t) W = cond (reflectᴮ H B s) (reflectᴴᴮ H B s) (refle isntNumber : ∀ H v → (valueType v ≢ number) → (typeOfᴱ H ∅ (val v) ≮: number) isntNumber H nil p = scalar-≢-impl-≮: nil number (λ ()) isntNumber H (addr a) p with remember (H [ a ]ᴴ) -isntNumber H (addr a) p | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , q) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ q)) (function-≮:-scalar number) -isntNumber H (addr a) p | (nothing , q) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ q)) (any-≮:-scalar number) +isntNumber H (addr a) p | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , q) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ q)) (function-≮:-scalar number) +isntNumber H (addr a) p | (nothing , q) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ q)) (unknown-≮:-scalar number) isntNumber H (number x) p = CONTRADICTION (p refl) isntNumber H (bool x) p = scalar-≢-impl-≮: boolean number (λ ()) isntNumber H (string x) p = scalar-≢-impl-≮: string number (λ ()) @@ -292,8 +292,8 @@ isntNumber H (string x) p = scalar-≢-impl-≮: string number (λ ()) isntString : ∀ H v → (valueType v ≢ string) → (typeOfᴱ H ∅ (val v) ≮: string) isntString H nil p = scalar-≢-impl-≮: nil string (λ ()) isntString H (addr a) p with remember (H [ a ]ᴴ) -isntString H (addr a) p | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , q) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ q)) (function-≮:-scalar string) -isntString H (addr a) p | (nothing , q) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ q)) (any-≮:-scalar string) +isntString H (addr a) p | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , q) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ q)) (function-≮:-scalar string) +isntString H (addr a) p | (nothing , q) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ q)) (unknown-≮:-scalar string) isntString H (number x) p = scalar-≢-impl-≮: number string (λ ()) isntString H (bool x) p = scalar-≢-impl-≮: boolean string (λ ()) isntString H (string x) p = CONTRADICTION (p refl) @@ -305,14 +305,14 @@ isntFunction H (number x) p = scalar-≮:-function number isntFunction H (bool x) p = scalar-≮:-function boolean isntFunction H (string x) p = scalar-≮:-function string -isntEmpty : ∀ H v → (typeOfᴱ H ∅ (val v) ≮: none) -isntEmpty H nil = scalar-≮:-none nil +isntEmpty : ∀ H v → (typeOfᴱ H ∅ (val v) ≮: never) +isntEmpty H nil = scalar-≮:-never nil isntEmpty H (addr a) with remember (H [ a ]ᴴ) -isntEmpty H (addr a) | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , p) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ p)) function-≮:-none -isntEmpty H (addr a) | (nothing , p) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ p)) any-≮:-none -isntEmpty H (number x) = scalar-≮:-none number -isntEmpty H (bool x) = scalar-≮:-none boolean -isntEmpty H (string x) = scalar-≮:-none string +isntEmpty H (addr a) | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , p) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ p)) function-≮:-never +isntEmpty H (addr a) | (nothing , p) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ p)) unknown-≮:-never +isntEmpty H (number x) = scalar-≮:-never number +isntEmpty H (bool x) = scalar-≮:-never boolean +isntEmpty H (string x) = scalar-≮:-never string runtimeBinOpWarning : ∀ H {op} v → BinOpError op (valueType v) → (typeOfᴱ H ∅ (val v) ≮: srcBinOp op) runtimeBinOpWarning H v (+ p) = isntNumber H v p @@ -330,7 +330,7 @@ runtimeWarningᴮ : ∀ H B → RuntimeErrorᴮ H B → Warningᴮ H (typeCheck runtimeWarningᴱ H (var x) UnboundVariable = UnboundVariable refl runtimeWarningᴱ H (val (addr a)) (SEGV p) = UnallocatedAddress p -runtimeWarningᴱ H (M $ N) (FunctionMismatch v w p) = FunctionCallMismatch (any-src-≮: (isntEmpty H w) (isntFunction H v p)) +runtimeWarningᴱ H (M $ N) (FunctionMismatch v w p) = FunctionCallMismatch (unknown-src-≮: (isntEmpty H w) (isntFunction H v p)) runtimeWarningᴱ H (M $ N) (app₁ err) = app₁ (runtimeWarningᴱ H M err) runtimeWarningᴱ H (M $ N) (app₂ err) = app₂ (runtimeWarningᴱ H N err) runtimeWarningᴱ H (block var b ∈ T is B end) (block err) = block₁ (runtimeWarningᴮ H B err) diff --git a/prototyping/Properties/Subtyping.agda b/prototyping/Properties/Subtyping.agda index 0a20f244..cc6bb5c1 100644 --- a/prototyping/Properties/Subtyping.agda +++ b/prototyping/Properties/Subtyping.agda @@ -4,8 +4,8 @@ module Properties.Subtyping where open import Agda.Builtin.Equality using (_≡_; refl) open import FFI.Data.Either using (Either; Left; Right; mapLR; swapLR; cond) -open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; any; none; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_) -open import Luau.Type using (Type; Scalar; strict; nil; number; string; boolean; none; any; _⇒_; _∪_; _∩_; tgt) +open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_) +open import Luau.Type using (Type; Scalar; strict; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_; tgt) open import Properties.Contradiction using (CONTRADICTION; ¬) open import Properties.Equality using (_≢_) open import Properties.Functions using (_∘_) @@ -47,8 +47,8 @@ dec-language (T₁ ⇒ T₂) (scalar s) = Left (function-scalar s) dec-language (T₁ ⇒ T₂) function = Right function dec-language (T₁ ⇒ T₂) (function-ok t) = mapLR function-ok function-ok (dec-language T₂ t) dec-language (T₁ ⇒ T₂) (function-err t) = mapLR function-err function-err (swapLR (dec-language T₁ t)) -dec-language none t = Left none -dec-language any t = Right any +dec-language never t = Left never +dec-language unknown t = Right unknown dec-language (T₁ ∪ T₂) t = cond (λ p → cond (Left ∘ _,_ p) (Right ∘ right) (dec-language T₂ t)) (Right ∘ left) (dec-language T₁ t) dec-language (T₁ ∩ T₂) t = cond (Left ∘ left) (λ p → cond (Left ∘ right) (Right ∘ _,_ p) (dec-language T₂ t)) (dec-language T₁ t) @@ -60,7 +60,7 @@ language-comp t (left p) (q₁ , q₂) = language-comp t p q₁ language-comp t (right p) (q₁ , q₂) = language-comp t p q₂ language-comp (scalar s) (scalar-scalar s p₁ p₂) (scalar s) = p₂ refl language-comp (scalar s) (function-scalar s) (scalar s) = language-comp function (scalar-function s) function -language-comp (scalar s) none (scalar ()) +language-comp (scalar s) never (scalar ()) language-comp function (scalar-function ()) function language-comp (function-ok t) (scalar-function-ok ()) (function-ok q) language-comp (function-ok t) (function-ok p) (function-ok q) = language-comp t p q @@ -104,11 +104,11 @@ function-≮:-scalar s = witness function function (scalar-function s) scalar-≮:-function : ∀ {S T U} → (Scalar U) → (U ≮: (S ⇒ T)) scalar-≮:-function s = witness (scalar s) (scalar s) (function-scalar s) -any-≮:-scalar : ∀ {U} → (Scalar U) → (any ≮: U) -any-≮:-scalar s = witness (function-ok (scalar s)) any (scalar-function-ok s) +unknown-≮:-scalar : ∀ {U} → (Scalar U) → (unknown ≮: U) +unknown-≮:-scalar s = witness (function-ok (scalar s)) unknown (scalar-function-ok s) -scalar-≮:-none : ∀ {U} → (Scalar U) → (U ≮: none) -scalar-≮:-none s = witness (scalar s) (scalar s) none +scalar-≮:-never : ∀ {U} → (Scalar U) → (U ≮: never) +scalar-≮:-never s = witness (scalar s) (scalar s) never scalar-≢-impl-≮: : ∀ {T U} → (Scalar T) → (Scalar U) → (T ≢ U) → (T ≮: U) scalar-≢-impl-≮: s₁ s₂ p = witness (scalar s₁) (scalar s₁) (scalar-scalar s₁ s₂ p) @@ -117,8 +117,8 @@ scalar-≢-impl-≮: s₁ s₂ p = witness (scalar s₁) (scalar s₁) (scalar-s tgt-function-ok : ∀ {T t} → (Language (tgt T) t) → Language T (function-ok t) tgt-function-ok {T = nil} (scalar ()) tgt-function-ok {T = T₁ ⇒ T₂} p = function-ok p -tgt-function-ok {T = none} (scalar ()) -tgt-function-ok {T = any} p = any +tgt-function-ok {T = never} (scalar ()) +tgt-function-ok {T = unknown} p = unknown tgt-function-ok {T = boolean} (scalar ()) tgt-function-ok {T = number} (scalar ()) tgt-function-ok {T = string} (scalar ()) @@ -131,7 +131,7 @@ function-ok-tgt (function-ok p) = p function-ok-tgt (left p) = left (function-ok-tgt p) function-ok-tgt (right p) = right (function-ok-tgt p) function-ok-tgt (p₁ , p₂) = (function-ok-tgt p₁ , function-ok-tgt p₂) -function-ok-tgt any = any +function-ok-tgt unknown = unknown skalar-function-ok : ∀ {t} → (¬Language skalar (function-ok t)) skalar-function-ok = (scalar-function-ok number , (scalar-function-ok string , (scalar-function-ok nil , scalar-function-ok boolean))) @@ -142,22 +142,22 @@ skalar-scalar boolean = right (right (right (scalar boolean))) skalar-scalar string = right (left (scalar string)) skalar-scalar nil = right (right (left (scalar nil))) -tgt-none-≮: : ∀ {T U} → (tgt T ≮: U) → (T ≮: (skalar ∪ (none ⇒ U))) -tgt-none-≮: (witness t p q) = witness (function-ok t) (tgt-function-ok p) (skalar-function-ok , function-ok q) +tgt-never-≮: : ∀ {T U} → (tgt T ≮: U) → (T ≮: (skalar ∪ (never ⇒ U))) +tgt-never-≮: (witness t p q) = witness (function-ok t) (tgt-function-ok p) (skalar-function-ok , function-ok q) -none-tgt-≮: : ∀ {T U} → (T ≮: (skalar ∪ (none ⇒ U))) → (tgt T ≮: U) -none-tgt-≮: (witness (scalar s) p (q₁ , q₂)) = CONTRADICTION (≮:-refl (witness (scalar s) (skalar-scalar s) q₁)) -none-tgt-≮: (witness function p (q₁ , scalar-function ())) -none-tgt-≮: (witness (function-ok t) p (q₁ , function-ok q₂)) = witness t (function-ok-tgt p) q₂ -none-tgt-≮: (witness (function-err (scalar s)) p (q₁ , function-err (scalar ()))) +never-tgt-≮: : ∀ {T U} → (T ≮: (skalar ∪ (never ⇒ U))) → (tgt T ≮: U) +never-tgt-≮: (witness (scalar s) p (q₁ , q₂)) = CONTRADICTION (≮:-refl (witness (scalar s) (skalar-scalar s) q₁)) +never-tgt-≮: (witness function p (q₁ , scalar-function ())) +never-tgt-≮: (witness (function-ok t) p (q₁ , function-ok q₂)) = witness t (function-ok-tgt p) q₂ +never-tgt-≮: (witness (function-err (scalar s)) p (q₁ , function-err (scalar ()))) -- Properties of src function-err-src : ∀ {T t} → (¬Language (src T) t) → Language T (function-err t) -function-err-src {T = nil} none = scalar-function-err nil +function-err-src {T = nil} never = scalar-function-err nil function-err-src {T = T₁ ⇒ T₂} p = function-err p -function-err-src {T = none} (scalar-scalar number () p) -function-err-src {T = none} (scalar-function-ok ()) -function-err-src {T = any} none = any +function-err-src {T = never} (scalar-scalar number () p) +function-err-src {T = never} (scalar-function-ok ()) +function-err-src {T = unknown} never = unknown function-err-src {T = boolean} p = scalar-function-err boolean function-err-src {T = number} p = scalar-function-err number function-err-src {T = string} p = scalar-function-err string @@ -168,8 +168,8 @@ function-err-src {T = T₁ ∩ T₂} (p₁ , p₂) = function-err-src p₁ , fun ¬function-err-src : ∀ {T t} → (Language (src T) t) → ¬Language T (function-err t) ¬function-err-src {T = nil} (scalar ()) ¬function-err-src {T = T₁ ⇒ T₂} p = function-err p -¬function-err-src {T = none} any = none -¬function-err-src {T = any} (scalar ()) +¬function-err-src {T = never} unknown = never +¬function-err-src {T = unknown} (scalar ()) ¬function-err-src {T = boolean} (scalar ()) ¬function-err-src {T = number} (scalar ()) ¬function-err-src {T = string} (scalar ()) @@ -178,53 +178,53 @@ function-err-src {T = T₁ ∩ T₂} (p₁ , p₂) = function-err-src p₁ , fun ¬function-err-src {T = T₁ ∩ T₂} (right p) = right (¬function-err-src p) src-¬function-err : ∀ {T t} → Language T (function-err t) → (¬Language (src T) t) -src-¬function-err {T = nil} p = none +src-¬function-err {T = nil} p = never src-¬function-err {T = T₁ ⇒ T₂} (function-err p) = p -src-¬function-err {T = none} (scalar-function-err ()) -src-¬function-err {T = any} p = none -src-¬function-err {T = boolean} p = none -src-¬function-err {T = number} p = none -src-¬function-err {T = string} p = none +src-¬function-err {T = never} (scalar-function-err ()) +src-¬function-err {T = unknown} p = never +src-¬function-err {T = boolean} p = never +src-¬function-err {T = number} p = never +src-¬function-err {T = string} p = never src-¬function-err {T = T₁ ∪ T₂} (left p) = left (src-¬function-err p) src-¬function-err {T = T₁ ∪ T₂} (right p) = right (src-¬function-err p) src-¬function-err {T = T₁ ∩ T₂} (p₁ , p₂) = (src-¬function-err p₁ , src-¬function-err p₂) src-¬scalar : ∀ {S T t} (s : Scalar S) → Language T (scalar s) → (¬Language (src T) t) -src-¬scalar number (scalar number) = none -src-¬scalar boolean (scalar boolean) = none -src-¬scalar string (scalar string) = none -src-¬scalar nil (scalar nil) = none +src-¬scalar number (scalar number) = never +src-¬scalar boolean (scalar boolean) = never +src-¬scalar string (scalar string) = never +src-¬scalar nil (scalar nil) = never src-¬scalar s (left p) = left (src-¬scalar s p) src-¬scalar s (right p) = right (src-¬scalar s p) src-¬scalar s (p₁ , p₂) = (src-¬scalar s p₁ , src-¬scalar s p₂) -src-¬scalar s any = none +src-¬scalar s unknown = never -src-any-≮: : ∀ {T U} → (T ≮: src U) → (U ≮: (T ⇒ any)) -src-any-≮: (witness t p q) = witness (function-err t) (function-err-src q) (¬function-err-src p) +src-unknown-≮: : ∀ {T U} → (T ≮: src U) → (U ≮: (T ⇒ unknown)) +src-unknown-≮: (witness t p q) = witness (function-err t) (function-err-src q) (¬function-err-src p) -any-src-≮: : ∀ {S T U} → (U ≮: S) → (T ≮: (U ⇒ any)) → (U ≮: src T) -any-src-≮: (witness t x x₁) (witness (scalar s) p (function-scalar s)) = witness t x (src-¬scalar s p) -any-src-≮: r (witness (function-ok (scalar s)) p (function-ok (scalar-scalar s () q))) -any-src-≮: r (witness (function-ok (function-ok _)) p (function-ok (scalar-function-ok ()))) -any-src-≮: r (witness (function-err t) p (function-err q)) = witness t q (src-¬function-err p) +unknown-src-≮: : ∀ {S T U} → (U ≮: S) → (T ≮: (U ⇒ unknown)) → (U ≮: src T) +unknown-src-≮: (witness t x x₁) (witness (scalar s) p (function-scalar s)) = witness t x (src-¬scalar s p) +unknown-src-≮: r (witness (function-ok (scalar s)) p (function-ok (scalar-scalar s () q))) +unknown-src-≮: r (witness (function-ok (function-ok _)) p (function-ok (scalar-function-ok ()))) +unknown-src-≮: r (witness (function-err t) p (function-err q)) = witness t q (src-¬function-err p) --- Properties of any and none -any-≮: : ∀ {T U} → (T ≮: U) → (any ≮: U) -any-≮: (witness t p q) = witness t any q +-- Properties of unknown and never +unknown-≮: : ∀ {T U} → (T ≮: U) → (unknown ≮: U) +unknown-≮: (witness t p q) = witness t unknown q -none-≮: : ∀ {T U} → (T ≮: U) → (T ≮: none) -none-≮: (witness t p q) = witness t p none +never-≮: : ∀ {T U} → (T ≮: U) → (T ≮: never) +never-≮: (witness t p q) = witness t p never -any-≮:-none : (any ≮: none) -any-≮:-none = witness (scalar nil) any none +unknown-≮:-never : (unknown ≮: never) +unknown-≮:-never = witness (scalar nil) unknown never -function-≮:-none : ∀ {T U} → ((T ⇒ U) ≮: none) -function-≮:-none = witness function function none +function-≮:-never : ∀ {T U} → ((T ⇒ U) ≮: never) +function-≮:-never = witness function function never -- A Gentle Introduction To Semantic Subtyping (https://www.cduce.org/papers/gentle.pdf) -- defines a "set-theoretic" model (sec 2.5) -- Unfortunately we don't quite have this property, due to uninhabited types, --- for example (none -> T) is equivalent to (none -> U) +-- for example (never -> T) is equivalent to (never -> U) -- when types are interpreted as sets of syntactic values. _⊆_ : ∀ {A : Set} → (A → Set) → (A → Set) → Set @@ -258,7 +258,7 @@ not-quite-set-theoretic-only-if : ∀ {S₁ T₁ S₂ T₂} → -- We don't quite have that this is a set-theoretic model -- it's only true when Language T₁ and ¬Language T₂ t₂ are inhabited - -- in particular it's not true when T₁ is none, or T₂ is any. + -- in particular it's not true when T₁ is never, or T₂ is unknown. ∀ s₂ t₂ → Language S₂ s₂ → ¬Language T₂ t₂ → -- This is the "only if" part of being a set-theoretic model @@ -285,11 +285,11 @@ not-quite-set-theoretic-only-if {S₁} {T₁} {S₂} {T₂} s₂ t₂ S₂s₂ -- A counterexample when the argument type is empty. -set-theoretic-counterexample-one : (∀ Q → Q ⊆ Comp((Language none) ⊗ Comp(Language number)) → Q ⊆ Comp((Language none) ⊗ Comp(Language string))) +set-theoretic-counterexample-one : (∀ Q → Q ⊆ Comp((Language never) ⊗ Comp(Language number)) → Q ⊆ Comp((Language never) ⊗ Comp(Language string))) set-theoretic-counterexample-one Q q ((scalar s) , u) Qtu (scalar () , p) set-theoretic-counterexample-one Q q ((function-err t) , u) Qtu (scalar-function-err () , p) -set-theoretic-counterexample-two : (none ⇒ number) ≮: (none ⇒ string) +set-theoretic-counterexample-two : (never ⇒ number) ≮: (never ⇒ string) set-theoretic-counterexample-two = witness (function-ok (scalar number)) (function-ok (scalar number)) (function-ok (scalar-scalar number string (λ ()))) @@ -298,14 +298,14 @@ set-theoretic-counterexample-two = witness -- The reason why this is connected to overloaded functions is that currently we have that the type of -- f(x) is (tgt T) where f:T. Really we should have the type depend on the type of x, that is use (tgt T U), -- where U is the type of x. In particular (tgt (S => T) (U & V)) should be the same as (tgt ((S&U) => T) V) --- and tgt(none => T) should be any. For example +-- and tgt(never => T) should be unknown. For example -- -- tgt((number => string) & (string => bool))(number) -- is tgt(number => string)(number) & tgt(string => bool)(number) --- is tgt(number => string)(number) & tgt(string => bool)(number&any) --- is tgt(number => string)(number) & tgt(string&number => bool)(any) --- is tgt(number => string)(number) & tgt(none => bool)(any) --- is string & any +-- is tgt(number => string)(number) & tgt(string => bool)(number&unknown) +-- is tgt(number => string)(number) & tgt(string&number => bool)(unknown) +-- is tgt(number => string)(number) & tgt(never => bool)(unknown) +-- is string & unknown -- is string -- -- there's some discussion of this in the Gentle Introduction paper. diff --git a/prototyping/Properties/TypeCheck.agda b/prototyping/Properties/TypeCheck.agda index ead0c097..a5916a13 100644 --- a/prototyping/Properties/TypeCheck.agda +++ b/prototyping/Properties/TypeCheck.agda @@ -8,9 +8,9 @@ open import Agda.Builtin.Equality using (_≡_; refl) open import Agda.Builtin.Bool using (Bool; true; false) open import FFI.Data.Maybe using (Maybe; just; nothing) open import FFI.Data.Either using (Either) -open import Luau.TypeCheck(m) using (_⊢ᴱ_∈_; _⊢ᴮ_∈_; ⊢ᴼ_; ⊢ᴴ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; nil; var; addr; number; bool; string; app; function; block; binexp; done; return; local; nothing; orAny; tgtBinOp) +open import Luau.TypeCheck(m) using (_⊢ᴱ_∈_; _⊢ᴮ_∈_; ⊢ᴼ_; ⊢ᴴ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; nil; var; addr; number; bool; string; app; function; block; binexp; done; return; local; nothing; orUnknown; tgtBinOp) open import Luau.Syntax using (Block; Expr; Value; BinaryOperator; yes; nil; addr; number; bool; string; val; var; binexp; _$_; function_is_end; block_is_end; _∙_; return; done; local_←_; _⟨_⟩; _⟨_⟩∈_; var_∈_; name; fun; arg; +; -; *; /; <; >; ==; ~=; <=; >=) -open import Luau.Type using (Type; nil; any; none; number; boolean; string; _⇒_; tgt) +open import Luau.Type using (Type; nil; unknown; never; number; boolean; string; _⇒_; tgt) open import Luau.RuntimeType using (RuntimeType; nil; number; function; string; valueType) open import Luau.VarCtxt using (VarCtxt; ∅; _↦_; _⊕_↦_; _⋒_; _⊝_) renaming (_[_] to _[_]ⱽ) open import Luau.Addr using (Addr) @@ -42,8 +42,8 @@ typeOfⱽ H (string x) = just string typeOfᴱ : Heap yes → VarCtxt → (Expr yes) → Type typeOfᴮ : Heap yes → VarCtxt → (Block yes) → Type -typeOfᴱ H Γ (var x) = orAny(Γ [ x ]ⱽ) -typeOfᴱ H Γ (val v) = orAny(typeOfⱽ H v) +typeOfᴱ H Γ (var x) = orUnknown(Γ [ x ]ⱽ) +typeOfᴱ H Γ (val v) = orUnknown(typeOfⱽ H v) typeOfᴱ H Γ (M $ N) = tgt(typeOfᴱ H Γ M) typeOfᴱ H Γ (function f ⟨ var x ∈ S ⟩∈ T is B end) = S ⇒ T typeOfᴱ H Γ (block var b ∈ T is B end) = T @@ -54,7 +54,7 @@ typeOfᴮ H Γ (local var x ∈ T ← M ∙ B) = typeOfᴮ H (Γ ⊕ x ↦ T) B typeOfᴮ H Γ (return M ∙ B) = typeOfᴱ H Γ M typeOfᴮ H Γ done = nil -mustBeFunction : ∀ H Γ v → (none ≢ src (typeOfᴱ H Γ (val v))) → (function ≡ valueType(v)) +mustBeFunction : ∀ H Γ v → (never ≢ src (typeOfᴱ H Γ (val v))) → (function ≡ valueType(v)) mustBeFunction H Γ nil p = CONTRADICTION (p refl) mustBeFunction H Γ (addr a) p = refl mustBeFunction H Γ (number n) p = CONTRADICTION (p refl) @@ -64,17 +64,17 @@ mustBeFunction H Γ (string x) p = CONTRADICTION (p refl) mustBeNumber : ∀ H Γ v → (typeOfᴱ H Γ (val v) ≡ number) → (valueType(v) ≡ number) mustBeNumber H Γ (addr a) p with remember (H [ a ]ᴴ) -mustBeNumber H Γ (addr a) p | (just O , q) with trans (cong orAny (cong typeOfᴹᴼ (sym q))) p +mustBeNumber H Γ (addr a) p | (just O , q) with trans (cong orUnknown (cong typeOfᴹᴼ (sym q))) p mustBeNumber H Γ (addr a) p | (just function f ⟨ var x ∈ T ⟩∈ U is B end , q) | () -mustBeNumber H Γ (addr a) p | (nothing , q) with trans (cong orAny (cong typeOfᴹᴼ (sym q))) p +mustBeNumber H Γ (addr a) p | (nothing , q) with trans (cong orUnknown (cong typeOfᴹᴼ (sym q))) p mustBeNumber H Γ (addr a) p | nothing , q | () mustBeNumber H Γ (number n) p = refl mustBeString : ∀ H Γ v → (typeOfᴱ H Γ (val v) ≡ string) → (valueType(v) ≡ string) mustBeString H Γ (addr a) p with remember (H [ a ]ᴴ) -mustBeString H Γ (addr a) p | (just O , q) with trans (cong orAny (cong typeOfᴹᴼ (sym q))) p +mustBeString H Γ (addr a) p | (just O , q) with trans (cong orUnknown (cong typeOfᴹᴼ (sym q))) p mustBeString H Γ (addr a) p | (just function f ⟨ var x ∈ T ⟩∈ U is B end , q) | () -mustBeString H Γ (addr a) p | (nothing , q) with trans (cong orAny (cong typeOfᴹᴼ (sym q))) p +mustBeString H Γ (addr a) p | (nothing , q) with trans (cong orUnknown (cong typeOfᴹᴼ (sym q))) p mustBeString H Γ (addr a) p | (nothing , q) | () mustBeString H Γ (string x) p = refl @@ -83,7 +83,7 @@ typeCheckᴮ : ∀ H Γ B → (Γ ⊢ᴮ B ∈ (typeOfᴮ H Γ B)) typeCheckᴱ H Γ (var x) = var refl typeCheckᴱ H Γ (val nil) = nil -typeCheckᴱ H Γ (val (addr a)) = addr (orAny (typeOfᴹᴼ (H [ a ]ᴴ))) +typeCheckᴱ H Γ (val (addr a)) = addr (orUnknown (typeOfᴹᴼ (H [ a ]ᴴ))) typeCheckᴱ H Γ (val (number n)) = number typeCheckᴱ H Γ (val (bool b)) = bool typeCheckᴱ H Γ (val (string x)) = string diff --git a/prototyping/Tests/PrettyPrinter/smoke_test/out.txt b/prototyping/Tests/PrettyPrinter/smoke_test/out.txt index 34e0c4fe..ca95cae9 100644 --- a/prototyping/Tests/PrettyPrinter/smoke_test/out.txt +++ b/prototyping/Tests/PrettyPrinter/smoke_test/out.txt @@ -10,10 +10,10 @@ local function comp(f) end local id2 = comp(id)(id) local nil2 = id2(nil) -local a : any = nil +local a : unknown = nil local b : nil = nil local c : (nil) -> nil = nil -local d : (any & nil) = nil -local e : any? = nil +local d : (unknown & nil) = nil +local e : unknown? = nil local f : number = 123.0 return id2(nil2) From 8e7845076b240cd10ea73680d48cad9525b31c20 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 14 Apr 2022 16:57:43 -0700 Subject: [PATCH 046/102] Sync to upstream/release/523 (#459) --- Analysis/include/Luau/Clone.h | 9 +- Analysis/include/Luau/Error.h | 10 +- Analysis/include/Luau/Frontend.h | 1 + Analysis/include/Luau/LValue.h | 4 + Analysis/include/Luau/Module.h | 2 +- Analysis/include/Luau/Normalize.h | 19 + Analysis/include/Luau/RecursionCounter.h | 26 +- Analysis/include/Luau/Substitution.h | 1 + Analysis/include/Luau/ToString.h | 3 + Analysis/include/Luau/TxnLog.h | 32 +- Analysis/include/Luau/TypeInfer.h | 24 +- Analysis/include/Luau/TypePack.h | 14 +- Analysis/include/Luau/TypeVar.h | 31 +- Analysis/include/Luau/Unifiable.h | 17 +- Analysis/include/Luau/Unifier.h | 22 +- Analysis/include/Luau/UnifierSharedState.h | 2 + Analysis/include/Luau/VisitTypeVar.h | 9 + Analysis/src/Autocomplete.cpp | 8 +- Analysis/src/Clone.cpp | 116 ++- Analysis/src/Error.cpp | 28 +- Analysis/src/Frontend.cpp | 41 +- Analysis/src/IostreamHelpers.cpp | 2 + Analysis/src/LValue.cpp | 17 + Analysis/src/Linter.cpp | 93 +- Analysis/src/Module.cpp | 39 +- Analysis/src/Normalize.cpp | 814 +++++++++++++++++ Analysis/src/Quantify.cpp | 36 + Analysis/src/Substitution.cpp | 138 ++- Analysis/src/ToDot.cpp | 31 + Analysis/src/ToString.cpp | 165 +++- Analysis/src/TopoSortStatements.cpp | 1 + Analysis/src/TxnLog.cpp | 44 +- Analysis/src/TypeAttach.cpp | 13 + Analysis/src/TypeInfer.cpp | 488 +++++++++-- Analysis/src/TypePack.cpp | 46 +- Analysis/src/TypeVar.cpp | 28 +- Analysis/src/Unifier.cpp | 337 +++++-- Ast/include/Luau/DenseHash.h | 118 ++- Ast/include/Luau/Lexer.h | 2 +- Ast/src/Lexer.cpp | 10 +- Ast/src/Parser.cpp | 5 +- Compiler/src/Compiler.cpp | 4 +- Compiler/src/CostModel.cpp | 258 ++++++ Compiler/src/CostModel.h | 18 + Sources.cmake | 6 + VM/src/ltable.cpp | 47 +- VM/src/ltablib.cpp | 2 +- VM/src/lvmexecute.cpp | 4 +- tests/CostModel.test.cpp | 101 +++ tests/Fixture.cpp | 4 +- tests/JsonEncoder.test.cpp | 4 +- tests/Linter.test.cpp | 3 +- tests/Module.test.cpp | 75 +- tests/NonstrictMode.test.cpp | 34 + tests/Normalize.test.cpp | 967 +++++++++++++++++++++ tests/Parser.test.cpp | 20 + tests/ToDot.test.cpp | 77 +- tests/ToString.test.cpp | 2 + tests/TopoSort.test.cpp | 32 +- tests/Transpiler.test.cpp | 2 +- tests/TypeInfer.annotations.test.cpp | 10 + tests/TypeInfer.builtins.test.cpp | 13 +- tests/TypeInfer.classes.test.cpp | 3 + tests/TypeInfer.functions.test.cpp | 177 +++- tests/TypeInfer.generics.test.cpp | 77 +- tests/TypeInfer.intersectionTypes.test.cpp | 46 +- tests/TypeInfer.oop.test.cpp | 16 +- tests/TypeInfer.operators.test.cpp | 3 +- tests/TypeInfer.provisional.test.cpp | 135 ++- tests/TypeInfer.refinements.test.cpp | 12 +- tests/TypeInfer.singletons.test.cpp | 11 +- tests/TypeInfer.tables.test.cpp | 61 +- tests/TypeInfer.test.cpp | 66 +- tests/TypeInfer.typePacks.cpp | 38 +- tests/TypeInfer.unionTypes.test.cpp | 25 +- tests/conformance/nextvar.lua | 15 + 76 files changed, 4575 insertions(+), 639 deletions(-) create mode 100644 Analysis/include/Luau/Normalize.h create mode 100644 Analysis/src/Normalize.cpp create mode 100644 Compiler/src/CostModel.cpp create mode 100644 Compiler/src/CostModel.h create mode 100644 tests/CostModel.test.cpp create mode 100644 tests/Normalize.test.cpp diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 917ef801..78aa92c7 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -14,12 +14,15 @@ using SeenTypePacks = std::unordered_map; struct CloneState { + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + int recursionCount = 0; bool encounteredFreeType = false; }; -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); +TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); } // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 53b946a0..70683141 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -287,12 +287,20 @@ struct TypesAreUnrelated bool operator==(const TypesAreUnrelated& rhs) const; }; +struct NormalizationTooComplex +{ + bool operator==(const NormalizationTooComplex&) const + { + return true; + } +}; + using TypeErrorData = Variant; + MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, TypesAreUnrelated, NormalizationTooComplex>; struct TypeError { diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 2266f548..e24e433c 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -70,6 +70,7 @@ struct SourceNode std::vector> requireLocations; bool dirty = true; bool dirtyAutocomplete = true; + double autocompleteLimitsMult = 1.0; }; struct FrontendOptions diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 3d510d5f..afb71415 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -35,8 +35,12 @@ const LValue* baseof(const LValue& lvalue); std::optional tryGetLValue(const class AstExpr& expr); // Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. +// TODO: remove with FFlagLuauTypecheckOptPass std::pair> getFullName(const LValue& lvalue); +// Utility function: breaks down an LValue to get at the Symbol +Symbol getBaseSymbol(const LValue& lvalue); + template const T* get(const LValue& lvalue) { diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 9a32f614..0dd44188 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -113,7 +113,7 @@ struct Module // This helps us to force TypeVar ownership into a DAG rather than a DCG. // Returns true if there were any free types encountered in the public interface. This // indicates a bug in the type checker that we want to surface. - bool clonePublicInterface(); + bool clonePublicInterface(InternalErrorReporter& ice); }; } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h new file mode 100644 index 00000000..262b54b2 --- /dev/null +++ b/Analysis/include/Luau/Normalize.h @@ -0,0 +1,19 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Substitution.h" +#include "Luau/TypeVar.h" +#include "Luau/Module.h" + +namespace Luau +{ + +struct InternalErrorReporter; + +bool isSubtype(TypeId superTy, TypeId subTy, InternalErrorReporter& ice); + +std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, const ModulePtr& module, InternalErrorReporter& ice); + +} // namespace Luau diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index 89632cea..03ae2c83 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -4,10 +4,21 @@ #include "Luau/Common.h" #include +#include + +LUAU_FASTFLAG(LuauRecursionLimitException); namespace Luau { +struct RecursionLimitException : public std::exception +{ + const char* what() const noexcept + { + return "Internal recursion counter limit exceeded"; + } +}; + struct RecursionCounter { RecursionCounter(int* count) @@ -28,11 +39,22 @@ private: struct RecursionLimiter : RecursionCounter { - RecursionLimiter(int* count, int limit) + // TODO: remove ctx after LuauRecursionLimitException is removed + RecursionLimiter(int* count, int limit, const char* ctx) : RecursionCounter(count) { + LUAU_ASSERT(ctx); if (limit > 0 && *count > limit) - throw std::runtime_error("Internal recursion counter limit exceeded"); + { + if (FFlag::LuauRecursionLimitException) + throw RecursionLimitException(); + else + { + std::string m = "Internal recursion counter limit exceeded: "; + m += ctx; + throw std::runtime_error(m); + } + } } }; diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 9662d5b3..6f5931e1 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -90,6 +90,7 @@ struct Tarjan std::vector lowlink; int childCount = 0; + int childLimit = 0; // This should never be null; ensure you initialize it before calling // substitution methods. diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 49ee82fe..f4db5e35 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -28,6 +28,7 @@ struct ToStringOptions bool functionTypeArguments = false; // If true, output function type argument names when they are available bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. + bool indent = false; size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; @@ -73,6 +74,8 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp std::string dump(TypeId ty); std::string dump(TypePackId ty); +std::string dump(const std::shared_ptr& scope, const char* name); + std::string generateName(size_t n); } // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index c8ebaaeb..995ed6c6 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -7,7 +7,7 @@ #include "Luau/TypeVar.h" #include "Luau/TypePack.h" -LUAU_FASTFLAG(LuauShareTxnSeen); +LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -64,13 +64,17 @@ T* getMutable(PendingTypePack* pending) struct TxnLog { TxnLog() - : ownedSeen() + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , ownedSeen() , sharedSeen(&ownedSeen) { } explicit TxnLog(TxnLog* parent) - : parent(parent) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , parent(parent) { if (parent) { @@ -83,14 +87,19 @@ struct TxnLog } explicit TxnLog(std::vector>* sharedSeen) - : sharedSeen(sharedSeen) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , sharedSeen(sharedSeen) { } TxnLog(TxnLog* parent, std::vector>* sharedSeen) - : parent(parent) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , parent(parent) , sharedSeen(sharedSeen) { + LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); } TxnLog(const TxnLog&) = delete; @@ -243,6 +252,12 @@ struct TxnLog return Luau::getMutable(ty); } + template + const T* get(TID ty) const + { + return this->getMutable(ty); + } + // Returns whether a given type or type pack is a given state, respecting the // log's pending state. // @@ -263,11 +278,8 @@ private: // unique_ptr is used to give us stable pointers across insertions into the // map. Otherwise, it would be really easy to accidentally invalidate the // pointers returned from queue/pending. - // - // We can't use a DenseHashMap here because we need a non-const iterator - // over the map when we concatenate. - std::unordered_map, DenseHashPointer> typeVarChanges; - std::unordered_map, DenseHashPointer> typePackChanges; + DenseHashMap> typeVarChanges; + DenseHashMap> typePackChanges; TxnLog* parent = nullptr; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 215da67f..ac880135 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -76,19 +76,32 @@ struct Instantiation : Substitution // A substitution which replaces free types by any struct Anyification : Substitution { - Anyification(TypeArena* arena, TypeId anyType, TypePackId anyTypePack) + Anyification(TypeArena* arena, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) : Substitution(TxnLog::empty(), arena) + , iceHandler(iceHandler) , anyType(anyType) , anyTypePack(anyTypePack) { } + InternalErrorReporter* iceHandler; + TypeId anyType; TypePackId anyTypePack; + bool normalizationTooComplex = false; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; TypeId clean(TypeId ty) override; TypePackId clean(TypePackId tp) override; + + bool ignoreChildren(TypeId ty) override + { + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } }; // A substitution which replaces the type parameters of a type function by arguments @@ -139,6 +152,7 @@ struct TypeChecker TypeChecker& operator=(const TypeChecker&) = delete; ModulePtr check(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); + ModulePtr checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); std::vector> getScopes() const; @@ -160,6 +174,7 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); + void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); ExprResult checkExpr( @@ -172,6 +187,7 @@ struct TypeChecker ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); + ExprResult checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); @@ -258,6 +274,8 @@ struct TypeChecker ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location); + std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); @@ -395,6 +413,7 @@ private: void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; + bool useConstrainedIntersections() const; public: /** Extract the types in a type pack, given the assumption that the pack must have some exact length. @@ -421,7 +440,10 @@ public: std::vector requireCycles; + // Type inference limits std::optional finishTime; + std::optional instantiationChildLimit; + std::optional unifierIterationLimit; public: const TypeId nilType; diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 85fa467f..bbc65f94 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -40,6 +40,7 @@ struct TypePack struct VariadicTypePack { TypeId ty; + bool hidden = false; // if true, we don't display this when toString()ing a pack with this variadic as its tail. }; struct TypePackVar @@ -109,10 +110,10 @@ private: }; TypePackIterator begin(TypePackId tp); -TypePackIterator begin(TypePackId tp, TxnLog* log); +TypePackIterator begin(TypePackId tp, const TxnLog* log); TypePackIterator end(TypePackId tp); -using SeenSet = std::set>; +using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); @@ -122,7 +123,7 @@ TypePackId follow(TypePackId tp, std::function mapper); size_t size(TypePackId tp, TxnLog* log = nullptr); bool finite(TypePackId tp, TxnLog* log = nullptr); size_t size(const TypePack& tp, TxnLog* log = nullptr); -std::optional first(TypePackId tp); +std::optional first(TypePackId tp, bool ignoreHiddenVariadics = true); TypePackVar* asMutable(TypePackId tp); TypePack* asMutable(const TypePack* tp); @@ -154,5 +155,12 @@ bool isEmpty(TypePackId tp); /// Flattens out a type pack. Also returns a valid TypePackId tail if the type pack's full size is not known std::pair, std::optional> flatten(TypePackId tp); +std::pair, std::optional> flatten(TypePackId tp, const TxnLog& log); + +/// Returs true if the type pack arose from a function that is declared to be variadic. +/// Returns *false* for function argument packs that are inferred to be safe to oversaturate! +bool isVariadic(TypePackId tp); +bool isVariadic(TypePackId tp, const TxnLog& log); + } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index f61e4044..ae7d1377 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -109,6 +109,23 @@ struct PrimitiveTypeVar } }; +struct ConstrainedTypeVar +{ + explicit ConstrainedTypeVar(TypeLevel level) + : level(level) + { + } + + explicit ConstrainedTypeVar(TypeLevel level, const std::vector& parts) + : parts(parts) + , level(level) + { + } + + std::vector parts; + TypeLevel level; +}; + // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Types for true and false struct BooleanSingleton @@ -248,6 +265,7 @@ struct FunctionTypeVar MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. bool hasSelf; Tags tags; + bool hasNoGenerics = false; }; enum class TableState @@ -418,8 +436,8 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -436,6 +454,7 @@ struct TypeVar final TypeVar(const TypeVariant& ty, bool persistent) : ty(ty) , persistent(persistent) + , normal(persistent) // We assume that all persistent types are irreducable. { } @@ -446,6 +465,10 @@ struct TypeVar final // Persistent TypeVars do not get cloned. bool persistent = false; + // Normalization sets this for types that are fully normalized. + // This implies that they are transitively immutable. + bool normal = false; + std::optional documentationSymbol; // Pointer to the type arena that allocated this type. @@ -458,7 +481,7 @@ struct TypeVar final TypeVar& operator=(TypeVariant&& rhs); }; -using SeenSet = std::set>; +using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); // Follow BoundTypeVars until we get to something real @@ -545,6 +568,8 @@ void persist(TypePackId tp); const TypeLevel* getLevel(TypeId ty); TypeLevel* getMutableLevel(TypeId ty); +std::optional getLevel(TypePackId tp); + const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index e8eafe68..64fa131d 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -56,6 +56,14 @@ struct TypeLevel } }; +inline TypeLevel max(const TypeLevel& a, const TypeLevel& b) +{ + if (a.subsumes(b)) + return b; + else + return a; +} + inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) { if (a.subsumes(b)) @@ -64,7 +72,9 @@ inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) return b; } -namespace Unifiable +} // namespace Luau + +namespace Luau::Unifiable { using Name = std::string; @@ -125,7 +135,6 @@ private: }; template -using Variant = Variant, Generic, Error, Value...>; +using Variant = Luau::Variant, Generic, Error, Value...>; -} // namespace Unifiable -} // namespace Luau +} // namespace Luau::Unifiable diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 474af50c..340feb7f 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -49,14 +49,14 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; + bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. CountMismatch::Context ctx = CountMismatch::Arg; UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, - TxnLog* parentLog = nullptr); - Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, + UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); @@ -106,7 +106,12 @@ private: std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); + void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy); + void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy); + public: + void unifyLowerBound(TypePackId subTy, TypePackId superTy); + // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); void occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); @@ -115,12 +120,7 @@ public: Unifier makeChildUnifier(); - // A utility function that appends the given error to the unifier's error log. - // This allows setting a breakpoint wherever the unifier reports an error. - void reportError(TypeError error) - { - errors.push_back(error); - } + void reportError(TypeError err); private: bool isNonstrictMode() const; @@ -135,4 +135,6 @@ private: std::optional firstPackErrorPos; }; +void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, TypePackId tp); + } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index 9a3ba56d..1a0b8b76 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -28,7 +28,9 @@ struct TypeIdPairHash struct UnifierCounters { int recursionCount = 0; + int recursionLimit = 0; int iterationCount = 0; + int iterationLimit = 0; }; struct UnifierSharedState diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 740854b3..d11cbd0d 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -82,6 +82,15 @@ void visit(TypeId ty, F& f, Set& seen) else if (auto etv = get(ty)) apply(ty, *etv, seen, f); + else if (auto ctv = get(ty)) + { + if (apply(ty, *ctv, seen, f)) + { + for (TypeId part : ctv->parts) + visit(part, f, seen); + } + } + else if (auto ptv = get(ty)) apply(ty, *ptv, seen, f); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index b7201ab3..e0e79cb4 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -151,8 +151,12 @@ static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionTyp auto idxExpr = nodes.back()->as(); bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; - auto args = Luau::flatten(func->argTypes); - bool noArgFunction = (args.first.empty() || (hasImplicitSelf && args.first.size() == 1)) && !args.second.has_value(); + auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); + + if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) + return ParenthesesRecommendation::CursorInside; + + bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; } diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index ac9705a7..8e7f7c07 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -6,7 +6,10 @@ #include "Luau/TypePack.h" #include "Luau/Unifiable.h" +LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) + LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -23,11 +26,11 @@ struct TypePackCloner; struct TypeCloner { - TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) + TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState) : dest(dest) , typeId(typeId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) + , seenTypes(cloneState.seenTypes) + , seenTypePacks(cloneState.seenTypePacks) , cloneState(cloneState) { } @@ -46,6 +49,7 @@ struct TypeCloner void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); void operator()(const PrimitiveTypeVar& t); + void operator()(const ConstrainedTypeVar& t); void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); @@ -65,11 +69,11 @@ struct TypePackCloner SeenTypePacks& seenTypePacks; CloneState& cloneState; - TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) + TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState) : dest(dest) , typePackId(typePackId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) + , seenTypes(cloneState.seenTypes) + , seenTypePacks(cloneState.seenTypePacks) , cloneState(cloneState) { } @@ -103,13 +107,15 @@ struct TypePackCloner // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. void operator()(const Unifiable::Bound& t) { - TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + TypePackId cloned = clone(t.boundTo, dest, cloneState); + if (FFlag::DebugLuauCopyBeforeNormalizing) + cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}}); seenTypePacks[typePackId] = cloned; } void operator()(const VariadicTypePack& t) { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}}); seenTypePacks[typePackId] = cloned; } @@ -121,10 +127,10 @@ struct TypePackCloner seenTypePacks[typePackId] = cloned; for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + destTp->head.push_back(clone(ty, dest, cloneState)); if (t.tail) - destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); + destTp->tail = clone(*t.tail, dest, cloneState); } }; @@ -150,7 +156,9 @@ void TypeCloner::operator()(const Unifiable::Generic& t) void TypeCloner::operator()(const Unifiable::Bound& t) { - TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + TypeId boundTo = clone(t.boundTo, dest, cloneState); + if (FFlag::DebugLuauCopyBeforeNormalizing) + boundTo = dest.addType(BoundTypeVar{boundTo}); seenTypes[typeId] = boundTo; } @@ -164,6 +172,23 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) defaultClone(t); } +void TypeCloner::operator()(const ConstrainedTypeVar& t) +{ + cloneState.encounteredFreeType = true; + + TypeId res = dest.addType(ConstrainedTypeVar{t.level}); + ConstrainedTypeVar* ctv = getMutable(res); + LUAU_ASSERT(ctv); + + seenTypes[typeId] = res; + + std::vector parts; + for (TypeId part : t.parts) + parts.push_back(clone(part, dest, cloneState)); + + ctv->parts = std::move(parts); +} + void TypeCloner::operator()(const SingletonTypeVar& t) { defaultClone(t); @@ -178,23 +203,26 @@ void TypeCloner::operator()(const FunctionTypeVar& t) seenTypes[typeId] = result; for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); + ftv->generics.push_back(clone(generic, dest, cloneState)); for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); + ftv->genericPacks.push_back(clone(genericPack, dest, cloneState)); ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); + ftv->argTypes = clone(t.argTypes, dest, cloneState); ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); + ftv->retType = clone(t.retType, dest, cloneState); + + if (FFlag::LuauTypecheckOptPass) + ftv->hasNoGenerics = t.hasNoGenerics; } void TypeCloner::operator()(const TableTypeVar& t) { // If table is now bound to another one, we ignore the content of the original - if (t.boundTo) + if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) { - TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + TypeId boundTo = clone(*t.boundTo, dest, cloneState); seenTypes[typeId] = boundTo; return; } @@ -209,18 +237,20 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->level = TypeLevel{0, 0}; + if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) + ttv->boundTo = clone(*t.boundTo, dest, cloneState); + for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), - clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); + arg = clone(arg, dest, cloneState); for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); + arg = clone(arg, dest, cloneState); if (ttv->state == TableState::Free) { @@ -240,8 +270,8 @@ void TypeCloner::operator()(const MetatableTypeVar& t) MetatableTypeVar* mtv = getMutable(result); seenTypes[typeId] = result; - mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); - mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); + mtv->table = clone(t.table, dest, cloneState); + mtv->metatable = clone(t.metatable, dest, cloneState); } void TypeCloner::operator()(const ClassTypeVar& t) @@ -252,13 +282,13 @@ void TypeCloner::operator()(const ClassTypeVar& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.parent) - ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); + ctv->parent = clone(*t.parent, dest, cloneState); if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); + ctv->metatable = clone(*t.metatable, dest, cloneState); } void TypeCloner::operator()(const AnyTypeVar& t) @@ -272,7 +302,7 @@ void TypeCloner::operator()(const UnionTypeVar& t) options.reserve(t.options.size()); for (TypeId ty : t.options) - options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + options.push_back(clone(ty, dest, cloneState)); TypeId result = dest.addType(UnionTypeVar{std::move(options)}); seenTypes[typeId] = result; @@ -287,7 +317,7 @@ void TypeCloner::operator()(const IntersectionTypeVar& t) LUAU_ASSERT(option != nullptr); for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + option->parts.push_back(clone(ty, dest, cloneState)); } void TypeCloner::operator()(const LazyTypeVar& t) @@ -297,36 +327,36 @@ void TypeCloner::operator()(const LazyTypeVar& t) } // anonymous namespace -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) { if (tp->persistent) return tp; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypePackId"); - TypePackId& res = seenTypePacks[tp]; + TypePackId& res = cloneState.seenTypePacks[tp]; if (res == nullptr) { - TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; + TypePackCloner cloner{dest, tp, cloneState}; Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. } return res; } -TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) { if (typeId->persistent) return typeId; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypeId"); - TypeId& res = seenTypes[typeId]; + TypeId& res = cloneState.seenTypes[typeId]; if (res == nullptr) { - TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; + TypeCloner cloner{dest, typeId, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. // Persistent types are not being cloned and we get the original type back which might be read-only @@ -337,33 +367,33 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks return res; } -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) { TypeFun result; for (auto param : typeFun.typeParams) { - TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); + TypeId ty = clone(param.ty, dest, cloneState); std::optional defaultValue; if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + defaultValue = clone(*param.defaultValue, dest, cloneState); result.typeParams.push_back({ty, defaultValue}); } for (auto param : typeFun.typePackParams) { - TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); + TypePackId tp = clone(param.tp, dest, cloneState); std::optional defaultValue; if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + defaultValue = clone(*param.defaultValue, dest, cloneState); result.typePackParams.push_back({tp, defaultValue}); } - result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); + result.type = clone(typeFun.type, dest, cloneState); return result; } diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 5eb2ea2a..cbec0b15 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -8,7 +8,6 @@ #include -LUAU_FASTFLAGVARIABLE(BetterDiagnosticCodesInStudio, false); LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleName, false); static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) @@ -252,14 +251,7 @@ struct ErrorConverter std::string operator()(const Luau::SyntaxError& e) const { - if (FFlag::BetterDiagnosticCodesInStudio) - { - return e.message; - } - else - { - return "Syntax error: " + e.message; - } + return e.message; } std::string operator()(const Luau::CodeTooComplex&) const @@ -451,6 +443,11 @@ struct ErrorConverter { return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated"; } + + std::string operator()(const NormalizationTooComplex&) const + { + return "Code is too complex to typecheck! Consider simplifying the code around this area"; + } }; struct InvalidNameChecker @@ -716,14 +713,14 @@ bool containsParseErrorName(const TypeError& error) } template -void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState cloneState) +void copyError(T& e, TypeArena& destArena, CloneState cloneState) { auto clone = [&](auto&& ty) { - return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks, cloneState); + return ::Luau::clone(ty, destArena, cloneState); }; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks, cloneState); + copyError(e, destArena, cloneState); }; if constexpr (false) @@ -844,18 +841,19 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& e.left = clone(e.left); e.right = clone(e.right); } + else if constexpr (std::is_same_v) + { + } else static_assert(always_false_v, "Non-exhaustive type switch"); } void copyErrors(ErrorVec& errors, TypeArena& destArena) { - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks, cloneState); + copyError(e, destArena, cloneState); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 000769fe..8b0b2210 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -11,16 +11,18 @@ #include "Luau/TimeTrace.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" -#include "Luau/Common.h" #include #include #include +LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauCyclicModuleTypeSurface) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 0) namespace Luau @@ -97,13 +99,11 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, checkedModule}; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; @@ -113,7 +113,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; @@ -440,13 +440,42 @@ CheckResult Frontend::check(const ModuleName& name, std::optional 0) + typeCheckerForAutocomplete.instantiationChildLimit = + std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckerForAutocomplete.unifierIterationLimit = + std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; + } + ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; + double duration = getTimestamp() - timestamp; + if (moduleForAutocomplete->timeout) + { checkResult.timeoutHits.push_back(moduleName); - stats.timeCheck += getTimestamp() - timestamp; + if (FFlag::LuauAutocompleteDynamicLimits) + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + } + else if (FFlag::LuauAutocompleteDynamicLimits && duration < autocompleteTimeLimit / 2.0) + { + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); + } + + stats.timeCheck += duration; stats.filesStrict += 1; sourceNode.dirtyAutocomplete = false; diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index a8f67589..0eaa485e 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -184,6 +184,8 @@ static void errorToString(std::ostream& stream, const T& err) } else if constexpr (std::is_same_v) stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; + else if constexpr (std::is_same_v) + stream << "NormalizationTooComplex { }"; else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/LValue.cpp b/Analysis/src/LValue.cpp index c9466a40..72555ab4 100644 --- a/Analysis/src/LValue.cpp +++ b/Analysis/src/LValue.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAG(LuauTypecheckOptPass) + namespace Luau { @@ -79,6 +81,8 @@ std::optional tryGetLValue(const AstExpr& node) std::pair> getFullName(const LValue& lvalue) { + LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); + const LValue* current = &lvalue; std::vector keys; while (auto field = get(*current)) @@ -92,6 +96,19 @@ std::pair> getFullName(const LValue& lvalue) return {*symbol, std::vector(keys.rbegin(), keys.rend())}; } +Symbol getBaseSymbol(const LValue& lvalue) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOptPass); + + const LValue* current = &lvalue; + while (auto field = get(*current)) + current = baseof(*current); + + const Symbol* symbol = get(*current); + LUAU_ASSERT(symbol); + return *symbol; +} + void merge(RefinementMap& l, const RefinementMap& r, std::function f) { for (const auto& [k, a] : r) diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index b7480e34..5608e4b3 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,7 +14,6 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) -LUAU_FASTFLAGVARIABLE(LuauLintNoRobloxBits, false) namespace Luau { @@ -1140,25 +1139,8 @@ private: Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata. Kind_Vector, // 'vector' but only used when type is used Kind_Userdata, // custom userdata type - - // TODO: remove these with LuauLintNoRobloxBits - Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc. - Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc. }; - bool containsPropName(TypeId ty, const std::string& propName) - { - LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); - - if (auto ctv = get(ty)) - return lookupClassProp(ctv, propName) != nullptr; - - if (auto ttv = get(ty)) - return ttv->props.find(propName) != ttv->props.end(); - - return false; - } - TypeKind getTypeKind(const std::string& name) { if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" || @@ -1168,23 +1150,10 @@ private: if (name == "vector") return Kind_Vector; - if (FFlag::LuauLintNoRobloxBits) - { - if (std::optional maybeTy = context->scope->lookupType(name)) - return Kind_Userdata; + if (std::optional maybeTy = context->scope->lookupType(name)) + return Kind_Userdata; - return Kind_Unknown; - } - else - { - if (std::optional maybeTy = context->scope->lookupType(name)) - // Kind_Userdata is probably not 100% precise but is close enough - return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata; - else if (std::optional maybeTy = context->scope->lookupImportedType("Enum", name)) - return Kind_Enum; - - return Kind_Unknown; - } + return Kind_Unknown; } void validateType(AstExprConstantString* expr, std::initializer_list expected, const char* expectedString) @@ -1202,67 +1171,11 @@ private: { if (kind == ek) return; - - // as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type - if (!FFlag::LuauLintNoRobloxBits && ek == Kind_Userdata && (name == "Instance" || name == "EnumItem")) - return; } emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s' (expected %s)", name.c_str(), expectedString); } - bool acceptsClassName(AstName method) - { - LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); - - return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" || - method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA"); - } - - bool visit(AstExprCall* node) override - { - // TODO: Simply remove the override - if (FFlag::LuauLintNoRobloxBits) - return true; - - if (AstExprIndexName* index = node->func->as()) - { - AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as() : NULL; - - if (arg0) - { - if (node->self && index->index == "IsA" && node->args.size == 1) - { - validateType(arg0, {Kind_Class, Kind_Enum}, "class or enum type"); - } - else if (node->self && (index->index == "GetService" || index->index == "FindService") && node->args.size == 1) - { - AstExprGlobal* g = index->expr->as(); - - if (g && (g->name == "game" || g->name == "Game")) - { - validateType(arg0, {Kind_Class}, "class type"); - } - } - else if (node->self && acceptsClassName(index->index) && node->args.size == 1) - { - validateType(arg0, {Kind_Class}, "class type"); - } - else if (!node->self && index->index == "new" && node->args.size <= 2) - { - AstExprGlobal* g = index->expr->as(); - - if (g && g->name == "Instance") - { - validateType(arg0, {Kind_Class}, "class type"); - } - } - } - } - - return true; - } - bool visit(AstExprBinary* node) override { if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 6bb45245..e2e3b436 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,8 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" -#include "Luau/Common.h" #include "Luau/Clone.h" +#include "Luau/Common.h" +#include "Luau/Normalize.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -14,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) namespace Luau { @@ -143,32 +145,51 @@ Module::~Module() unfreeze(internalTypes); } -bool Module::clonePublicInterface() +bool Module::clonePublicInterface(InternalErrorReporter& ice) { LUAU_ASSERT(interfaceTypes.typeVars.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; ScopePtr moduleScope = getModuleScope(); - moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, cloneState); + moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, cloneState); if (moduleScope->varargPack) - moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); + moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, cloneState); + + if (FFlag::LuauLowerBoundsCalculation) + { + normalize(moduleScope->returnType, interfaceTypes, ice); + if (moduleScope->varargPack) + normalize(*moduleScope->varargPack, interfaceTypes, ice); + } for (auto& [name, tf] : moduleScope->exportedTypeBindings) - tf = clone(tf, interfaceTypes, seenTypes, seenTypePacks, cloneState); + { + tf = clone(tf, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) + normalize(tf.type, interfaceTypes, ice); + } for (TypeId ty : moduleScope->returnType) + { if (get(follow(ty))) - *asMutable(ty) = AnyTypeVar{}; + { + auto t = asMutable(ty); + t->ty = AnyTypeVar{}; + t->normal = true; + } + } if (FFlag::LuauCloneDeclaredGlobals) { for (auto& [name, ty] : declaredGlobals) - ty = clone(ty, interfaceTypes, seenTypes, seenTypePacks, cloneState); + { + ty = clone(ty, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) + normalize(ty, interfaceTypes, ice); + } } freeze(internalTypes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp new file mode 100644 index 00000000..40341ac1 --- /dev/null +++ b/Analysis/src/Normalize.cpp @@ -0,0 +1,814 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Normalize.h" + +#include + +#include "Luau/Clone.h" +#include "Luau/DenseHash.h" +#include "Luau/Substitution.h" +#include "Luau/Unifier.h" +#include "Luau/VisitTypeVar.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) + +// This could theoretically be 2000 on amd64, but x86 requires this. +LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineIntersectionFix, false); + +namespace Luau +{ + +namespace +{ + +struct Replacer : Substitution +{ + TypeId sourceType; + TypeId replacedType; + DenseHashMap replacedTypes{nullptr}; + DenseHashMap replacedPacks{nullptr}; + + Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) + : Substitution(TxnLog::empty(), arena) + , sourceType(sourceType) + , replacedType(replacedType) + { + } + + bool isDirty(TypeId ty) override + { + if (!sourceType) + return false; + + auto vecHasSourceType = [sourceType = sourceType](const auto& vec) { + return end(vec) != std::find(begin(vec), end(vec), sourceType); + }; + + // Walk every kind of TypeVar and find pointers to sourceType + if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return vecHasSourceType(t->parts); + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + { + if (vecHasSourceType(t->generics)) + return true; + + return false; + } + else if (auto t = get(ty)) + { + if (t->boundTo) + return *t->boundTo == sourceType; + + for (const auto& [_name, prop] : t->props) + { + if (prop.type == sourceType) + return true; + } + + if (auto indexer = t->indexer) + { + if (indexer->indexType == sourceType || indexer->indexResultType == sourceType) + return true; + } + + if (vecHasSourceType(t->instantiatedTypeParams)) + return true; + + return false; + } + else if (auto t = get(ty)) + return t->table == sourceType || t->metatable == sourceType; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return vecHasSourceType(t->options); + else if (auto t = get(ty)) + return vecHasSourceType(t->parts); + else if (auto t = get(ty)) + return false; + + LUAU_ASSERT(!"Luau::Replacer::isDirty internal error: Unknown TypeVar type"); + LUAU_UNREACHABLE(); + } + + bool isDirty(TypePackId tp) override + { + if (auto it = replacedPacks.find(tp)) + return false; + + if (auto pack = get(tp)) + { + for (TypeId ty : pack->head) + { + if (ty == sourceType) + return true; + } + return false; + } + else if (auto vtp = get(tp)) + return vtp->ty == sourceType; + else + return false; + } + + TypeId clean(TypeId ty) override + { + LUAU_ASSERT(sourceType && replacedType); + + // Walk every kind of TypeVar and create a copy with sourceType replaced by replacedType + // Before returning, memoize the result for later use. + + // Helpfully, Substitution::clone() only shallow-clones the kinds of types that we care to work with. This + // function returns the identity for things like primitives. + TypeId res = clone(ty); + + if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + for (TypeId& part : t->parts) + { + if (part == sourceType) + part = replacedType; + } + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + // The constituent typepacks are cleaned separately. We just need to walk the generics array. + for (TypeId& g : t->generics) + { + if (g == sourceType) + g = replacedType; + } + } + else if (auto t = getMutable(res)) + { + for (auto& [_key, prop] : t->props) + { + if (prop.type == sourceType) + prop.type = replacedType; + } + } + else if (auto t = getMutable(res)) + { + if (t->table == sourceType) + t->table = replacedType; + if (t->metatable == sourceType) + t->table = replacedType; + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + for (TypeId& option : t->options) + { + if (option == sourceType) + option = replacedType; + } + } + else if (auto t = getMutable(res)) + { + for (TypeId& part : t->parts) + { + if (part == sourceType) + part = replacedType; + } + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else + LUAU_ASSERT(!"Luau::Replacer::clean internal error: Unknown TypeVar type"); + + replacedTypes[ty] = res; + return res; + } + + TypePackId clean(TypePackId tp) override + { + TypePackId res = clone(tp); + + if (auto pack = getMutable(res)) + { + for (TypeId& type : pack->head) + { + if (type == sourceType) + type = replacedType; + } + } + else if (auto vtp = getMutable(res)) + { + if (vtp->ty == sourceType) + vtp->ty = replacedType; + } + + replacedPacks[tp] = res; + return res; + } + + TypeId smartClone(TypeId t) + { + std::optional res = replace(t); + LUAU_ASSERT(res.has_value()); // TODO think about this + if (*res == t) + return clone(t); + return *res; + } +}; + +} // anonymous namespace + +bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) +{ + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(subTy, superTy); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; +} + +template +static bool areNormal_(const T& t, const DenseHashSet& seen, InternalErrorReporter& ice) +{ + int count = 0; + auto isNormal = [&](TypeId ty) { + ++count; + if (count >= FInt::LuauNormalizeIterationLimit) + ice.ice("Luau::areNormal hit iteration limit"); + + return ty->normal || seen.find(asMutable(ty)); + }; + + return std::all_of(begin(t), end(t), isNormal); +} + +static bool areNormal(const std::vector& types, const DenseHashSet& seen, InternalErrorReporter& ice) +{ + return areNormal_(types, seen, ice); +} + +static bool areNormal(TypePackId tp, const DenseHashSet& seen, InternalErrorReporter& ice) +{ + tp = follow(tp); + if (get(tp)) + return false; + + auto [head, tail] = flatten(tp); + + if (!areNormal_(head, seen, ice)) + return false; + + if (!tail) + return true; + + if (auto vtp = get(*tail)) + return vtp->ty->normal || seen.find(asMutable(vtp->ty)); + + return true; +} + +#define CHECK_ITERATION_LIMIT(...) \ + do \ + { \ + if (iterationLimit > FInt::LuauNormalizeIterationLimit) \ + { \ + limitExceeded = true; \ + return __VA_ARGS__; \ + } \ + ++iterationLimit; \ + } while (false) + +struct Normalize +{ + TypeArena& arena; + InternalErrorReporter& ice; + + // Debug data. Types being normalized are invalidated but trying to see what's going on is painful. + // To actually see the original type, read it by using the pointer of the type being normalized. + // e.g. in lldb, `e dump(originalTys[ty])`. + SeenTypes originalTys; + SeenTypePacks originalTps; + + int iterationLimit = 0; + bool limitExceeded = false; + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + template + void cycle(TID) + { + } + + bool operator()(TypeId ty, const FreeTypeVar&) + { + LUAU_ASSERT(!ty->normal); + return false; + } + + bool operator()(TypeId ty, const BoundTypeVar& btv) + { + // It should never be the case that this TypeVar is normal, but is bound to a non-normal type. + LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); + + asMutable(ty)->normal = btv.boundTo->normal; + return !ty->normal; + } + + bool operator()(TypeId ty, const PrimitiveTypeVar&) + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool operator()(TypeId ty, const GenericTypeVar&) + { + if (!ty->normal) + asMutable(ty)->normal = true; + + return false; + } + + bool operator()(TypeId ty, const ErrorTypeVar&) + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + ConstrainedTypeVar* ctv = const_cast(&ctvRef); + + std::vector parts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId part : parts) + visit_detail::visit(part, *this, seen); + + std::vector newParts = normalizeUnion(parts); + + const bool normal = areNormal(newParts, seen, ice); + + if (newParts.size() == 1) + *asMutable(ty) = BoundTypeVar{newParts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + + asMutable(ty)->normal = normal; + + return false; + } + + bool operator()(TypeId ty, const FunctionTypeVar& ftv) = delete; + bool operator()(TypeId ty, const FunctionTypeVar& ftv, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + visit_detail::visit(ftv.argTypes, *this, seen); + visit_detail::visit(ftv.retType, *this, seen); + + asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice); + + return false; + } + + bool operator()(TypeId ty, const TableTypeVar& ttv, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + bool normal = true; + + auto checkNormal = [&](TypeId t) { + // if t is on the stack, it is possible that this type is normal. + // If t is not normal and it is not on the stack, this type is definitely not normal. + if (!t->normal && !seen.find(asMutable(t))) + normal = false; + }; + + if (ttv.boundTo) + { + visit_detail::visit(*ttv.boundTo, *this, seen); + asMutable(ty)->normal = (*ttv.boundTo)->normal; + return false; + } + + for (const auto& [_name, prop] : ttv.props) + { + visit_detail::visit(prop.type, *this, seen); + checkNormal(prop.type); + } + + if (ttv.indexer) + { + visit_detail::visit(ttv.indexer->indexType, *this, seen); + checkNormal(ttv.indexer->indexType); + visit_detail::visit(ttv.indexer->indexResultType, *this, seen); + checkNormal(ttv.indexer->indexResultType); + } + + asMutable(ty)->normal = normal; + + return false; + } + + bool operator()(TypeId ty, const MetatableTypeVar& mtv, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + visit_detail::visit(mtv.table, *this, seen); + visit_detail::visit(mtv.metatable, *this, seen); + + asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; + + return false; + } + + bool operator()(TypeId ty, const ClassTypeVar& ctv) + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool operator()(TypeId ty, const AnyTypeVar&) + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool operator()(TypeId ty, const UnionTypeVar& utvRef, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + UnionTypeVar* utv = &const_cast(utvRef); + std::vector options = std::move(utv->options); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId option : options) + visit_detail::visit(option, *this, seen); + + std::vector newOptions = normalizeUnion(options); + + const bool normal = areNormal(newOptions, seen, ice); + + LUAU_ASSERT(!newOptions.empty()); + + if (newOptions.size() == 1) + *asMutable(ty) = BoundTypeVar{newOptions[0]}; + else + utv->options = std::move(newOptions); + + asMutable(ty)->normal = normal; + + return false; + } + + bool operator()(TypeId ty, const IntersectionTypeVar& itvRef, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + IntersectionTypeVar* itv = &const_cast(itvRef); + + std::vector oldParts = std::move(itv->parts); + + for (TypeId part : oldParts) + visit_detail::visit(part, *this, seen); + + std::vector tables; + for (TypeId part : oldParts) + { + part = follow(part); + if (get(part)) + tables.push_back(part); + else + { + Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD + combineIntoIntersection(replacer, itv, part); + } + } + + // Don't allocate a new table if there's just one in the intersection. + if (tables.size() == 1) + itv->parts.push_back(tables[0]); + else if (!tables.empty()) + { + const TableTypeVar* first = get(tables[0]); + LUAU_ASSERT(first); + + TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); + TableTypeVar* ttv = getMutable(newTable); + for (TypeId part : tables) + { + // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need + // to be rewritten to point at 'newTable' in the clone. + Replacer replacer{&arena, part, newTable}; + combineIntoTable(replacer, ttv, part); + } + + itv->parts.push_back(newTable); + } + + asMutable(ty)->normal = areNormal(itv->parts, seen, ice); + + if (itv->parts.size() == 1) + { + TypeId part = itv->parts[0]; + *asMutable(ty) = BoundTypeVar{part}; + } + + return false; + } + + bool operator()(TypeId ty, const LazyTypeVar&) + { + return false; + } + + std::vector normalizeUnion(const std::vector& options) + { + if (options.size() == 1) + return options; + + std::vector result; + + for (TypeId part : options) + combineIntoUnion(result, part); + + return result; + } + + void combineIntoUnion(std::vector& result, TypeId ty) + { + ty = follow(ty); + if (auto utv = get(ty)) + { + for (TypeId t : utv) + combineIntoUnion(result, t); + return; + } + + for (TypeId& part : result) + { + if (isSubtype(ty, part, ice)) + return; // no need to do anything + else if (isSubtype(part, ty, ice)) + { + part = ty; // replace the less general type by the more general one + return; + } + } + + result.push_back(ty); + } + + /** + * @param replacer knows how to clone a type such that any recursive references point at the new containing type. + * @param result is an intersection that is safe for us to mutate in-place. + */ + void combineIntoIntersection(Replacer& replacer, IntersectionTypeVar* result, TypeId ty) + { + // Note: this check guards against running out of stack space + // so if you increase the size of a stack frame, you'll need to decrease the limit. + CHECK_ITERATION_LIMIT(); + + ty = follow(ty); + if (auto itv = get(ty)) + { + for (TypeId part : itv->parts) + combineIntoIntersection(replacer, result, part); + return; + } + + // Let's say that the last part of our result intersection is always a table, if any table is part of this intersection + if (get(ty)) + { + if (result->parts.empty()) + result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); + + TypeId theTable = result->parts.back(); + + if (!get(FFlag::LuauNormalizeCombineIntersectionFix ? follow(theTable) : theTable)) + { + result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); + theTable = result->parts.back(); + } + + TypeId newTable = replacer.smartClone(theTable); + result->parts.back() = newTable; + + combineIntoTable(replacer, getMutable(newTable), ty); + } + else if (auto ftv = get(ty)) + { + bool merged = false; + for (TypeId& part : result->parts) + { + if (isSubtype(part, ty, ice)) + { + merged = true; + break; // no need to do anything + } + else if (isSubtype(ty, part, ice)) + { + merged = true; + part = ty; // replace the less general type by the more general one + break; + } + } + + if (!merged) + result->parts.push_back(ty); + } + else + result->parts.push_back(ty); + } + + TableState combineTableStates(TableState lhs, TableState rhs) + { + if (lhs == rhs) + return lhs; + + if (lhs == TableState::Free || rhs == TableState::Free) + return TableState::Free; + + if (lhs == TableState::Unsealed || rhs == TableState::Unsealed) + return TableState::Unsealed; + + return lhs; + } + + /** + * @param replacer gives us a way to clone a type such that recursive references are rewritten to the new + * "containing" type. + * @param table always points into a table that is safe for us to mutate. + */ + void combineIntoTable(Replacer& replacer, TableTypeVar* table, TypeId ty) + { + // Note: this check guards against running out of stack space + // so if you increase the size of a stack frame, you'll need to decrease the limit. + CHECK_ITERATION_LIMIT(); + + LUAU_ASSERT(table); + + ty = follow(ty); + + TableTypeVar* tyTable = getMutable(ty); + LUAU_ASSERT(tyTable); + + for (const auto& [propName, prop] : tyTable->props) + { + if (auto it = table->props.find(propName); it != table->props.end()) + { + /** + * If we are going to recursively merge intersections of tables, we need to ensure that we never mutate + * a table that comes from somewhere else in the type graph. + * + * smarClone() does some nice things for us: It will perform a clone that is as shallow as possible + * while still rewriting any cyclic references back to the new 'root' table. + * + * replacer also keeps a mapping of types that have previously been copied, so we have the added + * advantage here of knowing that, whether or not a new copy was actually made, the resulting TypeVar is + * safe for us to mutate in-place. + */ + TypeId clone = replacer.smartClone(it->second.type); + it->second.type = combine(replacer, clone, prop.type); + } + else + table->props.insert({propName, prop}); + } + + table->state = combineTableStates(table->state, tyTable->state); + table->level = max(table->level, tyTable->level); + } + + /** + * @param a is always cloned by the caller. It is safe to mutate in-place. + * @param b will never be mutated. + */ + TypeId combine(Replacer& replacer, TypeId a, TypeId b) + { + if (FFlag::LuauNormalizeCombineTableFix && a == b) + return a; + + if (!get(a) && !get(a)) + { + if (!FFlag::LuauNormalizeCombineTableFix && a == b) + return a; + else + return arena.addType(IntersectionTypeVar{{a, b}}); + } + + if (auto itv = getMutable(a)) + { + combineIntoIntersection(replacer, itv, b); + return a; + } + else if (auto ttv = getMutable(a)) + { + if (FFlag::LuauNormalizeCombineTableFix && !get(follow(b))) + return arena.addType(IntersectionTypeVar{{a, b}}); + combineIntoTable(replacer, ttv, b); + return a; + } + + LUAU_ASSERT(!"Impossible"); + LUAU_UNREACHABLE(); + } +}; + +#undef CHECK_ITERATION_LIMIT + +/** + * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) + */ +std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice) +{ + CloneState state; + if (FFlag::DebugLuauCopyBeforeNormalizing) + (void)clone(ty, arena, state); + + Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, n, seen); + + return {ty, !n.limitExceeded}; +} + +// TODO: Think about using a temporary arena and cloning types out of it so that we +// reclaim memory used by wantonly allocated intermediate types here. +// The main wrinkle here is that we don't want clone() to copy a type if the source and dest +// arena are the same. +std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice) +{ + return normalize(ty, module->internalTypes, ice); +} + +/** + * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) + */ +std::pair normalize(TypePackId tp, TypeArena& arena, InternalErrorReporter& ice) +{ + CloneState state; + if (FFlag::DebugLuauCopyBeforeNormalizing) + (void)clone(tp, arena, state); + + Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(tp, n, seen); + + return {tp, !n.limitExceeded}; +} + +std::pair normalize(TypePackId tp, const ModulePtr& module, InternalErrorReporter& ice) +{ + return normalize(tp, module->internalTypes, ice); +} + +} // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 94e169f1..305f83ce 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,6 +4,8 @@ #include "Luau/VisitTypeVar.h" +LUAU_FASTFLAG(LuauTypecheckOptPass) + namespace Luau { @@ -12,6 +14,8 @@ struct Quantifier TypeLevel level; std::vector generics; std::vector genericPacks; + bool seenGenericType = false; + bool seenMutableType = false; Quantifier(TypeLevel level) : level(level) @@ -23,6 +27,9 @@ struct Quantifier bool operator()(TypeId ty, const FreeTypeVar& ftv) { + if (FFlag::LuauTypecheckOptPass) + seenMutableType = true; + if (!level.subsumes(ftv.level)) return false; @@ -44,17 +51,40 @@ struct Quantifier return true; } + bool operator()(TypeId ty, const ConstrainedTypeVar&) + { + return true; + } + bool operator()(TypeId ty, const TableTypeVar&) { TableTypeVar& ttv = *getMutable(ty); + if (FFlag::LuauTypecheckOptPass) + { + if (ttv.state == TableState::Generic) + seenGenericType = true; + + if (ttv.state == TableState::Free) + seenMutableType = true; + } + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) return false; if (!level.subsumes(ttv.level)) + { + if (FFlag::LuauTypecheckOptPass && ttv.state == TableState::Unsealed) + seenMutableType = true; return false; + } if (ttv.state == TableState::Free) + { ttv.state = TableState::Generic; + + if (FFlag::LuauTypecheckOptPass) + seenGenericType = true; + } else if (ttv.state == TableState::Unsealed) ttv.state = TableState::Sealed; @@ -65,6 +95,9 @@ struct Quantifier bool operator()(TypePackId tp, const FreeTypePack& ftp) { + if (FFlag::LuauTypecheckOptPass) + seenMutableType = true; + if (!level.subsumes(ftp.level)) return false; @@ -84,6 +117,9 @@ void quantify(TypeId ty, TypeLevel level) LUAU_ASSERT(ftv); ftv->generics = q.generics; ftv->genericPacks = q.genericPacks; + + if (FFlag::LuauTypecheckOptPass && ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + ftv->hasNoGenerics = true; } } // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 770c7a47..8648b21e 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -7,24 +7,36 @@ #include #include +LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) +LUAU_FASTFLAG(LuauTypecheckOptPass) +LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) namespace Luau { void Tarjan::visitChildren(TypeId ty, int index) { - ty = log->follow(ty); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); + else + ty = log->follow(ty); if (ignoreChildren(ty)) return; - if (const FunctionTypeVar* ftv = log->getMutable(ty)) + if (FFlag::LuauTypecheckOptPass) + { + if (auto pty = log->pending(ty)) + ty = &pty->pending; + } + + if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { visitChild(ftv->argTypes); visitChild(ftv->retType); } - else if (const TableTypeVar* ttv = log->getMutable(ty)) + else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) @@ -41,38 +53,52 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId itp : ttv->instantiatedTypePackParams) visitChild(itp); } - else if (const MetatableTypeVar* mtv = log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { visitChild(mtv->table); visitChild(mtv->metatable); } - else if (const UnionTypeVar* utv = log->getMutable(ty)) + else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { for (TypeId opt : utv->options) visitChild(opt); } - else if (const IntersectionTypeVar* itv = log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { for (TypeId part : itv->parts) visitChild(part); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + for (TypeId part : ctv->parts) + visitChild(part); + } } void Tarjan::visitChildren(TypePackId tp, int index) { - tp = log->follow(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); + else + tp = log->follow(tp); if (ignoreChildren(tp)) return; - if (const TypePack* tpp = log->getMutable(tp)) + if (FFlag::LuauTypecheckOptPass) + { + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; + } + + if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { for (TypeId tv : tpp->head) visitChild(tv); if (tpp->tail) visitChild(*tpp->tail); } - else if (const VariadicTypePack* vtp = log->getMutable(tp)) + else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { visitChild(vtp->ty); } @@ -80,7 +106,10 @@ void Tarjan::visitChildren(TypePackId tp, int index) std::pair Tarjan::indexify(TypeId ty) { - ty = log->follow(ty); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); + else + ty = log->follow(ty); bool fresh = !typeToIndex.contains(ty); int& index = typeToIndex[ty]; @@ -98,7 +127,10 @@ std::pair Tarjan::indexify(TypeId ty) std::pair Tarjan::indexify(TypePackId tp) { - tp = log->follow(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); + else + tp = log->follow(tp); bool fresh = !packToIndex.contains(tp); int& index = packToIndex[tp]; @@ -141,7 +173,7 @@ TarjanResult Tarjan::loop() if (currEdge == -1) { ++childCount; - if (FInt::LuauTarjanChildLimit > 0 && FInt::LuauTarjanChildLimit < childCount) + if (childLimit > 0 && childLimit < childCount) return TarjanResult::TooManyChildren; stack.push_back(index); @@ -229,6 +261,9 @@ TarjanResult Tarjan::loop() TarjanResult Tarjan::visitRoot(TypeId ty) { childCount = 0; + if (childLimit == 0) + childLimit = FInt::LuauTarjanChildLimit; + ty = log->follow(ty); auto [index, fresh] = indexify(ty); @@ -239,6 +274,9 @@ TarjanResult Tarjan::visitRoot(TypeId ty) TarjanResult Tarjan::visitRoot(TypePackId tp) { childCount = 0; + if (childLimit == 0) + childLimit = FInt::LuauTarjanChildLimit; + tp = log->follow(tp); auto [index, fresh] = indexify(tp); @@ -347,7 +385,13 @@ TypeId Substitution::clone(TypeId ty) TypeId result = ty; - if (const FunctionTypeVar* ftv = log->getMutable(ty)) + if (FFlag::LuauTypecheckOptPass) + { + if (auto pty = log->pending(ty)) + ty = &pty->pending; + } + + if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; @@ -357,7 +401,7 @@ TypeId Substitution::clone(TypeId ty) clone.argNames = ftv->argNames; result = addType(std::move(clone)); } - else if (const TableTypeVar* ttv = log->getMutable(ty)) + else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -370,24 +414,29 @@ TypeId Substitution::clone(TypeId ty) clone.tags = ttv->tags; result = addType(std::move(clone)); } - else if (const MetatableTypeVar* mtv = log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; clone.syntheticName = mtv->syntheticName; result = addType(std::move(clone)); } - else if (const UnionTypeVar* utv = log->getMutable(ty)) + else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { UnionTypeVar clone; clone.options = utv->options; result = addType(std::move(clone)); } - else if (const IntersectionTypeVar* itv = log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { IntersectionTypeVar clone; clone.parts = itv->parts; result = addType(std::move(clone)); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + ConstrainedTypeVar clone{ctv->level, ctv->parts}; + result = addType(std::move(clone)); + } asMutable(result)->documentationSymbol = ty->documentationSymbol; return result; @@ -396,14 +445,21 @@ TypeId Substitution::clone(TypeId ty) TypePackId Substitution::clone(TypePackId tp) { tp = log->follow(tp); - if (const TypePack* tpp = log->getMutable(tp)) + + if (FFlag::LuauTypecheckOptPass) + { + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; + } + + if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { TypePack clone; clone.head = tpp->head; clone.tail = tpp->tail; return addTypePack(std::move(clone)); } - else if (const VariadicTypePack* vtp = log->getMutable(tp)) + else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { VariadicTypePack clone; clone.ty = vtp->ty; @@ -415,25 +471,34 @@ TypePackId Substitution::clone(TypePackId tp) void Substitution::foundDirty(TypeId ty) { - ty = log->follow(ty); - if (isDirty(ty)) - newTypes[ty] = clean(ty); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); else - newTypes[ty] = clone(ty); + ty = log->follow(ty); + + if (isDirty(ty)) + newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(ty)) : clean(ty); + else + newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(ty)) : clone(ty); } void Substitution::foundDirty(TypePackId tp) { - tp = log->follow(tp); - if (isDirty(tp)) - newPacks[tp] = clean(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); else - newPacks[tp] = clone(tp); + tp = log->follow(tp); + + if (isDirty(tp)) + newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(tp)) : clean(tp); + else + newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(tp)) : clone(tp); } TypeId Substitution::replace(TypeId ty) { ty = log->follow(ty); + if (TypeId* prevTy = newTypes.find(ty)) return *prevTy; else @@ -443,6 +508,7 @@ TypeId Substitution::replace(TypeId ty) TypePackId Substitution::replace(TypePackId tp) { tp = log->follow(tp); + if (TypePackId* prevTp = newPacks.find(tp)) return *prevTp; else @@ -451,7 +517,13 @@ TypePackId Substitution::replace(TypePackId tp) void Substitution::replaceChildren(TypeId ty) { - ty = log->follow(ty); + if (BoundTypeVar* btv = log->getMutable(ty); FFlag::LuauLowerBoundsCalculation && btv) + btv->boundTo = replace(btv->boundTo); + + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); + else + ty = log->follow(ty); if (ignoreChildren(ty)) return; @@ -493,11 +565,19 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& part : itv->parts) part = replace(part); } + else if (ConstrainedTypeVar* ctv = getMutable(ty)) + { + for (TypeId& part : ctv->parts) + part = replace(part); + } } void Substitution::replaceChildren(TypePackId tp) { - tp = log->follow(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); + else + tp = log->follow(tp); if (ignoreChildren(tp)) return; diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index df9d4188..cb54bfc1 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -237,6 +237,15 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + formatAppend(result, "ConstrainedTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : ctv->parts) + visitChild(part, index); + } else if (get(ty)) { formatAppend(result, "ErrorTypeVar %d", index); @@ -258,6 +267,28 @@ void StateDot::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable, index, "[metatable]"); } + else if (const SingletonTypeVar* stv = get(ty)) + { + std::string res; + + if (const StringSingleton* ss = get(stv)) + { + // Don't put in quotes anywhere. If it's outside of the call to escape, + // then it's invalid syntax. If it's inside, then escaping is super noisy. + res = "string: " + escape(ss->value); + } + else if (const BooleanSingleton* bs = get(stv)) + { + res = "boolean: "; + res += bs->value ? "true" : "false"; + } + else + LUAU_ASSERT(!"unknown singleton type"); + + formatAppend(result, "SingletonTypeVar %s", res.c_str()); + finishNodeLabel(ty); + finishNode(); + } else { LUAU_ASSERT(!"unknown type kind"); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 59ee6de2..610842da 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_FASTFLAG(LuauLowerBoundsCalculation) + /* * Prefix generic typenames with gen- * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 @@ -33,8 +35,8 @@ struct FindCyclicTypes bool exhaustive = false; std::unordered_set visited; std::unordered_set visitedPacks; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; void cycle(TypeId ty) { @@ -86,7 +88,7 @@ struct FindCyclicTypes }; template -void findCyclicTypes(std::unordered_set& cycles, std::unordered_set& cycleTPs, TID ty, bool exhaustive) +void findCyclicTypes(std::set& cycles, std::set& cycleTPs, TID ty, bool exhaustive) { FindCyclicTypes fct; fct.exhaustive = exhaustive; @@ -124,6 +126,7 @@ struct StringifierState std::unordered_map cycleTpNames; std::unordered_set seen; std::unordered_set usedNames; + size_t indentation = 0; bool exhaustive; @@ -216,6 +219,34 @@ struct StringifierState result.name += s; } + + void indent() + { + indentation += 4; + } + + void dedent() + { + indentation -= 4; + } + + void newline() + { + if (!opts.useLineBreaks) + return emit(" "); + + emit("\n"); + emitIndentation(); + } + +private: + void emitIndentation() + { + if (!opts.indent) + return; + + emit(std::string(indentation, ' ')); + } }; struct TypeVarStringifier @@ -321,7 +352,7 @@ struct TypeVarStringifier stringify(btv.boundTo); } - void operator()(TypeId ty, const Unifiable::Generic& gtv) + void operator()(TypeId ty, const GenericTypeVar& gtv) { if (gtv.explicitName) { @@ -332,6 +363,26 @@ struct TypeVarStringifier state.emit(state.getName(ty)); } + void operator()(TypeId, const ConstrainedTypeVar& ctv) + { + state.result.invalid = true; + + state.emit("[["); + + bool first = true; + for (TypeId ty : ctv.parts) + { + if (first) + first = false; + else + state.emit("|"); + + stringify(ty); + } + + state.emit("]]"); + } + void operator()(TypeId, const PrimitiveTypeVar& ptv) { switch (ptv.type) @@ -415,10 +466,25 @@ struct TypeVarStringifier state.emit(") -> "); bool plural = true; - if (auto retPack = get(follow(ftv.retType))) + + if (FFlag::LuauLowerBoundsCalculation) { - if (retPack->head.size() == 1 && !retPack->tail) - plural = false; + auto retBegin = begin(ftv.retType); + auto retEnd = end(ftv.retType); + if (retBegin != retEnd) + { + ++retBegin; + if (retBegin == retEnd && !retBegin.tail()) + plural = false; + } + } + else + { + if (auto retPack = get(follow(ftv.retType))) + { + if (retPack->head.size() == 1 && !retPack->tail) + plural = false; + } } if (plural) @@ -511,6 +577,7 @@ struct TypeVarStringifier } state.emit(openbrace); + state.indent(); bool comma = false; if (ttv.indexer) @@ -527,7 +594,10 @@ struct TypeVarStringifier for (const auto& [name, prop] : ttv.props) { if (comma) - state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + { + state.emit(","); + state.newline(); + } size_t length = state.result.name.length() - oldLength; @@ -553,6 +623,7 @@ struct TypeVarStringifier ++index; } + state.dedent(); state.emit(closedbrace); state.unsee(&ttv); @@ -563,7 +634,8 @@ struct TypeVarStringifier state.result.invalid = true; state.emit("{ @metatable "); stringify(mtv.metatable); - state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + state.emit(","); + state.newline(); stringify(mtv.table); state.emit(" }"); } @@ -784,13 +856,16 @@ struct TypePackStringifier if (tp.tail && !isEmpty(*tp.tail)) { - const auto& tail = *tp.tail; - if (first) - first = false; - else - state.emit(", "); + TypePackId tail = follow(*tp.tail); + if (auto vtp = get(tail); !vtp || (!FFlag::DebugLuauVerboseTypeNames && !vtp->hidden)) + { + if (first) + first = false; + else + state.emit(", "); - stringify(tail); + stringify(tail); + } } state.unsee(&tp); @@ -805,6 +880,8 @@ struct TypePackStringifier void operator()(TypePackId, const VariadicTypePack& pack) { state.emit("..."); + if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) + state.emit(""); stringify(pack.ty); } @@ -858,15 +935,12 @@ void TypeVarStringifier::stringify(TypePackId tpid, const std::vector& cycles, const std::unordered_set& cycleTPs, +static void assignCycleNames(const std::set& cycles, const std::set& cycleTPs, std::unordered_map& cycleNames, std::unordered_map& cycleTpNames, bool exhaustive) { int nextIndex = 1; - std::vector sortedCycles{cycles.begin(), cycles.end()}; - std::sort(sortedCycles.begin(), sortedCycles.end(), std::less{}); - - for (TypeId cycleTy : sortedCycles) + for (TypeId cycleTy : cycles) { std::string name; @@ -888,10 +962,7 @@ static void assignCycleNames(const std::unordered_set& cycles, const std cycleNames[cycleTy] = std::move(name); } - std::vector sortedCycleTps{cycleTPs.begin(), cycleTPs.end()}; - std::sort(sortedCycleTps.begin(), sortedCycleTps.end(), std::less()); - - for (TypePackId tp : sortedCycleTps) + for (TypePackId tp : cycleTPs) { std::string name = "tp" + std::to_string(nextIndex); ++nextIndex; @@ -913,8 +984,8 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) StringifierState state{opts, result, opts.nameMap}; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; findCyclicTypes(cycles, cycleTPs, ty, opts.exhaustive); @@ -1016,8 +1087,8 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) ToStringResult result; StringifierState state{opts, result, opts.nameMap}; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; findCyclicTypes(cycles, cycleTPs, tp, opts.exhaustive); @@ -1058,7 +1129,7 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) { + [&tvs, cycleTy = cycleTy](auto t) { return tvs(cycleTy, t); }, cycleTy->ty); @@ -1163,14 +1234,18 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp if (argPackIter.tail()) { - if (!first) - state.emit(", "); + if (auto vtp = get(*argPackIter.tail()); !vtp || !vtp->hidden) + { + if (!first) + state.emit(", "); - state.emit("...: "); - if (auto vtp = get(*argPackIter.tail())) - tvs.stringify(vtp->ty); - else - tvs.stringify(*argPackIter.tail()); + state.emit("...: "); + + if (vtp) + tvs.stringify(vtp->ty); + else + tvs.stringify(*argPackIter.tail()); + } } state.emit("): "); @@ -1210,6 +1285,24 @@ std::string dump(TypePackId ty) return s; } +std::string dump(const ScopePtr& scope, const char* name) +{ + auto binding = scope->linearSearchForBinding(name); + if (!binding) + { + printf("No binding %s\n", name); + return {}; + } + + TypeId ty = binding->typeId; + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; +} + std::string generateName(size_t i) { std::string n; diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 678001bf..1ea2e27d 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -215,6 +215,7 @@ struct ArcCollector : public AstVisitor } } + // Adds a dependency from the current node to the named node. void add(const Identifier& name) { Node** it = map.find(name); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 5fbb596d..a5f9d26c 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauTxnLogPreserveOwner, false) +LUAU_FASTFLAGVARIABLE(LuauJustOneCallFrameForHaveSeen, false) namespace Luau { @@ -161,18 +162,37 @@ void TxnLog::popSeen(TypePackId lhs, TypePackId rhs) bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const { - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) + if (FFlag::LuauJustOneCallFrameForHaveSeen && !FFlag::LuauTypecheckOptPass) { - return true; - } + // This function will technically work if `this` is nullptr, but this + // indicates a bug, so we explicitly assert. + LUAU_ASSERT(static_cast(this) != nullptr); - if (parent) + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (current->sharedSeen->end() != std::find(current->sharedSeen->begin(), current->sharedSeen->end(), sortedPair)) + return true; + } + + return false; + } + else { - return parent->haveSeen(lhs, rhs); - } + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) + { + return true; + } - return false; + if (!FFlag::LuauTypecheckOptPass && parent) + { + return parent->haveSeen(lhs, rhs); + } + + return false; + } } void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) @@ -222,8 +242,8 @@ PendingType* TxnLog::pending(TypeId ty) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) - return it->second.get(); + if (auto it = current->typeVarChanges.find(ty)) + return it->get(); } return nullptr; @@ -237,8 +257,8 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) - return it->second.get(); + if (auto it = current->typePackChanges.find(tp)) + return it->get(); } return nullptr; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index d575e023..bc8d0d4e 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -94,6 +94,16 @@ public: } } + AstType* operator()(const ConstrainedTypeVar& ctv) + { + AstArray types; + types.size = ctv.parts.size(); + types.data = static_cast(allocator->allocate(sizeof(AstType*) * ctv.parts.size())); + for (size_t i = 0; i < ctv.parts.size(); ++i) + types.data[i] = Luau::visit(*this, ctv.parts[i]->ty); + return allocator->alloc(Location(), types); + } + AstType* operator()(const SingletonTypeVar& stv) { if (const BooleanSingleton* bs = get(&stv)) @@ -364,6 +374,9 @@ public: AstTypePack* operator()(const VariadicTypePack& vtp) const { + if (vtp.hidden) + return nullptr; + return allocator->alloc(Location(), Luau::visit(*typeVisitor, vtp.ty->ty)); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 10930248..af42a4e6 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -3,12 +3,15 @@ #include "Luau/Common.h" #include "Luau/ModuleResolver.h" +#include "Luau/Normalize.h" +#include "Luau/Parser.h" #include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" #include "Luau/TypePack.h" +#include "Luau/ToString.h" #include "Luau/TypeUtils.h" #include "Luau/ToString.h" #include "Luau/TypeVar.h" @@ -19,14 +22,17 @@ LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauSeparateTypechecks) +LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauAutocompleteSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. +LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) @@ -39,6 +45,7 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) +LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) @@ -53,6 +60,8 @@ LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) +LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) +LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); namespace Luau { @@ -140,6 +149,34 @@ bool hasBreak(AstStat* node) } } +static bool hasReturn(const AstStat* node) +{ + struct Searcher : AstVisitor + { + bool result = false; + + bool visit(AstStat*) override + { + return !result; // if we've already found a return statement, don't bother to traverse inward anymore + } + + bool visit(AstStatReturn*) override + { + result = true; + return false; + } + + bool visit(AstExprFunction*) override + { + return false; // We don't care if the function uses a lambda that itself returns + } + }; + + Searcher searcher; + const_cast(node)->visit(&searcher); + return searcher.result; +} + // returns the last statement before the block exits, or nullptr if the block never exits const AstStat* getFallthrough(const AstStat* node) { @@ -253,6 +290,26 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) +{ + if (FFlag::LuauRecursionLimitException) + { + try + { + return checkWithoutRecursionCheck(module, mode, environmentScope); + } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(module.root->location); + return std::move(currentModule); + } + } + else + { + return checkWithoutRecursionCheck(module, mode, environmentScope); + } +} + +ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) { LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); @@ -268,6 +325,12 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona iceHandler->moduleName = module.name; + if (FFlag::LuauAutocompleteDynamicLimits) + { + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; + } + ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); @@ -312,7 +375,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona prepareErrorsForDisplay(currentModule->errors); - bool encounteredFreeType = currentModule->clonePublicInterface(); + bool encounteredFreeType = currentModule->clonePublicInterface(*iceHandler); if (encounteredFreeType) { reportError(TypeError{module.root->location, @@ -415,7 +478,26 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } + if (FFlag::LuauRecursionLimitException) + { + try + { + checkBlockWithoutRecursionCheck(scope, block); + } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(block.location); + return; + } + } + else + { + checkBlockWithoutRecursionCheck(scope, block); + } +} +void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) +{ int subLevel = 0; std::vector sorted(block.body.data, block.body.data + block.body.size); @@ -435,6 +517,16 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) std::unordered_map> functionDecls; + auto isLocalLambda = [](AstStat* stat) -> AstStatLocal* { + AstStatLocal* local = stat->as(); + + if (FFlag::LuauLowerBoundsCalculation && local && local->vars.size == 1 && local->values.size == 1 && + local->values.data[0]->is()) + return local; + else + return nullptr; + }; + auto checkBody = [&](AstStat* stat) { if (auto fun = stat->as()) { @@ -482,7 +574,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (containsFunctionCallOrReturn(**protoIter)) + if (containsFunctionCallOrReturn(**protoIter) || (FFlag::LuauLowerBoundsCalculation && isLocalLambda(*protoIter))) { while (checkIter != protoIter) { @@ -513,7 +605,8 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) functionDecls[*protoIter] = pair; ++subLevel; - TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); + TypeId leftType = follow(checkFunctionName(scope, *fun->name, funScope->level)); + unify(funTy, leftType, fun->location); } else if (auto fun = (*protoIter)->as()) @@ -658,6 +751,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } +void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location) +{ + Unifier state = mkUnifier(location); + state.unifyLowerBound(subTy, superTy); + + state.log.commit(); + + reportErrors(state.errors); +} + void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; @@ -682,6 +785,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; + if (useConstrainedIntersections()) + { + unifyLowerBound(retPack, scope->returnType, return_.location); + return; + } + // HACK: Nonstrict mode gets a bit too smart and strict for us when we // start typechecking everything across module boundaries. if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) @@ -1209,9 +1318,11 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco else if (tableSelf->state == TableState::Sealed) reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); + const bool tableIsExtendable = tableSelf && tableSelf->state != TableState::Sealed; + ty = follow(ty); - if (tableSelf && tableSelf->state != TableState::Sealed) + if (tableIsExtendable) tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; const FunctionTypeVar* funTy = get(ty); @@ -1224,7 +1335,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (tableSelf && tableSelf->state != TableState::Sealed) + if (tableIsExtendable) tableSelf->props[indexName->index.value] = { follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; } @@ -1372,7 +1483,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias for (auto param : binding->typePackParams) clone.instantiatedTypePackParams.push_back(param.tp); + bool isNormal = ty->normal; ty = addType(std::move(clone)); + + if (FFlag::LuauLowerBoundsCalculation) + asMutable(ty)->normal = isNormal; } } else @@ -1400,6 +1515,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (FFlag::LuauTwoPassAliasDefinitionFix && ok) bindingType = ty; + + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(bindingType, currentModule, *iceHandler); + bindingType = t; + if (!ok) + reportError(typealias.location, NormalizationTooComplex{}); + } } } @@ -1673,10 +1796,11 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (get(retPack)) + else if (const FreeTypePack* ftp = get(retPack)) { - TypeId head = freshType(scope); - TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); + TypeLevel level = FFlag::LuauLowerBoundsCalculation ? ftp->level : scope->level; + TypeId head = freshType(level); + TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(level)}}); unify(pack, retPack, expr.location); return {head, std::move(result.predicates)}; } @@ -1793,7 +1917,7 @@ std::optional TypeChecker::getIndexTypeFromType( for (TypeId t : utv) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeForType unions"); // Not needed when we normalize types. if (get(follow(t))) @@ -1817,12 +1941,25 @@ std::optional TypeChecker::getIndexTypeFromType( return std::nullopt; } - std::vector result = reduceUnion(goodOptions); + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(addType(UnionTypeVar{std::move(goodOptions)}), currentModule, + *iceHandler); // FIXME Inefficient. We craft a UnionTypeVar and immediately throw it away. - if (result.size() == 1) - return result[0]; + if (!ok) + reportError(location, NormalizationTooComplex{}); - return addType(UnionTypeVar{std::move(result)}); + return t; + } + else + { + std::vector result = reduceUnion(goodOptions); + + if (result.size() == 1) + return result[0]; + + return addType(UnionTypeVar{std::move(result)}); + } } else if (const IntersectionTypeVar* itv = get(type)) { @@ -1830,7 +1967,7 @@ std::optional TypeChecker::getIndexTypeFromType( for (TypeId t : itv->parts) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeFromType intersections"); if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) parts.push_back(*ty); @@ -1982,7 +2119,6 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) { if (!std::any_of(begin(utv), end(utv), isNil)) return ty; - } if (std::optional strippedUnion = tryStripUnionFromNil(ty)) @@ -2124,7 +2260,26 @@ TypeId TypeChecker::checkExprTable( ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + if (FFlag::LuauTableUseCounterInstead) + { + RecursionCounter _rc(&checkRecursionCount); + if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) + { + reportErrorCodeTooComplex(expr.location); + return {errorRecoveryType(scope)}; + } + + return checkExpr_(scope, expr, expectedType); + } + else + { + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "checkExpr for tables"); + return checkExpr_(scope, expr, expectedType); + } +} + +ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) +{ std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; @@ -3176,6 +3331,10 @@ std::pair TypeChecker::checkFunctionSignature( funScope->varargPack = anyTypePack; } } + else if (FFlag::LuauLowerBoundsCalculation && !isNonstrictMode()) + { + funScope->varargPack = addTypePack(TypePackVar{VariadicTypePack{anyType, /*hidden*/ true}}); + } std::vector argTypes; @@ -3311,9 +3470,24 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE { check(scope, *function.body); - // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (get_if(&funTy->retType->ty)) - *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + if (useConstrainedIntersections()) + { + TypePackId retPack = follow(funTy->retType); + // It is possible for a function to have no annotation and no return statement, and yet still have an ascribed return type + // if it is expected to conform to some other interface. (eg the function may be a lambda passed as a callback) + if (!hasReturn(function.body) && !function.returnAnnotation.has_value() && get(retPack)) + { + auto level = getLevel(retPack); + if (level && scope->level.subsumes(*level)) + *asMutable(retPack) = TypePack{{}, std::nullopt}; + } + } + else + { + // We explicitly don't follow here to check if we have a 'true' free type instead of bound one + if (get_if(&funTy->retType->ty)) + *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + } bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; @@ -3418,6 +3592,19 @@ void TypeChecker::checkArgumentList( size_t minParams = FFlag::LuauFixIncorrectLineNumberDuplicateType ? 0 : getMinParameterCount_DEPRECATED(paramPack); + auto reportCountMismatchError = [&state, &argLocations, minParams, paramPack, argPack]() { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + + size_t mp = minParams; + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + mp = getMinParameterCount(&state.log, paramPack); + + state.reportError(TypeError{location, CountMismatch{mp, std::distance(begin(argPack), end(argPack))}}); + }; + while (true) { state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; @@ -3472,6 +3659,8 @@ void TypeChecker::checkArgumentList( } else if (auto vtp = state.log.getMutable(tail)) { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. while (paramIter != endIter) { state.tryUnify(vtp->ty, *paramIter); @@ -3506,14 +3695,22 @@ void TypeChecker::checkArgumentList( else if (state.log.getMutable(t)) { } // ok - else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.getMutable(t)) + else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.get(t)) { } // ok else { if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) minParams = getMinParameterCount(&state.log, paramPack); - bool isVariadic = FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic && !finite(paramPack, &state.log); + + bool isVariadic = false; + if (FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic) + { + std::optional tail = flatten(paramPack, state.log).second; + if (tail) + isVariadic = Luau::isVariadic(*tail); + } + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}}); return; } @@ -3532,14 +3729,7 @@ void TypeChecker::checkArgumentList( unify(errorRecoveryType(scope), *argIter, state.location); ++argIter; } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + reportCountMismatchError(); return; } TypePackId tail = state.log.follow(*paramIter.tail()); @@ -3551,6 +3741,21 @@ void TypeChecker::checkArgumentList( } else if (auto vtp = state.log.getMutable(tail)) { + if (FFlag::LuauLowerBoundsCalculation && vtp->hidden) + { + // We know that this function can technically be oversaturated, but we have its definition and we + // know that it's useless. + + TypeId e = errorRecoveryType(scope); + while (argIter != endIter) + { + unify(e, *argIter, state.location); + ++argIter; + } + + reportCountMismatchError(); + return; + } // Function is variadic and requires that all subsequent parameters // be compatible with a type. size_t argIndex = paramIndex; @@ -3595,14 +3800,7 @@ void TypeChecker::checkArgumentList( } else if (state.log.getMutable(tail)) { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + reportCountMismatchError(); return; } } @@ -3661,7 +3859,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = follow(actualFunctionType); TypePackId retPack; - if (!FFlag::LuauWidenIfSupertypeIsFree2) + if (FFlag::LuauLowerBoundsCalculation || !FFlag::LuauWidenIfSupertypeIsFree2) { retPack = freshTypePack(scope->level); } @@ -3809,21 +4007,49 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {{errorRecoveryTypePack(scope)}}; } - if (get(fn)) + if (auto ftv = get(fn)) { // fn is one of the overloads of actualFunctionType, which // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. - TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - if (FFlag::LuauWidenIfSupertypeIsFree2) + if (useConstrainedIntersections()) { - UnifierOptions options; - options.isFunctionCall = true; - unify(r, fn, expr.location, options); + // This ternary is phrased deliberately. We need ties between sibling scopes to bias toward ftv->level. + const TypeLevel level = scope->level.subsumes(ftv->level) ? scope->level : ftv->level; + + std::vector adjustedArgTypes; + auto it = begin(argPack); + auto endIt = end(argPack); + Widen widen{¤tModule->internalTypes}; + for (; it != endIt; ++it) + { + TypeId t = *it; + TypeId widened = widen.substitute(t).value_or(t); // Surely widening is infallible + adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widened}})); + } + + TypePackId adjustedArgPack = addTypePack(TypePack{std::move(adjustedArgTypes), it.tail()}); + + TxnLog log; + promoteTypeLevels(log, ¤tModule->internalTypes, level, retPack); + log.commit(); + + *asMutable(fn) = FunctionTypeVar{level, adjustedArgPack, retPack}; + return {{retPack}}; } else - unify(fn, r, expr.location); - return {{retPack}}; + { + TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); + if (FFlag::LuauWidenIfSupertypeIsFree2) + { + UnifierOptions options; + options.isFunctionCall = true; + unify(r, fn, expr.location, options); + } + else + unify(fn, r, expr.location); + return {{retPack}}; + } } std::vector metaArgLocations; @@ -4363,10 +4589,17 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s bool Instantiation::isDirty(TypeId ty) { - if (log->getMutable(ty)) + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics) + return false; + return true; + } else + { return false; + } } bool Instantiation::isDirty(TypePackId tp) @@ -4414,14 +4647,21 @@ TypePackId Instantiation::clean(TypePackId tp) bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics) + return true; + // We aren't recursing in the case of a generic function which // binds the same generics. This can happen if, for example, there's recursive types. // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. // It's OK to use vector equality here, since we always generate fresh generics // whenever we quantify, so the vectors overlap if and only if they are equal. return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); + } else + { return false; + } } bool ReplaceGenerics::isDirty(TypeId ty) @@ -4464,16 +4704,24 @@ TypePackId ReplaceGenerics::clean(TypePackId tp) bool Anyification::isDirty(TypeId ty) { + if (ty->persistent) + return false; + if (const TableTypeVar* ttv = log->getMutable(ty)) return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); else if (log->getMutable(ty)) return true; + else if (get(ty)) + return true; else return false; } bool Anyification::isDirty(TypePackId tp) { + if (tp->persistent) + return false; + if (log->getMutable(tp)) return true; else @@ -4494,7 +4742,16 @@ TypeId Anyification::clean(TypeId ty) clone.syntheticName = ttv->syntheticName; clone.tags = ttv->tags; } - return addType(std::move(clone)); + TypeId res = addType(std::move(clone)); + asMutable(res)->normal = ty->normal; + return res; + } + else if (auto ctv = get(ty)) + { + auto [t, ok] = normalize(ty, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; } else return anyType; @@ -4511,16 +4768,34 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location ty = follow(ty); const FunctionTypeVar* ftv = get(ty); - if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) - return ty; + if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) + Luau::quantify(ty, scope->level); + + if (FFlag::LuauLowerBoundsCalculation && ftv) + { + auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + return t; + } - Luau::quantify(ty, scope->level); return ty; } TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { + if (FFlag::LuauTypecheckOptPass) + { + const FunctionTypeVar* ftv = get(follow(ty)); + if (ftv && ftv->hasNoGenerics) + return ty; + } + Instantiation instantiation{log, ¤tModule->internalTypes, scope->level}; + + if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit) + instantiation.childLimit = *instantiationChildLimit; + std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; @@ -4533,8 +4808,18 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { - Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + ty = t; + } + + Anyification anyification{¤tModule->internalTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); + if (anyification.normalizationTooComplex) + reportError(location, NormalizationTooComplex{}); if (any.has_value()) return *any; else @@ -4546,7 +4831,15 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { - Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + ty = t; + } + + Anyification anyification{¤tModule->internalTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4830,6 +5123,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation ToStringOptions opts; opts.exhaustive = true; opts.maxTableLength = 0; + opts.useLineBreaks = true; TypeId param = resolveType(scope, *lit->parameters.data[0].type); luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); @@ -5283,7 +5577,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, bool needsClone = follow(tf.type) == target; bool shouldMutate = (!FFlag::LuauOnlyMutateInstantiatedTables || getTableType(tf.type)); TableTypeVar* ttv = getMutableTableType(target); - + if (shouldMutate && ttv && needsClone) { // Substitution::clone is a shallow clone. If this is a metatable type, we @@ -5487,25 +5781,82 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV // We need to search in the provided Scope. Find t.x.y first. // We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x. // If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate. - const auto& [symbol, keys] = getFullName(lvalue); + if (!FFlag::LuauTypecheckOptPass) + { + const auto& [symbol, keys] = getFullName(lvalue); + + ScopePtr currentScope = scope; + while (currentScope) + { + std::optional found; + + std::vector childKeys; + const LValue* currentLValue = &lvalue; + while (currentLValue) + { + if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) + { + found = it->second; + break; + } + + childKeys.push_back(*currentLValue); + currentLValue = baseof(*currentLValue); + } + + if (!found) + { + // Should not be using scope->lookup. This is already recursive. + if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) + found = it->second.typeId; + else + { + // Nothing exists in this Scope. Just skip and try the parent one. + currentScope = currentScope->parent; + continue; + } + } + + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) + { + const LValue& key = *it; + + // Symbol can happen. Skip. + if (get(key)) + continue; + else if (auto field = get(key)) + { + found = getIndexTypeFromType(scope, *found, field->key, Location(), false); + if (!found) + return std::nullopt; // Turns out this type doesn't have the property at all. We're done. + } + else + LUAU_ASSERT(!"New LValue alternative not handled here."); + } + + return found; + } + + // No entry for it at all. Can happen when LValue root is a global. + return std::nullopt; + } + + const Symbol symbol = getBaseSymbol(lvalue); ScopePtr currentScope = scope; while (currentScope) { std::optional found; - std::vector childKeys; - const LValue* currentLValue = &lvalue; - while (currentLValue) + const LValue* topLValue = nullptr; + + for (topLValue = &lvalue; topLValue; topLValue = baseof(*topLValue)) { - if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) + if (auto it = currentScope->refinements.find(*topLValue); it != currentScope->refinements.end()) { found = it->second; break; } - - childKeys.push_back(*currentLValue); - currentLValue = baseof(*currentLValue); } if (!found) @@ -5521,9 +5872,15 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV } } + // We need to walk the l-value path in reverse, so we collect components into a vector + std::vector childKeys; + + for (const LValue* curr = &lvalue; curr != topLValue; curr = baseof(*curr)) + childKeys.push_back(curr); + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) { - const LValue& key = *it; + const LValue& key = **it; // Symbol can happen. Skip. if (get(key)) @@ -5938,6 +6295,11 @@ bool TypeChecker::isNonstrictMode() const return (currentModule->mode == Mode::Nonstrict) || (currentModule->mode == Mode::NoCheck); } +bool TypeChecker::useConstrainedIntersections() const +{ + return FFlag::LuauLowerBoundsCalculation && !isNonstrictMode(); +} + std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp, size_t expectedLength, const Location& location) { TypePackId expectedTypePack = addTypePack({}); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 5bb05234..30503233 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -104,7 +104,7 @@ TypePackIterator begin(TypePackId tp) return TypePackIterator{tp}; } -TypePackIterator begin(TypePackId tp, TxnLog* log) +TypePackIterator begin(TypePackId tp, const TxnLog* log) { return TypePackIterator{tp, log}; } @@ -256,7 +256,7 @@ size_t size(const TypePack& tp, TxnLog* log) return result; } -std::optional first(TypePackId tp) +std::optional first(TypePackId tp, bool ignoreHiddenVariadics) { auto it = begin(tp); auto endIter = end(tp); @@ -266,7 +266,7 @@ std::optional first(TypePackId tp) if (auto tail = it.tail()) { - if (auto vtp = get(*tail)) + if (auto vtp = get(*tail); vtp && (!vtp->hidden || !ignoreHiddenVariadics)) return vtp->ty; } @@ -299,6 +299,46 @@ std::pair, std::optional> flatten(TypePackId tp) return {res, iter.tail()}; } +std::pair, std::optional> flatten(TypePackId tp, const TxnLog& log) +{ + tp = log.follow(tp); + + std::vector flattened; + std::optional tail = std::nullopt; + + TypePackIterator it(tp, &log); + + for (; it != end(tp); ++it) + { + flattened.push_back(*it); + } + + tail = it.tail(); + + return {flattened, tail}; +} + +bool isVariadic(TypePackId tp) +{ + return isVariadic(tp, *TxnLog::empty()); +} + +bool isVariadic(TypePackId tp, const TxnLog& log) +{ + std::optional tail = flatten(tp, log).second; + + if (!tail) + return false; + + if (log.get(*tail)) + return true; + + if (auto vtp = log.get(*tail); vtp && !vtp->hidden) + return true; + + return false; +} + TypePackVar* asMutable(TypePackId tp) { return const_cast(tp); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index dbc412fc..0fbfdbf0 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -177,7 +177,7 @@ bool maybeString(TypeId ty) if (FFlag::LuauSubtypingAddOptPropsToUnsealedTables) { ty = follow(ty); - + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) return true; @@ -366,7 +366,7 @@ bool maybeSingleton(TypeId ty) bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) { - RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit, "hasLength"); ty = follow(ty); @@ -654,9 +654,9 @@ static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persi static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}; static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}; static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; -static TypeVar anyType_{AnyTypeVar{}}; -static TypeVar errorType_{ErrorTypeVar{}}; -static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}}; +static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; +static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; +static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}, /*persistent*/ true}; static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; static TypePackVar errorTypePack_{Unifiable::Error{}}; @@ -698,7 +698,7 @@ TypeId SingletonTypes::makeStringMetatable() { const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); const TypeId optionalString = arena->addType(UnionTypeVar{{nilType, stringType}}); - const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, &booleanType_}}); + const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, booleanType}}); const TypePackId oneStringPack = arena->addTypePack({stringType}); const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); @@ -802,6 +802,7 @@ void persist(TypeId ty) continue; asMutable(t)->persistent = true; + asMutable(t)->normal = true; // all persistent types are assumed to be normal if (auto btv = get(t)) queue.push_back(btv->boundTo); @@ -838,6 +839,11 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } + else if (auto ctv = get(t)) + { + for (TypeId opt : ctv->parts) + queue.push_back(opt); + } else if (auto mtv = get(t)) { queue.push_back(mtv->table); @@ -899,6 +905,16 @@ TypeLevel* getMutableLevel(TypeId ty) return const_cast(getLevel(ty)); } +std::optional getLevel(TypePackId tp) +{ + tp = follow(tp); + + if (auto ftv = get(tp)) + return ftv->level; + else + return std::nullopt; +} + const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) { while (cls) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 398dc9e2..f9ea58cc 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,9 +14,12 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); +LUAU_FASTINT(LuauTypeInferIterationLimit); +LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauExtendedIndexerError, false); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); +LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) @@ -27,6 +30,7 @@ LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) LUAU_FASTFLAGVARIABLE(LuauUnifierCacheErrors, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) +LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -126,7 +130,6 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel visitTypeVarOnce(ty, ptl, seen); } -// TODO: use this and make it static. void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -305,8 +308,7 @@ static std::optional> getTableMat return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, - TxnLog* parentLog) +Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) , log(parentLog) @@ -326,6 +328,7 @@ Unifier::Unifier(TypeArena* types, Mode mode, std::vector 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + if (FFlag::LuauAutocompleteDynamicLimits) { - reportError(TypeError{location, UnificationTooComplex{}}); - return; + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } + } + else + { + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } } superTy = log.follow(superTy); @@ -354,6 +369,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; + if (log.get(superTy)) + return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy); + auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); @@ -442,7 +460,18 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (get(superTy) || get(superTy)) return tryUnifyWithAny(subTy, superTy); - if (get(subTy) || get(subTy)) + if (get(subTy)) + { + if (anyIsTop) + { + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + return; + } + else + return tryUnifyWithAny(superTy, subTy); + } + + if (get(subTy)) return tryUnifyWithAny(superTy, subTy); bool cacheEnabled; @@ -484,7 +513,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (const UnionTypeVar* uv = log.getMutable(subTy)) + if (log.get(subTy)) + tryUnifyWithConstrainedSubTypeVar(subTy, superTy); + else if (const UnionTypeVar* uv = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, uv, superTy); } @@ -946,7 +977,7 @@ struct WeirdIter LUAU_ASSERT(log.getMutable(newTail)); level = log.getMutable(packId)->level; - log.replace(packId, Unifiable::Bound(newTail)); + log.replace(packId, BoundTypePack(newTail)); packId = newTail; pack = log.getMutable(newTail); index = 0; @@ -994,39 +1025,32 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall tryUnify_(subTp, superTp, isFunctionCall); } -static std::pair, std::optional> logAwareFlatten(TypePackId tp, const TxnLog& log) -{ - tp = log.follow(tp); - - std::vector flattened; - std::optional tail = std::nullopt; - - TypePackIterator it(tp, &log); - - for (; it != end(tp); ++it) - { - flattened.push_back(*it); - } - - tail = it.tail(); - - return {flattened, tail}; -} - /* * This is quite tricky: we are walking two rope-like structures and unifying corresponding elements. * If one is longer than the other, but the short end is free, we grow it to the required length. */ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "TypePackId tryUnify_"); ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + if (FFlag::LuauAutocompleteDynamicLimits) { - reportError(TypeError{location, UnificationTooComplex{}}); - return; + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } + } + else + { + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } } superTp = log.follow(superTp); @@ -1087,8 +1111,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If the size of two heads does not match, but both packs have free tail // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = logAwareFlatten(superTp, log); - auto [subTypes, subTail] = logAwareFlatten(subTp, log); + auto [superTypes, superTail] = flatten(superTp, log); + auto [subTypes, subTail] = flatten(subTp, log); bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && (subTail && log.getMutable(*subTail)); @@ -1165,19 +1189,20 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal else { // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) + if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && superIter.good() && isOptional(*superIter)) { superIter.advance(); continue; } - else if (subIter.good() && isOptional(*subIter)) + else if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && subIter.good() && isOptional(*subIter)) { subIter.advance(); continue; } // In nonstrict mode, any also marks an optional argument. - else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) + else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && + log.getMutable(log.follow(*superIter))) { superIter.advance(); continue; @@ -1195,7 +1220,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal return; } - if (!isFunctionCall && subIter.good()) + if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && !isFunctionCall && subIter.good()) { // Sometimes it is ok to pass too many arguments return; @@ -1418,14 +1443,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (FFlag::LuauAnyInIsOptionalIsOptional) { - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) + if (subIter == subTable->props.end() && + (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) missingProperties.push_back(propName); } else { bool isAny = log.getMutable(log.follow(superProp.type)); - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) + if (subIter == subTable->props.end() && + (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && + !isAny) missingProperties.push_back(propName); } } @@ -1438,8 +1466,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // And vice versa if we're invariant - if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && - superTable->state != TableState::Free) + if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && superTable->state != TableState::Free) { for (const auto& [propName, subProp] : subTable->props) { @@ -1453,7 +1480,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else { bool isAny = log.is(log.follow(subProp.type)); - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) + if (superIter == superTable->props.end() && + (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) extraProperties.push_back(propName); } } @@ -1499,13 +1527,15 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); } - else if (FFlag::LuauAnyInIsOptionalIsOptional && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) + else if (FFlag::LuauAnyInIsOptionalIsOptional && + (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) { } - else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get(follow(prop.type)))) + else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && + (isOptional(prop.type) || get(follow(prop.type)))) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: should isOptional(anyType) be true? @@ -1664,9 +1694,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (FFlag::LuauTxnLogDontRetryForIndexers) { - // Changing the indexer can invalidate the table pointers. - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); + // Changing the indexer can invalidate the table pointers. + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); } else if (FFlag::LuauTxnLogCheckForInvalidation) { @@ -1921,8 +1951,6 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!superTable || !subTable) ice("passed non-table types to unifySealedTables"); - Unifier innerState = makeChildUnifier(); - std::vector missingPropertiesInSuper; bool isUnnamedTable = subTable->name == std::nullopt && subTable->syntheticName == std::nullopt; bool errorReported = false; @@ -1944,6 +1972,8 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec } } + Unifier innerState = makeChildUnifier(); + // Tables must have exactly the same props and their types must all unify for (const auto& it : superTable->props) { @@ -2376,6 +2406,180 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); } +void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy) +{ + const ConstrainedTypeVar* subConstrained = get(subTy); + if (!subConstrained) + ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!"); + + const std::vector& subTyParts = subConstrained->parts; + + // A | B <: T if A <: T and B <: T + bool failed = false; + std::optional unificationTooComplex; + + const size_t count = subTyParts.size(); + + for (size_t i = 0; i < count; ++i) + { + TypeId type = subTyParts[i]; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, superTy); + + if (i == count - 1) + log.concat(std::move(innerState.log)); + + ++i; + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + + if (!innerState.errors.empty()) + { + failed = true; + break; + } + } + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (failed) + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + else + log.replace(subTy, BoundTypeVar{superTy}); +} + +void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy) +{ + ConstrainedTypeVar* superC = log.getMutable(superTy); + if (!superC) + ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!"); + + // subTy could be a + // table + // metatable + // class + // function + // primitive + // free + // generic + // intersection + // union + // Do we really just tack it on? I think we might! + // We can certainly do some deduplication. + // Is there any point to deducing Player|Instance when we could just reduce to Instance? + // Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type? + // Maybe we do a simplification step during quantification. + + auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy); + if (it != superC->parts.end()) + return; + + superC->parts.push_back(subTy); +} + +void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy) +{ + // The duplication between this and regular typepack unification is tragic. + + auto superIter = begin(superTy, &log); + auto superEndIter = end(superTy); + + auto subIter = begin(subTy, &log); + auto subEndIter = end(subTy); + + int count = FInt::LuauTypeInferLowerBoundsIterationLimit; + + for (; subIter != subEndIter; ++subIter) + { + if (0 >= --count) + ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound"); + + if (superIter != superEndIter) + { + tryUnify_(*subIter, *superIter); + ++superIter; + continue; + } + + if (auto t = superIter.tail()) + { + TypePackId tailPack = follow(*t); + + if (log.get(tailPack)) + occursCheck(tailPack, subTy); + + FreeTypePack* freeTailPack = log.getMutable(tailPack); + if (!freeTailPack) + return; + + TypeLevel level = freeTailPack->level; + + TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); + + for (; subIter != subEndIter; ++subIter) + { + tp->head.push_back(types->addType(ConstrainedTypeVar{level, {follow(*subIter)}})); + } + + tp->tail = subIter.tail(); + } + + return; + } + + if (superIter != superEndIter) + { + if (auto subTail = subIter.tail()) + { + TypePackId subTailPack = follow(*subTail); + if (get(subTailPack)) + { + TypePack* tp = getMutable(log.replace(subTailPack, TypePack{})); + + for (; superIter != superEndIter; ++superIter) + tp->head.push_back(*superIter); + } + } + else + { + while (superIter != superEndIter) + { + if (!isOptional(*superIter)) + { + errors.push_back(TypeError{location, CountMismatch{size(superTy), size(subTy), CountMismatch::Return}}); + return; + } + ++superIter; + } + } + + return; + } + + // Both iters are at their respective tails + auto subTail = subIter.tail(); + auto superTail = superIter.tail(); + if (subTail && superTail) + tryUnify(*subTail, *superTail); + else if (subTail) + { + const FreeTypePack* freeSubTail = log.getMutable(*subTail); + if (freeSubTail) + { + log.replace(*subTail, TypePack{}); + } + } + else if (superTail) + { + const FreeTypePack* freeSuperTail = log.getMutable(*superTail); + if (freeSuperTail) + { + log.replace(*superTail, TypePack{}); + } + } +} + void Unifier::occursCheck(TypeId needle, TypeId haystack) { sharedState.tempSeenTy.clear(); @@ -2385,7 +2589,8 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypeId"); auto check = [&](TypeId tv) { occursCheck(seen, needle, tv); @@ -2425,6 +2630,11 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays for (TypeId ty : a->parts) check(ty); } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->parts) + check(ty); + } } void Unifier::occursCheck(TypePackId needle, TypePackId haystack) @@ -2450,7 +2660,8 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (!log.getMutable(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypePackId"); while (!log.getMutable(haystack)) { @@ -2474,7 +2685,23 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; + if (FFlag::LuauTypecheckOptPass) + { + Unifier u = Unifier{types, mode, location, variance, sharedState, &log}; + u.anyIsTop = anyIsTop; + return u; + } + + Unifier u = Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; + u.anyIsTop = anyIsTop; + return u; +} + +// A utility function that appends the given error to the unifier's error log. +// This allows setting a breakpoint wherever the unifier reports an error. +void Unifier::reportError(TypeError err) +{ + errors.push_back(std::move(err)); } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index 65939bee..f8543111 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -32,6 +32,7 @@ class DenseHashTable { public: class const_iterator; + class iterator; DenseHashTable(const Key& empty_key, size_t buckets = 0) : count(0) @@ -43,7 +44,7 @@ public: // don't move this to initializer list! this works around an MSVC codegen issue on AMD CPUs: // https://developercommunity.visualstudio.com/t/stdvector-constructor-from-size-t-is-25-times-slow/1546547 if (buckets) - data.resize(buckets, ItemInterface::create(empty_key)); + resize_data(buckets); } void clear() @@ -125,7 +126,7 @@ public: if (data.empty() && data.capacity() >= newsize) { LUAU_ASSERT(count == 0); - data.resize(newsize, ItemInterface::create(empty_key)); + resize_data(newsize); return; } @@ -169,6 +170,21 @@ public: return const_iterator(this, data.size()); } + iterator begin() + { + size_t start = 0; + + while (start < data.size() && eq(ItemInterface::getKey(data[start]), empty_key)) + start++; + + return iterator(this, start); + } + + iterator end() + { + return iterator(this, data.size()); + } + size_t size() const { return count; @@ -233,7 +249,82 @@ public: size_t index; }; + class iterator + { + public: + iterator() + : set(0) + , index(0) + { + } + + iterator(DenseHashTable* set, size_t index) + : set(set) + , index(index) + { + } + + MutableItem& operator*() const + { + return *reinterpret_cast(&set->data[index]); + } + + MutableItem* operator->() const + { + return reinterpret_cast(&set->data[index]); + } + + bool operator==(const iterator& other) const + { + return set == other.set && index == other.index; + } + + bool operator!=(const iterator& other) const + { + return set != other.set || index != other.index; + } + + iterator& operator++() + { + size_t size = set->data.size(); + + do + { + index++; + } while (index < size && set->eq(ItemInterface::getKey(set->data[index]), set->empty_key)); + + return *this; + } + + iterator operator++(int) + { + iterator res = *this; + ++*this; + return res; + } + + private: + DenseHashTable* set; + size_t index; + }; + private: + template + void resize_data(size_t count, typename std::enable_if_t>* dummy = nullptr) + { + data.resize(count, ItemInterface::create(empty_key)); + } + + template + void resize_data(size_t count, typename std::enable_if_t>* dummy = nullptr) + { + size_t size = data.size(); + data.resize(count); + + for (size_t i = size; i < count; i++) + data[i].first = empty_key; + } + std::vector data; size_t count; Key empty_key; @@ -290,6 +381,7 @@ class DenseHashSet public: typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::iterator iterator; DenseHashSet(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) @@ -336,6 +428,16 @@ public: { return impl.end(); } + + iterator begin() + { + return impl.begin(); + } + + iterator end() + { + return impl.end(); + } }; // This is a faster alternative of unordered_map, but it does not implement the same interface (i.e. it does not support erasing and has @@ -348,6 +450,7 @@ class DenseHashMap public: typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::iterator iterator; DenseHashMap(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) @@ -401,10 +504,21 @@ public: { return impl.begin(); } + const_iterator end() const { return impl.end(); } + + iterator begin() + { + return impl.begin(); + } + + iterator end() + { + return impl.end(); + } }; } // namespace Luau diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index d7d867f4..4f3dbbd5 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -173,7 +173,7 @@ public: } const Lexeme& next(); - const Lexeme& next(bool skipComments); + const Lexeme& next(bool skipComments, bool updatePrevLocation); void nextline(); Lexeme lookahead(); diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 70c6c78d..5dd4f04e 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -349,13 +349,11 @@ void Lexer::setReadNames(bool read) const Lexeme& Lexer::next() { - return next(this->skipComments); + return next(this->skipComments, true); } -const Lexeme& Lexer::next(bool skipComments) +const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) { - bool first = true; - // in skipComments mode we reject valid comments do { @@ -363,11 +361,11 @@ const Lexeme& Lexer::next(bool skipComments) while (isSpace(peekch())) consume(); - if (!FFlag::LuauParseLocationIgnoreCommentSkip || first) + if (!FFlag::LuauParseLocationIgnoreCommentSkip || updatePrevLocation) prevLocation = lexeme.location; lexeme = readNext(); - first = false; + updatePrevLocation = false; } while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment)); return lexeme; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index f9d32178..badd3fd3 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -11,6 +11,7 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseRecoverUnexpectedPack, false) +LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false) namespace Luau { @@ -2789,7 +2790,7 @@ void Parser::nextLexeme() { if (options.captureComments) { - Lexeme::Type type = lexer.next(/* skipComments= */ false).type; + Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type; while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) { @@ -2813,7 +2814,7 @@ void Parser::nextLexeme() hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); } - type = lexer.next(/* skipComments= */ false).type; + type = lexer.next(/* skipComments= */ false, !FFlag::LuauParseLocationIgnoreCommentSkipInCapture).type; } } else diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 6330bf1f..8ef69e75 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -1386,8 +1386,8 @@ struct Compiler const Constant* cv = constants.find(expr->index); - if (cv && cv->type == Constant::Type_Number && double(int(cv->valueNumber)) == cv->valueNumber && cv->valueNumber >= 1 && - cv->valueNumber <= 256) + if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && + double(int(cv->valueNumber)) == cv->valueNumber) { uint8_t rt = compileExprAuto(expr->expr, rs); uint8_t i = uint8_t(int(cv->valueNumber) - 1); diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp new file mode 100644 index 00000000..d8511bdb --- /dev/null +++ b/Compiler/src/CostModel.cpp @@ -0,0 +1,258 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "CostModel.h" + +#include "Luau/Common.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +namespace Compile +{ + +inline uint64_t parallelAddSat(uint64_t x, uint64_t y) +{ + uint64_t s = x + y; + uint64_t m = s & 0x8080808080808080ull; // saturation mask + + return (s ^ m) | (m - (m >> 7)); +} + +struct Cost +{ + static const uint64_t kLiteral = ~0ull; + + // cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant + uint64_t model; + // constant mask: 8-byte 0xff mask; equal to all ff's for literals, for variables only byte #i (1+) is set to align with model + uint64_t constant; + + Cost(int cost = 0, uint64_t constant = 0) + : model(cost < 0x7f ? cost : 0x7f) + , constant(constant) + { + } + + Cost operator+(const Cost& other) const + { + Cost result; + result.model = parallelAddSat(model, other.model); + return result; + } + + Cost& operator+=(const Cost& other) + { + model = parallelAddSat(model, other.model); + constant = 0; + return *this; + } + + static Cost fold(const Cost& x, const Cost& y) + { + uint64_t newmodel = parallelAddSat(x.model, y.model); + uint64_t newconstant = x.constant & y.constant; + + // the extra cost for folding is 1; the discount is 1 for the variable that is shared by x&y (or whichever one is used in x/y if the other is + // literal) + uint64_t extra = (newconstant == kLiteral) ? 0 : (1 | (0x0101010101010101ull & newconstant)); + + Cost result; + result.model = parallelAddSat(newmodel, extra); + result.constant = newconstant; + + return result; + } +}; + +struct CostVisitor : AstVisitor +{ + DenseHashMap vars; + Cost result; + + CostVisitor() + : vars(nullptr) + { + } + + Cost model(AstExpr* node) + { + if (AstExprGroup* expr = node->as()) + { + return model(expr->expr); + } + else if (node->is() || node->is() || node->is() || + node->is()) + { + return Cost(0, Cost::kLiteral); + } + else if (AstExprLocal* expr = node->as()) + { + const uint64_t* i = vars.find(expr->local); + + return Cost(0, i ? *i : 0); // locals typically don't require extra instructions to compute + } + else if (node->is()) + { + return 1; + } + else if (node->is()) + { + return 3; + } + else if (AstExprCall* expr = node->as()) + { + Cost cost = 3; + cost += model(expr->func); + + for (size_t i = 0; i < expr->args.size; ++i) + { + Cost ac = model(expr->args.data[i]); + // for constants/locals we still need to copy them to the argument list + cost += ac.model == 0 ? Cost(1) : ac; + } + + return cost; + } + else if (AstExprIndexName* expr = node->as()) + { + return model(expr->expr) + 1; + } + else if (AstExprIndexExpr* expr = node->as()) + { + return model(expr->expr) + model(expr->index) + 1; + } + else if (AstExprFunction* expr = node->as()) + { + return 10; // high baseline cost due to allocation + } + else if (AstExprTable* expr = node->as()) + { + Cost cost = 10; // high baseline cost due to allocation + + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + + if (item.key) + cost += model(item.key); + + cost += model(item.value); + cost += 1; + } + + return cost; + } + else if (AstExprUnary* expr = node->as()) + { + return Cost::fold(model(expr->expr), Cost(0, Cost::kLiteral)); + } + else if (AstExprBinary* expr = node->as()) + { + return Cost::fold(model(expr->left), model(expr->right)); + } + else if (AstExprTypeAssertion* expr = node->as()) + { + return model(expr->expr); + } + else if (AstExprIfElse* expr = node->as()) + { + return model(expr->condition) + model(expr->trueExpr) + model(expr->falseExpr) + 2; + } + else + { + LUAU_ASSERT(!"Unknown expression type"); + return {}; + } + } + + void assign(AstExpr* expr) + { + // variable assignments reset variable mask, so that further uses of this variable aren't discounted + // this doesn't work perfectly with backwards control flow like loops, but is good enough for a single pass + if (AstExprLocal* lv = expr->as()) + if (uint64_t* i = vars.find(lv->local)) + *i = 0; + } + + bool visit(AstExpr* node) override + { + // note: we short-circuit the visitor traversal through any expression trees by returning false + // recursive traversal is happening inside model() which makes it easier to get the resulting value of the subexpression + result += model(node); + + return false; + } + + bool visit(AstStat* node) override + { + if (node->is()) + result += 2; + else if (node->is() || node->is() || node->is() || node->is()) + result += 2; + else if (node->is() || node->is()) + result += 1; + + return true; + } + + bool visit(AstStatLocal* node) override + { + for (size_t i = 0; i < node->values.size; ++i) + { + Cost arg = model(node->values.data[i]); + + // propagate constant mask from expression through variables + if (arg.constant && i < node->vars.size) + vars[node->vars.data[i]] = arg.constant; + + result += arg; + } + + return false; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + return true; + } + + bool visit(AstStatCompoundAssign* node) override + { + assign(node->var); + + // if lhs is not a local, setting it requires an extra table operation + result += node->var->is() ? 1 : 2; + + return true; + } +}; + +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount) +{ + CostVisitor visitor; + for (size_t i = 0; i < varCount && i < 7; ++i) + visitor.vars[vars[i]] = 0xffull << (i * 8 + 8); + + root->visit(&visitor); + + return visitor.result.model; +} + +int computeCost(uint64_t model, const bool* varsConst, size_t varCount) +{ + int cost = int(model & 0x7f); + + // don't apply discounts to what is likely a saturated sum + if (cost == 0x7f) + return cost; + + for (size_t i = 0; i < varCount && i < 7; ++i) + cost -= int((model >> (8 * i + 8)) & 0x7f) * varsConst[i]; + + return cost; +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/CostModel.h b/Compiler/src/CostModel.h new file mode 100644 index 00000000..c27861ec --- /dev/null +++ b/Compiler/src/CostModel.h @@ -0,0 +1,18 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" + +namespace Luau +{ +namespace Compile +{ + +// cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); + +// cost is computed as B - sum(Di * Ci), where B is baseline cost, Di is the discount for each variable and Ci is 1 when variable #i is constant +int computeCost(uint64_t model, const bool* varsConst, size_t varCount); + +} // namespace Compile +} // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index 6f110f1f..60e5dfda 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -32,11 +32,13 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/Compiler.cpp Compiler/src/Builtins.cpp Compiler/src/ConstantFolding.cpp + Compiler/src/CostModel.cpp Compiler/src/TableShape.cpp Compiler/src/ValueTracking.cpp Compiler/src/lcode.cpp Compiler/src/Builtins.h Compiler/src/ConstantFolding.h + Compiler/src/CostModel.h Compiler/src/TableShape.h Compiler/src/ValueTracking.h ) @@ -58,6 +60,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/LValue.h Analysis/include/Luau/Module.h Analysis/include/Luau/ModuleResolver.h + Analysis/include/Luau/Normalize.h Analysis/include/Luau/Predicate.h Analysis/include/Luau/Quantify.h Analysis/include/Luau/RecursionCounter.h @@ -94,6 +97,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Linter.cpp Analysis/src/LValue.cpp Analysis/src/Module.cpp + Analysis/src/Normalize.cpp Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp @@ -216,6 +220,7 @@ if(TARGET Luau.UnitTest) tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp + tests/CostModel.test.cpp tests/Config.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp @@ -224,6 +229,7 @@ if(TARGET Luau.UnitTest) tests/LValue.test.cpp tests/Module.test.cpp tests/NonstrictMode.test.cpp + tests/Normalize.test.cpp tests/Parser.test.cpp tests/RequireTracer.test.cpp tests/StringUtils.test.cpp diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 1c75c0b0..dc40b6ef 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -34,7 +34,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false) -LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary, false) +LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary2, false) // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 @@ -390,6 +390,8 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) setarrayvector(L, t, nasize); /* create new hash part with appropriate size */ setnodevector(L, t, nhsize); + /* used for the migration check at the end */ + LuaNode* nnew = t->node; if (nasize < oldasize) { /* array part must shrink? */ t->sizearray = nasize; @@ -413,6 +415,8 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) /* shrink array */ luaM_reallocarray(L, t->array, oldasize, nasize, TValue, t->memcat); } + /* used for the migration check at the end */ + TValue* anew = t->array; /* re-insert elements from hash part */ if (FFlag::LuauTableRehashRework) { @@ -441,14 +445,30 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) } } + /* make sure we haven't recursively rehashed during element migration */ + LUAU_ASSERT(nnew == t->node); + LUAU_ASSERT(anew == t->array); + if (nold != dummynode) luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); /* free old array */ } +static int adjustasize(Table* t, int size, const TValue* ek) +{ + LUAU_ASSERT(FFlag::LuauTableNewBoundary2); + bool tbound = t->node != dummynode || size < t->sizearray; + int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; + /* move the array size up until the boundary is guaranteed to be inside the array part */ + while (size + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, size + 1)))) + size++; + return size; +} + void luaH_resizearray(lua_State* L, Table* t, int nasize) { int nsize = (t->node == dummynode) ? 0 : sizenode(t); - resize(L, t, nasize, nsize); + int asize = FFlag::LuauTableNewBoundary2 ? adjustasize(t, nasize, NULL) : nasize; + resize(L, t, asize, nsize); } void luaH_resizehash(lua_State* L, Table* t, int nhsize) @@ -470,21 +490,12 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) totaluse++; /* compute new size for array part */ int na = computesizes(nums, &nasize); + int nh = totaluse - na; /* enforce the boundary invariant; for performance, only do hash lookups if we must */ - if (FFlag::LuauTableNewBoundary) - { - bool tbound = t->node != dummynode || nasize < t->sizearray; - int ekindex = ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; - /* move the array size up until the boundary is guaranteed to be inside the array part */ - while (nasize + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, nasize + 1)))) - { - nasize++; - na++; - } - } + if (FFlag::LuauTableNewBoundary2) + nasize = adjustasize(t, nasize, ek); /* resize the table to new computed sizes */ - LUAU_ASSERT(na <= totaluse); - resize(L, t, nasize, totaluse - na); + resize(L, t, nasize, nh); } /* @@ -544,7 +555,7 @@ static LuaNode* getfreepos(Table* t) static TValue* newkey(lua_State* L, Table* t, const TValue* key) { /* enforce boundary invariant */ - if (FFlag::LuauTableNewBoundary && ttisnumber(key) && nvalue(key) == t->sizearray + 1) + if (FFlag::LuauTableNewBoundary2 && ttisnumber(key) && nvalue(key) == t->sizearray + 1) { rehash(L, t, key); /* grow table */ @@ -735,7 +746,7 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key) static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j) { - LUAU_ASSERT(!FFlag::LuauTableNewBoundary); + LUAU_ASSERT(!FFlag::LuauTableNewBoundary2); unsigned int i = j; /* i is zero or a present index */ j++; /* find `i' and `j' such that i is present and j is not */ @@ -820,7 +831,7 @@ int luaH_getn(Table* t) maybesetaboundary(t, boundary); return boundary; } - else if (FFlag::LuauTableNewBoundary) + else if (FFlag::LuauTableNewBoundary2) { /* validate boundary invariant */ LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1))); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 41887f4b..9c1f387e 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -199,7 +199,7 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); - void (*telemetrycb)(lua_State* L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; + void (*telemetrycb)(lua_State * L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb && e >= f) { diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 34949efb..39c60eac 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,7 +16,7 @@ #include -LUAU_FASTFLAG(LuauTableNewBoundary) +LUAU_FASTFLAG(LuauTableNewBoundary2) // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ @@ -2268,7 +2268,7 @@ static void luau_execute(lua_State* L) VM_NEXT(); } } - else if (FFlag::LuauTableNewBoundary || (h->lsizenode == 0 && ttisnil(gval(h->node)))) + else if (FFlag::LuauTableNewBoundary2 || (h->lsizenode == 0 && ttisnil(gval(h->node)))) { // fallthrough to exit VM_NEXT(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp new file mode 100644 index 00000000..ec04932f --- /dev/null +++ b/tests/CostModel.test.cpp @@ -0,0 +1,101 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" + +#include "doctest.h" + +using namespace Luau; + +namespace Luau +{ +namespace Compile +{ + +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); +int computeCost(uint64_t model, const bool* varsConst, size_t varCount); + +} // namespace Compile +} // namespace Luau + +TEST_SUITE_BEGIN("CostModel"); + +static uint64_t modelFunction(const char* source) +{ + Allocator allocator; + AstNameTable names(allocator); + + ParseResult result = Parser::parse(source, strlen(source), names, allocator); + REQUIRE(result.root != nullptr); + + AstStatFunction* func = result.root->body.data[0]->as(); + REQUIRE(func); + + return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size); +} + +TEST_CASE("Expression") +{ + uint64_t model = modelFunction(R"( +function test(a, b, c) + return a + (b + 1) * (b + 1) - c +end +)"); + + const bool args1[] = {false, false, false}; + const bool args2[] = {false, true, false}; + + CHECK_EQ(5, Luau::Compile::computeCost(model, args1, 3)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 3)); +} + +TEST_CASE("PropagateVariable") +{ + uint64_t model = modelFunction(R"( +function test(a) + local b = a * a * a + return b * b +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(0, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("LoopAssign") +{ + uint64_t model = modelFunction(R"( +function test(a) + for i=1,3 do + a[i] = i + end +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + // loop baseline cost is 2 + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("MutableVariable") +{ + uint64_t model = modelFunction(R"( +function test(a, b) + local x = a * a + x += b + return x * x +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 9dc9feee..d8b37a65 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -231,7 +231,7 @@ ModulePtr Fixture::getMainModule() SourceModule* Fixture::getMainSourceModule() { - return frontend.getSourceModule(fromString("MainModule")); + return frontend.getSourceModule(fromString(mainModuleName)); } std::optional Fixture::getPrimitiveType(TypeId ty) @@ -259,7 +259,7 @@ std::optional Fixture::getType(const std::string& name) TypeId Fixture::requireType(const std::string& name) { std::optional ty = getType(name); - REQUIRE(bool(ty)); + REQUIRE_MESSAGE(bool(ty), "Unable to requireType \"" << name << "\""); return follow(*ty); } diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp index 6711d979..1d2ad645 100644 --- a/tests/JsonEncoder.test.cpp +++ b/tests/JsonEncoder.test.cpp @@ -68,7 +68,9 @@ TEST_CASE("encode_tables") REQUIRE(parseResult.errors.size() == 0); std::string json = toJson(parseResult.root); - CHECK(json == R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); + CHECK( + json == + R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); } TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 9ce9a4c2..05ee9a7b 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -597,8 +597,6 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { - ScopedFastFlag sff("LuauLintNoRobloxBits", true); - unfreeze(typeChecker.globalTypes); TableTypeVar::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, @@ -1439,6 +1437,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") { unfreeze(typeChecker.globalTypes); TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}}); + persist(instanceType); typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; getMutable(instanceType)->props = { diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index de063121..738893db 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -2,6 +2,7 @@ #include "Luau/Clone.h" #include "Luau/Module.h" #include "Luau/Scope.h" +#include "Luau/RecursionCounter.h" #include "Fixture.h" @@ -9,6 +10,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("ModuleTests"); TEST_CASE_FIXTURE(Fixture, "is_within_comment") @@ -42,29 +45,23 @@ TEST_CASE_FIXTURE(Fixture, "is_within_comment") TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks, cloneState); + TypeId newNumber = clone(typeChecker.numberType, dest, cloneState); CHECK_EQ(newNumber, typeChecker.numberType); } TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; // Create a new number type that isn't persistent unfreeze(typeChecker.globalTypes); TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number}); freeze(typeChecker.globalTypes); - TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks, cloneState); + TypeId newNumber = clone(oldNumber, dest, cloneState); CHECK_NE(newNumber, oldNumber); CHECK_EQ(*oldNumber, *newNumber); @@ -90,12 +87,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TypeId counterType = requireType("Cyclic"); - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - TypeArena dest; - TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks, cloneState); + CloneState cloneState; + TypeId counterCopy = clone(counterType, dest, cloneState); TableTypeVar* ttv = getMutable(counterCopy); REQUIRE(ttv != nullptr); @@ -112,8 +106,11 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") REQUIRE(methodReturnType); CHECK_EQ(methodReturnType, counterCopy); - CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type - CHECK_EQ(2, dest.typeVars.size()); // One table and one function + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(3, dest.typePacks.size()); // function args, its return type, and the hidden any... pack + else + CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type + CHECK_EQ(2, dest.typeVars.size()); // One table and one function } TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") @@ -143,15 +140,12 @@ TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") TEST_CASE_FIXTURE(Fixture, "deepClone_union") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks, cloneState); + TypeId newUnion = clone(oldUnion, dest, cloneState); CHECK_NE(newUnion, oldUnion); CHECK_EQ("number | string", toString(newUnion)); @@ -161,15 +155,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks, cloneState); + TypeId newIntersection = clone(oldIntersection, dest, cloneState); CHECK_NE(newIntersection, oldIntersection); CHECK_EQ("number & string", toString(newIntersection)); @@ -191,12 +182,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") std::nullopt, &exampleMetaClass, {}, {}}}; TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks, cloneState); + TypeId cloned = clone(&exampleClass, dest, cloneState); const ClassTypeVar* ctv = get(cloned); REQUIRE(ctv != nullptr); @@ -216,16 +204,14 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") TypePackVar freeTp(FreeTypePack{TypeLevel{}}); TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, cloneState); + TypeId clonedTy = clone(&freeTy, dest, cloneState); CHECK_EQ("any", toString(clonedTy)); CHECK(cloneState.encounteredFreeType); cloneState = {}; - TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, cloneState); + TypePackId clonedTp = clone(&freeTp, dest, cloneState); CHECK_EQ("...any", toString(clonedTp)); CHECK(cloneState.encounteredFreeType); } @@ -237,16 +223,32 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") ttv->state = TableState::Free; TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, cloneState); + TypeId cloned = clone(&tableTy, dest, cloneState); const TableTypeVar* clonedTtv = get(cloned); CHECK_EQ(clonedTtv->state, TableState::Sealed); CHECK(cloneState.encounteredFreeType); } +TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") +{ + TypeArena src; + + TypeId constrained = src.addType(ConstrainedTypeVar{TypeLevel{}, {getSingletonTypes().numberType, getSingletonTypes().stringType}}); + + TypeArena dest; + CloneState cloneState; + + TypeId cloned = clone(constrained, dest, cloneState); + CHECK_NE(constrained, cloned); + + const ConstrainedTypeVar* ctv = get(cloned); + REQUIRE_EQ(2, ctv->parts.size()); + CHECK_EQ(getSingletonTypes().numberType, ctv->parts[0]); + CHECK_EQ(getSingletonTypes().stringType, ctv->parts[1]); +} + TEST_CASE_FIXTURE(Fixture, "clone_self_property") { ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; @@ -284,6 +286,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; + ScopedFastFlag sff{"LuauRecursionLimitException", true}; TypeArena src; @@ -299,11 +302,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") } TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - CHECK_THROWS_AS(clone(table, dest, seenTypes, seenTypePacks, cloneState), std::runtime_error); + CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); } TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index d3faea2a..a8a12b69 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -275,4 +275,38 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType)); } +TEST_CASE_FIXTURE(Fixture, "returning_insufficient_return_values") +{ + CheckResult result = check(R"( + --!nonstrict + + function foo(): (boolean, string?) + if true then + return true, "hello" + else + return false + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "returning_too_many_values") +{ + CheckResult result = check(R"( + --!nonstrict + + function foo(): boolean + if true then + return true, "hello" + else + return false + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp new file mode 100644 index 00000000..5a84201a --- /dev/null +++ b/tests/Normalize.test.cpp @@ -0,0 +1,967 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +#include "Luau/Normalize.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +struct NormalizeFixture : Fixture +{ + ScopedFastFlag sff1{"LuauLowerBoundsCalculation", true}; + ScopedFastFlag sff2{"LuauTableSubtypingVariance2", true}; +}; + +void createSomeClasses(TypeChecker& typeChecker) +{ + auto& arena = typeChecker.globalTypes; + + unfreeze(arena); + + TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr}); + + ClassTypeVar* parentClass = getMutable(parentType); + parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; + + parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; + + addGlobalBinding(typeChecker, "Parent", {parentType}); + typeChecker.globalScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; + + TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr}); + + ClassTypeVar* childClass = getMutable(childType); + childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; + + addGlobalBinding(typeChecker, "Child", {childType}); + typeChecker.globalScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; + + TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr}); + + addGlobalBinding(typeChecker, "Unrelated", {unrelatedType}); + typeChecker.globalScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; + + freeze(arena); +} + +static bool isSubtype(TypeId a, TypeId b) +{ + InternalErrorReporter ice; + return isSubtype(a, b, ice); +} + +TEST_SUITE_BEGIN("isSubtype"); + +TEST_CASE_FIXTURE(NormalizeFixture, "primitives") +{ + check(R"( + local a = 41 + local b = 32 + + local c = "hello" + local d = "world" + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(d, c)); + CHECK(!isSubtype(d, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions") +{ + check(R"( + function a(x: number): number return x end + function b(x: number): number return x end + + function c(x: number?): number return x end + function d(x: number): number? return x end + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(c, a)); + CHECK(!isSubtype(d, a)); + CHECK(isSubtype(a, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_and_any") +{ + check(R"( + function a(n: number) return "string" end + function b(q: any) return 5 :: any end + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + // Intuition: + // We cannot use b where a is required because we cannot rely on b to return a string. + // We cannot use a where b is required because we cannot rely on a to accept non-number arguments. + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(a, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arities") +{ + check(R"( + type A = (any) -> () + type B = (any, any) -> () + type T = A & B + + local a: A + local b: B + local t: T + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(a, b)); // !! + CHECK(!isSubtype(b, a)); + + CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") +{ + check(R"( + local a: (number) -> () + local b: () -> () + + local c: () -> number + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(c, a)); + + CHECK(!isSubtype(a, b)); + CHECK(!isSubtype(c, b)); + + CHECK(!isSubtype(a, c)); + CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_optional_parameters") +{ + /* + * (T0..TN) <: (T0..TN, A?) + * (T0..TN) <: (T0..TN, any) + * (T0..TN, A?) R <: U -> S if U <: T and R <: S + * A | B <: T if A <: T and B <: T + * T <: A | B if T <: A or T <: B + */ + check(R"( + local a: (number?) -> () + local b: (number) -> () + local c: (number, number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, number?) -> () <: (number) -> (number) + * The packs have inequal lengths, but (number) <: (number, number?) + * and number <: number + */ + CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * because (number, number?) () () + * because (number, number?) () + local b: (number) -> () + local c: (number, any) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, any) -> () (number) + * The packs have inequal lengths + */ + CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * The packs have inequal lengths + */ + CHECK(!isSubtype(a, c)); + + /* + * (number) -> () () + * The packs have inequal lengths + */ + CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "variadic_functions_with_no_head") +{ + check(R"( + local a: (...number) -> () + local b: (...number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "variadic_function_with_head") +{ + check(R"( + local a: (...number) -> () + local b: (number, number) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "union") +{ + check(R"( + local a: number | string + local b: number + local c: string + local d: number? + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(!isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(b, d)); + CHECK(!isSubtype(d, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_union_prop") +{ + check(R"( + local a: {x: number} + local b: {x: number?} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop") +{ + check(R"( + local a: {x: number} + local b: {x: any} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection") +{ + check(R"( + local a: number & string + local b: number + local c: string + local d: number & nil + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); + + CHECK(!isSubtype(c, a)); + CHECK(isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection") +{ + check(R"( + local a: number & string + local b: number | nil + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_table_prop") +{ + check(R"( + type T = {x: {y: number}} & {x: {y: string}} + local a: T + )"); + + CHECK_EQ("{| x: {| y: number & string |} |}", toString(requireType("a"))); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "tables") +{ + check(R"( + local a: {x: number} + local b: {x: any} + local c: {y: number} + local d: {x: number, y: number} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); + + CHECK(!isSubtype(c, a)); + CHECK(!isSubtype(a, c)); + + CHECK(isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(d, b)); + CHECK(!isSubtype(b, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant") +{ + check(R"( + local a: {[string]: number} + local b: {[string]: any} + local c: {[string]: number} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(isSubtype(a, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "mismatched_indexers") +{ + check(R"( + local a: {x: number} + local b: {[string]: number} + local c: {} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(!isSubtype(c, b)); + CHECK(isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_table") +{ + check(R"( + type A = {method: (A) -> ()} + local a: A + + type B = {method: (any) -> ()} + local b: B + + type C = {method: (C) -> ()} + local c: C + + type D = {method: (D) -> (), another: (D) -> ()} + local d: D + + type E = {method: (A) -> (), another: (E) -> ()} + local e: E + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + TypeId e = requireType("e"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(e, a)); + CHECK(!isSubtype(a, e)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "classes") +{ + createSomeClasses(typeChecker); + + TypeId p = typeChecker.globalScope->lookupType("Parent")->type; + TypeId c = typeChecker.globalScope->lookupType("Child")->type; + TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type; + + CHECK(isSubtype(c, p)); + CHECK(!isSubtype(p, c)); + CHECK(!isSubtype(u, p)); + CHECK(!isSubtype(p, u)); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "metatable" * doctest::expected_failures{1}) +{ + check(R"( + local T = {} + T.__index = T + function T.new() + return setmetatable({}, T) + end + + function T:method() end + + local a: typeof(T.new) + local b: {method: (any) -> ()} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_tables") +{ + check(R"( + type T = {x: number} & ({x: number} & {y: string?}) + local t: T + )"); + + CHECK("{| x: number, y: string? |}" == toString(requireType("t"))); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("Normalize"); + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_disjoint_tables") +{ + check(R"( + type T = {a: number} & {b: number} + local t: T + )"); + + CHECK_EQ("{| a: number, b: number |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_overlapping_tables") +{ + check(R"( + type T = {a: number, b: string} & {b: number, c: string} + local t: T + )"); + + CHECK_EQ("{| a: number, b: number & string, c: string |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_confluent_overlapping_tables") +{ + check(R"( + type T = {a: number, b: string} & {b: string, c: string} + local t: T + )"); + + CHECK_EQ("{| a: number, b: string, c: string |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_subtype_relationship") +{ + check(R"( + local t: {x: number} | {x: number?} + )"); + + ModulePtr tempModule{new Module}; + + // HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze + // the arena that the type lives in. + ModulePtr mainModule = getMainModule(); + unfreeze(mainModule->internalTypes); + + TypeId tType = requireType("t"); + normalize(tType, tempModule, *typeChecker.iceHandler); + + CHECK_EQ("{| x: number? |}", toString(tType, {true})); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions") +{ + check(R"( + type T = ((any) -> string) & ((number) -> string) + local t: T + )"); + + CHECK_EQ("(any) -> string", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + check(R"( + --!nonstrict + + if Math.random() then + return function(initialState, handlers) + return function(state, action) + return state + end + end + else + return function(initialState, handlers) + return function(state, action) + return state + end + end + end + )"); + + CHECK_EQ("(any, any) -> (...any)", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_is_not_a_constrained_intersection") +{ + check(R"( + function foo(x:number, y:number) + return x + y + end + )"); + + CHECK_EQ("(number, number) -> number", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function") +{ + check(R"( + function apply(f, x) + return f(x) + end + + local a = apply(function(x: number) return x + x end, 5) + )"); + + TypeId aType = requireType("a"); + CHECK_MESSAGE(isNumber(follow(aType)), "Expected a number but got ", toString(aType)); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_with_annotation") +{ + check(R"( + function apply(f: (a) -> b, x) + return f(x) + end + )"); + + CHECK_EQ("((a) -> b, a) -> b", toString(requireType("apply"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_marked_normal") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + check(R"( + type Fiber = { + return_: Fiber? + } + + local f: Fiber + )"); + + TypeId t = requireType("f"); + CHECK(t->normal); +} + +TEST_CASE_FIXTURE(Fixture, "variadic_tail_is_marked_normal") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + type Weirdo = (...{x: number}) -> () + + local w: Weirdo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("w"); + auto ftv = get(t); + REQUIRE(ftv); + + auto [argHead, argTail] = flatten(ftv->argTypes); + CHECK(argHead.empty()); + REQUIRE(argTail.has_value()); + + auto vtp = get(*argTail); + REQUIRE(vtp); + CHECK(vtp->ty->normal); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly") +{ + CheckResult result = check(R"( + local Cyclic = {} + function Cyclic.get() + return Cyclic + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = requireType("Cyclic"); + CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true})); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_distinct_free_types") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function fussy(a, b) + if math.random() > 0.5 then + return a + else + return b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(a, b) -> a | b" == toString(requireType("fussy"))); +} + +TEST_CASE_FIXTURE(Fixture, "constrained_intersection_of_intersections") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + local f : (() -> number) | ((number) -> number) + local g : (() -> number) | ((string) -> number) + + function h() + if math.random() then + return f + else + return g + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId h = requireType("h"); + + CHECK("() -> (() -> number) | ((number) -> number) | ((string) -> number)" == toString(h)); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + type X = {} + type Y = {y: number} + type Z = {z: string} + type W = {w: boolean} + type T = {x: Y & X} & {x:Z & W} + + local x: X + local y: Y + local z: Z + local w: W + local t: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("{| |}" == toString(requireType("x"), {true})); + CHECK("{| y: number |}" == toString(requireType("y"), {true})); + CHECK("{| z: string |}" == toString(requireType("z"), {true})); + CHECK("{| w: boolean |}" == toString(requireType("w"), {true})); + CHECK("{| x: {| w: boolean, y: number, z: string |} |}" == toString(requireType("t"), {true})); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_2") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(w, x, y, z) + y.y = 5 + z.z = "five" + w.w = true + + type Z = {x: typeof(x) & typeof(y)} & {x: typeof(w) & typeof(z)} + + return ((nil :: any) :: Z) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(4 == args.size()); + CHECK("{+ w: boolean +}" == toString(args[0])); + CHECK("a" == toString(args[1])); + CHECK("{+ y: number +}" == toString(args[2])); + CHECK("{+ z: string +}" == toString(args[3])); + + std::vector ret = flatten(ftv->retType).first; + + REQUIRE(1 == ret.size()); + CHECK("{| x: a & {- w: boolean, y: number, z: string -} |}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_3") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y, z) + x.x = true + y.y = y + z.z = "five" + + type Z = {x: typeof(y)} & {x: typeof(x) & typeof(z)} + + return ((nil :: any) :: Z) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(3 == args.size()); + CHECK("{+ x: boolean +}" == toString(args[0])); + CHECK("t1 where t1 = {+ y: t1 +}" == toString(args[1])); + CHECK("{+ z: string +}" == toString(args[2])); + + std::vector ret = flatten(ftv->retType).first; + + REQUIRE(1 == ret.size()); + CHECK("{| x: {- x: boolean, y: t1, z: string -} |} where t1 = {+ y: t1 +}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_4") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y, z) + x.x = true + z.z = "five" + + type R = {x: typeof(y)} & {x: typeof(x) & typeof(z)} + local r: R + + y.y = r + + return r + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(3 == args.size()); + CHECK("{+ x: boolean +}" == toString(args[0])); + CHECK("{+ y: t1 +} where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(args[1])); + CHECK("{+ z: string +}" == toString(args[2])); + + std::vector ret = flatten(ftv->retType).first; + + REQUIRE(1 == ret.size()); + CHECK("t1 where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeCombineTableFix", true}, + }; + // CLI-52787 + // ends up combining {_:any} with any, recursively + // which used to ICE because this combines a table with a non-table. + CheckResult result = check(R"( + export type t0 = any & { _: {_:any} } & { _:any } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeCombineIntersectionFix", true}, + }; + + CheckResult result = check(R"( + export type t0 = {_:{_:any} & {_:any|string}} & {_:{_:{}}} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 79f9ecab..b941103d 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1618,6 +1618,26 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); } +TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture") +{ + ScopedFastFlag luauParseLocationIgnoreCommentSkip{"LuauParseLocationIgnoreCommentSkip", true}; + ScopedFastFlag luauParseLocationIgnoreCommentSkipInCapture{"LuauParseLocationIgnoreCommentSkipInCapture", true}; + + // Same should hold when comments are captured + ParseOptions opts; + opts.captureComments = true; + + AstStatBlock* block = parse(R"( + type F = number + --comment + print('hello') + )", + opts); + + REQUIRE_EQ(2, block->body.size); + CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); +} + TEST_CASE_FIXTURE(Fixture, "parse_error_loop_control") { matchParseError("break", "break statement must be inside a loop"); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 29bdd866..f3fda54e 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauLowerBoundsCalculation) + using namespace Luau; struct ToDotClassFixture : Fixture @@ -101,9 +103,34 @@ local function f(a, ...: string) return a end )"); LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(a, ...string) -> a", toString(requireType("f"))); + ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + + if (FFlag::LuauLowerBoundsCalculation) + { + CHECK_EQ(R"(digraph graphname { +n1 [label="FunctionTypeVar 1"]; +n1 -> n2 [label="arg"]; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="GenericTypeVar 3"]; +n2 -> n4 [label="tail"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n1 -> n6 [label="ret"]; +n6 [label="TypePack 6"]; +n6 -> n7; +n7 [label="BoundTypeVar 7"]; +n7 -> n3; +})", + toDot(requireType("f"), opts)); + } + else + { + CHECK_EQ(R"(digraph graphname { n1 [label="FunctionTypeVar 1"]; n1 -> n2 [label="arg"]; n2 [label="TypePack 2"]; @@ -119,7 +146,8 @@ n6 -> n7; n7 [label="TypePack 7"]; n7 -> n3; })", - toDot(requireType("f"), opts)); + toDot(requireType("f"), opts)); + } } TEST_CASE_FIXTURE(Fixture, "union") @@ -361,4 +389,49 @@ n3 [label="number"]; toDot(*ty, opts)); } +TEST_CASE_FIXTURE(Fixture, "constrained") +{ + // ConstrainedTypeVars never appear in the final type graph, so we have to create one directly + // to dotify it. + TypeVar t{ConstrainedTypeVar{TypeLevel{}, {typeChecker.numberType, typeChecker.stringType, typeChecker.nilType}}}; + + ToDotOptions opts; + opts.showPointers = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="ConstrainedTypeVar 1"]; +n1 -> n2; +n2 [label="number"]; +n1 -> n3; +n3 [label="string"]; +n1 -> n4; +n4 [label="nil"]; +})", + toDot(&t, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "singletontypes") +{ + CheckResult result = check(R"( + local x: "hi" | "\"hello\"" | true | false + )"); + + ToDotOptions opts; + opts.showPointers = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="UnionTypeVar 1"]; +n1 -> n2; +n2 [label="SingletonTypeVar string: hi"]; +n1 -> n3; +)" +"n3 [label=\"SingletonTypeVar string: \\\"hello\\\"\"];" +R"( +n1 -> n4; +n4 [label="SingletonTypeVar boolean: true"]; +n1 -> n5; +n5 [label="SingletonTypeVar boolean: false"]; +})", toDot(requireType("x"), opts)); +} + TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 3051e209..ccf5c583 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); + TEST_SUITE_BEGIN("ToString"); TEST_CASE_FIXTURE(Fixture, "primitive") diff --git a/tests/TopoSort.test.cpp b/tests/TopoSort.test.cpp index 9b990866..1f14ae88 100644 --- a/tests/TopoSort.test.cpp +++ b/tests/TopoSort.test.cpp @@ -340,26 +340,28 @@ TEST_CASE_FIXTURE(Fixture, "nested_type_annotations_depends_on_later_typealiases TEST_CASE_FIXTURE(Fixture, "return_comes_last") { - CheckResult result = check(R"( -export type Module = { bar: (number) -> boolean, foo: () -> string } + AstStatBlock* program = parse(R"( + local module = {} -return function() : Module - local module = {} + local function confuseCompiler() return module.foo() end - local function confuseCompiler() return module.foo() end - - module.foo = function() return "" end + module.foo = function() return "" end - function module.bar(x:number) - confuseCompiler() - return true - end - - return module -end + function module.bar(x:number) + confuseCompiler() + return true + end + + return module )"); - LUAU_REQUIRE_NO_ERRORS(result); + auto sorted = toposort(*program); + + CHECK_EQ(sorted[0], program->body.data[0]); + CHECK_EQ(sorted[2], program->body.data[1]); + CHECK_EQ(sorted[1], program->body.data[2]); + CHECK_EQ(sorted[3], program->body.data[3]); + CHECK_EQ(sorted[4], program->body.data[4]); } TEST_CASE_FIXTURE(Fixture, "break_comes_last") diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 5ac45ff2..0c324cd0 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -388,7 +388,7 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") std::string actual = decorateWithTypes(code); - CHECK_EQ(expected, decorateWithTypes(code)); + CHECK_EQ(expected, actual); } TEST_CASE_FIXTURE(Fixture, "function_type_location") diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2ad11d01..e2971ad5 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -753,4 +753,14 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_intersection_typevar") REQUIRE(ocf); } +TEST_CASE_FIXTURE(Fixture, "instantiation_clone_has_to_follow") +{ + CheckResult result = check(R"( + export type t8 = (t0)&(((true)|(any))->"") + export type t0 = ({})&({_:{[any]:number},}) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index c6fbebed..1ae65947 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("BuiltinTests"); TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") @@ -557,9 +559,9 @@ TEST_CASE_FIXTURE(Fixture, "xpcall") )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("boolean", toString(requireType("a"))); - REQUIRE_EQ("number", toString(requireType("b"))); - REQUIRE_EQ("boolean", toString(requireType("c"))); + CHECK_EQ("boolean", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("boolean", toString(requireType("c"))); } TEST_CASE_FIXTURE(Fixture, "see_thru_select") @@ -881,7 +883,10 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); + else + CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 98fa66eb..8e3629e7 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -91,6 +91,9 @@ struct ClassFixture : Fixture typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; addGlobalBinding(typeChecker, "Vector2", vector2Type, "@test"); + for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) + persist(tf.type); + freeze(arena); } }; diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 1713216a..65993681 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TypeInferFunctions"); TEST_CASE_FIXTURE(Fixture, "tc_function") @@ -98,7 +100,7 @@ TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") end return result - end + end return T )"); @@ -274,6 +276,10 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( function f(g) return f(f) @@ -281,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); + CHECK_EQ("t1 where t1 = (t1) -> (a...)", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") @@ -481,10 +487,10 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") std::vector fArgs = flatten(fType->argTypes).first; - TypeId xType = argVec[1]; + TypeId xType = follow(argVec[1]); CHECK_EQ(1, fArgs.size()); - CHECK_EQ(xType, fArgs[0]); + CHECK_EQ(xType, follow(fArgs[0])); } TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") @@ -1043,13 +1049,16 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); + if (!FFlag::LuauLowerBoundsCalculation) + { + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") @@ -1142,13 +1151,16 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); + if (!FFlag::LuauLowerBoundsCalculation) + { + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") @@ -1338,6 +1350,126 @@ end CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); } +TEST_CASE_FIXTURE(Fixture, "inconsistent_return_types") +{ + const ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(a: boolean, b: number) + if a then + return nil + else + return b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(boolean, number) -> number?", toString(requireType("foo"))); + + // TODO: Test multiple returns + // Think of various cases where typepacks need to grow. maybe consult other tests + // Basic normalization of ConstrainedTypeVars during quantification +} + +TEST_CASE_FIXTURE(Fixture, "inconsistent_higher_order_function") +{ + const ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(f) + f(5) + f("six") + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("((number | string) -> (a...)) -> ()", toString(requireType("foo"))); +} + + +/* The bug here is that we are using the same level 2.0 for both the body of resolveDispatcher and the + * lambda useCallback. + * + * I think what we want to do is, at each scope level, never reuse the same sublevel. + * + * We also adjust checkBlock to consider the syntax `local x = function() ... end` to be sortable + * in the same way as `local function x() ... end`. This causes the function `resolveDispatcher` to be + * checked before the lambda. + */ +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + --!strict + + local function resolveDispatcher() + return (nil :: any) :: {useCallback: (any) -> any} + end + + local useCallback = function(deps: any) + return resolveDispatcher().useCallback(deps) + end + )"); + + // LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken. + // You get a TypeMismatch error where both types stringify the same. + + CHECK(result.errors.empty()); + if (!result.errors.empty()) + { + for (const auto& e : result.errors) + printf("%s: %s\n", toString(e.location).c_str(), toString(e).c_str()); + } +} + +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time2") +{ + CheckResult result = check(R"( + --!strict + + local function resolveDispatcher() + return (nil :: any) :: {useContext: (number?) -> any} + end + + local useContext + useContext = function(unstable_observedBits: number?) + resolveDispatcher().useContext(unstable_observedBits) + end + )"); + + // LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken. + // You get a TypeMismatch error where both types stringify the same. + + CHECK(result.errors.empty()); + if (!result.errors.empty()) + { + for (const auto& e : result.errors) + printf("%s %s: %s\n", e.moduleName.c_str(), toString(e.location).c_str(), toString(e).c_str()); + } +} + +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time3") +{ + CheckResult result = check(R"( + local foo + + foo():bar(function() + return foo() + end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_unsealed_overwrite") { ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; @@ -1471,4 +1603,17 @@ pcall(wrapper, test) CHECK(acm->isVariadic); } +TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type") +{ + CheckResult result = check(R"( + function f() + return 5, f() + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(nullptr != get(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index f360a77c..49d31fc6 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -230,8 +230,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); - CHECK_EQ(args[0], idFun->generics[0]); - CHECK_EQ(rets[0], idFun->generics[0]); + CHECK_EQ(follow(args[0]), follow(idFun->generics[0])); + CHECK_EQ(follow(rets[0]), follow(idFun->generics[0])); } TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") @@ -253,8 +253,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); - CHECK_EQ(args[0], idFun->generics[0]); - CHECK_EQ(rets[0], idFun->generics[0]); + CHECK_EQ(follow(args[0]), follow(idFun->generics[0])); + CHECK_EQ(follow(rets[0]), follow(idFun->generics[0])); } TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") @@ -705,10 +705,10 @@ end TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") { ScopedFastFlag sffs[] = { - { "LuauTableSubtypingVariance2", true }, - { "LuauUnsealedTableLiteral", true }, - { "LuauPropertiesGetExpectedType", true }, - { "LuauRecursiveTypeParameterRestriction", true }, + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + {"LuauPropertiesGetExpectedType", true}, + {"LuauRecursiveTypeParameterRestriction", true}, }; CheckResult result = check(R"( @@ -843,6 +843,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_function") LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(a) -> a", toString(requireType("id"))); CHECK_EQ(*typeChecker.numberType, *requireType("a")); CHECK_EQ(*typeChecker.nilType, *requireType("b")); } @@ -1037,25 +1038,39 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end -return sum(2, 3, function(a, b) return a + b end) + local function sum(x: a, y: a, f: (a, a) -> a) + return f(x, y) + end + return sum(2, 3, function(a, b) return a + b end) )"); LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end -local a = {1, 2, 3} -local r = map(a, function(a) return a + a > 100 end) + local function map(arr: {a}, f: (a) -> b) + local r = {} + for i,v in ipairs(arr) do + table.insert(r, f(v)) + end + return r + end + local a = {1, 2, 3} + local r = map(a, function(a) return a + a > 100 end) )"); LUAU_REQUIRE_NO_ERRORS(result); REQUIRE_EQ("{boolean}", toString(requireType("r"))); check(R"( -local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end -local a = {1, 2, 3} -local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) + local function foldl(arr: {a}, init: b, f: (b, a) -> b) + local r = init + for i,v in ipairs(arr) do + r = f(r, v) + end + return r + end + local a = {1, 2, 3} + local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1065,25 +1080,19 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") { CheckResult result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + local g12: ((T, (T) -> T) -> T) & ((T, T, (T, T) -> T) -> T) -local g12: typeof(g1) & typeof(g2) - -g12(1, function(x) return x + x end) -g12(1, 2, function(x, y) return x + y end) + g12(1, function(x) return x + x end) + g12(1, 2, function(x, y) return x + y end) )"); LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + local g12: ((T, (T) -> T) -> T) & ((T, T, (T, T) -> T) -> T) -local g12: typeof(g1) & typeof(g2) - -g12({x=1}, function(x) return {x=-x.x} end) -g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) + g12({x=1}, function(x) return {x=-x.x} end) + g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1121,12 +1130,12 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { CheckResult result = check(R"( -type A = { x: number } -local a: A = { x = 1 } -local b = a -type B = typeof(b) -type X = T -local c: X + type A = { x: number } + local a: A = { x = 1 } + local b = a + type B = typeof(b) + type X = T + local c: X )"); LUAU_REQUIRE_NO_ERRORS(result); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index ac7a6532..3675919f 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("IntersectionTypes"); TEST_CASE_FIXTURE(Fixture, "select_correct_union_fn") @@ -306,7 +308,10 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table '{| x: number, y: number |}'"); + else + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); } TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") @@ -314,27 +319,34 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; CheckResult result = check(R"( - type X = { x: (number) -> number } - type Y = { y: (string) -> string } + type X = { x: (number) -> number } + type Y = { y: (string) -> string } - type XY = X & Y + type XY = X & Y - local xy : XY = { - x = function(a: number) return -a end, - y = function(a: string) return a .. "b" end - } - function xy.z(a:number) return a * 10 end - function xy:y(a:number) return a * 10 end - function xy:w(a:number) return a * 10 end + local xy : XY = { + x = function(a: number) return -a end, + y = function(a: string) return a .. "b" end + } + function xy.z(a:number) return a * 10 end + function xy:y(a:number) return a * 10 end + function xy:w(a:number) return a * 10 end )"); LUAU_REQUIRE_ERROR_COUNT(4, result); CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' caused by: Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table '{| x: (number) -> number, y: (string) -> string |}'"); + else + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); - CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); + + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table '{| x: (number) -> number, y: (string) -> string |}'"); + else + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); } TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") @@ -375,6 +387,8 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") { + ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}}; + CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -393,6 +407,8 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") { + ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}}; + CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -427,8 +443,8 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") repeat type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) function _(l0):(t0)&(t0) - while nil do - end + while nil do + end end until _(_)(_)._ )"); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 40831bf6..5cd3f3ba 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -199,16 +199,16 @@ end TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") { CheckResult result = check(R"( ---!nonstrict -local f = {} -function f:foo(a: number, b: number) end + --!nonstrict + local f = {} + function f:foo(a: number, b: number) end -function bar(...) - f.foo(f, 1, ...) -end + function bar(...) + f.foo(f, 1, ...) + end -bar(2) -)"); + bar(2) + )"); LUAU_REQUIRE_NO_ERRORS(result); } diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 6a8a9d93..5f2e2404 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -91,7 +91,8 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") const FunctionTypeVar* functionType = get(requireType("add")); std::optional retType = first(functionType->retType); - CHECK_EQ(std::optional(typeChecker.numberType), retType); + REQUIRE(retType.has_value()); + CHECK_EQ(typeChecker.numberType, follow(*retType)); CHECK_EQ(requireType("n"), typeChecker.numberType); CHECK_EQ(requireType("s"), typeChecker.stringType); } diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 2e16b21e..6b3741fa 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; @@ -527,6 +528,7 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_NO_ERRORS(result); } +// FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { ScopedFastFlag sff[]{ @@ -556,10 +558,19 @@ TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type LUAU_REQUIRE_NO_ERRORS(result); - // f and g should have the type () -> () - CHECK_EQ("() -> (a...)", toString(requireType("f"))); - CHECK_EQ("() -> (a...)", toString(requireType("g"))); - CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now + if (FFlag::LuauLowerBoundsCalculation) + { + CHECK_EQ("() -> ()", toString(requireType("f"))); + CHECK_EQ("() -> ()", toString(requireType("g"))); + CHECK_EQ("nil", toString(requireType("x"))); + } + else + { + // f and g should have the type () -> () + CHECK_EQ("() -> (a...)", toString(requireType("f"))); + CHECK_EQ("() -> (a...)", toString(requireType("g"))); + CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now + } } TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") @@ -575,6 +586,10 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", false}, + }; + CheckResult result = check(R"( local function f() return end local g = function() return f() end @@ -585,6 +600,10 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", false}, + }; + CheckResult result = check(R"( --!strict local function f(...) return ... end @@ -594,4 +613,112 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } +TEST_CASE_FIXTURE(Fixture, "lower_bounds_calculation_is_too_permissive_with_overloaded_higher_order_functions") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(f) + f(5, 'a') + f('b', 6) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // We incorrectly infer that the argument to foo could be called with (number, number) or (string, string) + // even though that is strictly more permissive than the actual source text shows. + CHECK("((number | string, number | string) -> (a...)) -> ()" == toString(requireType("foo"))); +} + +// Once fixed, move this to Normalize.test.cpp +TEST_CASE_FIXTURE(Fixture, "normalization_fails_on_certain_kinds_of_cyclic_tables") +{ +#if defined(_DEBUG) || defined(_NOOPT) + ScopedFastInt sfi("LuauNormalizeIterationLimit", 500); +#endif + + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y) + x.x = y + y.x = x + + type R = {x: typeof(x)} & {x: typeof(y)} + local r: R + + return r + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(nullptr != get(result.errors[0])); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(Fixture, "pcall_returns_at_least_two_value_but_function_returns_nothing") +{ + CheckResult result = check(R"( + local function f(): () end + local ok, res = pcall(f) + )"); + + LUAU_REQUIRE_ERRORS(result); + // LUAU_REQUIRE_NO_ERRORS(result); + // CHECK_EQ("boolean", toString(requireType("ok"))); + // CHECK_EQ("any", toString(requireType("res"))); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(Fixture, "choose_the_right_overload_for_pcall") +{ + CheckResult result = check(R"( + local function f(): number + if math.random() > 0.5 then + return 5 + else + error("something") + end + end + + local ok, res = pcall(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireType("ok"))); + CHECK_EQ("number", toString(requireType("res"))); + // CHECK_EQ("any", toString(requireType("res"))); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(Fixture, "function_returns_many_things_but_first_of_it_is_forgotten") +{ + CheckResult result = check(R"( + local function f(): (number, string, boolean) + if math.random() > 0.5 then + return 5, "hello", true + else + error("something") + end + end + + local ok, res, s, b = pcall(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireType("ok"))); + CHECK_EQ("number", toString(requireType("res"))); + // CHECK_EQ("any", toString(requireType("res"))); + CHECK_EQ("string", toString(requireType("s"))); + CHECK_EQ("boolean", toString(requireType("b"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index cddeab6e..ce22bcb1 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Normalize.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -8,6 +9,7 @@ LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAG(LuauWeakEqConstraint) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; @@ -48,6 +50,7 @@ struct RefinementClassFixture : Fixture {"Y", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}}, }; + normalize(vec3, arena, *typeChecker.iceHandler); TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); @@ -55,17 +58,21 @@ struct RefinementClassFixture : Fixture TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + normalize(isA, arena, *typeChecker.iceHandler); getMutable(inst)->props = { {"Name", Property{typeChecker.stringType}}, {"IsA", Property{isA}}, }; + normalize(inst, arena, *typeChecker.iceHandler); TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); + normalize(folder, arena, *typeChecker.iceHandler); TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); getMutable(part)->props = { {"Position", Property{vec3}}, }; + normalize(part, arena, *typeChecker.iceHandler); typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; @@ -697,7 +704,10 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("{| x: number, y: number |}", toString(requireTypeAtPosition({4, 28}))); + else + CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index d39341ea..2b01c29e 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -5,8 +5,6 @@ #include "doctest.h" #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(BetterDiagnosticCodesInStudio) - using namespace Luau; TEST_SUITE_BEGIN("TypeSingletons"); @@ -261,14 +259,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::BetterDiagnosticCodesInStudio) - { - CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); - } - else - { - CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); - } + CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 0484351d..ca1b8de7 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -11,6 +11,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -1211,7 +1213,10 @@ TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_c )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(get(result.errors[0])); + if (FFlag::LuauLowerBoundsCalculation) + CHECK(get(result.errors[0])); + else + CHECK(get(result.errors[0])); } // This unit test could be flaky if the fix has regressed. @@ -2922,6 +2927,60 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the LUAU_REQUIRE_NO_ERRORS(result); } +// The real bug here was that we weren't always uncondionally typechecking a trailing return statement last. +TEST_CASE_FIXTURE(Fixture, "dont_leak_free_table_props") +{ + CheckResult result = check(R"( + local function a(state) + print(state.blah) + end + + local function b(state) -- The bug was that we inferred state: {blah: any, gwar: any} + print(state.gwar) + end + + return function() + return function(state) + a(state) + b(state) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({+ blah: a +}) -> ()", toString(requireType("a"))); + CHECK_EQ("({+ gwar: a +}) -> ()", toString(requireType("b"))); + CHECK_EQ("() -> ({+ blah: a, gwar: b +}) -> ()", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + )"); + + CHECK_EQ("(t1) -> {| Byte: (b) -> (a...), PeekByte: (c) -> (a...) |} where t1 = {+ byte: (t1, number) -> (a...) +}", + toString(requireType("Base64FileReader"))); +} + TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") { ScopedFastFlag sff{"LuauCheckImplicitNumbericKeys", true}; diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 660ddcfc..6abd96b9 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -13,6 +13,7 @@ #include +LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) LUAU_FASTFLAG(LuauEqConstraint) @@ -177,7 +178,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_case") )"); LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); } TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") @@ -293,7 +293,7 @@ TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") // In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in type // checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. -TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") +TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count") { #if defined(LUAU_ENABLE_ASAN) int limit = 250; @@ -302,12 +302,14 @@ TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") #else int limit = 600; #endif - ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; - ScopedFastInt luauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", limit - 100}; - ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", 0}; - CHECK_NOTHROW(check("print('Hello!')")); - CHECK_THROWS_AS(check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"), std::runtime_error); + ScopedFastFlag sff{"LuauTableUseCounterInstead", true}; + ScopedFastInt sfi{"LuauCheckRecursionLimit", limit}; + + CheckResult result = check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(nullptr != get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "check_block_recursion_limit") @@ -721,9 +723,9 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") local l0 do end while _ do - function _:_() - _ += _(_._(_:n0(xpcall,_))) - end + function _:_() + _ += _(_._(_:n0(xpcall,_))) + end end )"); @@ -978,4 +980,48 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") +{ + ScopedFastInt sfi("LuauTypeInferRecursionLimit", 2); + ScopedFastFlag sff{"LuauRecursionLimitException", true}; + + CheckResult result = check(R"( + function complex() + function _(l0:t0): (any, ()->()) + return 0,_ + end + type t0 = t0 | {} + _(nil) + end + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") +{ + ScopedFastFlag substituteFollowNewTypes{"LuauSubstituteFollowNewTypes", true}; + + CheckResult result = check(R"( + local obj = {} + + function obj:Method() + self.fieldA = function(object) + if object.a then + self.arr[object] = true + elseif object.b then + self.fieldB[object] = object:Connect(function(arg) + self.arr[arg] = nil + end) + end + end + end + + return obj + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 130f33d7..f141622f 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") @@ -27,8 +29,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const auto& [returns, tail] = flatten(takeTwoType->retType); CHECK_EQ(2, returns.size()); - CHECK_EQ(typeChecker.numberType, returns[0]); - CHECK_EQ(typeChecker.numberType, returns[1]); + CHECK_EQ(typeChecker.numberType, follow(returns[0])); + CHECK_EQ(typeChecker.numberType, follow(returns[1])); CHECK(!tail); } @@ -74,9 +76,9 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const auto& [rets, tail] = flatten(takeOneMoreType->retType); REQUIRE_EQ(3, rets.size()); - CHECK_EQ(typeChecker.numberType, rets[0]); - CHECK_EQ(typeChecker.numberType, rets[1]); - CHECK_EQ(typeChecker.numberType, rets[2]); + CHECK_EQ(typeChecker.numberType, follow(rets[0])); + CHECK_EQ(typeChecker.numberType, follow(rets[1])); + CHECK_EQ(typeChecker.numberType, follow(rets[2])); CHECK(!tail); } @@ -91,26 +93,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* applyType = get(requireType("apply")); - REQUIRE(applyType != nullptr); - - std::vector applyArgs = flatten(applyType->argTypes).first; - REQUIRE_EQ(3, applyArgs.size()); - - const FunctionTypeVar* fType = get(follow(applyArgs[0])); - REQUIRE(fType != nullptr); - - const FunctionTypeVar* gType = get(follow(applyArgs[1])); - REQUIRE(gType != nullptr); - - std::vector gArgs = flatten(gType->argTypes).first; - REQUIRE_EQ(1, gArgs.size()); - - // function(function(t1, T2...): (t3, T4...), function(t5): (t1, T2...), t5): (t3, T4...) - - REQUIRE_EQ(*gArgs[0], *applyArgs[2]); - REQUIRE_EQ(toString(fType->argTypes), toString(gType->retType)); - REQUIRE_EQ(toString(fType->retType), toString(applyType->retType)); + CHECK_EQ("((b...) -> (c...), (a) -> (b...), a) -> (c...)", toString(requireType("apply"))); } TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") @@ -328,7 +311,10 @@ local c: Packed auto ttvA = get(requireType("a")); REQUIRE(ttvA); CHECK_EQ(toString(requireType("a")), "Packed"); - CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); + else + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); REQUIRE(ttvA->instantiatedTypeParams.size() == 1); REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ff207a18..96bdd534 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -6,6 +6,7 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauEqConstraint) using namespace Luau; @@ -254,11 +255,11 @@ local c = bf.a.y TEST_CASE_FIXTURE(Fixture, "optional_union_functions") { CheckResult result = check(R"( -local a = {} -function a.foo(x:number, y:number) return x + y end -type A = typeof(a) -local b: A? = a -local c = b.foo(1, 2) + local a = {} + function a.foo(x:number, y:number) return x + y end + type A = typeof(a) + local b: A? = a + local c = b.foo(1, 2) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -356,7 +357,10 @@ a.x = 2 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("Value of type '{| x: number, y: number |}?' could be nil", toString(result.errors[0])); + else + CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "optional_length_error") @@ -533,8 +537,13 @@ TEST_CASE_FIXTURE(Fixture, "table_union_write_indirect") LUAU_REQUIRE_ERROR_COUNT(1, result); // NOTE: union normalization will improve this message - CHECK_EQ(toString(result.errors[0]), - R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[0]), "Type '(string) -> number' could not be converted into '(number) -> string'\n" + "caused by:\n" + " Argument #1 type is not compatible. Type 'number' could not be converted into 'string'"); + else + CHECK_EQ(toString(result.errors[0]), + R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); } diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index ab9be42c..c8176456 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -581,4 +581,19 @@ do assert(#arr == 5) end +-- test boundary invariant maintenance when table is filled using SETLIST opcode +do + local arr = {[2]=2,1} + assert(#arr == 2) +end + +-- test boundary invariant maintenance when table is filled using table.move +do + local t1 = {1, 2, 3, 4, 5} + local t2 = {[6] = 6} + + table.move(t1, 1, 5, 1, t2) + assert(#t2 == 6) +end + return"OK" From 5bb9f379b07e378db0a170e7c4030e3a943b2f14 Mon Sep 17 00:00:00 2001 From: Alan Jeffrey <403333+asajeffrey@users.noreply.github.com> Date: Fri, 15 Apr 2022 19:19:42 -0500 Subject: [PATCH 047/102] Unified strict and nonstrict mode in the prototype (#458) --- prototyping/Luau/StrictMode.agda | 9 +++---- prototyping/Luau/StrictMode/ToString.agda | 4 ++-- prototyping/Luau/Type.agda | 29 ++++++++--------------- prototyping/Luau/TypeCheck.agda | 9 ++----- prototyping/Properties/StrictMode.agda | 8 +++---- prototyping/Properties/Subtyping.agda | 4 +--- prototyping/Properties/TypeCheck.agda | 11 +++------ 7 files changed, 24 insertions(+), 50 deletions(-) diff --git a/prototyping/Luau/StrictMode.agda b/prototyping/Luau/StrictMode.agda index b6769f01..1b028042 100644 --- a/prototyping/Luau/StrictMode.agda +++ b/prototyping/Luau/StrictMode.agda @@ -5,18 +5,15 @@ module Luau.StrictMode where open import Agda.Builtin.Equality using (_≡_) open import FFI.Data.Maybe using (just; nothing) open import Luau.Syntax using (Expr; Stat; Block; BinaryOperator; yes; nil; addr; var; binexp; var_∈_; _⟨_⟩∈_; function_is_end; _$_; block_is_end; local_←_; _∙_; done; return; name; +; -; *; /; <; >; <=; >=; ··) -open import Luau.Type using (Type; strict; nil; number; string; boolean; _⇒_; _∪_; _∩_; tgt) +open import Luau.Type using (Type; nil; number; string; boolean; _⇒_; _∪_; _∩_; src; tgt) open import Luau.Subtyping using (_≮:_) open import Luau.Heap using (Heap; function_is_end) renaming (_[_] to _[_]ᴴ) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_) renaming (_[_] to _[_]ⱽ) -open import Luau.TypeCheck(strict) using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; ⊢ᴴ_; ⊢ᴼ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; var; addr; app; binexp; block; return; local; function; srcBinOp) +open import Luau.TypeCheck using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; ⊢ᴴ_; ⊢ᴼ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; var; addr; app; binexp; block; return; local; function; srcBinOp) open import Properties.Contradiction using (¬) -open import Properties.TypeCheck(strict) using (typeCheckᴮ) +open import Properties.TypeCheck using (typeCheckᴮ) open import Properties.Product using (_,_) -src : Type → Type -src = Luau.Type.src strict - data Warningᴱ (H : Heap yes) {Γ} : ∀ {M T} → (Γ ⊢ᴱ M ∈ T) → Set data Warningᴮ (H : Heap yes) {Γ} : ∀ {B T} → (Γ ⊢ᴮ B ∈ T) → Set diff --git a/prototyping/Luau/StrictMode/ToString.agda b/prototyping/Luau/StrictMode/ToString.agda index 08ee13b8..eee5722e 100644 --- a/prototyping/Luau/StrictMode/ToString.agda +++ b/prototyping/Luau/StrictMode/ToString.agda @@ -7,8 +7,8 @@ open import FFI.Data.String using (String; _++_) open import Luau.Subtyping using (_≮:_; Tree; witness; scalar; function; function-ok; function-err) open import Luau.StrictMode using (Warningᴱ; Warningᴮ; UnallocatedAddress; UnboundVariable; FunctionCallMismatch; FunctionDefnMismatch; BlockMismatch; app₁; app₂; BinOpMismatch₁; BinOpMismatch₂; bin₁; bin₂; block₁; return; LocalVarMismatch; local₁; local₂; function₁; function₂; heap; expr; block; addr) open import Luau.Syntax using (Expr; val; yes; var; var_∈_; _⟨_⟩∈_; _$_; addr; number; binexp; nil; function_is_end; block_is_end; done; return; local_←_; _∙_; fun; arg; name) -open import Luau.Type using (strict; number; boolean; string; nil) -open import Luau.TypeCheck(strict) using (_⊢ᴮ_∈_; _⊢ᴱ_∈_) +open import Luau.Type using (number; boolean; string; nil) +open import Luau.TypeCheck using (_⊢ᴮ_∈_; _⊢ᴱ_∈_) open import Luau.Addr.ToString using (addrToString) open import Luau.Var.ToString using (varToString) open import Luau.Type.ToString using (typeToString) diff --git a/prototyping/Luau/Type.agda b/prototyping/Luau/Type.agda index 30c45388..59d1107f 100644 --- a/prototyping/Luau/Type.agda +++ b/prototyping/Luau/Type.agda @@ -146,25 +146,16 @@ just T ≡ᴹᵀ just U with T ≡ᵀ U (just T ≡ᴹᵀ just T) | yes refl = yes refl (just T ≡ᴹᵀ just U) | no p = no (λ q → p (just-inv q)) -data Mode : Set where - strict : Mode - nonstrict : Mode - -src : Mode → Type → Type -src m nil = never -src m number = never -src m boolean = never -src m string = never -src m (S ⇒ T) = S --- In nonstrict mode, functions are covaraiant, in strict mode they're contravariant -src strict (S ∪ T) = (src strict S) ∩ (src strict T) -src nonstrict (S ∪ T) = (src nonstrict S) ∪ (src nonstrict T) -src strict (S ∩ T) = (src strict S) ∪ (src strict T) -src nonstrict (S ∩ T) = (src nonstrict S) ∩ (src nonstrict T) -src strict never = unknown -src nonstrict never = never -src strict unknown = never -src nonstrict unknown = unknown +src : Type → Type +src nil = never +src number = never +src boolean = never +src string = never +src (S ⇒ T) = S +src (S ∪ T) = (src S) ∩ (src T) +src (S ∩ T) = (src S) ∪ (src T) +src never = unknown +src unknown = never tgt : Type → Type tgt nil = never diff --git a/prototyping/Luau/TypeCheck.agda b/prototyping/Luau/TypeCheck.agda index aea6507a..cabd27a8 100644 --- a/prototyping/Luau/TypeCheck.agda +++ b/prototyping/Luau/TypeCheck.agda @@ -1,8 +1,6 @@ {-# OPTIONS --rewriting #-} -open import Luau.Type using (Mode) - -module Luau.TypeCheck (m : Mode) where +module Luau.TypeCheck where open import Agda.Builtin.Equality using (_≡_) open import FFI.Data.Maybe using (Maybe; just) @@ -10,15 +8,12 @@ open import Luau.Syntax using (Expr; Stat; Block; BinaryOperator; yes; nil; addr open import Luau.Var using (Var) open import Luau.Addr using (Addr) open import Luau.Heap using (Heap; Object; function_is_end) renaming (_[_] to _[_]ᴴ) -open import Luau.Type using (Type; Mode; nil; unknown; number; boolean; string; _⇒_; tgt) +open import Luau.Type using (Type; nil; unknown; number; boolean; string; _⇒_; src; tgt) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_) renaming (_[_] to _[_]ⱽ) open import FFI.Data.Vector using (Vector) open import FFI.Data.Maybe using (Maybe; just; nothing) open import Properties.Product using (_×_; _,_) -src : Type → Type -src = Luau.Type.src m - orUnknown : Maybe Type → Type orUnknown nothing = unknown orUnknown (just T) = T diff --git a/prototyping/Properties/StrictMode.agda b/prototyping/Properties/StrictMode.agda index 1165fdaa..fd2cf2f2 100644 --- a/prototyping/Properties/StrictMode.agda +++ b/prototyping/Properties/StrictMode.agda @@ -11,8 +11,8 @@ open import Luau.StrictMode using (Warningᴱ; Warningᴮ; Warningᴼ; Warning open import Luau.Substitution using (_[_/_]ᴮ; _[_/_]ᴱ; _[_/_]ᴮunless_; var_[_/_]ᴱwhenever_) open import Luau.Subtyping using (_≮:_; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_; Tree; Language; ¬Language) open import Luau.Syntax using (Expr; yes; var; val; var_∈_; _⟨_⟩∈_; _$_; addr; number; bool; string; binexp; nil; function_is_end; block_is_end; done; return; local_←_; _∙_; fun; arg; name; ==; ~=) -open import Luau.Type using (Type; strict; nil; number; boolean; string; _⇒_; never; unknown; _∩_; _∪_; tgt; _≡ᵀ_; _≡ᴹᵀ_) -open import Luau.TypeCheck(strict) using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; _⊢ᴴᴮ_▷_∈_; _⊢ᴴᴱ_▷_∈_; nil; var; addr; app; function; block; done; return; local; orUnknown; srcBinOp; tgtBinOp) +open import Luau.Type using (Type; nil; number; boolean; string; _⇒_; never; unknown; _∩_; _∪_; src; tgt; _≡ᵀ_; _≡ᴹᵀ_) +open import Luau.TypeCheck using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; _⊢ᴴᴮ_▷_∈_; _⊢ᴴᴱ_▷_∈_; nil; var; addr; app; function; block; done; return; local; orUnknown; srcBinOp; tgtBinOp) open import Luau.Var using (_≡ⱽ_) open import Luau.Addr using (_≡ᴬ_) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_; ⊕-lookup-miss; ⊕-swap; ⊕-over) renaming (_[_] to _[_]ⱽ) @@ -23,13 +23,11 @@ open import Properties.Dec using (Dec; yes; no) open import Properties.Contradiction using (CONTRADICTION; ¬) open import Properties.Functions using (_∘_) open import Properties.Subtyping using (unknown-≮:; ≡-trans-≮:; ≮:-trans-≡; never-tgt-≮:; tgt-never-≮:; src-unknown-≮:; unknown-src-≮:; ≮:-trans; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-never; unknown-≮:-scalar; scalar-≮:-never; unknown-≮:-never) -open import Properties.TypeCheck(strict) using (typeOfᴼ; typeOfᴹᴼ; typeOfⱽ; typeOfᴱ; typeOfᴮ; typeCheckᴱ; typeCheckᴮ; typeCheckᴼ; typeCheckᴴ) +open import Properties.TypeCheck using (typeOfᴼ; typeOfᴹᴼ; typeOfⱽ; typeOfᴱ; typeOfᴮ; typeCheckᴱ; typeCheckᴮ; typeCheckᴼ; typeCheckᴴ) open import Luau.OpSem using (_⟦_⟧_⟶_; _⊢_⟶*_⊣_; _⊢_⟶ᴮ_⊣_; _⊢_⟶ᴱ_⊣_; app₁; app₂; function; beta; return; block; done; local; subst; binOp₀; binOp₁; binOp₂; refl; step; +; -; *; /; <; >; ==; ~=; <=; >=; ··) open import Luau.RuntimeError using (BinOpError; RuntimeErrorᴱ; RuntimeErrorᴮ; FunctionMismatch; BinOpMismatch₁; BinOpMismatch₂; UnboundVariable; SEGV; app₁; app₂; bin₁; bin₂; block; local; return; +; -; *; /; <; >; <=; >=; ··) open import Luau.RuntimeType using (RuntimeType; valueType; number; string; boolean; nil; function) -src = Luau.Type.src strict - data _⊑_ (H : Heap yes) : Heap yes → Set where refl : (H ⊑ H) snoc : ∀ {H′ a O} → (H′ ≡ᴴ H ⊕ a ↦ O) → (H ⊑ H′) diff --git a/prototyping/Properties/Subtyping.agda b/prototyping/Properties/Subtyping.agda index cc6bb5c1..b713eaf7 100644 --- a/prototyping/Properties/Subtyping.agda +++ b/prototyping/Properties/Subtyping.agda @@ -5,14 +5,12 @@ module Properties.Subtyping where open import Agda.Builtin.Equality using (_≡_; refl) open import FFI.Data.Either using (Either; Left; Right; mapLR; swapLR; cond) open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_) -open import Luau.Type using (Type; Scalar; strict; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_; tgt) +open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_; src; tgt) open import Properties.Contradiction using (CONTRADICTION; ¬) open import Properties.Equality using (_≢_) open import Properties.Functions using (_∘_) open import Properties.Product using (_×_; _,_) -src = Luau.Type.src strict - -- Language membership is decidable dec-language : ∀ T t → Either (¬Language T t) (Language T t) dec-language nil (scalar number) = Left (scalar-scalar number nil (λ ())) diff --git a/prototyping/Properties/TypeCheck.agda b/prototyping/Properties/TypeCheck.agda index a5916a13..0726a4be 100644 --- a/prototyping/Properties/TypeCheck.agda +++ b/prototyping/Properties/TypeCheck.agda @@ -1,16 +1,14 @@ {-# OPTIONS --rewriting #-} -open import Luau.Type using (Mode) - -module Properties.TypeCheck (m : Mode) where +module Properties.TypeCheck where open import Agda.Builtin.Equality using (_≡_; refl) open import Agda.Builtin.Bool using (Bool; true; false) open import FFI.Data.Maybe using (Maybe; just; nothing) open import FFI.Data.Either using (Either) -open import Luau.TypeCheck(m) using (_⊢ᴱ_∈_; _⊢ᴮ_∈_; ⊢ᴼ_; ⊢ᴴ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; nil; var; addr; number; bool; string; app; function; block; binexp; done; return; local; nothing; orUnknown; tgtBinOp) +open import Luau.TypeCheck using (_⊢ᴱ_∈_; _⊢ᴮ_∈_; ⊢ᴼ_; ⊢ᴴ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; nil; var; addr; number; bool; string; app; function; block; binexp; done; return; local; nothing; orUnknown; tgtBinOp) open import Luau.Syntax using (Block; Expr; Value; BinaryOperator; yes; nil; addr; number; bool; string; val; var; binexp; _$_; function_is_end; block_is_end; _∙_; return; done; local_←_; _⟨_⟩; _⟨_⟩∈_; var_∈_; name; fun; arg; +; -; *; /; <; >; ==; ~=; <=; >=) -open import Luau.Type using (Type; nil; unknown; never; number; boolean; string; _⇒_; tgt) +open import Luau.Type using (Type; nil; unknown; never; number; boolean; string; _⇒_; src; tgt) open import Luau.RuntimeType using (RuntimeType; nil; number; function; string; valueType) open import Luau.VarCtxt using (VarCtxt; ∅; _↦_; _⊕_↦_; _⋒_; _⊝_) renaming (_[_] to _[_]ⱽ) open import Luau.Addr using (Addr) @@ -22,9 +20,6 @@ open import Properties.Equality using (_≢_; sym; trans; cong) open import Properties.Product using (_×_; _,_) open import Properties.Remember using (Remember; remember; _,_) -src : Type → Type -src = Luau.Type.src m - typeOfᴼ : Object yes → Type typeOfᴼ (function f ⟨ var x ∈ S ⟩∈ T is B end) = (S ⇒ T) From e0a6461173216b76bdf731410f50485b89c83aa4 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 21 Apr 2022 14:44:27 -0700 Subject: [PATCH 048/102] Sync to upstream/release/524 (#462) --- Analysis/include/Luau/Clone.h | 2 +- Analysis/include/Luau/Frontend.h | 21 +- Analysis/include/Luau/TypeVar.h | 6 +- Analysis/include/Luau/Unifier.h | 1 - Analysis/include/Luau/VisitTypeVar.h | 2 +- Analysis/src/Autocomplete.cpp | 4 +- Analysis/src/Clone.cpp | 42 ++- Analysis/src/Error.cpp | 23 +- Analysis/src/Frontend.cpp | 50 ++-- Analysis/src/Module.cpp | 18 +- Analysis/src/Normalize.cpp | 44 +-- Analysis/src/Substitution.cpp | 15 +- Analysis/src/ToDot.cpp | 2 +- Analysis/src/Transpiler.cpp | 53 ++-- Analysis/src/TypeAttach.cpp | 14 + Analysis/src/TypeInfer.cpp | 37 ++- Analysis/src/TypeVar.cpp | 6 + Analysis/src/Unifier.cpp | 110 ++------ Ast/include/Luau/StringUtils.h | 1 + Ast/src/Parser.cpp | 8 + Ast/src/StringUtils.cpp | 2 +- CLI/Repl.cpp | 72 ++++- CMakeLists.txt | 6 + Compiler/include/Luau/BytecodeBuilder.h | 7 + Compiler/src/BytecodeBuilder.cpp | 70 ++++- Compiler/src/Compiler.cpp | 127 ++++++++- Compiler/src/ConstantFolding.cpp | 13 +- Compiler/src/ConstantFolding.h | 3 +- VM/src/lapi.cpp | 24 +- VM/src/lgc.cpp | 26 +- VM/src/lgc.h | 2 +- VM/src/lstate.h | 7 +- bench/bench.py | 13 +- fuzz/proto.cpp | 13 +- tests/Autocomplete.test.cpp | 36 +++ tests/Compiler.test.cpp | 359 +++++++++++++++++++++++- tests/CostModel.test.cpp | 125 +++++++++ tests/JsonEncoder.test.cpp | 332 +++++++++++++++++++++- tests/Linter.test.cpp | 2 +- tests/Module.test.cpp | 24 +- tests/NonstrictMode.test.cpp | 46 ++- tests/Normalize.test.cpp | 44 ++- tests/ToDot.test.cpp | 4 +- tests/Transpiler.test.cpp | 17 ++ tests/TypeInfer.classes.test.cpp | 34 ++- tests/TypeInfer.definitions.test.cpp | 2 - tests/TypeInfer.functions.test.cpp | 30 +- tests/TypeInfer.modules.test.cpp | 4 - tests/TypeInfer.provisional.test.cpp | 14 - tests/TypeInfer.refinements.test.cpp | 8 +- tests/TypeInfer.tables.test.cpp | 2 - tests/TypeInfer.test.cpp | 18 +- tests/TypeVar.test.cpp | 6 +- 53 files changed, 1596 insertions(+), 355 deletions(-) diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 78aa92c7..9b6ffa62 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -18,7 +18,7 @@ struct CloneState SeenTypePacks seenTypePacks; int recursionCount = 0; - bool encounteredFreeType = false; + bool encounteredFreeType = false; // TODO: Remove with LuauLosslessClone. }; TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index e24e433c..59125470 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -13,6 +13,7 @@ #include LUAU_FASTFLAG(LuauSeparateTypechecks) +LUAU_FASTFLAG(LuauDirtySourceModule) namespace Luau { @@ -57,19 +58,27 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleNa struct SourceNode { - bool isDirty(bool forAutocomplete) const + bool hasDirtySourceModule() const + { + LUAU_ASSERT(FFlag::LuauDirtySourceModule); + + return dirtySourceModule; + } + + bool hasDirtyModule(bool forAutocomplete) const { if (FFlag::LuauSeparateTypechecks) - return forAutocomplete ? dirtyAutocomplete : dirty; + return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; else - return dirty; + return dirtyModule; } ModuleName name; std::unordered_set requires; std::vector> requireLocations; - bool dirty = true; - bool dirtyAutocomplete = true; + bool dirtySourceModule = true; + bool dirtyModule = true; + bool dirtyModuleForAutocomplete = true; double autocompleteLimitsMult = 1.0; }; @@ -163,7 +172,7 @@ struct Frontend void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); private: - std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete); + std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete_DEPRECATED); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index ae7d1377..84576758 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -373,15 +373,17 @@ struct ClassTypeVar std::optional metatable; // metaclass? Tags tags; std::shared_ptr userData; + ModuleName definitionModuleName; - ClassTypeVar( - Name name, Props props, std::optional parent, std::optional metatable, Tags tags, std::shared_ptr userData) + ClassTypeVar(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, + std::shared_ptr userData, ModuleName definitionModuleName) : name(name) , props(props) , parent(parent) , metatable(metatable) , tags(tags) , userData(userData) + , definitionModuleName(definitionModuleName) { } }; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 340feb7f..418d4ca4 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -92,7 +92,6 @@ private: bool canCacheResult(TypeId subTy, TypeId superTy); void cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount); - void cacheResult_DEPRECATED(TypeId subTy, TypeId superTy); public: void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index d11cbd0d..045190ea 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -52,7 +52,7 @@ inline void unsee(std::unordered_set& seen, const void* tv) inline void unsee(DenseHashSet& seen, const void* tv) { - // When DenseHashSet is used for 'visitOnce', where don't forget visited elements + // When DenseHashSet is used for 'visitTypeVarOnce', where don't forget visited elements } template diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index e0e79cb4..dec12d01 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteSingletonTypes, false); +LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteClassSecurityLevel, false); LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) static const std::unordered_set kStatementStartingKeywords = { @@ -462,7 +463,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId containingClass = containingClass.value_or(cls); fillProps(cls->props); if (cls->parent) - autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, cls); + autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, + FFlag::LuauFixAutocompleteClassSecurityLevel ? containingClass : cls); } else if (auto tbl = get(ty)) fillProps(tbl->props); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 8e7f7c07..d5bd9dab 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -10,6 +10,7 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypecheckOptPass) +LUAU_FASTFLAGVARIABLE(LuauLosslessClone, false) namespace Luau { @@ -87,11 +88,18 @@ struct TypePackCloner void operator()(const Unifiable::Free& t) { - cloneState.encounteredFreeType = true; + if (FFlag::LuauLosslessClone) + { + defaultClone(t); + } + else + { + cloneState.encounteredFreeType = true; - TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); - TypePackId cloned = dest.addTypePack(*err); - seenTypePacks[typePackId] = cloned; + TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); + TypePackId cloned = dest.addTypePack(*err); + seenTypePacks[typePackId] = cloned; + } } void operator()(const Unifiable::Generic& t) @@ -143,10 +151,18 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { - cloneState.encounteredFreeType = true; - TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); - TypeId cloned = dest.addType(*err); - seenTypes[typeId] = cloned; + if (FFlag::LuauLosslessClone) + { + defaultClone(t); + } + else + { + cloneState.encounteredFreeType = true; + + TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); + TypeId cloned = dest.addType(*err); + seenTypes[typeId] = cloned; + } } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -174,7 +190,8 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) void TypeCloner::operator()(const ConstrainedTypeVar& t) { - cloneState.encounteredFreeType = true; + if (!FFlag::LuauLosslessClone) + cloneState.encounteredFreeType = true; TypeId res = dest.addType(ConstrainedTypeVar{t.level}); ConstrainedTypeVar* ctv = getMutable(res); @@ -252,7 +269,7 @@ void TypeCloner::operator()(const TableTypeVar& t) for (TypePackId& arg : ttv->instantiatedTypePackParams) arg = clone(arg, dest, cloneState); - if (ttv->state == TableState::Free) + if (!FFlag::LuauLosslessClone && ttv->state == TableState::Free) { cloneState.encounteredFreeType = true; @@ -276,7 +293,7 @@ void TypeCloner::operator()(const MetatableTypeVar& t) void TypeCloner::operator()(const ClassTypeVar& t) { - TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); + TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName}); ClassTypeVar* ctv = getMutable(result); seenTypes[typeId] = result; @@ -361,7 +378,10 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) // Persistent types are not being cloned and we get the original type back which might be read-only if (!res->persistent) + { asMutable(res)->documentationSymbol = typeId->documentationSymbol; + asMutable(res)->normal = typeId->normal; + } } return res; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index cbec0b15..24ed4ac1 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -8,8 +8,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleName, false); - static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { std::string s = "expects "; @@ -59,27 +57,20 @@ struct ErrorConverter std::string result; - if (FFlag::LuauTypeMismatchModuleName) + if (givenTypeName == wantedTypeName) { - if (givenTypeName == wantedTypeName) + if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) { - if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) + if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) { - if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) - { - result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + - "' from '" + *wantedDefinitionModule + "'"; - } + result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + + "' from '" + *wantedDefinitionModule + "'"; } } + } - if (result.empty()) - result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; - } - else - { + if (result.empty()) result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; - } if (tm.error) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 8b0b2210..34ccdac4 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -23,6 +23,7 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) +LUAU_FASTFLAGVARIABLE(LuauDirtySourceModule, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 0) namespace Luau @@ -358,7 +359,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond.isDirty(frontendOptions.forAutocomplete)) + if (it != sourceNodes.end() && !it->second.hasDirtyModule(frontendOptions.forAutocomplete)) { // No recheck required. if (FFlag::LuauSeparateTypechecks) @@ -402,7 +403,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalerrors.begin(), module->errors.end()); moduleResolver.modules[moduleName] = std::move(module); - sourceNode.dirty = false; + sourceNode.dirtyModule = false; } return checkResult; @@ -618,7 +619,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec // this relies on the fact that markDirty marks reverse-dependencies dirty as well // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // to be built, *and* can't form a cycle with any nodes we did process. - if (!it->second.isDirty(forAutocomplete)) + if (!it->second.hasDirtyModule(forAutocomplete)) continue; // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization @@ -768,7 +769,7 @@ LintResult Frontend::lint(const SourceModule& module, std::optionalsecond.isDirty(forAutocomplete); + return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete); } /* @@ -810,20 +811,31 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked if (markedDirty) markedDirty->push_back(next); - if (FFlag::LuauSeparateTypechecks) + if (FFlag::LuauDirtySourceModule) { - if (sourceNode.dirty && sourceNode.dirtyAutocomplete) + LUAU_ASSERT(FFlag::LuauSeparateTypechecks); + + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) continue; - sourceNode.dirty = true; - sourceNode.dirtyAutocomplete = true; + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; + } + else if (FFlag::LuauSeparateTypechecks) + { + if (sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + continue; + + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; } else { - if (sourceNode.dirty) + if (sourceNode.dirtyModule) continue; - sourceNode.dirty = true; + sourceNode.dirtyModule = true; } if (0 == reverseDeps.count(name)) @@ -851,13 +863,14 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. -std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete) +std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete_DEPRECATED) { LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.isDirty(forAutocomplete)) + if (it != sourceNodes.end() && + (FFlag::LuauDirtySourceModule ? !it->second.hasDirtySourceModule() : !it->second.hasDirtyModule(forAutocomplete_DEPRECATED))) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) @@ -901,17 +914,20 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceNode.requires.clear(); sourceNode.requireLocations.clear(); + if (FFlag::LuauDirtySourceModule) + sourceNode.dirtySourceModule = false; + if (FFlag::LuauSeparateTypechecks) { if (it == sourceNodes.end()) { - sourceNode.dirty = true; - sourceNode.dirtyAutocomplete = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; } } else { - sourceNode.dirty = true; + sourceNode.dirtyModule = true; } for (const auto& [moduleName, location] : requireTrace.requires) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index e2e3b436..bafd4371 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,8 +14,8 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) LUAU_FASTFLAG(LuauLowerBoundsCalculation) +LUAU_FASTFLAG(LuauLosslessClone) namespace Luau { @@ -182,20 +182,20 @@ bool Module::clonePublicInterface(InternalErrorReporter& ice) } } - if (FFlag::LuauCloneDeclaredGlobals) + for (auto& [name, ty] : declaredGlobals) { - for (auto& [name, ty] : declaredGlobals) - { - ty = clone(ty, interfaceTypes, cloneState); - if (FFlag::LuauLowerBoundsCalculation) - normalize(ty, interfaceTypes, ice); - } + ty = clone(ty, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) + normalize(ty, interfaceTypes, ice); } freeze(internalTypes); freeze(interfaceTypes); - return cloneState.encounteredFreeType; + if (FFlag::LuauLosslessClone) + return false; // TODO: make function return void. + else + return cloneState.encounteredFreeType; } } // namespace Luau diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 40341ac1..043526ed 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -5,7 +5,6 @@ #include #include "Luau/Clone.h" -#include "Luau/DenseHash.h" #include "Luau/Substitution.h" #include "Luau/Unifier.h" #include "Luau/VisitTypeVar.h" @@ -254,7 +253,7 @@ bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) } template -static bool areNormal_(const T& t, const DenseHashSet& seen, InternalErrorReporter& ice) +static bool areNormal_(const T& t, const std::unordered_set& seen, InternalErrorReporter& ice) { int count = 0; auto isNormal = [&](TypeId ty) { @@ -262,18 +261,19 @@ static bool areNormal_(const T& t, const DenseHashSet& seen, InternalErro if (count >= FInt::LuauNormalizeIterationLimit) ice.ice("Luau::areNormal hit iteration limit"); - return ty->normal || seen.find(asMutable(ty)); + // The follow is here because a bound type may not be normal, but the bound type is normal. + return ty->normal || follow(ty)->normal || seen.find(asMutable(ty)) != seen.end(); }; return std::all_of(begin(t), end(t), isNormal); } -static bool areNormal(const std::vector& types, const DenseHashSet& seen, InternalErrorReporter& ice) +static bool areNormal(const std::vector& types, const std::unordered_set& seen, InternalErrorReporter& ice) { return areNormal_(types, seen, ice); } -static bool areNormal(TypePackId tp, const DenseHashSet& seen, InternalErrorReporter& ice) +static bool areNormal(TypePackId tp, const std::unordered_set& seen, InternalErrorReporter& ice) { tp = follow(tp); if (get(tp)) @@ -288,7 +288,7 @@ static bool areNormal(TypePackId tp, const DenseHashSet& seen, InternalEr return true; if (auto vtp = get(*tail)) - return vtp->ty->normal || seen.find(asMutable(vtp->ty)); + return vtp->ty->normal || follow(vtp->ty)->normal || seen.find(asMutable(vtp->ty)) != seen.end(); return true; } @@ -335,9 +335,14 @@ struct Normalize return false; } - bool operator()(TypeId ty, const BoundTypeVar& btv) + bool operator()(TypeId ty, const BoundTypeVar& btv, std::unordered_set& seen) { - // It should never be the case that this TypeVar is normal, but is bound to a non-normal type. + // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. + // So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack. + if (seen.find(asMutable(btv.boundTo)) != seen.end()) + return false; + + // It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases. LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); asMutable(ty)->normal = btv.boundTo->normal; @@ -365,7 +370,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, DenseHashSet& seen) + bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -391,8 +396,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const FunctionTypeVar& ftv) = delete; - bool operator()(TypeId ty, const FunctionTypeVar& ftv, DenseHashSet& seen) + bool operator()(TypeId ty, const FunctionTypeVar& ftv, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -407,7 +411,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const TableTypeVar& ttv, DenseHashSet& seen) + bool operator()(TypeId ty, const TableTypeVar& ttv, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -419,7 +423,7 @@ struct Normalize auto checkNormal = [&](TypeId t) { // if t is on the stack, it is possible that this type is normal. // If t is not normal and it is not on the stack, this type is definitely not normal. - if (!t->normal && !seen.find(asMutable(t))) + if (!t->normal && seen.find(asMutable(t)) == seen.end()) normal = false; }; @@ -449,7 +453,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const MetatableTypeVar& mtv, DenseHashSet& seen) + bool operator()(TypeId ty, const MetatableTypeVar& mtv, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -477,7 +481,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const UnionTypeVar& utvRef, DenseHashSet& seen) + bool operator()(TypeId ty, const UnionTypeVar& utvRef, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -507,7 +511,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const IntersectionTypeVar& itvRef, DenseHashSet& seen) + bool operator()(TypeId ty, const IntersectionTypeVar& itvRef, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -775,8 +779,8 @@ std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorRepo (void)clone(ty, arena, state); Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; - DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, n, seen); + std::unordered_set seen; + visitTypeVar(ty, n, seen); return {ty, !n.limitExceeded}; } @@ -800,8 +804,8 @@ std::pair normalize(TypePackId tp, TypeArena& arena, InternalE (void)clone(tp, arena, state); Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; - DenseHashSet seen{nullptr}; - visitTypeVarOnce(tp, n, seen); + std::unordered_set seen; + visitTypeVar(tp, n, seen); return {tp, !n.limitExceeded}; } diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 8648b21e..1b51fa3d 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -11,6 +11,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) LUAU_FASTFLAG(LuauTypecheckOptPass) LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) +LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowPossibleMutations, false) namespace Luau { @@ -106,7 +107,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) std::pair Tarjan::indexify(TypeId ty) { - if (FFlag::LuauTypecheckOptPass) + if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) LUAU_ASSERT(ty == log->follow(ty)); else ty = log->follow(ty); @@ -127,7 +128,7 @@ std::pair Tarjan::indexify(TypeId ty) std::pair Tarjan::indexify(TypePackId tp) { - if (FFlag::LuauTypecheckOptPass) + if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) LUAU_ASSERT(tp == log->follow(tp)); else tp = log->follow(tp); @@ -148,7 +149,8 @@ std::pair Tarjan::indexify(TypePackId tp) void Tarjan::visitChild(TypeId ty) { - ty = log->follow(ty); + if (!FFlag::LuauSubstituteFollowPossibleMutations) + ty = log->follow(ty); edgesTy.push_back(ty); edgesTp.push_back(nullptr); @@ -156,7 +158,8 @@ void Tarjan::visitChild(TypeId ty) void Tarjan::visitChild(TypePackId tp) { - tp = log->follow(tp); + if (!FFlag::LuauSubstituteFollowPossibleMutations) + tp = log->follow(tp); edgesTy.push_back(nullptr); edgesTp.push_back(tp); @@ -471,7 +474,7 @@ TypePackId Substitution::clone(TypePackId tp) void Substitution::foundDirty(TypeId ty) { - if (FFlag::LuauTypecheckOptPass) + if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) LUAU_ASSERT(ty == log->follow(ty)); else ty = log->follow(ty); @@ -484,7 +487,7 @@ void Substitution::foundDirty(TypeId ty) void Substitution::foundDirty(TypePackId tp) { - if (FFlag::LuauTypecheckOptPass) + if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) LUAU_ASSERT(tp == log->follow(tp)); else tp = log->follow(tp); diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index cb54bfc1..9b396c80 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -327,7 +327,7 @@ void StateDot::visitChildren(TypePackId tp, int index) } else if (const VariadicTypePack* vtp = get(tp)) { - formatAppend(result, "VariadicTypePack %d", index); + formatAppend(result, "VariadicTypePack %s%d", vtp->hidden ? "hidden " : "", index); finishNodeLabel(tp); finishNode(); diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 92ed241e..1577bd63 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -1025,31 +1025,42 @@ struct Printer } else if (const auto& a = typeAnnotation.as()) { - CommaSeparatorInserter comma(writer); + AstTypeReference* indexType = a->indexer ? a->indexer->indexType->as() : nullptr; - writer.symbol("{"); - - for (std::size_t i = 0; i < a->props.size; ++i) + if (a->props.size == 0 && indexType && indexType->name == "number") { - comma(); - advance(a->props.data[i].location.begin); - writer.identifier(a->props.data[i].name.value); - if (a->props.data[i].type) - { - writer.symbol(":"); - visualizeTypeAnnotation(*a->props.data[i].type); - } - } - if (a->indexer) - { - comma(); - writer.symbol("["); - visualizeTypeAnnotation(*a->indexer->indexType); - writer.symbol("]"); - writer.symbol(":"); + writer.symbol("{"); visualizeTypeAnnotation(*a->indexer->resultType); + writer.symbol("}"); + } + else + { + CommaSeparatorInserter comma(writer); + + writer.symbol("{"); + + for (std::size_t i = 0; i < a->props.size; ++i) + { + comma(); + advance(a->props.data[i].location.begin); + writer.identifier(a->props.data[i].name.value); + if (a->props.data[i].type) + { + writer.symbol(":"); + visualizeTypeAnnotation(*a->props.data[i].type); + } + } + if (a->indexer) + { + comma(); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + writer.symbol("]"); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + } + writer.symbol("}"); } - writer.symbol("}"); } else if (auto a = typeAnnotation.as()) { diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index bc8d0d4e..0f4534b7 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -479,6 +479,20 @@ public: { return visitLocal(al->local); } + + virtual bool visit(AstStatFor* stat) override + { + visitLocal(stat->var); + return true; + } + + virtual bool visit(AstStatForIn* stat) override + { + for (size_t i = 0; i < stat->vars.size; ++i) + visitLocal(stat->vars.data[i]); + return true; + } + virtual bool visit(AstExprFunction* fn) override { // TODO: add generics if the inferred type of the function is generic CLI-39908 diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index af42a4e6..6411e2ab 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" +#include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/ModuleResolver.h" #include "Luau/Normalize.h" @@ -47,7 +48,6 @@ LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) -LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. @@ -61,7 +61,9 @@ LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) +LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); +LUAU_FASTFLAG(LuauLosslessClone) namespace Luau { @@ -376,7 +378,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo prepareErrorsForDisplay(currentModule->errors); bool encounteredFreeType = currentModule->clonePublicInterface(*iceHandler); - if (encounteredFreeType) + if (!FFlag::LuauLosslessClone && encounteredFreeType) { reportError(TypeError{module.root->location, GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); @@ -785,7 +787,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; - if (useConstrainedIntersections()) + if (FFlag::LuauReturnTypeInferenceInNonstrict ? FFlag::LuauLowerBoundsCalculation : useConstrainedIntersections()) { unifyLowerBound(retPack, scope->returnType, return_.location); return; @@ -1241,7 +1243,12 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco // If in nonstrict mode and allowing redefinition of global function, restore the previous definition type // in case this function has a differing signature. The signature discrepancy will be caught in checkBlock. if (previouslyDefined) + { + if (FFlag::LuauReturnTypeInferenceInNonstrict && FFlag::LuauLowerBoundsCalculation) + quantify(funScope, ty, exprName->location); + globalBindings[name] = oldBinding; + } else globalBindings[name] = {quantify(funScope, ty, exprName->location), exprName->location}; @@ -1555,7 +1562,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar Name className(declaredClass.name.value); - TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {})); + TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); ClassTypeVar* ctv = getMutable(classTy); TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); @@ -3284,7 +3291,7 @@ std::pair TypeChecker::checkFunctionSignature( TypePackId retPack; if (expr.returnAnnotation) retPack = resolveTypePack(funScope, *expr.returnAnnotation); - else if (isNonstrictMode()) + else if (FFlag::LuauReturnTypeInferenceInNonstrict ? (!FFlag::LuauLowerBoundsCalculation && isNonstrictMode()) : isNonstrictMode()) retPack = anyTypePack; else if (expectedFunctionType) { @@ -5328,19 +5335,9 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (const auto& indexer = table->indexer) tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); - if (FFlag::LuauTypeMismatchModuleName) - { - TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; - ttv.definitionModuleName = currentModuleName; - return addType(std::move(ttv)); - } - else - { - return addType(TableTypeVar{ - props, tableIndexer, scope->level, - TableState::Sealed // FIXME: probably want a way to annotate other kinds of tables maybe - }); - } + TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; + ttv.definitionModuleName = currentModuleName; + return addType(std::move(ttv)); } else if (const auto& func = annotation.as()) { @@ -5602,9 +5599,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, { ttv->instantiatedTypeParams = typeParams; ttv->instantiatedTypePackParams = typePackParams; - - if (FFlag::LuauTypeMismatchModuleName) - ttv->definitionModuleName = currentModuleName; + ttv->definitionModuleName = currentModuleName; } return instantiated; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 0fbfdbf0..4d42573c 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -27,6 +27,7 @@ LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAGVARIABLE(LuauAnyInIsOptionalIsOptional, false) +LUAU_FASTFLAGVARIABLE(LuauClassDefinitionModuleInError, false) namespace Luau { @@ -304,6 +305,11 @@ std::optional getDefinitionModuleName(TypeId type) if (ftv->definition) return ftv->definition->definitionModuleName; } + else if (auto ctv = get(type); ctv && FFlag::LuauClassDefinitionModuleInError) + { + if (!ctv->definitionModuleName.empty()) + return ctv->definitionModuleName; + } return std::nullopt; } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index f9ea58cc..9862d7b3 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -17,7 +17,6 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); -LUAU_FASTFLAGVARIABLE(LuauExtendedIndexerError, false); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); @@ -28,7 +27,6 @@ LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) -LUAU_FASTFLAGVARIABLE(LuauUnifierCacheErrors, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAG(LuauTypecheckOptPass) @@ -474,32 +472,21 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (get(subTy)) return tryUnifyWithAny(superTy, subTy); - bool cacheEnabled; auto& cache = sharedState.cachedUnify; // What if the types are immutable and we proved their relation before - if (FFlag::LuauUnifierCacheErrors) + bool cacheEnabled = !isFunctionCall && !isIntersection && variance == Invariant; + + if (cacheEnabled) { - cacheEnabled = !isFunctionCall && !isIntersection && variance == Invariant; - - if (cacheEnabled) - { - if (cache.contains({subTy, superTy})) - return; - - if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) - { - reportError(TypeError{location, *error}); - return; - } - } - } - else - { - cacheEnabled = !isFunctionCall && !isIntersection; - - if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) + if (cache.contains({subTy, superTy})) return; + + if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) + { + reportError(TypeError{location, *error}); + return; + } } // If we have seen this pair of types before, we are currently recursing into cyclic types. @@ -543,12 +530,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(superTy) && log.getMutable(subTy)) { tryUnifyTables(subTy, superTy, isIntersection); - - if (!FFlag::LuauUnifierCacheErrors) - { - if (cacheEnabled && errors.empty()) - cacheResult_DEPRECATED(subTy, superTy); - } } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. @@ -568,7 +549,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - if (FFlag::LuauUnifierCacheErrors && cacheEnabled) + if (cacheEnabled) cacheResult(subTy, superTy, errorCount); log.popSeen(superTy, subTy); @@ -705,21 +686,10 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { TypeId type = uv->options[i]; - if (FFlag::LuauUnifierCacheErrors) + if (cache.contains({subTy, type})) { - if (cache.contains({subTy, type})) - { - startIndex = i; - break; - } - } - else - { - if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) - { - startIndex = i; - break; - } + startIndex = i; + break; } } } @@ -807,21 +777,10 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV { TypeId type = uv->parts[i]; - if (FFlag::LuauUnifierCacheErrors) + if (cache.contains({type, superTy})) { - if (cache.contains({type, superTy})) - { - startIndex = i; - break; - } - } - else - { - if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) - { - startIndex = i; - break; - } + startIndex = i; + break; } } } @@ -896,19 +855,6 @@ void Unifier::cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount) } } -void Unifier::cacheResult_DEPRECATED(TypeId subTy, TypeId superTy) -{ - LUAU_ASSERT(!FFlag::LuauUnifierCacheErrors); - - if (!canCacheResult(subTy, superTy)) - return; - - sharedState.cachedUnify.insert({superTy, subTy}); - - if (variance == Invariant) - sharedState.cachedUnify.insert({subTy, superTy}); -} - struct WeirdIter { TypePackId packId; @@ -1650,24 +1596,16 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) Unifier innerState = makeChildUnifier(); - if (FFlag::LuauExtendedIndexerError) - { - innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); + innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); - bool reported = !innerState.errors.empty(); + bool reported = !innerState.errors.empty(); - checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); + checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); - innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); - if (!reported) - checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); - } - else - { - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - } + if (!reported) + checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -2225,7 +2163,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) void Unifier::tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer) { - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2 || !FFlag::LuauExtendedIndexerError); + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); tryUnify_(subIndexer.indexType, superIndexer.indexType); tryUnify_(subIndexer.indexResultType, superIndexer.indexResultType); diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index 6ecf0606..6ae9e977 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -19,6 +19,7 @@ std::string format(const char* fmt, ...) LUAU_PRINTF_ATTR(1, 2); std::string vformat(const char* fmt, va_list args); void formatAppend(std::string& str, const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); +void vformatAppend(std::string& ret, const char* fmt, va_list args); std::string join(const std::vector& segments, std::string_view delimiter); std::string join(const std::vector& segments, std::string_view delimiter); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index badd3fd3..31ff3f77 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -167,6 +167,7 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc Function top; top.vararg = true; + functionStack.reserve(8); functionStack.push_back(top); nameSelf = names.addStatic("self"); @@ -186,6 +187,13 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc // all hot comments parsed after the first non-comment lexeme are special in that they don't affect type checking / linting mode hotcommentHeader = false; + + // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays + localStack.reserve(16); + scratchStat.reserve(16); + scratchExpr.reserve(16); + scratchLocal.reserve(16); + scratchBinding.reserve(16); } bool Parser::blockFollow(const Lexeme& l) diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 9c7fed31..0dc3f3f5 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -11,7 +11,7 @@ namespace Luau { -static void vformatAppend(std::string& ret, const char* fmt, va_list args) +void vformatAppend(std::string& ret, const char* fmt, va_list args) { va_list argscopy; va_copy(argscopy, args); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 5fd6d341..345cb7ac 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -579,7 +579,8 @@ static bool compileFile(const char* name, CompileFormat format) if (format == CompileFormat::Text) { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals); + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpSource(*source); } @@ -636,13 +637,60 @@ static int assertionHandler(const char* expr, const char* file, int line, const return 1; } +static void setLuauFlags(bool state) +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + { + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = state; + } +} + +static void setFlag(std::string_view name, bool state) +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + { + if (name == flag->name) + { + flag->value = state; + return; + } + } + + fprintf(stderr, "Warning: --fflag unrecognized flag '%.*s'.\n\n", int(name.length()), name.data()); +} + +static void applyFlagKeyValue(std::string_view element) +{ + if (size_t separator = element.find('='); separator != std::string_view::npos) + { + std::string_view key = element.substr(0, separator); + std::string_view value = element.substr(separator + 1); + + if (value == "true") + setFlag(key, true); + else if (value == "false") + setFlag(key, false); + else + fprintf(stderr, "Warning: --fflag unrecognized value '%.*s' for flag '%.*s'.\n\n", int(value.length()), value.data(), int(key.length()), + key.data()); + } + else + { + if (element == "true") + setLuauFlags(true); + else if (element == "false") + setLuauFlags(false); + else + setFlag(element, true); + } +} + int replMain(int argc, char** argv) { Luau::assertHandler() = assertionHandler; - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - if (strncmp(flag->name, "Luau", 4) == 0) - flag->value = true; + setLuauFlags(true); CliMode mode = CliMode::Unknown; CompileFormat compileFormat{}; @@ -727,6 +775,22 @@ int replMain(int argc, char** argv) return 1; #endif } + else if (strncmp(argv[i], "--fflags=", 9) == 0) + { + std::string_view list = argv[i] + 9; + + while (!list.empty()) + { + size_t ending = list.find(","); + + applyFlagKeyValue(list.substr(0, ending)); + + if (ending != std::string_view::npos) + list.remove_prefix(ending + 1); + else + break; + } + } else if (argv[i][0] == '-') { fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); diff --git a/CMakeLists.txt b/CMakeLists.txt index c6ccebc5..af03b33a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,6 +73,12 @@ else() list(APPEND LUAU_OPTIONS -Wall) # All warnings endif() +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + # Some gcc versions treat var in `if (type var = val)` as unused + # Some gcc versions treat variables used in constexpr if blocks as unused + list(APPEND LUAU_OPTIONS -Wno-unused) +endif() + # Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere if(LUAU_WERROR) if(MSVC) diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index 287bf4ee..67b93028 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -3,6 +3,7 @@ #include "Luau/Bytecode.h" #include "Luau/DenseHash.h" +#include "Luau/StringUtils.h" #include @@ -80,6 +81,8 @@ public: void pushDebugUpval(StringRef name); uint32_t getDebugPC() const; + void addDebugRemark(const char* format, ...) LUAU_PRINTF_ATTR(2, 3); + void finalize(); enum DumpFlags @@ -88,6 +91,7 @@ public: Dump_Lines = 1 << 1, Dump_Source = 1 << 2, Dump_Locals = 1 << 3, + Dump_Remarks = 1 << 4, }; void setDumpFlags(uint32_t flags) @@ -228,6 +232,9 @@ private: DenseHashMap stringTable; + DenseHashMap debugRemarks; + std::string debugRemarkBuffer; + BytecodeEncoder* encoder = nullptr; std::string bytecode; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 6944de0f..6c6f1225 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -181,9 +181,17 @@ BytecodeBuilder::BytecodeBuilder(BytecodeEncoder* encoder) : constantMap({Constant::Type_Nil, ~0ull}) , tableShapeMap(TableShape()) , stringTable({nullptr, 0}) + , debugRemarks(~0u) , encoder(encoder) { LUAU_ASSERT(stringTable.find(StringRef{"", 0}) == nullptr); + + // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays + insns.reserve(32); + lines.reserve(32); + constants.reserve(16); + protos.reserve(16); + functions.reserve(8); } uint32_t BytecodeBuilder::beginFunction(uint8_t numparams, bool isvararg) @@ -219,8 +227,8 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) validate(); #endif - // very approximate: 4 bytes per instruction for code, 1 byte for debug line, and 1-2 bytes for aux data like constants - func.data.reserve(insns.size() * 7); + // very approximate: 4 bytes per instruction for code, 1 byte for debug line, and 1-2 bytes for aux data like constants plus overhead + func.data.reserve(32 + insns.size() * 7); writeFunction(func.data, currentFunction); @@ -242,6 +250,9 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) constantMap.clear(); tableShapeMap.clear(); + + debugRemarks.clear(); + debugRemarkBuffer.clear(); } void BytecodeBuilder::setMainFunction(uint32_t fid) @@ -505,9 +516,40 @@ uint32_t BytecodeBuilder::getDebugPC() const return uint32_t(insns.size()); } +void BytecodeBuilder::addDebugRemark(const char* format, ...) +{ + if ((dumpFlags & Dump_Remarks) == 0) + return; + + size_t offset = debugRemarkBuffer.size(); + + va_list args; + va_start(args, format); + vformatAppend(debugRemarkBuffer, format, args); + va_end(args); + + // we null-terminate all remarks to avoid storing remark length + debugRemarkBuffer += '\0'; + + debugRemarks[uint32_t(insns.size())] = uint32_t(offset); +} + void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); + + // preallocate space for bytecode blob + size_t capacity = 16; + + for (auto& p : stringTable) + capacity += p.first.length + 2; + + for (const Function& func : functions) + capacity += func.data.size(); + + bytecode.reserve(capacity); + + // assemble final bytecode blob bytecode = char(LBC_VERSION); writeStringTable(bytecode); @@ -663,6 +705,8 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const void BytecodeBuilder::writeLineInfo(std::string& ss) const { + LUAU_ASSERT(!lines.empty()); + // this function encodes lines inside each span as a 8-bit delta to span baseline // span is always a power of two; depending on the line info input, it may need to be as low as 1 int span = 1 << 24; @@ -693,7 +737,17 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const } // second pass: compute span base - std::vector baseline((lines.size() - 1) / span + 1); + int baselineOne = 0; + std::vector baselineScratch; + int* baseline = &baselineOne; + size_t baselineSize = (lines.size() - 1) / span + 1; + + if (baselineSize > 1) + { + // avoid heap allocation for single-element baseline which is most functions (<256 lines) + baselineScratch.resize(baselineSize); + baseline = baselineScratch.data(); + } for (size_t offset = 0; offset < lines.size(); offset += span) { @@ -725,7 +779,7 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const int lastLine = 0; - for (size_t i = 0; i < baseline.size(); ++i) + for (size_t i = 0; i < baselineSize; ++i) { writeInt(ss, baseline[i] - lastLine); lastLine = baseline[i]; @@ -1695,6 +1749,14 @@ std::string BytecodeBuilder::dumpCurrentFunction() const continue; } + if (dumpFlags & Dump_Remarks) + { + const uint32_t* remark = debugRemarks.find(uint32_t(code - insns.data())); + + if (remark) + formatAppend(result, "REMARK %s\n", debugRemarkBuffer.c_str() + *remark); + } + if (dumpFlags & Dump_Source) { int line = lines[code - insns.data()]; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 8ef69e75..810caaee 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -8,12 +8,17 @@ #include "Builtins.h" #include "ConstantFolding.h" +#include "CostModel.h" #include "TableShape.h" #include "ValueTracking.h" #include #include #include +#include + +LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) +LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThresholdMaxBoost, 300) namespace Luau { @@ -77,8 +82,12 @@ struct Compiler , globals(AstName()) , variables(nullptr) , constants(nullptr) + , locstants(nullptr) , tableShapes(nullptr) { + // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays + localStack.reserve(16); + upvals.reserve(16); } uint8_t getLocal(AstLocal* local) @@ -209,7 +218,9 @@ struct Compiler Function& f = functions[func]; f.id = fid; - f.upvals = std::move(upvals); + f.upvals = upvals; + + upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes return fid; } @@ -2133,10 +2144,119 @@ struct Compiler pushLocal(stat->vars.data[i], uint8_t(vars + i)); } + int getConstantShort(AstExpr* expr) + { + const Constant* c = constants.find(expr); + + if (c && c->type == Constant::Type_Number) + { + double n = c->valueNumber; + + if (n >= -32767 && n <= 32767 && double(int(n)) == n) + return int(n); + } + + return INT_MIN; + } + + bool canUnrollForBody(AstStatFor* stat) + { + struct CanUnrollVisitor : AstVisitor + { + bool result = true; + + bool visit(AstExpr* node) override + { + // functions may capture loop variable, and our upval handling doesn't handle elided variables (constant) + result = result && !node->is(); + return result; + } + + bool visit(AstStat* node) override + { + // while we can easily unroll nested loops, our cost model doesn't take unrolling into account so this can result in code explosion + // we also avoid continue/break since they introduce control flow across iterations + result = result && !node->is() && !node->is() && !node->is(); + return result; + } + }; + + CanUnrollVisitor canUnroll; + stat->body->visit(&canUnroll); + + return canUnroll.result; + } + + bool tryCompileUnrolledFor(AstStatFor* stat, int thresholdBase, int thresholdMaxBoost) + { + int from = getConstantShort(stat->from); + int to = getConstantShort(stat->to); + int step = stat->step ? getConstantShort(stat->step) : 1; + + // check that limits are reasonably small and trip count can be computed + if (from == INT_MIN || to == INT_MIN || step == INT_MIN || step == 0 || (step < 0 && to > from) || (step > 0 && to < from)) + { + bytecode.addDebugRemark("loop unroll failed: invalid iteration count"); + return false; + } + + if (!canUnrollForBody(stat)) + { + bytecode.addDebugRemark("loop unroll failed: unsupported loop body"); + return false; + } + + int tripCount = (to - from) / step + 1; + + if (tripCount > thresholdBase * thresholdMaxBoost / 100) + { + bytecode.addDebugRemark("loop unroll failed: too many iterations (%d)", tripCount); + return false; + } + + AstLocal* var = stat->var; + uint64_t costModel = modelCost(stat->body, &var, 1); + + // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to unrolling + bool varc = true; + int unrolledCost = computeCost(costModel, &varc, 1) * tripCount; + int baselineCost = (computeCost(costModel, nullptr, 0) + 1) * tripCount; + int unrollProfit = (unrolledCost == 0) ? thresholdMaxBoost : std::min(thresholdMaxBoost, 100 * baselineCost / unrolledCost); + + int threshold = thresholdBase * unrollProfit / 100; + + if (unrolledCost > threshold) + { + bytecode.addDebugRemark( + "loop unroll failed: too expensive (iterations %d, cost %d, profit %.2fx)", tripCount, unrolledCost, double(unrollProfit) / 100); + return false; + } + + bytecode.addDebugRemark("loop unroll succeeded (iterations %d, cost %d, profit %.2fx)", tripCount, unrolledCost, double(unrollProfit) / 100); + + for (int i = from; step > 0 ? i <= to : i >= to; i += step) + { + // we need to re-fold constants in the loop body with the new value; this reuses computed constant values elsewhere in the tree + locstants[var].type = Constant::Type_Number; + locstants[var].valueNumber = i; + + foldConstants(constants, variables, locstants, stat); + + compileStat(stat->body); + } + + return true; + } + void compileStatFor(AstStatFor* stat) { RegScope rs(this); + // Optimization: small loops can be unrolled when it is profitable + if (options.optimizationLevel >= 2 && isConstant(stat->to) && isConstant(stat->from) && (!stat->step || isConstant(stat->step))) + if (tryCompileUnrolledFor(stat, FInt::LuauCompileLoopUnrollThreshold, FInt::LuauCompileLoopUnrollThresholdMaxBoost)) + return; + size_t oldLocals = localStack.size(); size_t oldJumps = loopJumps.size(); @@ -2826,6 +2946,8 @@ struct Compiler : self(self) , functions(functions) { + // preallocate the result; this works around std::vector's inefficient growth policy for small arrays + functions.reserve(16); } bool visit(AstExprFunction* node) override @@ -2979,6 +3101,7 @@ struct Compiler DenseHashMap globals; DenseHashMap variables; DenseHashMap constants; + DenseHashMap locstants; DenseHashMap tableShapes; unsigned int regTop = 0; @@ -3008,7 +3131,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName if (options.optimizationLevel >= 1) { // this pass analyzes constantness of expressions - foldConstants(compiler.constants, compiler.variables, root); + foldConstants(compiler.constants, compiler.variables, compiler.locstants, root); // this pass analyzes table assignments to estimate table shapes for initially empty tables predictTableShapes(compiler.tableShapes, root); diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 35ea0bf0..7ad91d4b 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -191,13 +191,13 @@ struct ConstantVisitor : AstVisitor { DenseHashMap& constants; DenseHashMap& variables; + DenseHashMap& locals; - DenseHashMap locals; - - ConstantVisitor(DenseHashMap& constants, DenseHashMap& variables) + ConstantVisitor( + DenseHashMap& constants, DenseHashMap& variables, DenseHashMap& locals) : constants(constants) , variables(variables) - , locals(nullptr) + , locals(locals) { } @@ -385,9 +385,10 @@ struct ConstantVisitor : AstVisitor } }; -void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root) +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, + DenseHashMap& locals, AstNode* root) { - ConstantVisitor visitor{constants, variables}; + ConstantVisitor visitor{constants, variables, locals}; root->visit(&visitor); } diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h index c0e63539..0a995d75 100644 --- a/Compiler/src/ConstantFolding.h +++ b/Compiler/src/ConstantFolding.h @@ -42,7 +42,8 @@ struct Constant } }; -void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root); +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, + DenseHashMap& locals, AstNode* root); } // namespace Compile } // namespace Luau diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 46b10934..431f7e59 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,6 +14,8 @@ #include +LUAU_FASTFLAG(LuauGcWorkTrackFix) + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -1050,6 +1052,7 @@ int lua_gc(lua_State* L, int what, int data) { size_t prevthreshold = g->GCthreshold; size_t amount = (cast_to(size_t, data) << 10); + ptrdiff_t oldcredit = g->gcstate == GCSpause ? 0 : g->GCthreshold - g->totalbytes; // temporarily adjust the threshold so that we can perform GC work if (amount <= g->totalbytes) @@ -1069,9 +1072,9 @@ int lua_gc(lua_State* L, int what, int data) while (g->GCthreshold <= g->totalbytes) { - luaC_step(L, false); + size_t stepsize = luaC_step(L, false); - actualwork += g->gcstepsize; + actualwork += FFlag::LuauGcWorkTrackFix ? stepsize : g->gcstepsize; if (g->gcstate == GCSpause) { /* end of cycle? */ @@ -1107,11 +1110,20 @@ int lua_gc(lua_State* L, int what, int data) // if cycle hasn't finished, advance threshold forward for the amount of extra work performed if (g->gcstate != GCSpause) { - // if a new cycle was triggered by explicit step, we ignore old threshold as that shows an incorrect 'credit' of GC work - if (waspaused) - g->GCthreshold = g->totalbytes + actualwork; + if (FFlag::LuauGcWorkTrackFix) + { + // if a new cycle was triggered by explicit step, old 'credit' of GC work is 0 + ptrdiff_t newthreshold = g->totalbytes + actualwork + oldcredit; + g->GCthreshold = newthreshold < 0 ? 0 : newthreshold; + } else - g->GCthreshold = prevthreshold + actualwork; + { + // if a new cycle was triggered by explicit step, we ignore old threshold as that shows an incorrect 'credit' of GC work + if (waspaused) + g->GCthreshold = g->totalbytes + actualwork; + else + g->GCthreshold = prevthreshold + actualwork; + } } break; } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 8fc930d5..e7b73fe7 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,9 +13,10 @@ #include -#define GC_SWEEPMAX 40 -#define GC_SWEEPCOST 10 -#define GC_SWEEPPAGESTEPCOST 4 +LUAU_FASTFLAGVARIABLE(LuauGcWorkTrackFix, false) +LUAU_FASTFLAGVARIABLE(LuauGcSweepCostFix, false) + +#define GC_SWEEPPAGESTEPCOST (FFlag::LuauGcSweepCostFix ? 16 : 4) #define GC_INTERRUPT(state) \ { \ @@ -64,7 +65,7 @@ static void recordGcStateStep(global_State* g, int startgcstate, double seconds, case GCSpropagate: case GCSpropagateagain: g->gcmetrics.currcycle.marktime += seconds; - g->gcmetrics.currcycle.markrequests += g->gcstepsize; + g->gcmetrics.currcycle.markwork += work; if (assist) g->gcmetrics.currcycle.markassisttime += seconds; @@ -74,7 +75,7 @@ static void recordGcStateStep(global_State* g, int startgcstate, double seconds, break; case GCSsweep: g->gcmetrics.currcycle.sweeptime += seconds; - g->gcmetrics.currcycle.sweeprequests += g->gcstepsize; + g->gcmetrics.currcycle.sweepwork += work; if (assist) g->gcmetrics.currcycle.sweepassisttime += seconds; @@ -87,13 +88,11 @@ static void recordGcStateStep(global_State* g, int startgcstate, double seconds, { g->gcmetrics.stepassisttimeacc += seconds; g->gcmetrics.currcycle.assistwork += work; - g->gcmetrics.currcycle.assistrequests += g->gcstepsize; } else { g->gcmetrics.stepexplicittimeacc += seconds; g->gcmetrics.currcycle.explicitwork += work; - g->gcmetrics.currcycle.explicitrequests += g->gcstepsize; } } @@ -878,11 +877,11 @@ static size_t getheaptrigger(global_State* g, size_t heapgoal) return heaptrigger < int64_t(g->totalbytes) ? g->totalbytes : (heaptrigger > int64_t(heapgoal) ? heapgoal : size_t(heaptrigger)); } -void luaC_step(lua_State* L, bool assist) +size_t luaC_step(lua_State* L, bool assist) { global_State* g = L->global; - int lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ + int lim = FFlag::LuauGcWorkTrackFix ? g->gcstepsize * g->gcstepmul / 100 : (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -902,12 +901,13 @@ void luaC_step(lua_State* L, bool assist) int lastgcstate = g->gcstate; size_t work = gcstep(L, lim); - (void)work; #ifdef LUAI_GCMETRICS recordGcStateStep(g, lastgcstate, lua_clock() - lasttimestamp, assist, work); #endif + size_t actualstepsize = work * 100 / g->gcstepmul; + // at the end of the last cycle if (g->gcstate == GCSpause) { @@ -927,14 +927,16 @@ void luaC_step(lua_State* L, bool assist) } else { - g->GCthreshold = g->totalbytes + g->gcstepsize; + g->GCthreshold = g->totalbytes + (FFlag::LuauGcWorkTrackFix ? actualstepsize : g->gcstepsize); // compensate if GC is "behind schedule" (has some debt to pay) - if (g->GCthreshold > debt) + if (FFlag::LuauGcWorkTrackFix ? g->GCthreshold >= debt : g->GCthreshold > debt) g->GCthreshold -= debt; } GC_INTERRUPT(lastgcstate); + + return actualstepsize; } void luaC_fullgc(lua_State* L) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index dcd070b7..08d1ff5d 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -133,7 +133,7 @@ #define luaC_init(L, o, tt) luaC_initobj(L, cast_to(GCObject*, (o)), tt) LUAI_FUNC void luaC_freeall(lua_State* L); -LUAI_FUNC void luaC_step(lua_State* L, bool assist); +LUAI_FUNC size_t luaC_step(lua_State* L, bool assist); LUAI_FUNC void luaC_fullgc(lua_State* L); LUAI_FUNC void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt); LUAI_FUNC void luaC_initupval(lua_State* L, UpVal* uv); diff --git a/VM/src/lstate.h b/VM/src/lstate.h index e7c37373..45d9ba2c 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -106,7 +106,7 @@ struct GCCycleMetrics double markassisttime = 0.0; double markmaxexplicittime = 0.0; size_t markexplicitsteps = 0; - size_t markrequests = 0; + size_t markwork = 0; double atomicstarttimestamp = 0.0; size_t atomicstarttotalsizebytes = 0; @@ -122,10 +122,7 @@ struct GCCycleMetrics double sweepassisttime = 0.0; double sweepmaxexplicittime = 0.0; size_t sweepexplicitsteps = 0; - size_t sweeprequests = 0; - - size_t assistrequests = 0; - size_t explicitrequests = 0; + size_t sweepwork = 0; size_t assistwork = 0; size_t explicitwork = 0; diff --git a/bench/bench.py b/bench/bench.py index 39f219f3..67fc8cf7 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -814,13 +814,12 @@ def run(args, argsubcb): analyzeResult('', mainResult, compareResults) else: - for subdir, dirs, files in os.walk(arguments.folder): - for filename in files: - filepath = subdir + os.sep + filename - - if filename.endswith(".lua"): - if arguments.run_test == None or re.match(arguments.run_test, filename[:-4]): - runTest(subdir, filename, filepath) + all_files = [subdir + os.sep + filename for subdir, dirs, files in os.walk(arguments.folder) for filename in files] + for filepath in sorted(all_files): + subdir, filename = os.path.split(filepath) + if filename.endswith(".lua"): + if arguments.run_test == None or re.match(arguments.run_test, filename[:-4]): + runTest(subdir, filename, filepath) if arguments.sort and len(plotValueLists) > 1: rearrange(rearrangeSortKeyForComparison) diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 1022831b..a48f068b 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -103,7 +103,7 @@ int registerTypes(Luau::TypeChecker& env) // Vector3 stub TypeId vector3MetaType = arena.addType(TableTypeVar{}); - TypeId vector3InstanceType = arena.addType(ClassTypeVar{"Vector3", {}, nullopt, vector3MetaType, {}, {}}); + TypeId vector3InstanceType = arena.addType(ClassTypeVar{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test"}); getMutable(vector3InstanceType)->props = { {"X", {env.numberType}}, {"Y", {env.numberType}}, @@ -117,7 +117,7 @@ int registerTypes(Luau::TypeChecker& env) env.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; // Instance stub - TypeId instanceType = arena.addType(ClassTypeVar{"Instance", {}, nullopt, nullopt, {}, {}}); + TypeId instanceType = arena.addType(ClassTypeVar{"Instance", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(instanceType)->props = { {"Name", {env.stringType}}, }; @@ -125,7 +125,7 @@ int registerTypes(Luau::TypeChecker& env) env.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; // Part stub - TypeId partType = arena.addType(ClassTypeVar{"Part", {}, instanceType, nullopt, {}, {}}); + TypeId partType = arena.addType(ClassTypeVar{"Part", {}, instanceType, nullopt, {}, {}, "Test"}); getMutable(partType)->props = { {"Position", {vector3InstanceType}}, }; @@ -173,7 +173,7 @@ struct FuzzConfigResolver : Luau::ConfigResolver { FuzzConfigResolver() { - defaultConfig.mode = Luau::Mode::Nonstrict; // typecheckTwice option will cover Strict mode + defaultConfig.mode = Luau::Mode::Nonstrict; defaultConfig.enabledLint.warningMask = ~0ull; defaultConfig.parseOptions.captureComments = true; } @@ -275,6 +275,11 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) // lint (note that we need access to types so we need to do this with typeck in scope) if (kFuzzLinter && result.errors.empty()) frontend.lint(name, std::nullopt); + + // Second pass in strict mode (forced by auto-complete) + Luau::FrontendOptions opts; + opts.forAutocomplete = true; + frontend.check(name, opts); } catch (std::exception&) { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 2e7902f5..f66e23ed 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3034,4 +3034,40 @@ string:@1 CHECK(ac.entryMap["sub"].wrongIndexType == true); } +TEST_CASE_FIXTURE(ACFixture, "source_module_preservation_and_invalidation") +{ + check(R"( +local a = { x = 2, y = 4 } +a.@1 + )"); + + frontend.clear(); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.check("MainModule", {}); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.markDirty("MainModule", nullptr); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.check("MainModule", {}); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 83dad729..f3e60690 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -17,11 +17,13 @@ std::string rep(const std::string& s, size_t n); using namespace Luau; -static std::string compileFunction(const char* source, uint32_t id) +static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1) { Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); - Luau::compileOrThrow(bcb, source); + Luau::CompileOptions options; + options.optimizationLevel = optimizationLevel; + Luau::compileOrThrow(bcb, source, options); return bcb.dumpFunction(id); } @@ -2689,6 +2691,27 @@ local 8: reg 3, start pc 34 line 21, end pc 34 line 21 )"); } +TEST_CASE("DebugRemarks") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Remarks); + + uint32_t fid = bcb.beginFunction(0); + + bcb.addDebugRemark("test remark #%d", 42); + bcb.emitABC(LOP_RETURN, 0, 1, 0); + + bcb.endFunction(0, 0); + + bcb.setMainFunction(fid); + bcb.finalize(); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +REMARK test remark #42 +RETURN R0 0 +)"); +} + TEST_CASE("AssignmentConflict") { // assignments are left to right @@ -4076,4 +4099,336 @@ RETURN R1 6 )"); } +TEST_CASE("LoopUnrollBasic") +{ + // forward loops + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,2 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 2 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 2 +SETTABLEN R1 R0 2 +RETURN R0 1 +)"); + + // backward loops + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=2,1,-1 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 2 +SETTABLEN R1 R0 2 +LOADN R1 1 +SETTABLEN R1 R0 1 +RETURN R0 1 +)"); + + // loops with step that doesn't divide to-from + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,4,2 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 3 +SETTABLEN R1 R0 3 +RETURN R0 1 +)"); +} + +TEST_CASE("LoopUnrollUnsupported") +{ + // can't unroll loops with non-constant bounds + CHECK_EQ("\n" + compileFunction(R"( +for i=x,y,z do +end +)", + 0, 2), + R"( +GETIMPORT R2 1 +GETIMPORT R0 3 +GETIMPORT R1 5 +FORNPREP R0 +1 +FORNLOOP R0 -1 +RETURN R0 0 +)"); + + // can't unroll loops with bounds where we can't compute trip count + CHECK_EQ("\n" + compileFunction(R"( +for i=2,1 do +end +)", + 0, 2), + R"( +LOADN R2 2 +LOADN R0 1 +LOADN R1 1 +FORNPREP R0 +1 +FORNLOOP R0 -1 +RETURN R0 0 +)"); + + // can't unroll loops with bounds that might be imprecise (non-integer) + CHECK_EQ("\n" + compileFunction(R"( +for i=1,2,0.1 do +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 2 +LOADK R1 K0 +FORNPREP R0 +1 +FORNLOOP R0 -1 +RETURN R0 0 +)"); + + // can't unroll loops if the bounds are too large, as it might overflow trip count math + CHECK_EQ("\n" + compileFunction(R"( +for i=4294967295,4294967296 do +end +)", + 0, 2), + R"( +LOADK R2 K0 +LOADK R0 K1 +LOADN R1 1 +FORNPREP R0 +1 +FORNLOOP R0 -1 +RETURN R0 0 +)"); + + // can't unroll loops if the body has loop control flow or nested loops + CHECK_EQ("\n" + compileFunction(R"( +for i=1,1 do + for j=1,1 do + if i == 1 then + continue + else + break + end + end +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 1 +LOADN R1 1 +FORNPREP R0 +11 +LOADN R5 1 +LOADN R3 1 +LOADN R4 1 +FORNPREP R3 +6 +JUMPIFNOTEQK R2 K0 +5 +JUMP +2 +JUMP +1 +JUMP +1 +FORNLOOP R3 -6 +FORNLOOP R0 -11 +RETURN R0 0 +)"); + + // can't unroll loops if the body has functions that refer to loop variables + CHECK_EQ("\n" + compileFunction(R"( +for i=1,1 do + local x = function() return i end +end +)", + 1, 2), + R"( +LOADN R2 1 +LOADN R0 1 +LOADN R1 1 +FORNPREP R0 +3 +NEWCLOSURE R3 P0 +CAPTURE VAL R2 +FORNLOOP R0 -3 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollCost") +{ + ScopedFastInt sfis[] = { + {"LuauCompileLoopUnrollThreshold", 25}, + {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + }; + + // loops with short body + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,10 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 10 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 2 +SETTABLEN R1 R0 2 +LOADN R1 3 +SETTABLEN R1 R0 3 +LOADN R1 4 +SETTABLEN R1 R0 4 +LOADN R1 5 +SETTABLEN R1 R0 5 +LOADN R1 6 +SETTABLEN R1 R0 6 +LOADN R1 7 +SETTABLEN R1 R0 7 +LOADN R1 8 +SETTABLEN R1 R0 8 +LOADN R1 9 +SETTABLEN R1 R0 9 +LOADN R1 10 +SETTABLEN R1 R0 10 +RETURN R0 1 +)"); + + // loops with body that's too long + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,100 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R3 1 +LOADN R1 100 +LOADN R2 1 +FORNPREP R1 +2 +SETTABLE R3 R0 R3 +FORNLOOP R1 -2 +RETURN R0 1 +)"); + + // loops with body that's long but has a high boost factor due to constant folding + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,30 do + t[i] = i * i * i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 8 +SETTABLEN R1 R0 2 +LOADN R1 27 +SETTABLEN R1 R0 3 +LOADN R1 64 +SETTABLEN R1 R0 4 +LOADN R1 125 +SETTABLEN R1 R0 5 +LOADN R1 216 +SETTABLEN R1 R0 6 +LOADN R1 343 +SETTABLEN R1 R0 7 +LOADN R1 512 +SETTABLEN R1 R0 8 +LOADN R1 729 +SETTABLEN R1 R0 9 +LOADN R1 1000 +SETTABLEN R1 R0 10 +LOADN R1 1331 +SETTABLEN R1 R0 11 +LOADN R1 1728 +SETTABLEN R1 R0 12 +LOADN R1 2197 +SETTABLEN R1 R0 13 +LOADN R1 2744 +SETTABLEN R1 R0 14 +LOADN R1 3375 +SETTABLEN R1 R0 15 +LOADN R1 4096 +SETTABLEN R1 R0 16 +LOADN R1 4913 +SETTABLEN R1 R0 17 +LOADN R1 5832 +SETTABLEN R1 R0 18 +LOADN R1 6859 +SETTABLEN R1 R0 19 +LOADN R1 8000 +SETTABLEN R1 R0 20 +LOADN R1 9261 +SETTABLEN R1 R0 21 +LOADN R1 10648 +SETTABLEN R1 R0 22 +LOADN R1 12167 +SETTABLEN R1 R0 23 +LOADN R1 13824 +SETTABLEN R1 R0 24 +LOADN R1 15625 +SETTABLEN R1 R0 25 +LOADN R1 17576 +SETTABLEN R1 R0 26 +LOADN R1 19683 +SETTABLEN R1 R0 27 +LOADN R1 21952 +SETTABLEN R1 R0 28 +LOADN R1 24389 +SETTABLEN R1 R0 29 +LOADN R1 27000 +SETTABLEN R1 R0 30 +RETURN R0 1 +)"); + + // loops with body that's long and doesn't have a high boost factor + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,10 do + t[i] = math.abs(math.sin(i)) +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 10 +LOADN R3 1 +LOADN R1 10 +LOADN R2 1 +FORNPREP R1 +11 +FASTCALL1 24 R3 +3 +MOVE R6 R3 +GETIMPORT R5 2 +CALL R5 1 -1 +FASTCALL 2 +2 +GETIMPORT R4 4 +CALL R4 -1 1 +SETTABLE R4 R0 R3 +FORNLOOP R1 -11 +RETURN R0 1 +)"); +} + TEST_SUITE_END(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp index ec04932f..aa5b7284 100644 --- a/tests/CostModel.test.cpp +++ b/tests/CostModel.test.cpp @@ -98,4 +98,129 @@ end CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); } +TEST_CASE("ImportCall") +{ + uint64_t model = modelFunction(R"( +function test(a) + return Instance.new(a) +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("FastCall") +{ + uint64_t model = modelFunction(R"( +function test(a) + return math.abs(a + 1) +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + // note: we currently don't treat fast calls differently from cost model perspective + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(5, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("ControlFlow") +{ + uint64_t model = modelFunction(R"( +function test(a) + while a < 0 do + a += 1 + end + for i=1,2 do + a += 1 + end + for i in pairs({}) do + a += 1 + if a % 2 == 0 then continue end + end + repeat + a += 1 + if a % 2 == 0 then break end + until a > 10 + return a +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(38, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(37, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("Conditional") +{ + uint64_t model = modelFunction(R"( +function test(a) + return if a < 0 then -a else a +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(4, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("VarArgs") +{ + uint64_t model = modelFunction(R"( +function test(...) + return select('#', ...) :: number +end +)"); + + CHECK_EQ(8, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("TablesFunctions") +{ + uint64_t model = modelFunction(R"( +function test() + return { 42, op = function() end } +end +)"); + + CHECK_EQ(22, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("CostOverflow") +{ + uint64_t model = modelFunction(R"( +function test() + return {{{{{{{{{{{{{{{}}}}}}}}}}}}}}} +end +)"); + + CHECK_EQ(127, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("TableAssign") +{ + uint64_t model = modelFunction(R"( +function test(a) + for i=1,#a do + a[i] = i + end +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(4, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); +} + TEST_SUITE_END(); diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp index 1d2ad645..6f1cebc6 100644 --- a/tests/JsonEncoder.test.cpp +++ b/tests/JsonEncoder.test.cpp @@ -9,6 +9,46 @@ using namespace Luau; +struct JsonEncoderFixture +{ + Allocator allocator; + AstNameTable names{allocator}; + + ParseResult parse(std::string_view src) + { + ParseOptions opts; + opts.allowDeclarationSyntax = true; + return Parser::parse(src.data(), src.size(), names, allocator, opts); + } + + AstStatBlock* expectParse(std::string_view src) + { + ParseResult res = parse(src); + REQUIRE(res.errors.size() == 0); + return res.root; + } + + AstStat* expectParseStatement(std::string_view src) + { + AstStatBlock* root = expectParse(src); + REQUIRE(1 == root->body.size); + return root->body.data[0]; + } + + AstExpr* expectParseExpr(std::string_view src) + { + std::string s = "a = "; + s.append(src); + AstStatBlock* root = expectParse(s); + + AstStatAssign* statAssign = root->body.data[0]->as(); + REQUIRE(statAssign != nullptr); + REQUIRE(statAssign->values.size == 1); + + return statAssign->values.data[0]; + } +}; + TEST_SUITE_BEGIN("JsonEncoderTests"); TEST_CASE("encode_constants") @@ -51,7 +91,7 @@ TEST_CASE("encode_AstStatBlock") toJson(&block)); } -TEST_CASE("encode_tables") +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_tables") { std::string src = R"( local x: { @@ -61,16 +101,294 @@ TEST_CASE("encode_tables") } )"; - Allocator allocator; - AstNameTable names(allocator); - ParseResult parseResult = Parser::parse(src.c_str(), src.length(), names, allocator); - - REQUIRE(parseResult.errors.size() == 0); - std::string json = toJson(parseResult.root); + AstStatBlock* root = expectParse(src); + std::string json = toJson(root); CHECK( json == R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); } +TEST_CASE("encode_AstExprGroup") +{ + AstExprConstantNumber number{Location{}, 5.0}; + AstExprGroup group{Location{}, &number}; + + std::string json = toJson(&group); + + const std::string expected = R"({"type":"AstExprGroup","location":"0,0 - 0,0","expr":{"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":5}})"; + + CHECK(json == expected); +} + +TEST_CASE("encode_AstExprGlobal") +{ + AstExprGlobal global{Location{}, AstName{"print"}}; + + std::string json = toJson(&global); + std::string expected = R"({"type":"AstExprGlobal","location":"0,0 - 0,0","global":"print"})"; + + CHECK(json == expected); +} + +TEST_CASE("encode_AstExprLocal") +{ + AstLocal local{AstName{"foo"}, Location{}, nullptr, 0, 0, nullptr}; + AstExprLocal exprLocal{Location{}, &local, false}; + + CHECK(toJson(&exprLocal) == R"({"type":"AstExprLocal","location":"0,0 - 0,0","local":{"type":null,"name":"foo","location":"0,0 - 0,0"}})"); +} + +TEST_CASE("encode_AstExprVarargs") +{ + AstExprVarargs varargs{Location{}}; + + CHECK(toJson(&varargs) == R"({"type":"AstExprVarargs","location":"0,0 - 0,0"})"); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprCall") +{ + AstExpr* expr = expectParseExpr("foo(1, 2, 3)"); + std::string_view expected = R"({"type":"AstExprCall","location":"0,4 - 0,16","func":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"args":[{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},{"type":"AstExprConstantNumber","location":"0,11 - 0,12","value":2},{"type":"AstExprConstantNumber","location":"0,14 - 0,15","value":3}],"self":false,"argLocation":"0,8 - 0,16"})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIndexName") +{ + AstExpr* expr = expectParseExpr("foo.bar"); + + std::string_view expected = R"({"type":"AstExprIndexName","location":"0,4 - 0,11","expr":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"index":"bar","indexLocation":"0,8 - 0,11","op":"."})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIndexExpr") +{ + AstExpr* expr = expectParseExpr("foo['bar']"); + + std::string_view expected = R"({"type":"AstExprIndexExpr","location":"0,4 - 0,14","expr":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"index":{"type":"AstExprConstantString","location":"0,8 - 0,13","value":"bar"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprFunction") +{ + AstExpr* expr = expectParseExpr("function (a) return a end"); + + std::string_view expected = R"({"type":"AstExprFunction","location":"0,4 - 0,29","generics":[],"genericPacks":[],"args":[{"type":null,"name":"a","location":"0,14 - 0,15"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,16 - 0,26","body":[{"type":"AstStatReturn","location":"0,17 - 0,25","list":[{"type":"AstExprLocal","location":"0,24 - 0,25","local":{"type":null,"name":"a","location":"0,14 - 0,15"}}]}]},"functionDepth":1,"debugname":"","hasEnd":true})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprTable") +{ + AstExpr* expr = expectParseExpr("{true, key=true, [key2]=true}"); + + std::string_view expected = R"({"type":"AstExprTable","location":"0,4 - 0,33","items":[{"kind":"item","value":{"type":"AstExprConstantBool","location":"0,5 - 0,9","value":true}},{"kind":"record","key":{"type":"AstExprConstantString","location":"0,11 - 0,14","value":"key"},"value":{"type":"AstExprConstantBool","location":"0,15 - 0,19","value":true}},{"kind":"general","key":{"type":"AstExprGlobal","location":"0,22 - 0,26","global":"key2"},"value":{"type":"AstExprConstantBool","location":"0,28 - 0,32","value":true}}]})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprUnary") +{ + AstExpr* expr = expectParseExpr("-b"); + + std::string_view expected = R"({"type":"AstExprUnary","location":"0,4 - 0,6","op":"minus","expr":{"type":"AstExprGlobal","location":"0,5 - 0,6","global":"b"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprBinary") +{ + AstExpr* expr = expectParseExpr("b + c"); + + std::string_view expected = R"({"type":"AstExprBinary","location":"0,4 - 0,9","op":"Add","left":{"type":"AstExprGlobal","location":"0,4 - 0,5","global":"b"},"right":{"type":"AstExprGlobal","location":"0,8 - 0,9","global":"c"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprTypeAssertion") +{ + AstExpr* expr = expectParseExpr("b :: any"); + + std::string_view expected = R"({"type":"AstExprTypeAssertion","location":"0,4 - 0,12","expr":{"type":"AstExprGlobal","location":"0,4 - 0,5","global":"b"},"annotation":{"type":"AstTypeReference","location":"0,9 - 0,12","name":"any","parameters":[]}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprError") +{ + std::string_view src = "a = "; + ParseResult parseResult = Parser::parse(src.data(), src.size(), names, allocator); + + REQUIRE(1 == parseResult.root->body.size); + + AstStatAssign* statAssign = parseResult.root->body.data[0]->as(); + REQUIRE(statAssign != nullptr); + REQUIRE(1 == statAssign->values.size); + + AstExpr* expr = statAssign->values.data[0]; + + std::string_view expected = R"({"type":"AstExprError","location":"0,4 - 0,4","expressions":[],"messageIndex":0})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatIf") +{ + AstStat* statement = expectParseStatement("if true then else end"); + + std::string_view expected = R"({"type":"AstStatIf","location":"0,0 - 0,21","condition":{"type":"AstExprConstantBool","location":"0,3 - 0,7","value":true},"thenbody":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"elsebody":{"type":"AstStatBlock","location":"0,17 - 0,18","body":[]},"hasThen":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatWhile") +{ + AstStat* statement = expectParseStatement("while true do end"); + + std::string_view expected = R"({"type":"AtStatWhile","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatRepeat") +{ + AstStat* statement = expectParseStatement("repeat until true"); + + std::string_view expected = R"({"type":"AstStatRepeat","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,13 - 0,17","value":true},"body":{"type":"AstStatBlock","location":"0,6 - 0,7","body":[]},"hasUntil":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatBreak") +{ + AstStat* statement = expectParseStatement("while true do break end"); + + std::string_view expected = R"({"type":"AtStatWhile","location":"0,0 - 0,23","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,20","body":[{"type":"AstStatBreak","location":"0,14 - 0,19"}]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatContinue") +{ + AstStat* statement = expectParseStatement("while true do continue end"); + + std::string_view expected = R"({"type":"AtStatWhile","location":"0,0 - 0,26","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,23","body":[{"type":"AstStatContinue","location":"0,14 - 0,22"}]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatFor") +{ + AstStat* statement = expectParseStatement("for a=0,1 do end"); + + std::string_view expected = R"({"type":"AstStatFor","location":"0,0 - 0,16","var":{"type":null,"name":"a","location":"0,4 - 0,5"},"from":{"type":"AstExprConstantNumber","location":"0,6 - 0,7","value":0},"to":{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},"body":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatForIn") +{ + AstStat* statement = expectParseStatement("for a in b do end"); + + std::string_view expected = R"({"type":"AstStatForIn","location":"0,0 - 0,17","vars":[{"type":null,"name":"a","location":"0,4 - 0,5"}],"values":[{"type":"AstExprGlobal","location":"0,9 - 0,10","global":"b"}],"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasIn":true,"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatCompoundAssign") +{ + AstStat* statement = expectParseStatement("a += b"); + + std::string_view expected = R"({"type":"AstStatCompoundAssign","location":"0,0 - 0,6","op":"Add","var":{"type":"AstExprGlobal","location":"0,0 - 0,1","global":"a"},"value":{"type":"AstExprGlobal","location":"0,5 - 0,6","global":"b"}})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatLocalFunction") +{ + AstStat* statement = expectParseStatement("local function a(b) return end"); + + std::string_view expected = R"({"type":"AstStatLocalFunction","location":"0,0 - 0,30","name":{"type":null,"name":"a","location":"0,15 - 0,16"},"func":{"type":"AstExprFunction","location":"0,0 - 0,30","generics":[],"genericPacks":[],"args":[{"type":null,"name":"b","location":"0,17 - 0,18"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,19 - 0,27","body":[{"type":"AstStatReturn","location":"0,20 - 0,26","list":[]}]},"functionDepth":1,"debugname":"a","hasEnd":true}})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatTypeAlias") +{ + AstStat* statement = expectParseStatement("type A = B"); + + std::string_view expected = R"({"type":"AstStatTypeAlias","location":"0,0 - 0,10","name":"A","generics":[],"genericPacks":[],"type":{"type":"AstTypeReference","location":"0,9 - 0,10","name":"B","parameters":[]},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction") +{ + AstStat* statement = expectParseStatement("declare function foo(x: number): string"); + + std::string_view expected = R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","params":{"types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","parameters":[]}]},"retTypes":{"types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","parameters":[]}]},"generics":[],"genericPacks":[]})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") +{ + AstStatBlock* root = expectParse(R"( + declare class Foo + prop: number + function method(self, foo: number): string + end + + declare class Bar extends Foo + prop2: string + end + )"); + + REQUIRE(2 == root->body.size); + + std::string_view expected1 = R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","parameters":[]}},{"name":"method","type":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","parameters":[]}]},"returnTypes":{"types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","parameters":[]}]}}}]})"; + CHECK(toJson(root->body.data[0]) == expected1); + + std::string_view expected2 = R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","parameters":[]}}]})"; + CHECK(toJson(root->body.data[1]) == expected2); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") +{ + AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); + + std::string_view expected = R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,35","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"types":[]}}]},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypeError") +{ + ParseResult parseResult = parse("type T = "); + REQUIRE(1 == parseResult.root->body.size); + + AstStat* statement = parseResult.root->body.data[0]; + + std::string_view expected = R"({"type":"AstStatTypeAlias","location":"0,0 - 0,9","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeError","location":"0,8 - 0,9","types":[],"messageIndex":0},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypePackExplicit") +{ + AstStatBlock* root = expectParse(R"( + type A = () -> T... + local a: A<(number, string)> + )"); + + CHECK(2 == root->body.size); + + std::string_view expected = R"({"type":"AstStatLocal","location":"2,8 - 2,36","vars":[{"type":{"type":"AstTypeReference","location":"2,17 - 2,36","name":"A","parameters":[{"type":"AstTypePackExplicit","location":"2,19 - 2,20","typeList":{"types":[{"type":"AstTypeReference","location":"2,20 - 2,26","name":"number","parameters":[]},{"type":"AstTypeReference","location":"2,28 - 2,34","name":"string","parameters":[]}]}}]},"name":"a","location":"2,14 - 2,15"}],"values":[]})"; + + CHECK(toJson(root->body.data[1]) == expected); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 05ee9a7b..6649cb7f 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1436,7 +1436,7 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") { unfreeze(typeChecker.globalTypes); - TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}}); + TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); persist(instanceType); typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 738893db..af7d76de 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -173,13 +173,13 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") { {"__add", {typeChecker.anyType}}, }, - std::nullopt, std::nullopt, {}, {}}}; + std::nullopt, std::nullopt, {}, {}, "Test"}}; TypeVar exampleClass{ClassTypeVar{"ExampleClass", { {"PropOne", {typeChecker.numberType}}, {"PropTwo", {typeChecker.stringType}}, }, - std::nullopt, &exampleMetaClass, {}, {}}}; + std::nullopt, &exampleMetaClass, {}, {}, "Test"}}; TypeArena dest; CloneState cloneState; @@ -196,9 +196,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") CHECK_EQ("ExampleClassMeta", metatable->name); } -TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") +TEST_CASE_FIXTURE(Fixture, "clone_free_types") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + ScopedFastFlag sff[]{ + {"LuauErrorRecoveryType", true}, + {"LuauLosslessClone", true}, + }; TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); @@ -207,17 +210,17 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") CloneState cloneState; TypeId clonedTy = clone(&freeTy, dest, cloneState); - CHECK_EQ("any", toString(clonedTy)); - CHECK(cloneState.encounteredFreeType); + CHECK(get(clonedTy)); cloneState = {}; TypePackId clonedTp = clone(&freeTp, dest, cloneState); - CHECK_EQ("...any", toString(clonedTp)); - CHECK(cloneState.encounteredFreeType); + CHECK(get(clonedTp)); } -TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") +TEST_CASE_FIXTURE(Fixture, "clone_free_tables") { + ScopedFastFlag sff{"LuauLosslessClone", true}; + TypeVar tableTy{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tableTy); ttv->state = TableState::Free; @@ -227,8 +230,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") TypeId cloned = clone(&tableTy, dest, cloneState); const TableTypeVar* clonedTtv = get(cloned); - CHECK_EQ(clonedTtv->state, TableState::Sealed); - CHECK(cloneState.encounteredFreeType); + CHECK_EQ(clonedTtv->state, TableState::Free); } TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index a8a12b69..9748eb27 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -13,6 +13,29 @@ using namespace Luau; TEST_SUITE_BEGIN("NonstrictModeTests"); +TEST_CASE_FIXTURE(Fixture, "function_returns_number_or_string") +{ + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true} + }; + + CheckResult result = check(R"( + --!nonstrict + local function f() + if math.random() > 0.5 then + return 5 + else + return "hi" + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("() -> number | string" == toString(requireType("f"))); +} + TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") { CheckResult result = check(R"( @@ -35,8 +58,13 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") REQUIRE_EQ(0, rets.size()); } -TEST_CASE_FIXTURE(Fixture, "infer_the_maximum_number_of_values_the_function_could_return") +TEST_CASE_FIXTURE(Fixture, "first_return_type_dictates_number_of_return_types") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict function getMinCardCountForWidth(width) @@ -51,22 +79,18 @@ TEST_CASE_FIXTURE(Fixture, "infer_the_maximum_number_of_values_the_function_coul TypeId t = requireType("getMinCardCountForWidth"); REQUIRE(t); - REQUIRE_EQ("(any) -> (...any)", toString(t)); + REQUIRE_EQ("(any) -> number", toString(t)); } -#if 0 -// Maybe we want this? TEST_CASE_FIXTURE(Fixture, "return_annotation_is_still_checked") { CheckResult result = check(R"( + --!nonstrict function foo(x): number return 'hello' end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE_NE(*typeChecker.anyType, *requireType("foo")); } -#endif TEST_CASE_FIXTURE(Fixture, "function_parameters_are_any") { @@ -256,6 +280,12 @@ TEST_CASE_FIXTURE(Fixture, "delay_function_does_not_require_its_argument_to_retu TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + {"LuauSealExports", true}, + }; + CheckResult result = check(R"( --!nonstrict @@ -272,7 +302,7 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType)); + REQUIRE_EQ("((any) -> string) | {| foo: any |}", toString(getMainModule()->getModuleScope()->returnType)); } TEST_CASE_FIXTURE(Fixture, "returning_insufficient_return_values") diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 5a84201a..d3778f67 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -21,7 +21,7 @@ void createSomeClasses(TypeChecker& typeChecker) unfreeze(arena); - TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr}); + TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); ClassTypeVar* parentClass = getMutable(parentType); parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; @@ -31,7 +31,7 @@ void createSomeClasses(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "Parent", {parentType}); typeChecker.globalScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; - TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr}); + TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); ClassTypeVar* childClass = getMutable(childType); childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; @@ -39,7 +39,7 @@ void createSomeClasses(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "Child", {childType}); typeChecker.globalScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; - TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr}); + TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); addGlobalBinding(typeChecker, "Unrelated", {unrelatedType}); typeChecker.globalScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; @@ -400,7 +400,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "table_with_table_prop") CHECK_EQ("{| x: {| y: number & string |} |}", toString(requireType("a"))); } -#if 0 TEST_CASE_FIXTURE(NormalizeFixture, "tables") { check(R"( @@ -428,6 +427,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "tables") CHECK(!isSubtype(b, d)); } +#if 0 TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant") { check(R"( @@ -619,6 +619,7 @@ TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauReturnTypeInferenceInNonstrict", true}, }; check(R"( @@ -639,7 +640,7 @@ TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type") end )"); - CHECK_EQ("(any, any) -> (...any)", toString(getMainModule()->getModuleScope()->returnType)); + CHECK_EQ("(any, any) -> (any, any) -> any", toString(getMainModule()->getModuleScope()->returnType)); } TEST_CASE_FIXTURE(Fixture, "return_type_is_not_a_constrained_intersection") @@ -950,6 +951,27 @@ TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "visiting_a_type_twice_is_not_considered_normal") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + --!strict + function f(a, b) + local function g() + if math.random() > 0.5 then + return a() + else + return b + end + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(() -> a, a) -> ()", toString(requireType("f"))); +} + TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") { ScopedFastFlag flags[] = { @@ -964,4 +986,16 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "fuzz_failure_bound_type_is_normal_but_not_its_bounded_to") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type t252 = ((t0)|(any))|(any) + type t0 = t252,t24...> + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index f3fda54e..332a4b22 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -21,13 +21,13 @@ struct ToDotClassFixture : Fixture TypeId baseClassMetaType = arena.addType(TableTypeVar{}); - TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}}); + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test"}); getMutable(baseClassInstanceType)->props = { {"BaseField", {typeChecker.numberType}}, }; typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}}); + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { {"ChildField", {typeChecker.stringType}}, }; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 0c324cd0..b02a52b2 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -661,4 +661,21 @@ type t4 = false CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_array_types") +{ + std::string code = R"( +type t1 = {number} +type t2 = {[string]: number} + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple_types") +{ + std::string code = "for k:string,v:boolean in next,{}do end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 8e3629e7..5a6e4032 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -19,13 +19,13 @@ struct ClassFixture : Fixture unfreeze(arena); - TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}}); + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(baseClassInstanceType)->props = { {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, {"BaseField", {numberType}}, }; - TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}}); + TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(baseClassType)->props = { {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, @@ -34,39 +34,39 @@ struct ClassFixture : Fixture typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; addGlobalBinding(typeChecker, "BaseClass", baseClassType, "@test"); - TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}}); + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}}); + TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(childClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; addGlobalBinding(typeChecker, "ChildClass", childClassType, "@test"); - TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}}); + TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(grandChildInstanceType)->props = { {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}}); + TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(grandChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; addGlobalBinding(typeChecker, "GrandChild", childClassType, "@test"); - TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}}); + TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(anotherChildInstanceType)->props = { {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}}); + TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(anotherChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, }; @@ -75,13 +75,13 @@ struct ClassFixture : Fixture TypeId vector2MetaType = arena.addType(TableTypeVar{}); - TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}}); + TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); getMutable(vector2InstanceType)->props = { {"X", {numberType}}, {"Y", {numberType}}, }; - TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}}); + TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(vector2Type)->props = { {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, }; @@ -468,4 +468,18 @@ caused by: toString(result.errors[0])); } +TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") +{ + ScopedFastFlag luauClassDefinitionModuleInError{"LuauClassDefinitionModuleInError", true}; + + CheckResult result = check(R"( +local i = ChildClass.New() +type ChildClass = { x: number } +local a: ChildClass = i + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'ChildClass' from 'Test' could not be converted into 'ChildClass' from 'MainModule'", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 898d8902..4545b8db 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -295,8 +295,6 @@ TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_type TEST_CASE_FIXTURE(Fixture, "single_class_type_identity_in_global_types") { - ScopedFastFlag luauCloneDeclaredGlobals{"LuauCloneDeclaredGlobals", true}; - loadDefinition(R"( declare class Cls end diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 65993681..7cd7bec3 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -656,6 +656,11 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict @@ -664,7 +669,7 @@ TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") end return function() - return f():andThen() + return f() end )"); @@ -791,14 +796,18 @@ TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict - function Test(a) + function Test(a): ...any return 1, "" end - local tab = {} table.insert(tab, Test(1)); )"); @@ -1616,4 +1625,19 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type") CHECK(nullptr != get(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") +{ + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + local function f() return end + local g = function() return f() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index e5eeae31..fa1f519c 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -307,8 +307,6 @@ type Rename = typeof(x.x) TEST_CASE_FIXTURE(Fixture, "module_type_conflict") { - ScopedFastFlag luauTypeMismatchModuleName{"LuauTypeMismatchModuleName", true}; - fileResolver.source["game/A"] = R"( export type T = { x: number } return {} @@ -343,8 +341,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "module_type_conflict_instantiated") { - ScopedFastFlag luauTypeMismatchModuleName{"LuauTypeMismatchModuleName", true}; - fileResolver.source["game/A"] = R"( export type Wrap = { x: T } return {} diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 6b3741fa..4b5075d9 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -584,20 +584,6 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } -TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") -{ - ScopedFastFlag sff[] = { - {"LuauLowerBoundsCalculation", false}, - }; - - CheckResult result = check(R"( - local function f() return end - local g = function() return f() end - )"); - - LUAU_REQUIRE_ERRORS(result); // Should not have any errors. -} - TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { ScopedFastFlag sff[] = { diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index ce22bcb1..136ca00a 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -44,7 +44,7 @@ struct RefinementClassFixture : Fixture TypeArena& arena = typeChecker.globalTypes; unfreeze(arena); - TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); + TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); getMutable(vec3)->props = { {"X", Property{typeChecker.numberType}}, {"Y", Property{typeChecker.numberType}}, @@ -52,7 +52,7 @@ struct RefinementClassFixture : Fixture }; normalize(vec3, arena, *typeChecker.iceHandler); - TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); + TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); @@ -66,9 +66,9 @@ struct RefinementClassFixture : Fixture }; normalize(inst, arena, *typeChecker.iceHandler); - TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); + TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); normalize(folder, arena, *typeChecker.iceHandler); - TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); + TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); getMutable(part)->props = { {"Position", Property{vec3}}, }; diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index ca1b8de7..2a727bb3 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2086,7 +2086,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") { ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedIndexerError{"LuauExtendedIndexerError", true}; CheckResult result = check(R"( type A = { [number]: string } @@ -2105,7 +2104,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") { ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedIndexerError{"LuauExtendedIndexerError", true}; CheckResult result = check(R"( type A = { [number]: number } diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 6abd96b9..a578b1cf 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -86,16 +86,21 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nocheck function f(x) - return x + return 5 end -- we get type information even if there's type errors f(1, 2) )"); - CHECK_EQ("(any) -> (...any)", toString(requireType("f"))); + CHECK_EQ("(any) -> number", toString(requireType("f"))); LUAU_REQUIRE_NO_ERRORS(result); } @@ -363,6 +368,11 @@ TEST_CASE_FIXTURE(Fixture, "globals") TEST_CASE_FIXTURE(Fixture, "globals2") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict foo = function() return 1 end @@ -373,9 +383,9 @@ TEST_CASE_FIXTURE(Fixture, "globals2") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("() -> (...any)", toString(tm->wantedType)); + CHECK_EQ("() -> number", toString(tm->wantedType)); CHECK_EQ("string", toString(tm->givenType)); - CHECK_EQ("() -> (...any)", toString(requireType("foo"))); + CHECK_EQ("() -> number", toString(requireType("foo"))); } TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index fd5f4dbc..d03bb03c 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -275,7 +275,7 @@ TEST_CASE("tagging_tables") TEST_CASE("tagging_classes") { - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; CHECK(!Luau::hasTag(&base, "foo")); Luau::attachTag(&base, "foo"); CHECK(Luau::hasTag(&base, "foo")); @@ -283,8 +283,8 @@ TEST_CASE("tagging_classes") TEST_CASE("tagging_subclasses") { - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; - TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}}; + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; + TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr, "Test"}}; CHECK(!Luau::hasTag(&base, "foo")); CHECK(!Luau::hasTag(&derived, "foo")); From 74c84815a0ca7f1799a800156833cfc44016f12f Mon Sep 17 00:00:00 2001 From: Alan Jeffrey <403333+asajeffrey@users.noreply.github.com> Date: Thu, 28 Apr 2022 15:00:55 -0500 Subject: [PATCH 049/102] Prototyping type normalizaton (#466) * Added type normalization --- prototyping/Luau/FunctionTypes.agda | 38 ++ prototyping/Luau/StrictMode.agda | 3 +- prototyping/Luau/Subtyping.agda | 2 +- prototyping/Luau/Type.agda | 24 +- prototyping/Luau/TypeCheck.agda | 3 +- prototyping/Luau/TypeNormalization.agda | 69 ++++ prototyping/Properties.agda | 3 + prototyping/Properties/DecSubtyping.agda | 70 ++++ prototyping/Properties/FunctionTypes.agda | 150 +++++++ prototyping/Properties/StrictMode.agda | 6 +- prototyping/Properties/Subtyping.agda | 331 ++++++++++----- prototyping/Properties/TypeCheck.agda | 3 +- prototyping/Properties/TypeNormalization.agda | 376 ++++++++++++++++++ 13 files changed, 940 insertions(+), 138 deletions(-) create mode 100644 prototyping/Luau/FunctionTypes.agda create mode 100644 prototyping/Luau/TypeNormalization.agda create mode 100644 prototyping/Properties/DecSubtyping.agda create mode 100644 prototyping/Properties/FunctionTypes.agda create mode 100644 prototyping/Properties/TypeNormalization.agda diff --git a/prototyping/Luau/FunctionTypes.agda b/prototyping/Luau/FunctionTypes.agda new file mode 100644 index 00000000..7607052b --- /dev/null +++ b/prototyping/Luau/FunctionTypes.agda @@ -0,0 +1,38 @@ +{-# OPTIONS --rewriting #-} + +open import FFI.Data.Either using (Either; Left; Right) +open import Luau.Type using (Type; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) +open import Luau.TypeNormalization using (normalize) + +module Luau.FunctionTypes where + +-- The domain of a normalized type +srcⁿ : Type → Type +srcⁿ (S ⇒ T) = S +srcⁿ (S ∩ T) = srcⁿ S ∪ srcⁿ T +srcⁿ never = unknown +srcⁿ T = never + +-- To get the domain of a type, we normalize it first We need to do +-- this, since if we try to use it on non-normalized types, we get +-- +-- src(number ∩ string) = src(number) ∪ src(string) = never ∪ never +-- src(never) = unknown +-- +-- so src doesn't respect type equivalence. +src : Type → Type +src (S ⇒ T) = S +src T = srcⁿ(normalize T) + +-- The codomain of a type +tgt : Type → Type +tgt nil = never +tgt (S ⇒ T) = T +tgt never = never +tgt unknown = unknown +tgt number = never +tgt boolean = never +tgt string = never +tgt (S ∪ T) = (tgt S) ∪ (tgt T) +tgt (S ∩ T) = (tgt S) ∩ (tgt T) + diff --git a/prototyping/Luau/StrictMode.agda b/prototyping/Luau/StrictMode.agda index 1b028042..d3c0f153 100644 --- a/prototyping/Luau/StrictMode.agda +++ b/prototyping/Luau/StrictMode.agda @@ -5,7 +5,8 @@ module Luau.StrictMode where open import Agda.Builtin.Equality using (_≡_) open import FFI.Data.Maybe using (just; nothing) open import Luau.Syntax using (Expr; Stat; Block; BinaryOperator; yes; nil; addr; var; binexp; var_∈_; _⟨_⟩∈_; function_is_end; _$_; block_is_end; local_←_; _∙_; done; return; name; +; -; *; /; <; >; <=; >=; ··) -open import Luau.Type using (Type; nil; number; string; boolean; _⇒_; _∪_; _∩_; src; tgt) +open import Luau.FunctionTypes using (src; tgt) +open import Luau.Type using (Type; nil; number; string; boolean; _⇒_; _∪_; _∩_) open import Luau.Subtyping using (_≮:_) open import Luau.Heap using (Heap; function_is_end) renaming (_[_] to _[_]ᴴ) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_) renaming (_[_] to _[_]ⱽ) diff --git a/prototyping/Luau/Subtyping.agda b/prototyping/Luau/Subtyping.agda index 943f459b..624b6be4 100644 --- a/prototyping/Luau/Subtyping.agda +++ b/prototyping/Luau/Subtyping.agda @@ -25,7 +25,6 @@ data Language where function : ∀ {T U} → Language (T ⇒ U) function function-ok : ∀ {T U u} → (Language U u) → Language (T ⇒ U) (function-ok u) function-err : ∀ {T U t} → (¬Language T t) → Language (T ⇒ U) (function-err t) - scalar-function-err : ∀ {S t} → (Scalar S) → Language S (function-err t) left : ∀ {T U t} → Language T t → Language (T ∪ U) t right : ∀ {T U u} → Language U u → Language (T ∪ U) u _,_ : ∀ {T U t} → Language T t → Language U t → Language (T ∩ U) t @@ -36,6 +35,7 @@ data ¬Language where scalar-scalar : ∀ {S T} → (s : Scalar S) → (Scalar T) → (S ≢ T) → ¬Language T (scalar s) scalar-function : ∀ {S} → (Scalar S) → ¬Language S function scalar-function-ok : ∀ {S u} → (Scalar S) → ¬Language S (function-ok u) + scalar-function-err : ∀ {S t} → (Scalar S) → ¬Language S (function-err t) function-scalar : ∀ {S T U} (s : Scalar S) → ¬Language (T ⇒ U) (scalar s) function-ok : ∀ {T U u} → (¬Language U u) → ¬Language (T ⇒ U) (function-ok u) function-err : ∀ {T U t} → (Language T t) → ¬Language (T ⇒ U) (function-err t) diff --git a/prototyping/Luau/Type.agda b/prototyping/Luau/Type.agda index 59d1107f..1d0ec9e5 100644 --- a/prototyping/Luau/Type.agda +++ b/prototyping/Luau/Type.agda @@ -24,6 +24,8 @@ data Scalar : Type → Set where string : Scalar string nil : Scalar nil +skalar = number ∪ (string ∪ (nil ∪ boolean)) + lhs : Type → Type lhs (T ⇒ _) = T lhs (T ∪ _) = T @@ -146,28 +148,6 @@ just T ≡ᴹᵀ just U with T ≡ᵀ U (just T ≡ᴹᵀ just T) | yes refl = yes refl (just T ≡ᴹᵀ just U) | no p = no (λ q → p (just-inv q)) -src : Type → Type -src nil = never -src number = never -src boolean = never -src string = never -src (S ⇒ T) = S -src (S ∪ T) = (src S) ∩ (src T) -src (S ∩ T) = (src S) ∪ (src T) -src never = unknown -src unknown = never - -tgt : Type → Type -tgt nil = never -tgt (S ⇒ T) = T -tgt never = never -tgt unknown = unknown -tgt number = never -tgt boolean = never -tgt string = never -tgt (S ∪ T) = (tgt S) ∪ (tgt T) -tgt (S ∩ T) = (tgt S) ∩ (tgt T) - optional : Type → Type optional nil = nil optional (T ∪ nil) = (T ∪ nil) diff --git a/prototyping/Luau/TypeCheck.agda b/prototyping/Luau/TypeCheck.agda index cabd27a8..d4fabb90 100644 --- a/prototyping/Luau/TypeCheck.agda +++ b/prototyping/Luau/TypeCheck.agda @@ -7,8 +7,9 @@ open import FFI.Data.Maybe using (Maybe; just) open import Luau.Syntax using (Expr; Stat; Block; BinaryOperator; yes; nil; addr; number; bool; string; val; var; var_∈_; _⟨_⟩∈_; function_is_end; _$_; block_is_end; binexp; local_←_; _∙_; done; return; name; +; -; *; /; <; >; ==; ~=; <=; >=; ··) open import Luau.Var using (Var) open import Luau.Addr using (Addr) +open import Luau.FunctionTypes using (src; tgt) open import Luau.Heap using (Heap; Object; function_is_end) renaming (_[_] to _[_]ᴴ) -open import Luau.Type using (Type; nil; unknown; number; boolean; string; _⇒_; src; tgt) +open import Luau.Type using (Type; nil; unknown; number; boolean; string; _⇒_) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_) renaming (_[_] to _[_]ⱽ) open import FFI.Data.Vector using (Vector) open import FFI.Data.Maybe using (Maybe; just; nothing) diff --git a/prototyping/Luau/TypeNormalization.agda b/prototyping/Luau/TypeNormalization.agda new file mode 100644 index 00000000..341883ea --- /dev/null +++ b/prototyping/Luau/TypeNormalization.agda @@ -0,0 +1,69 @@ +module Luau.TypeNormalization where + +open import Luau.Type using (Type; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) + +-- The top non-function type +¬function : Type +¬function = number ∪ (string ∪ (nil ∪ boolean)) + +-- Unions and intersections of normalized types +_∪ᶠ_ : Type → Type → Type +_∪ⁿˢ_ : Type → Type → Type +_∩ⁿˢ_ : Type → Type → Type +_∪ⁿ_ : Type → Type → Type +_∩ⁿ_ : Type → Type → Type + +-- Union of function types +(F₁ ∩ F₂) ∪ᶠ G = (F₁ ∪ᶠ G) ∩ (F₂ ∪ᶠ G) +F ∪ᶠ (G₁ ∩ G₂) = (F ∪ᶠ G₁) ∩ (F ∪ᶠ G₂) +(R ⇒ S) ∪ᶠ (T ⇒ U) = (R ∩ⁿ T) ⇒ (S ∪ⁿ U) +F ∪ᶠ G = F ∪ G + +-- Union of normalized types +S ∪ⁿ (T₁ ∪ T₂) = (S ∪ⁿ T₁) ∪ T₂ +S ∪ⁿ unknown = unknown +S ∪ⁿ never = S +unknown ∪ⁿ T = unknown +never ∪ⁿ T = T +(S₁ ∪ S₂) ∪ⁿ G = (S₁ ∪ⁿ G) ∪ S₂ +F ∪ⁿ G = F ∪ᶠ G + +-- Intersection of normalized types +S ∩ⁿ (T₁ ∪ T₂) = (S ∩ⁿ T₁) ∪ⁿˢ (S ∩ⁿˢ T₂) +S ∩ⁿ unknown = S +S ∩ⁿ never = never +(S₁ ∪ S₂) ∩ⁿ G = (S₁ ∩ⁿ G) +unknown ∩ⁿ G = G +never ∩ⁿ G = never +F ∩ⁿ G = F ∩ G + +-- Intersection of normalized types with a scalar +(S₁ ∪ nil) ∩ⁿˢ nil = nil +(S₁ ∪ boolean) ∩ⁿˢ boolean = boolean +(S₁ ∪ number) ∩ⁿˢ number = number +(S₁ ∪ string) ∩ⁿˢ string = string +(S₁ ∪ S₂) ∩ⁿˢ T = S₁ ∩ⁿˢ T +unknown ∩ⁿˢ T = T +F ∩ⁿˢ T = never + +-- Union of normalized types with an optional scalar +S ∪ⁿˢ never = S +unknown ∪ⁿˢ T = unknown +(S₁ ∪ nil) ∪ⁿˢ nil = S₁ ∪ nil +(S₁ ∪ boolean) ∪ⁿˢ boolean = S₁ ∪ boolean +(S₁ ∪ number) ∪ⁿˢ number = S₁ ∪ number +(S₁ ∪ string) ∪ⁿˢ string = S₁ ∪ string +(S₁ ∪ S₂) ∪ⁿˢ T = (S₁ ∪ⁿˢ T) ∪ S₂ +F ∪ⁿˢ T = F ∪ T + +-- Normalize! +normalize : Type → Type +normalize nil = never ∪ nil +normalize (S ⇒ T) = (normalize S ⇒ normalize T) +normalize never = never +normalize unknown = unknown +normalize boolean = never ∪ boolean +normalize number = never ∪ number +normalize string = never ∪ string +normalize (S ∪ T) = normalize S ∪ⁿ normalize T +normalize (S ∩ T) = normalize S ∩ⁿ normalize T diff --git a/prototyping/Properties.agda b/prototyping/Properties.agda index 5594812e..b696c0fa 100644 --- a/prototyping/Properties.agda +++ b/prototyping/Properties.agda @@ -4,10 +4,13 @@ module Properties where import Properties.Contradiction import Properties.Dec +import Properties.DecSubtyping import Properties.Equality import Properties.Functions +import Properties.FunctionTypes import Properties.Remember import Properties.Step import Properties.StrictMode import Properties.Subtyping import Properties.TypeCheck +import Properties.TypeNormalization diff --git a/prototyping/Properties/DecSubtyping.agda b/prototyping/Properties/DecSubtyping.agda new file mode 100644 index 00000000..332520a9 --- /dev/null +++ b/prototyping/Properties/DecSubtyping.agda @@ -0,0 +1,70 @@ +{-# OPTIONS --rewriting #-} + +module Properties.DecSubtyping where + +open import Agda.Builtin.Equality using (_≡_; refl) +open import FFI.Data.Either using (Either; Left; Right; mapLR; swapLR; cond) +open import Luau.FunctionTypes using (src; srcⁿ; tgt) +open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_) +open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) +open import Properties.Contradiction using (CONTRADICTION; ¬) +open import Properties.Functions using (_∘_) +open import Properties.Subtyping using (<:-refl; <:-trans; ≮:-trans-<:; <:-trans-≮:; <:-never; <:-unknown; <:-∪-left; <:-∪-right; <:-∪-lub; ≮:-∪-left; ≮:-∪-right; <:-∩-left; <:-∩-right; <:-∩-glb; ≮:-∩-left; ≮:-∩-right; dec-language; scalar-<:; <:-everything; <:-function; ≮:-function-left; ≮:-function-right) +open import Properties.TypeNormalization using (FunType; Normal; never; unknown; _∩_; _∪_; _⇒_; normal; <:-normalize; normalize-<:) +open import Properties.FunctionTypes using (fun-¬scalar; ¬fun-scalar; fun-function; src-unknown-≮:; tgt-never-≮:; src-tgtᶠ-<:) +open import Properties.Equality using (_≢_) + +-- Honest this terminates, since src and tgt reduce the depth of nested arrows +{-# TERMINATING #-} +dec-subtypingˢⁿ : ∀ {T U} → Scalar T → Normal U → Either (T ≮: U) (T <: U) +dec-subtypingᶠ : ∀ {T U} → FunType T → FunType U → Either (T ≮: U) (T <: U) +dec-subtypingᶠⁿ : ∀ {T U} → FunType T → Normal U → Either (T ≮: U) (T <: U) +dec-subtypingⁿ : ∀ {T U} → Normal T → Normal U → Either (T ≮: U) (T <: U) +dec-subtyping : ∀ T U → Either (T ≮: U) (T <: U) + +dec-subtypingˢⁿ T U with dec-language _ (scalar T) +dec-subtypingˢⁿ T U | Left p = Left (witness (scalar T) (scalar T) p) +dec-subtypingˢⁿ T U | Right p = Right (scalar-<: T p) + +dec-subtypingᶠ {T = T} _ (U ⇒ V) with dec-subtypingⁿ U (normal (src T)) | dec-subtypingⁿ (normal (tgt T)) V +dec-subtypingᶠ {T = T} _ (U ⇒ V) | Left p | q = Left (≮:-trans-<: (src-unknown-≮: (≮:-trans-<: p (<:-normalize (src T)))) (<:-function <:-refl <:-unknown)) +dec-subtypingᶠ {T = T} _ (U ⇒ V) | Right p | Left q = Left (≮:-trans-<: (tgt-never-≮: (<:-trans-≮: (normalize-<: (tgt T)) q)) (<:-trans (<:-function <:-never <:-refl) <:-∪-right)) +dec-subtypingᶠ T (U ⇒ V) | Right p | Right q = Right (src-tgtᶠ-<: T (<:-trans p (normalize-<: _)) (<:-trans (<:-normalize _) q)) + +dec-subtypingᶠ T (U ∩ V) with dec-subtypingᶠ T U | dec-subtypingᶠ T V +dec-subtypingᶠ T (U ∩ V) | Left p | q = Left (≮:-∩-left p) +dec-subtypingᶠ T (U ∩ V) | Right p | Left q = Left (≮:-∩-right q) +dec-subtypingᶠ T (U ∩ V) | Right p | Right q = Right (<:-∩-glb p q) + +dec-subtypingᶠⁿ T never = Left (witness function (fun-function T) never) +dec-subtypingᶠⁿ T unknown = Right <:-unknown +dec-subtypingᶠⁿ T (U ⇒ V) = dec-subtypingᶠ T (U ⇒ V) +dec-subtypingᶠⁿ T (U ∩ V) = dec-subtypingᶠ T (U ∩ V) +dec-subtypingᶠⁿ T (U ∪ V) with dec-subtypingᶠⁿ T U +dec-subtypingᶠⁿ T (U ∪ V) | Left (witness t p q) = Left (witness t p (q , ¬fun-scalar V T p)) +dec-subtypingᶠⁿ T (U ∪ V) | Right p = Right (<:-trans p <:-∪-left) + +dec-subtypingⁿ never U = Right <:-never +dec-subtypingⁿ unknown unknown = Right <:-refl +dec-subtypingⁿ unknown U with dec-subtypingᶠⁿ (never ⇒ unknown) U +dec-subtypingⁿ unknown U | Left p = Left (<:-trans-≮: <:-unknown p) +dec-subtypingⁿ unknown U | Right p₁ with dec-subtypingˢⁿ number U +dec-subtypingⁿ unknown U | Right p₁ | Left p = Left (<:-trans-≮: <:-unknown p) +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ with dec-subtypingˢⁿ string U +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Left p = Left (<:-trans-≮: <:-unknown p) +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Right p₃ with dec-subtypingˢⁿ nil U +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Right p₃ | Left p = Left (<:-trans-≮: <:-unknown p) +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Right p₃ | Right p₄ with dec-subtypingˢⁿ boolean U +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Right p₃ | Right p₄ | Left p = Left (<:-trans-≮: <:-unknown p) +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Right p₃ | Right p₄ | Right p₅ = Right (<:-trans <:-everything (<:-∪-lub p₁ (<:-∪-lub p₂ (<:-∪-lub p₃ (<:-∪-lub p₄ p₅))))) +dec-subtypingⁿ (S ⇒ T) U = dec-subtypingᶠⁿ (S ⇒ T) U +dec-subtypingⁿ (S ∩ T) U = dec-subtypingᶠⁿ (S ∩ T) U +dec-subtypingⁿ (S ∪ T) U with dec-subtypingⁿ S U | dec-subtypingˢⁿ T U +dec-subtypingⁿ (S ∪ T) U | Left p | q = Left (≮:-∪-left p) +dec-subtypingⁿ (S ∪ T) U | Right p | Left q = Left (≮:-∪-right q) +dec-subtypingⁿ (S ∪ T) U | Right p | Right q = Right (<:-∪-lub p q) + +dec-subtyping T U with dec-subtypingⁿ (normal T) (normal U) +dec-subtyping T U | Left p = Left (<:-trans-≮: (normalize-<: T) (≮:-trans-<: p (<:-normalize U))) +dec-subtyping T U | Right p = Right (<:-trans (<:-normalize T) (<:-trans p (normalize-<: U))) + diff --git a/prototyping/Properties/FunctionTypes.agda b/prototyping/Properties/FunctionTypes.agda new file mode 100644 index 00000000..514477f1 --- /dev/null +++ b/prototyping/Properties/FunctionTypes.agda @@ -0,0 +1,150 @@ +{-# OPTIONS --rewriting #-} + +module Properties.FunctionTypes where + +open import FFI.Data.Either using (Either; Left; Right; mapLR; swapLR; cond) +open import Luau.FunctionTypes using (srcⁿ; src; tgt) +open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_) +open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_; skalar) +open import Properties.Contradiction using (CONTRADICTION; ¬; ⊥) +open import Properties.Functions using (_∘_) +open import Properties.Subtyping using (<:-refl; ≮:-refl; <:-trans-≮:; skalar-scalar; <:-impl-⊇; skalar-function-ok; language-comp) +open import Properties.TypeNormalization using (FunType; Normal; never; unknown; _∩_; _∪_; _⇒_; normal; <:-normalize; normalize-<:) + +-- Properties of src +function-err-srcⁿ : ∀ {T t} → (FunType T) → (¬Language (srcⁿ T) t) → Language T (function-err t) +function-err-srcⁿ (S ⇒ T) p = function-err p +function-err-srcⁿ (S ∩ T) (p₁ , p₂) = (function-err-srcⁿ S p₁ , function-err-srcⁿ T p₂) + +¬function-err-srcᶠ : ∀ {T t} → (FunType T) → (Language (srcⁿ T) t) → ¬Language T (function-err t) +¬function-err-srcᶠ (S ⇒ T) p = function-err p +¬function-err-srcᶠ (S ∩ T) (left p) = left (¬function-err-srcᶠ S p) +¬function-err-srcᶠ (S ∩ T) (right p) = right (¬function-err-srcᶠ T p) + +¬function-err-srcⁿ : ∀ {T t} → (Normal T) → (Language (srcⁿ T) t) → ¬Language T (function-err t) +¬function-err-srcⁿ never p = never +¬function-err-srcⁿ unknown (scalar ()) +¬function-err-srcⁿ (S ⇒ T) p = function-err p +¬function-err-srcⁿ (S ∩ T) (left p) = left (¬function-err-srcᶠ S p) +¬function-err-srcⁿ (S ∩ T) (right p) = right (¬function-err-srcᶠ T p) +¬function-err-srcⁿ (S ∪ T) (scalar ()) + +¬function-err-src : ∀ {T t} → (Language (src T) t) → ¬Language T (function-err t) +¬function-err-src {T = S ⇒ T} p = function-err p +¬function-err-src {T = nil} p = scalar-function-err nil +¬function-err-src {T = never} p = never +¬function-err-src {T = unknown} (scalar ()) +¬function-err-src {T = boolean} p = scalar-function-err boolean +¬function-err-src {T = number} p = scalar-function-err number +¬function-err-src {T = string} p = scalar-function-err string +¬function-err-src {T = S ∪ T} p = <:-impl-⊇ (<:-normalize (S ∪ T)) _ (¬function-err-srcⁿ (normal (S ∪ T)) p) +¬function-err-src {T = S ∩ T} p = <:-impl-⊇ (<:-normalize (S ∩ T)) _ (¬function-err-srcⁿ (normal (S ∩ T)) p) + +src-¬function-errᶠ : ∀ {T t} → (FunType T) → Language T (function-err t) → (¬Language (srcⁿ T) t) +src-¬function-errᶠ (S ⇒ T) (function-err p) = p +src-¬function-errᶠ (S ∩ T) (p₁ , p₂) = (src-¬function-errᶠ S p₁ , src-¬function-errᶠ T p₂) + +src-¬function-errⁿ : ∀ {T t} → (Normal T) → Language T (function-err t) → (¬Language (srcⁿ T) t) +src-¬function-errⁿ unknown p = never +src-¬function-errⁿ (S ⇒ T) (function-err p) = p +src-¬function-errⁿ (S ∩ T) (p₁ , p₂) = (src-¬function-errᶠ S p₁ , src-¬function-errᶠ T p₂) +src-¬function-errⁿ (S ∪ T) p = never + +src-¬function-err : ∀ {T t} → Language T (function-err t) → (¬Language (src T) t) +src-¬function-err {T = S ⇒ T} (function-err p) = p +src-¬function-err {T = unknown} p = never +src-¬function-err {T = S ∪ T} p = src-¬function-errⁿ (normal (S ∪ T)) (<:-normalize (S ∪ T) _ p) +src-¬function-err {T = S ∩ T} p = src-¬function-errⁿ (normal (S ∩ T)) (<:-normalize (S ∩ T) _ p) + +fun-¬scalar : ∀ {S T} (s : Scalar S) → FunType T → ¬Language T (scalar s) +fun-¬scalar s (S ⇒ T) = function-scalar s +fun-¬scalar s (S ∩ T) = left (fun-¬scalar s S) + +¬fun-scalar : ∀ {S T t} (s : Scalar S) → FunType T → Language T t → ¬Language S t +¬fun-scalar s (S ⇒ T) function = scalar-function s +¬fun-scalar s (S ⇒ T) (function-ok p) = scalar-function-ok s +¬fun-scalar s (S ⇒ T) (function-err p) = scalar-function-err s +¬fun-scalar s (S ∩ T) (p₁ , p₂) = ¬fun-scalar s T p₂ + +fun-function : ∀ {T} → FunType T → Language T function +fun-function (S ⇒ T) = function +fun-function (S ∩ T) = (fun-function S , fun-function T) + +srcⁿ-¬scalar : ∀ {S T t} (s : Scalar S) → Normal T → Language T (scalar s) → (¬Language (srcⁿ T) t) +srcⁿ-¬scalar s never (scalar ()) +srcⁿ-¬scalar s unknown p = never +srcⁿ-¬scalar s (S ⇒ T) (scalar ()) +srcⁿ-¬scalar s (S ∩ T) (p₁ , p₂) = CONTRADICTION (language-comp (scalar s) (fun-¬scalar s S) p₁) +srcⁿ-¬scalar s (S ∪ T) p = never + +src-¬scalar : ∀ {S T t} (s : Scalar S) → Language T (scalar s) → (¬Language (src T) t) +src-¬scalar {T = nil} s p = never +src-¬scalar {T = T ⇒ U} s (scalar ()) +src-¬scalar {T = never} s (scalar ()) +src-¬scalar {T = unknown} s p = never +src-¬scalar {T = boolean} s p = never +src-¬scalar {T = number} s p = never +src-¬scalar {T = string} s p = never +src-¬scalar {T = T ∪ U} s p = srcⁿ-¬scalar s (normal (T ∪ U)) (<:-normalize (T ∪ U) (scalar s) p) +src-¬scalar {T = T ∩ U} s p = srcⁿ-¬scalar s (normal (T ∩ U)) (<:-normalize (T ∩ U) (scalar s) p) + +srcⁿ-unknown-≮: : ∀ {T U} → (Normal U) → (T ≮: srcⁿ U) → (U ≮: (T ⇒ unknown)) +srcⁿ-unknown-≮: never (witness t p q) = CONTRADICTION (language-comp t q unknown) +srcⁿ-unknown-≮: unknown (witness t p q) = witness (function-err t) unknown (function-err p) +srcⁿ-unknown-≮: (U ⇒ V) (witness t p q) = witness (function-err t) (function-err q) (function-err p) +srcⁿ-unknown-≮: (U ∩ V) (witness t p q) = witness (function-err t) (function-err-srcⁿ (U ∩ V) q) (function-err p) +srcⁿ-unknown-≮: (U ∪ V) (witness t p q) = witness (scalar V) (right (scalar V)) (function-scalar V) + +src-unknown-≮: : ∀ {T U} → (T ≮: src U) → (U ≮: (T ⇒ unknown)) +src-unknown-≮: {U = nil} (witness t p q) = witness (scalar nil) (scalar nil) (function-scalar nil) +src-unknown-≮: {U = T ⇒ U} (witness t p q) = witness (function-err t) (function-err q) (function-err p) +src-unknown-≮: {U = never} (witness t p q) = CONTRADICTION (language-comp t q unknown) +src-unknown-≮: {U = unknown} (witness t p q) = witness (function-err t) unknown (function-err p) +src-unknown-≮: {U = boolean} (witness t p q) = witness (scalar boolean) (scalar boolean) (function-scalar boolean) +src-unknown-≮: {U = number} (witness t p q) = witness (scalar number) (scalar number) (function-scalar number) +src-unknown-≮: {U = string} (witness t p q) = witness (scalar string) (scalar string) (function-scalar string) +src-unknown-≮: {U = T ∪ U} p = <:-trans-≮: (normalize-<: (T ∪ U)) (srcⁿ-unknown-≮: (normal (T ∪ U)) p) +src-unknown-≮: {U = T ∩ U} p = <:-trans-≮: (normalize-<: (T ∩ U)) (srcⁿ-unknown-≮: (normal (T ∩ U)) p) + +unknown-src-≮: : ∀ {S T U} → (U ≮: S) → (T ≮: (U ⇒ unknown)) → (U ≮: src T) +unknown-src-≮: (witness t x x₁) (witness (scalar s) p (function-scalar s)) = witness t x (src-¬scalar s p) +unknown-src-≮: r (witness (function-ok (scalar s)) p (function-ok (scalar-scalar s () q))) +unknown-src-≮: r (witness (function-ok (function-ok _)) p (function-ok (scalar-function-ok ()))) +unknown-src-≮: r (witness (function-err t) p (function-err q)) = witness t q (src-¬function-err p) + +-- Properties of tgt +tgt-function-ok : ∀ {T t} → (Language (tgt T) t) → Language T (function-ok t) +tgt-function-ok {T = nil} (scalar ()) +tgt-function-ok {T = T₁ ⇒ T₂} p = function-ok p +tgt-function-ok {T = never} (scalar ()) +tgt-function-ok {T = unknown} p = unknown +tgt-function-ok {T = boolean} (scalar ()) +tgt-function-ok {T = number} (scalar ()) +tgt-function-ok {T = string} (scalar ()) +tgt-function-ok {T = T₁ ∪ T₂} (left p) = left (tgt-function-ok p) +tgt-function-ok {T = T₁ ∪ T₂} (right p) = right (tgt-function-ok p) +tgt-function-ok {T = T₁ ∩ T₂} (p₁ , p₂) = (tgt-function-ok p₁ , tgt-function-ok p₂) + +function-ok-tgt : ∀ {T t} → Language T (function-ok t) → (Language (tgt T) t) +function-ok-tgt (function-ok p) = p +function-ok-tgt (left p) = left (function-ok-tgt p) +function-ok-tgt (right p) = right (function-ok-tgt p) +function-ok-tgt (p₁ , p₂) = (function-ok-tgt p₁ , function-ok-tgt p₂) +function-ok-tgt unknown = unknown + +tgt-never-≮: : ∀ {T U} → (tgt T ≮: U) → (T ≮: (skalar ∪ (never ⇒ U))) +tgt-never-≮: (witness t p q) = witness (function-ok t) (tgt-function-ok p) (skalar-function-ok , function-ok q) + +never-tgt-≮: : ∀ {T U} → (T ≮: (skalar ∪ (never ⇒ U))) → (tgt T ≮: U) +never-tgt-≮: (witness (scalar s) p (q₁ , q₂)) = CONTRADICTION (≮:-refl (witness (scalar s) (skalar-scalar s) q₁)) +never-tgt-≮: (witness function p (q₁ , scalar-function ())) +never-tgt-≮: (witness (function-ok t) p (q₁ , function-ok q₂)) = witness t (function-ok-tgt p) q₂ +never-tgt-≮: (witness (function-err (scalar s)) p (q₁ , function-err (scalar ()))) + +src-tgtᶠ-<: : ∀ {T U V} → (FunType T) → (U <: src T) → (tgt T <: V) → (T <: (U ⇒ V)) +src-tgtᶠ-<: T p q (scalar s) r = CONTRADICTION (language-comp (scalar s) (fun-¬scalar s T) r) +src-tgtᶠ-<: T p q function r = function +src-tgtᶠ-<: T p q (function-ok s) r = function-ok (q s (function-ok-tgt r)) +src-tgtᶠ-<: T p q (function-err s) r = function-err (<:-impl-⊇ p s (src-¬function-err r)) + + diff --git a/prototyping/Properties/StrictMode.agda b/prototyping/Properties/StrictMode.agda index fd2cf2f2..69e9131c 100644 --- a/prototyping/Properties/StrictMode.agda +++ b/prototyping/Properties/StrictMode.agda @@ -11,7 +11,8 @@ open import Luau.StrictMode using (Warningᴱ; Warningᴮ; Warningᴼ; Warning open import Luau.Substitution using (_[_/_]ᴮ; _[_/_]ᴱ; _[_/_]ᴮunless_; var_[_/_]ᴱwhenever_) open import Luau.Subtyping using (_≮:_; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_; Tree; Language; ¬Language) open import Luau.Syntax using (Expr; yes; var; val; var_∈_; _⟨_⟩∈_; _$_; addr; number; bool; string; binexp; nil; function_is_end; block_is_end; done; return; local_←_; _∙_; fun; arg; name; ==; ~=) -open import Luau.Type using (Type; nil; number; boolean; string; _⇒_; never; unknown; _∩_; _∪_; src; tgt; _≡ᵀ_; _≡ᴹᵀ_) +open import Luau.FunctionTypes using (src; tgt) +open import Luau.Type using (Type; nil; number; boolean; string; _⇒_; never; unknown; _∩_; _∪_; _≡ᵀ_; _≡ᴹᵀ_) open import Luau.TypeCheck using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; _⊢ᴴᴮ_▷_∈_; _⊢ᴴᴱ_▷_∈_; nil; var; addr; app; function; block; done; return; local; orUnknown; srcBinOp; tgtBinOp) open import Luau.Var using (_≡ⱽ_) open import Luau.Addr using (_≡ᴬ_) @@ -22,7 +23,8 @@ open import Properties.Equality using (_≢_; sym; cong; trans; subst₁) open import Properties.Dec using (Dec; yes; no) open import Properties.Contradiction using (CONTRADICTION; ¬) open import Properties.Functions using (_∘_) -open import Properties.Subtyping using (unknown-≮:; ≡-trans-≮:; ≮:-trans-≡; never-tgt-≮:; tgt-never-≮:; src-unknown-≮:; unknown-src-≮:; ≮:-trans; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-never; unknown-≮:-scalar; scalar-≮:-never; unknown-≮:-never) +open import Properties.FunctionTypes using (never-tgt-≮:; tgt-never-≮:; src-unknown-≮:; unknown-src-≮:) +open import Properties.Subtyping using (unknown-≮:; ≡-trans-≮:; ≮:-trans-≡; ≮:-trans; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-never; unknown-≮:-scalar; scalar-≮:-never; unknown-≮:-never) open import Properties.TypeCheck using (typeOfᴼ; typeOfᴹᴼ; typeOfⱽ; typeOfᴱ; typeOfᴮ; typeCheckᴱ; typeCheckᴮ; typeCheckᴼ; typeCheckᴴ) open import Luau.OpSem using (_⟦_⟧_⟶_; _⊢_⟶*_⊣_; _⊢_⟶ᴮ_⊣_; _⊢_⟶ᴱ_⊣_; app₁; app₂; function; beta; return; block; done; local; subst; binOp₀; binOp₁; binOp₂; refl; step; +; -; *; /; <; >; ==; ~=; <=; >=; ··) open import Luau.RuntimeError using (BinOpError; RuntimeErrorᴱ; RuntimeErrorᴮ; FunctionMismatch; BinOpMismatch₁; BinOpMismatch₂; UnboundVariable; SEGV; app₁; app₂; bin₁; bin₂; block; local; return; +; -; *; /; <; >; <=; >=; ··) diff --git a/prototyping/Properties/Subtyping.agda b/prototyping/Properties/Subtyping.agda index b713eaf7..34e6691f 100644 --- a/prototyping/Properties/Subtyping.agda +++ b/prototyping/Properties/Subtyping.agda @@ -4,9 +4,10 @@ module Properties.Subtyping where open import Agda.Builtin.Equality using (_≡_; refl) open import FFI.Data.Either using (Either; Left; Right; mapLR; swapLR; cond) +open import FFI.Data.Maybe using (Maybe; just; nothing) open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_) -open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_; src; tgt) -open import Properties.Contradiction using (CONTRADICTION; ¬) +open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_; skalar) +open import Properties.Contradiction using (CONTRADICTION; ¬; ⊥) open import Properties.Equality using (_≢_) open import Properties.Functions using (_∘_) open import Properties.Product using (_×_; _,_) @@ -19,28 +20,28 @@ dec-language nil (scalar string) = Left (scalar-scalar string nil (λ ())) dec-language nil (scalar nil) = Right (scalar nil) dec-language nil function = Left (scalar-function nil) dec-language nil (function-ok t) = Left (scalar-function-ok nil) -dec-language nil (function-err t) = Right (scalar-function-err nil) +dec-language nil (function-err t) = Left (scalar-function-err nil) dec-language boolean (scalar number) = Left (scalar-scalar number boolean (λ ())) dec-language boolean (scalar boolean) = Right (scalar boolean) dec-language boolean (scalar string) = Left (scalar-scalar string boolean (λ ())) dec-language boolean (scalar nil) = Left (scalar-scalar nil boolean (λ ())) dec-language boolean function = Left (scalar-function boolean) dec-language boolean (function-ok t) = Left (scalar-function-ok boolean) -dec-language boolean (function-err t) = Right (scalar-function-err boolean) +dec-language boolean (function-err t) = Left (scalar-function-err boolean) dec-language number (scalar number) = Right (scalar number) dec-language number (scalar boolean) = Left (scalar-scalar boolean number (λ ())) dec-language number (scalar string) = Left (scalar-scalar string number (λ ())) dec-language number (scalar nil) = Left (scalar-scalar nil number (λ ())) dec-language number function = Left (scalar-function number) dec-language number (function-ok t) = Left (scalar-function-ok number) -dec-language number (function-err t) = Right (scalar-function-err number) +dec-language number (function-err t) = Left (scalar-function-err number) dec-language string (scalar number) = Left (scalar-scalar number string (λ ())) dec-language string (scalar boolean) = Left (scalar-scalar boolean string (λ ())) dec-language string (scalar string) = Right (scalar string) dec-language string (scalar nil) = Left (scalar-scalar nil string (λ ())) dec-language string function = Left (scalar-function string) dec-language string (function-ok t) = Left (scalar-function-ok string) -dec-language string (function-err t) = Right (scalar-function-err string) +dec-language string (function-err t) = Left (scalar-function-err string) dec-language (T₁ ⇒ T₂) (scalar s) = Left (function-scalar s) dec-language (T₁ ⇒ T₂) function = Right function dec-language (T₁ ⇒ T₂) (function-ok t) = mapLR function-ok function-ok (dec-language T₂ t) @@ -73,6 +74,11 @@ language-comp (function-err t) (function-err p) (function-err q) = language-comp <:-impl-¬≮: : ∀ {T U} → (T <: U) → ¬(T ≮: U) <:-impl-¬≮: p (witness t q r) = language-comp t r (p t q) +<:-impl-⊇ : ∀ {T U} → (T <: U) → ∀ t → ¬Language U t → ¬Language T t +<:-impl-⊇ {T} p t q with dec-language T t +<:-impl-⊇ {_} p t q | Left r = r +<:-impl-⊇ {_} p t q | Right r = CONTRADICTION (language-comp t q (p t r)) + -- reflexivity ≮:-refl : ∀ {T} → ¬(T ≮: T) ≮:-refl (witness t p q) = language-comp t q p @@ -91,10 +97,162 @@ language-comp (function-err t) (function-err p) (function-err q) = language-comp ≮:-trans {T = T} (witness t p q) = mapLR (witness t p) (λ z → witness t z q) (dec-language T t) <:-trans : ∀ {S T U} → (S <: T) → (T <: U) → (S <: U) -<:-trans p q = ¬≮:-impl-<: (cond (<:-impl-¬≮: p) (<:-impl-¬≮: q) ∘ ≮:-trans) +<:-trans p q t r = q t (p t r) + +<:-trans-≮: : ∀ {S T U} → (S <: T) → (S ≮: U) → (T ≮: U) +<:-trans-≮: p (witness t q r) = witness t (p t q) r + +≮:-trans-<: : ∀ {S T U} → (S ≮: U) → (T <: U) → (S ≮: T) +≮:-trans-<: (witness t p q) r = witness t p (<:-impl-⊇ r t q) + +-- Properties of union + +<:-union : ∀ {R S T U} → (R <: T) → (S <: U) → ((R ∪ S) <: (T ∪ U)) +<:-union p q t (left r) = left (p t r) +<:-union p q t (right r) = right (q t r) + +<:-∪-left : ∀ {S T} → S <: (S ∪ T) +<:-∪-left t p = left p + +<:-∪-right : ∀ {S T} → T <: (S ∪ T) +<:-∪-right t p = right p + +<:-∪-lub : ∀ {S T U} → (S <: U) → (T <: U) → ((S ∪ T) <: U) +<:-∪-lub p q t (left r) = p t r +<:-∪-lub p q t (right r) = q t r + +<:-∪-symm : ∀ {T U} → (T ∪ U) <: (U ∪ T) +<:-∪-symm t (left p) = right p +<:-∪-symm t (right p) = left p + +<:-∪-assocl : ∀ {S T U} → (S ∪ (T ∪ U)) <: ((S ∪ T) ∪ U) +<:-∪-assocl t (left p) = left (left p) +<:-∪-assocl t (right (left p)) = left (right p) +<:-∪-assocl t (right (right p)) = right p + +<:-∪-assocr : ∀ {S T U} → ((S ∪ T) ∪ U) <: (S ∪ (T ∪ U)) +<:-∪-assocr t (left (left p)) = left p +<:-∪-assocr t (left (right p)) = right (left p) +<:-∪-assocr t (right p) = right (right p) + +≮:-∪-left : ∀ {S T U} → (S ≮: U) → ((S ∪ T) ≮: U) +≮:-∪-left (witness t p q) = witness t (left p) q + +≮:-∪-right : ∀ {S T U} → (T ≮: U) → ((S ∪ T) ≮: U) +≮:-∪-right (witness t p q) = witness t (right p) q + +-- Properties of intersection + +<:-intersect : ∀ {R S T U} → (R <: T) → (S <: U) → ((R ∩ S) <: (T ∩ U)) +<:-intersect p q t (r₁ , r₂) = (p t r₁ , q t r₂) + +<:-∩-left : ∀ {S T} → (S ∩ T) <: S +<:-∩-left t (p , _) = p + +<:-∩-right : ∀ {S T} → (S ∩ T) <: T +<:-∩-right t (_ , p) = p + +<:-∩-glb : ∀ {S T U} → (S <: T) → (S <: U) → (S <: (T ∩ U)) +<:-∩-glb p q t r = (p t r , q t r) + +<:-∩-symm : ∀ {T U} → (T ∩ U) <: (U ∩ T) +<:-∩-symm t (p₁ , p₂) = (p₂ , p₁) + +≮:-∩-left : ∀ {S T U} → (S ≮: T) → (S ≮: (T ∩ U)) +≮:-∩-left (witness t p q) = witness t p (left q) + +≮:-∩-right : ∀ {S T U} → (S ≮: U) → (S ≮: (T ∩ U)) +≮:-∩-right (witness t p q) = witness t p (right q) + +-- Distribution properties +<:-∩-distl-∪ : ∀ {S T U} → (S ∩ (T ∪ U)) <: ((S ∩ T) ∪ (S ∩ U)) +<:-∩-distl-∪ t (p₁ , left p₂) = left (p₁ , p₂) +<:-∩-distl-∪ t (p₁ , right p₂) = right (p₁ , p₂) + +∩-distl-∪-<: : ∀ {S T U} → ((S ∩ T) ∪ (S ∩ U)) <: (S ∩ (T ∪ U)) +∩-distl-∪-<: t (left (p₁ , p₂)) = (p₁ , left p₂) +∩-distl-∪-<: t (right (p₁ , p₂)) = (p₁ , right p₂) + +<:-∩-distr-∪ : ∀ {S T U} → ((S ∪ T) ∩ U) <: ((S ∩ U) ∪ (T ∩ U)) +<:-∩-distr-∪ t (left p₁ , p₂) = left (p₁ , p₂) +<:-∩-distr-∪ t (right p₁ , p₂) = right (p₁ , p₂) + +∩-distr-∪-<: : ∀ {S T U} → ((S ∩ U) ∪ (T ∩ U)) <: ((S ∪ T) ∩ U) +∩-distr-∪-<: t (left (p₁ , p₂)) = (left p₁ , p₂) +∩-distr-∪-<: t (right (p₁ , p₂)) = (right p₁ , p₂) + +<:-∪-distl-∩ : ∀ {S T U} → (S ∪ (T ∩ U)) <: ((S ∪ T) ∩ (S ∪ U)) +<:-∪-distl-∩ t (left p) = (left p , left p) +<:-∪-distl-∩ t (right (p₁ , p₂)) = (right p₁ , right p₂) + +∪-distl-∩-<: : ∀ {S T U} → ((S ∪ T) ∩ (S ∪ U)) <: (S ∪ (T ∩ U)) +∪-distl-∩-<: t (left p₁ , p₂) = left p₁ +∪-distl-∩-<: t (right p₁ , left p₂) = left p₂ +∪-distl-∩-<: t (right p₁ , right p₂) = right (p₁ , p₂) + +<:-∪-distr-∩ : ∀ {S T U} → ((S ∩ T) ∪ U) <: ((S ∪ U) ∩ (T ∪ U)) +<:-∪-distr-∩ t (left (p₁ , p₂)) = left p₁ , left p₂ +<:-∪-distr-∩ t (right p) = (right p , right p) + +∪-distr-∩-<: : ∀ {S T U} → ((S ∪ U) ∩ (T ∪ U)) <: ((S ∩ T) ∪ U) +∪-distr-∩-<: t (left p₁ , left p₂) = left (p₁ , p₂) +∪-distr-∩-<: t (left p₁ , right p₂) = right p₂ +∪-distr-∩-<: t (right p₁ , p₂) = right p₁ + +-- Properties of functions +<:-function : ∀ {R S T U} → (R <: S) → (T <: U) → (S ⇒ T) <: (R ⇒ U) +<:-function p q function function = function +<:-function p q (function-ok t) (function-ok r) = function-ok (q t r) +<:-function p q (function-err s) (function-err r) = function-err (<:-impl-⊇ p s r) + +<:-function-∩-∪ : ∀ {R S T U} → ((R ⇒ T) ∩ (S ⇒ U)) <: ((R ∪ S) ⇒ (T ∪ U)) +<:-function-∩-∪ function (function , function) = function +<:-function-∩-∪ (function-ok t) (function-ok p₁ , function-ok p₂) = function-ok (right p₂) +<:-function-∩-∪ (function-err _) (function-err p₁ , function-err q₂) = function-err (p₁ , q₂) + +<:-function-∩ : ∀ {S T U} → ((S ⇒ T) ∩ (S ⇒ U)) <: (S ⇒ (T ∩ U)) +<:-function-∩ function (function , function) = function +<:-function-∩ (function-ok t) (function-ok p₁ , function-ok p₂) = function-ok (p₁ , p₂) +<:-function-∩ (function-err s) (function-err p₁ , function-err p₂) = function-err p₂ + +<:-function-∪ : ∀ {R S T U} → ((R ⇒ S) ∪ (T ⇒ U)) <: ((R ∩ T) ⇒ (S ∪ U)) +<:-function-∪ function (left function) = function +<:-function-∪ (function-ok t) (left (function-ok p)) = function-ok (left p) +<:-function-∪ (function-err s) (left (function-err p)) = function-err (left p) +<:-function-∪ (scalar s) (left (scalar ())) +<:-function-∪ function (right function) = function +<:-function-∪ (function-ok t) (right (function-ok p)) = function-ok (right p) +<:-function-∪ (function-err s) (right (function-err x)) = function-err (right x) +<:-function-∪ (scalar s) (right (scalar ())) + +<:-function-∪-∩ : ∀ {R S T U} → ((R ∩ S) ⇒ (T ∪ U)) <: ((R ⇒ T) ∪ (S ⇒ U)) +<:-function-∪-∩ function function = left function +<:-function-∪-∩ (function-ok t) (function-ok (left p)) = left (function-ok p) +<:-function-∪-∩ (function-ok t) (function-ok (right p)) = right (function-ok p) +<:-function-∪-∩ (function-err s) (function-err (left p)) = left (function-err p) +<:-function-∪-∩ (function-err s) (function-err (right p)) = right (function-err p) + +≮:-function-left : ∀ {R S T U} → (R ≮: S) → (S ⇒ T) ≮: (R ⇒ U) +≮:-function-left (witness t p q) = witness (function-err t) (function-err q) (function-err p) + +≮:-function-right : ∀ {R S T U} → (T ≮: U) → (S ⇒ T) ≮: (R ⇒ U) +≮:-function-right (witness t p q) = witness (function-ok t) (function-ok p) (function-ok q) -- Properties of scalars -skalar = number ∪ (string ∪ (nil ∪ boolean)) +skalar-function-ok : ∀ {t} → (¬Language skalar (function-ok t)) +skalar-function-ok = (scalar-function-ok number , (scalar-function-ok string , (scalar-function-ok nil , scalar-function-ok boolean))) + +scalar-<: : ∀ {S T} → (s : Scalar S) → Language T (scalar s) → (S <: T) +scalar-<: number p (scalar number) (scalar number) = p +scalar-<: boolean p (scalar boolean) (scalar boolean) = p +scalar-<: string p (scalar string) (scalar string) = p +scalar-<: nil p (scalar nil) (scalar nil) = p + +scalar-∩-function-<:-never : ∀ {S T U} → (Scalar S) → ((T ⇒ U) ∩ S) <: never +scalar-∩-function-<:-never number .(scalar number) (() , scalar number) +scalar-∩-function-<:-never boolean .(scalar boolean) (() , scalar boolean) +scalar-∩-function-<:-never string .(scalar string) (() , scalar string) +scalar-∩-function-<:-never nil .(scalar nil) (() , scalar nil) function-≮:-scalar : ∀ {S T U} → (Scalar U) → ((S ⇒ T) ≮: U) function-≮:-scalar s = witness function function (scalar-function s) @@ -111,28 +269,8 @@ scalar-≮:-never s = witness (scalar s) (scalar s) never scalar-≢-impl-≮: : ∀ {T U} → (Scalar T) → (Scalar U) → (T ≢ U) → (T ≮: U) scalar-≢-impl-≮: s₁ s₂ p = witness (scalar s₁) (scalar s₁) (scalar-scalar s₁ s₂ p) --- Properties of tgt -tgt-function-ok : ∀ {T t} → (Language (tgt T) t) → Language T (function-ok t) -tgt-function-ok {T = nil} (scalar ()) -tgt-function-ok {T = T₁ ⇒ T₂} p = function-ok p -tgt-function-ok {T = never} (scalar ()) -tgt-function-ok {T = unknown} p = unknown -tgt-function-ok {T = boolean} (scalar ()) -tgt-function-ok {T = number} (scalar ()) -tgt-function-ok {T = string} (scalar ()) -tgt-function-ok {T = T₁ ∪ T₂} (left p) = left (tgt-function-ok p) -tgt-function-ok {T = T₁ ∪ T₂} (right p) = right (tgt-function-ok p) -tgt-function-ok {T = T₁ ∩ T₂} (p₁ , p₂) = (tgt-function-ok p₁ , tgt-function-ok p₂) - -function-ok-tgt : ∀ {T t} → Language T (function-ok t) → (Language (tgt T) t) -function-ok-tgt (function-ok p) = p -function-ok-tgt (left p) = left (function-ok-tgt p) -function-ok-tgt (right p) = right (function-ok-tgt p) -function-ok-tgt (p₁ , p₂) = (function-ok-tgt p₁ , function-ok-tgt p₂) -function-ok-tgt unknown = unknown - -skalar-function-ok : ∀ {t} → (¬Language skalar (function-ok t)) -skalar-function-ok = (scalar-function-ok number , (scalar-function-ok string , (scalar-function-ok nil , scalar-function-ok boolean))) +scalar-≢-∩-<:-never : ∀ {T U V} → (Scalar T) → (Scalar U) → (T ≢ U) → (T ∩ U) <: V +scalar-≢-∩-<:-never s t p u (scalar s₁ , scalar s₂) = CONTRADICTION (p refl) skalar-scalar : ∀ {T} (s : Scalar T) → (Language skalar (scalar s)) skalar-scalar number = left (scalar number) @@ -140,72 +278,6 @@ skalar-scalar boolean = right (right (right (scalar boolean))) skalar-scalar string = right (left (scalar string)) skalar-scalar nil = right (right (left (scalar nil))) -tgt-never-≮: : ∀ {T U} → (tgt T ≮: U) → (T ≮: (skalar ∪ (never ⇒ U))) -tgt-never-≮: (witness t p q) = witness (function-ok t) (tgt-function-ok p) (skalar-function-ok , function-ok q) - -never-tgt-≮: : ∀ {T U} → (T ≮: (skalar ∪ (never ⇒ U))) → (tgt T ≮: U) -never-tgt-≮: (witness (scalar s) p (q₁ , q₂)) = CONTRADICTION (≮:-refl (witness (scalar s) (skalar-scalar s) q₁)) -never-tgt-≮: (witness function p (q₁ , scalar-function ())) -never-tgt-≮: (witness (function-ok t) p (q₁ , function-ok q₂)) = witness t (function-ok-tgt p) q₂ -never-tgt-≮: (witness (function-err (scalar s)) p (q₁ , function-err (scalar ()))) - --- Properties of src -function-err-src : ∀ {T t} → (¬Language (src T) t) → Language T (function-err t) -function-err-src {T = nil} never = scalar-function-err nil -function-err-src {T = T₁ ⇒ T₂} p = function-err p -function-err-src {T = never} (scalar-scalar number () p) -function-err-src {T = never} (scalar-function-ok ()) -function-err-src {T = unknown} never = unknown -function-err-src {T = boolean} p = scalar-function-err boolean -function-err-src {T = number} p = scalar-function-err number -function-err-src {T = string} p = scalar-function-err string -function-err-src {T = T₁ ∪ T₂} (left p) = left (function-err-src p) -function-err-src {T = T₁ ∪ T₂} (right p) = right (function-err-src p) -function-err-src {T = T₁ ∩ T₂} (p₁ , p₂) = function-err-src p₁ , function-err-src p₂ - -¬function-err-src : ∀ {T t} → (Language (src T) t) → ¬Language T (function-err t) -¬function-err-src {T = nil} (scalar ()) -¬function-err-src {T = T₁ ⇒ T₂} p = function-err p -¬function-err-src {T = never} unknown = never -¬function-err-src {T = unknown} (scalar ()) -¬function-err-src {T = boolean} (scalar ()) -¬function-err-src {T = number} (scalar ()) -¬function-err-src {T = string} (scalar ()) -¬function-err-src {T = T₁ ∪ T₂} (p₁ , p₂) = (¬function-err-src p₁ , ¬function-err-src p₂) -¬function-err-src {T = T₁ ∩ T₂} (left p) = left (¬function-err-src p) -¬function-err-src {T = T₁ ∩ T₂} (right p) = right (¬function-err-src p) - -src-¬function-err : ∀ {T t} → Language T (function-err t) → (¬Language (src T) t) -src-¬function-err {T = nil} p = never -src-¬function-err {T = T₁ ⇒ T₂} (function-err p) = p -src-¬function-err {T = never} (scalar-function-err ()) -src-¬function-err {T = unknown} p = never -src-¬function-err {T = boolean} p = never -src-¬function-err {T = number} p = never -src-¬function-err {T = string} p = never -src-¬function-err {T = T₁ ∪ T₂} (left p) = left (src-¬function-err p) -src-¬function-err {T = T₁ ∪ T₂} (right p) = right (src-¬function-err p) -src-¬function-err {T = T₁ ∩ T₂} (p₁ , p₂) = (src-¬function-err p₁ , src-¬function-err p₂) - -src-¬scalar : ∀ {S T t} (s : Scalar S) → Language T (scalar s) → (¬Language (src T) t) -src-¬scalar number (scalar number) = never -src-¬scalar boolean (scalar boolean) = never -src-¬scalar string (scalar string) = never -src-¬scalar nil (scalar nil) = never -src-¬scalar s (left p) = left (src-¬scalar s p) -src-¬scalar s (right p) = right (src-¬scalar s p) -src-¬scalar s (p₁ , p₂) = (src-¬scalar s p₁ , src-¬scalar s p₂) -src-¬scalar s unknown = never - -src-unknown-≮: : ∀ {T U} → (T ≮: src U) → (U ≮: (T ⇒ unknown)) -src-unknown-≮: (witness t p q) = witness (function-err t) (function-err-src q) (¬function-err-src p) - -unknown-src-≮: : ∀ {S T U} → (U ≮: S) → (T ≮: (U ⇒ unknown)) → (U ≮: src T) -unknown-src-≮: (witness t x x₁) (witness (scalar s) p (function-scalar s)) = witness t x (src-¬scalar s p) -unknown-src-≮: r (witness (function-ok (scalar s)) p (function-ok (scalar-scalar s () q))) -unknown-src-≮: r (witness (function-ok (function-ok _)) p (function-ok (scalar-function-ok ()))) -unknown-src-≮: r (witness (function-err t) p (function-err q)) = witness t q (src-¬function-err p) - -- Properties of unknown and never unknown-≮: : ∀ {T U} → (T ≮: U) → (unknown ≮: U) unknown-≮: (witness t p q) = witness t unknown q @@ -219,6 +291,28 @@ unknown-≮:-never = witness (scalar nil) unknown never function-≮:-never : ∀ {T U} → ((T ⇒ U) ≮: never) function-≮:-never = witness function function never +<:-never : ∀ {T} → (never <: T) +<:-never t (scalar ()) + +≮:-never-left : ∀ {S T U} → (S <: (T ∪ U)) → (S ≮: T) → (S ∩ U) ≮: never +≮:-never-left p (witness t q₁ q₂) with p t q₁ +≮:-never-left p (witness t q₁ q₂) | left r = CONTRADICTION (language-comp t q₂ r) +≮:-never-left p (witness t q₁ q₂) | right r = witness t (q₁ , r) never + +≮:-never-right : ∀ {S T U} → (S <: (T ∪ U)) → (S ≮: U) → (S ∩ T) ≮: never +≮:-never-right p (witness t q₁ q₂) with p t q₁ +≮:-never-right p (witness t q₁ q₂) | left r = witness t (q₁ , r) never +≮:-never-right p (witness t q₁ q₂) | right r = CONTRADICTION (language-comp t q₂ r) + +<:-unknown : ∀ {T} → (T <: unknown) +<:-unknown t p = unknown + +<:-everything : unknown <: ((never ⇒ unknown) ∪ skalar) +<:-everything (scalar s) p = right (skalar-scalar s) +<:-everything function p = left function +<:-everything (function-ok t) p = left (function-ok unknown) +<:-everything (function-err s) p = left (function-err never) + -- A Gentle Introduction To Semantic Subtyping (https://www.cduce.org/papers/gentle.pdf) -- defines a "set-theoretic" model (sec 2.5) -- Unfortunately we don't quite have this property, due to uninhabited types, @@ -234,13 +328,21 @@ _⊗_ : ∀ {A B : Set} → (A → Set) → (B → Set) → ((A × B) → Set) Comp : ∀ {A : Set} → (A → Set) → (A → Set) Comp P a = ¬(P a) +Lift : ∀ {A : Set} → (A → Set) → (Maybe A → Set) +Lift P nothing = ⊥ +Lift P (just a) = P a + set-theoretic-if : ∀ {S₁ T₁ S₂ T₂} → -- This is the "if" part of being a set-theoretic model + -- though it uses the definition from Frisch's thesis + -- rather than from the Gentle Introduction. The difference + -- being the presence of Lift, (written D_Ω in Defn 4.2 of + -- https://www.cduce.org/papers/frisch_phd.pdf). (Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂)) → - (∀ Q → Q ⊆ Comp((Language S₁) ⊗ Comp(Language T₁)) → Q ⊆ Comp((Language S₂) ⊗ Comp(Language T₂))) + (∀ Q → Q ⊆ Comp((Language S₁) ⊗ Comp(Lift(Language T₁))) → Q ⊆ Comp((Language S₂) ⊗ Comp(Lift(Language T₂)))) -set-theoretic-if {S₁} {T₁} {S₂} {T₂} p Q q (t , u) Qtu (S₂t , ¬T₂u) = q (t , u) Qtu (S₁t , ¬T₁u) where +set-theoretic-if {S₁} {T₁} {S₂} {T₂} p Q q (t , just u) Qtu (S₂t , ¬T₂u) = q (t , just u) Qtu (S₁t , ¬T₁u) where S₁t : Language S₁ t S₁t with dec-language S₁ t @@ -252,6 +354,14 @@ set-theoretic-if {S₁} {T₁} {S₂} {T₂} p Q q (t , u) Qtu (S₂t , ¬T₂u) ¬T₁u T₁u with p (function-ok u) (function-ok T₁u) ¬T₁u T₁u | function-ok T₂u = ¬T₂u T₂u +set-theoretic-if {S₁} {T₁} {S₂} {T₂} p Q q (t , nothing) Qt- (S₂t , _) = q (t , nothing) Qt- (S₁t , λ ()) where + + S₁t : Language S₁ t + S₁t with dec-language S₁ t + S₁t | Left ¬S₁t with p (function-err t) (function-err ¬S₁t) + S₁t | Left ¬S₁t | function-err ¬S₂t = CONTRADICTION (language-comp t ¬S₂t S₂t) + S₁t | Right r = r + not-quite-set-theoretic-only-if : ∀ {S₁ T₁ S₂ T₂} → -- We don't quite have that this is a set-theoretic model @@ -260,32 +370,33 @@ not-quite-set-theoretic-only-if : ∀ {S₁ T₁ S₂ T₂} → ∀ s₂ t₂ → Language S₂ s₂ → ¬Language T₂ t₂ → -- This is the "only if" part of being a set-theoretic model - (∀ Q → Q ⊆ Comp((Language S₁) ⊗ Comp(Language T₁)) → Q ⊆ Comp((Language S₂) ⊗ Comp(Language T₂))) → + (∀ Q → Q ⊆ Comp((Language S₁) ⊗ Comp(Lift(Language T₁))) → Q ⊆ Comp((Language S₂) ⊗ Comp(Lift(Language T₂)))) → (Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂)) not-quite-set-theoretic-only-if {S₁} {T₁} {S₂} {T₂} s₂ t₂ S₂s₂ ¬T₂t₂ p = r where - Q : (Tree × Tree) → Set - Q (t , u) = Either (¬Language S₁ t) (Language T₁ u) + Q : (Tree × Maybe Tree) → Set + Q (t , just u) = Either (¬Language S₁ t) (Language T₁ u) + Q (t , nothing) = ¬Language S₁ t - q : Q ⊆ Comp((Language S₁) ⊗ Comp(Language T₁)) - q (t , u) (Left ¬S₁t) (S₁t , ¬T₁u) = language-comp t ¬S₁t S₁t - q (t , u) (Right T₂u) (S₁t , ¬T₁u) = ¬T₁u T₂u + q : Q ⊆ Comp((Language S₁) ⊗ Comp(Lift(Language T₁))) + q (t , just u) (Left ¬S₁t) (S₁t , ¬T₁u) = language-comp t ¬S₁t S₁t + q (t , just u) (Right T₂u) (S₁t , ¬T₁u) = ¬T₁u T₂u + q (t , nothing) ¬S₁t (S₁t , _) = language-comp t ¬S₁t S₁t r : Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂) r function function = function - r (function-err t) (function-err ¬S₁t) with dec-language S₂ t - r (function-err t) (function-err ¬S₁t) | Left ¬S₂t = function-err ¬S₂t - r (function-err t) (function-err ¬S₁t) | Right S₂t = CONTRADICTION (p Q q (t , t₂) (Left ¬S₁t) (S₂t , language-comp t₂ ¬T₂t₂)) + r (function-err s) (function-err ¬S₁s) with dec-language S₂ s + r (function-err s) (function-err ¬S₁s) | Left ¬S₂s = function-err ¬S₂s + r (function-err s) (function-err ¬S₁s) | Right S₂s = CONTRADICTION (p Q q (s , nothing) ¬S₁s (S₂s , λ ())) r (function-ok t) (function-ok T₁t) with dec-language T₂ t - r (function-ok t) (function-ok T₁t) | Left ¬T₂t = CONTRADICTION (p Q q (s₂ , t) (Right T₁t) (S₂s₂ , language-comp t ¬T₂t)) + r (function-ok t) (function-ok T₁t) | Left ¬T₂t = CONTRADICTION (p Q q (s₂ , just t) (Right T₁t) (S₂s₂ , language-comp t ¬T₂t)) r (function-ok t) (function-ok T₁t) | Right T₂t = function-ok T₂t -- A counterexample when the argument type is empty. -set-theoretic-counterexample-one : (∀ Q → Q ⊆ Comp((Language never) ⊗ Comp(Language number)) → Q ⊆ Comp((Language never) ⊗ Comp(Language string))) +set-theoretic-counterexample-one : (∀ Q → Q ⊆ Comp((Language never) ⊗ Comp(Lift(Language number))) → Q ⊆ Comp((Language never) ⊗ Comp(Lift(Language string)))) set-theoretic-counterexample-one Q q ((scalar s) , u) Qtu (scalar () , p) -set-theoretic-counterexample-one Q q ((function-err t) , u) Qtu (scalar-function-err () , p) set-theoretic-counterexample-two : (never ⇒ number) ≮: (never ⇒ string) set-theoretic-counterexample-two = witness diff --git a/prototyping/Properties/TypeCheck.agda b/prototyping/Properties/TypeCheck.agda index 0726a4be..37fbeda5 100644 --- a/prototyping/Properties/TypeCheck.agda +++ b/prototyping/Properties/TypeCheck.agda @@ -8,7 +8,8 @@ open import FFI.Data.Maybe using (Maybe; just; nothing) open import FFI.Data.Either using (Either) open import Luau.TypeCheck using (_⊢ᴱ_∈_; _⊢ᴮ_∈_; ⊢ᴼ_; ⊢ᴴ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; nil; var; addr; number; bool; string; app; function; block; binexp; done; return; local; nothing; orUnknown; tgtBinOp) open import Luau.Syntax using (Block; Expr; Value; BinaryOperator; yes; nil; addr; number; bool; string; val; var; binexp; _$_; function_is_end; block_is_end; _∙_; return; done; local_←_; _⟨_⟩; _⟨_⟩∈_; var_∈_; name; fun; arg; +; -; *; /; <; >; ==; ~=; <=; >=) -open import Luau.Type using (Type; nil; unknown; never; number; boolean; string; _⇒_; src; tgt) +open import Luau.FunctionTypes using (src; tgt) +open import Luau.Type using (Type; nil; unknown; never; number; boolean; string; _⇒_) open import Luau.RuntimeType using (RuntimeType; nil; number; function; string; valueType) open import Luau.VarCtxt using (VarCtxt; ∅; _↦_; _⊕_↦_; _⋒_; _⊝_) renaming (_[_] to _[_]ⱽ) open import Luau.Addr using (Addr) diff --git a/prototyping/Properties/TypeNormalization.agda b/prototyping/Properties/TypeNormalization.agda new file mode 100644 index 00000000..299f648c --- /dev/null +++ b/prototyping/Properties/TypeNormalization.agda @@ -0,0 +1,376 @@ +{-# OPTIONS --rewriting #-} + +module Properties.TypeNormalization where + +open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) +open import Luau.Subtyping using (scalar-function-err) +open import Luau.TypeNormalization using (_∪ⁿ_; _∩ⁿ_; _∪ᶠ_; _∪ⁿˢ_; _∩ⁿˢ_; normalize) +open import Luau.Subtyping using (_<:_) +open import Properties.Subtyping using (<:-trans; <:-refl; <:-unknown; <:-never; <:-∪-left; <:-∪-right; <:-∪-lub; <:-∩-left; <:-∩-right; <:-∩-glb; <:-∩-symm; <:-function; <:-function-∪-∩; <:-function-∩-∪; <:-function-∪; <:-everything; <:-union; <:-∪-assocl; <:-∪-assocr; <:-∪-symm; <:-intersect; ∪-distl-∩-<:; ∪-distr-∩-<:; <:-∪-distr-∩; <:-∪-distl-∩; ∩-distl-∪-<:; <:-∩-distl-∪; <:-∩-distr-∪; scalar-∩-function-<:-never; scalar-≢-∩-<:-never) + +-- Notmal forms for types +data FunType : Type → Set +data Normal : Type → Set + +data FunType where + _⇒_ : ∀ {S T} → Normal S → Normal T → FunType (S ⇒ T) + _∩_ : ∀ {F G} → FunType F → FunType G → FunType (F ∩ G) + +data Normal where + never : Normal never + unknown : Normal unknown + _⇒_ : ∀ {S T} → Normal S → Normal T → Normal (S ⇒ T) + _∩_ : ∀ {F G} → FunType F → FunType G → Normal (F ∩ G) + _∪_ : ∀ {S T} → Normal S → Scalar T → Normal (S ∪ T) + +data OptScalar : Type → Set where + never : OptScalar never + number : OptScalar number + boolean : OptScalar boolean + string : OptScalar string + nil : OptScalar nil + +-- Normalization produces normal types +normal : ∀ T → Normal (normalize T) +normalᶠ : ∀ {F} → FunType F → Normal F +normal-∪ⁿ : ∀ {S T} → Normal S → Normal T → Normal (S ∪ⁿ T) +normal-∩ⁿ : ∀ {S T} → Normal S → Normal T → Normal (S ∩ⁿ T) +normal-∪ⁿˢ : ∀ {S T} → Normal S → OptScalar T → Normal (S ∪ⁿˢ T) +normal-∩ⁿˢ : ∀ {S T} → Normal S → Scalar T → OptScalar (S ∩ⁿˢ T) +normal-∪ᶠ : ∀ {F G} → FunType F → FunType G → FunType (F ∪ᶠ G) + +normal nil = never ∪ nil +normal (S ⇒ T) = normalᶠ ((normal S) ⇒ (normal T)) +normal never = never +normal unknown = unknown +normal boolean = never ∪ boolean +normal number = never ∪ number +normal string = never ∪ string +normal (S ∪ T) = normal-∪ⁿ (normal S) (normal T) +normal (S ∩ T) = normal-∩ⁿ (normal S) (normal T) + +normalᶠ (S ⇒ T) = S ⇒ T +normalᶠ (F ∩ G) = F ∩ G + +normal-∪ⁿ S (T₁ ∪ T₂) = (normal-∪ⁿ S T₁) ∪ T₂ +normal-∪ⁿ S never = S +normal-∪ⁿ S unknown = unknown +normal-∪ⁿ never (T ⇒ U) = T ⇒ U +normal-∪ⁿ never (G₁ ∩ G₂) = G₁ ∩ G₂ +normal-∪ⁿ unknown (T ⇒ U) = unknown +normal-∪ⁿ unknown (G₁ ∩ G₂) = unknown +normal-∪ⁿ (R ⇒ S) (T ⇒ U) = normalᶠ (normal-∪ᶠ (R ⇒ S) (T ⇒ U)) +normal-∪ⁿ (R ⇒ S) (G₁ ∩ G₂) = normalᶠ (normal-∪ᶠ (R ⇒ S) (G₁ ∩ G₂)) +normal-∪ⁿ (F₁ ∩ F₂) (T ⇒ U) = normalᶠ (normal-∪ᶠ (F₁ ∩ F₂) (T ⇒ U)) +normal-∪ⁿ (F₁ ∩ F₂) (G₁ ∩ G₂) = normalᶠ (normal-∪ᶠ (F₁ ∩ F₂) (G₁ ∩ G₂)) +normal-∪ⁿ (S₁ ∪ S₂) (T₁ ⇒ T₂) = normal-∪ⁿ S₁ (T₁ ⇒ T₂) ∪ S₂ +normal-∪ⁿ (S₁ ∪ S₂) (G₁ ∩ G₂) = normal-∪ⁿ S₁ (G₁ ∩ G₂) ∪ S₂ + +normal-∩ⁿ S never = never +normal-∩ⁿ S unknown = S +normal-∩ⁿ S (T ∪ U) = normal-∪ⁿˢ (normal-∩ⁿ S T) (normal-∩ⁿˢ S U ) +normal-∩ⁿ never (T ⇒ U) = never +normal-∩ⁿ unknown (T ⇒ U) = T ⇒ U +normal-∩ⁿ (R ⇒ S) (T ⇒ U) = (R ⇒ S) ∩ (T ⇒ U) +normal-∩ⁿ (R ∩ S) (T ⇒ U) = (R ∩ S) ∩ (T ⇒ U) +normal-∩ⁿ (R ∪ S) (T ⇒ U) = normal-∩ⁿ R (T ⇒ U) +normal-∩ⁿ never (T ∩ U) = never +normal-∩ⁿ unknown (T ∩ U) = T ∩ U +normal-∩ⁿ (R ⇒ S) (T ∩ U) = (R ⇒ S) ∩ (T ∩ U) +normal-∩ⁿ (R ∩ S) (T ∩ U) = (R ∩ S) ∩ (T ∩ U) +normal-∩ⁿ (R ∪ S) (T ∩ U) = normal-∩ⁿ R (T ∩ U) + +normal-∪ⁿˢ S never = S +normal-∪ⁿˢ never number = never ∪ number +normal-∪ⁿˢ unknown number = unknown +normal-∪ⁿˢ (R ⇒ S) number = (R ⇒ S) ∪ number +normal-∪ⁿˢ (R ∩ S) number = (R ∩ S) ∪ number +normal-∪ⁿˢ (R ∪ number) number = R ∪ number +normal-∪ⁿˢ (R ∪ boolean) number = normal-∪ⁿˢ R number ∪ boolean +normal-∪ⁿˢ (R ∪ string) number = normal-∪ⁿˢ R number ∪ string +normal-∪ⁿˢ (R ∪ nil) number = normal-∪ⁿˢ R number ∪ nil +normal-∪ⁿˢ never boolean = never ∪ boolean +normal-∪ⁿˢ unknown boolean = unknown +normal-∪ⁿˢ (R ⇒ S) boolean = (R ⇒ S) ∪ boolean +normal-∪ⁿˢ (R ∩ S) boolean = (R ∩ S) ∪ boolean +normal-∪ⁿˢ (R ∪ number) boolean = normal-∪ⁿˢ R boolean ∪ number +normal-∪ⁿˢ (R ∪ boolean) boolean = R ∪ boolean +normal-∪ⁿˢ (R ∪ string) boolean = normal-∪ⁿˢ R boolean ∪ string +normal-∪ⁿˢ (R ∪ nil) boolean = normal-∪ⁿˢ R boolean ∪ nil +normal-∪ⁿˢ never string = never ∪ string +normal-∪ⁿˢ unknown string = unknown +normal-∪ⁿˢ (R ⇒ S) string = (R ⇒ S) ∪ string +normal-∪ⁿˢ (R ∩ S) string = (R ∩ S) ∪ string +normal-∪ⁿˢ (R ∪ number) string = normal-∪ⁿˢ R string ∪ number +normal-∪ⁿˢ (R ∪ boolean) string = normal-∪ⁿˢ R string ∪ boolean +normal-∪ⁿˢ (R ∪ string) string = R ∪ string +normal-∪ⁿˢ (R ∪ nil) string = normal-∪ⁿˢ R string ∪ nil +normal-∪ⁿˢ never nil = never ∪ nil +normal-∪ⁿˢ unknown nil = unknown +normal-∪ⁿˢ (R ⇒ S) nil = (R ⇒ S) ∪ nil +normal-∪ⁿˢ (R ∩ S) nil = (R ∩ S) ∪ nil +normal-∪ⁿˢ (R ∪ number) nil = normal-∪ⁿˢ R nil ∪ number +normal-∪ⁿˢ (R ∪ boolean) nil = normal-∪ⁿˢ R nil ∪ boolean +normal-∪ⁿˢ (R ∪ string) nil = normal-∪ⁿˢ R nil ∪ string +normal-∪ⁿˢ (R ∪ nil) nil = R ∪ nil + +normal-∩ⁿˢ never number = never +normal-∩ⁿˢ never boolean = never +normal-∩ⁿˢ never string = never +normal-∩ⁿˢ never nil = never +normal-∩ⁿˢ unknown number = number +normal-∩ⁿˢ unknown boolean = boolean +normal-∩ⁿˢ unknown string = string +normal-∩ⁿˢ unknown nil = nil +normal-∩ⁿˢ (R ⇒ S) number = never +normal-∩ⁿˢ (R ⇒ S) boolean = never +normal-∩ⁿˢ (R ⇒ S) string = never +normal-∩ⁿˢ (R ⇒ S) nil = never +normal-∩ⁿˢ (R ∩ S) number = never +normal-∩ⁿˢ (R ∩ S) boolean = never +normal-∩ⁿˢ (R ∩ S) string = never +normal-∩ⁿˢ (R ∩ S) nil = never +normal-∩ⁿˢ (R ∪ number) number = number +normal-∩ⁿˢ (R ∪ boolean) number = normal-∩ⁿˢ R number +normal-∩ⁿˢ (R ∪ string) number = normal-∩ⁿˢ R number +normal-∩ⁿˢ (R ∪ nil) number = normal-∩ⁿˢ R number +normal-∩ⁿˢ (R ∪ number) boolean = normal-∩ⁿˢ R boolean +normal-∩ⁿˢ (R ∪ boolean) boolean = boolean +normal-∩ⁿˢ (R ∪ string) boolean = normal-∩ⁿˢ R boolean +normal-∩ⁿˢ (R ∪ nil) boolean = normal-∩ⁿˢ R boolean +normal-∩ⁿˢ (R ∪ number) string = normal-∩ⁿˢ R string +normal-∩ⁿˢ (R ∪ boolean) string = normal-∩ⁿˢ R string +normal-∩ⁿˢ (R ∪ string) string = string +normal-∩ⁿˢ (R ∪ nil) string = normal-∩ⁿˢ R string +normal-∩ⁿˢ (R ∪ number) nil = normal-∩ⁿˢ R nil +normal-∩ⁿˢ (R ∪ boolean) nil = normal-∩ⁿˢ R nil +normal-∩ⁿˢ (R ∪ string) nil = normal-∩ⁿˢ R nil +normal-∩ⁿˢ (R ∪ nil) nil = nil + +normal-∪ᶠ (R ⇒ S) (T ⇒ U) = (normal-∩ⁿ R T) ⇒ (normal-∪ⁿ S U) +normal-∪ᶠ (R ⇒ S) (G ∩ H) = normal-∪ᶠ (R ⇒ S) G ∩ normal-∪ᶠ (R ⇒ S) H +normal-∪ᶠ (E ∩ F) G = normal-∪ᶠ E G ∩ normal-∪ᶠ F G + +scalar-∩-fun-<:-never : ∀ {F S} → FunType F → Scalar S → (F ∩ S) <: never +scalar-∩-fun-<:-never (T ⇒ U) S = scalar-∩-function-<:-never S +scalar-∩-fun-<:-never (F ∩ G) S = <:-trans (<:-intersect <:-∩-left <:-refl) (scalar-∩-fun-<:-never F S) + +flipper : ∀ {S T U} → ((S ∪ T) ∪ U) <: ((S ∪ U) ∪ T) +flipper = <:-trans <:-∪-assocr (<:-trans (<:-union <:-refl <:-∪-symm) <:-∪-assocl) + +∩-<:-∩ⁿ : ∀ {S T} → Normal S → Normal T → (S ∩ T) <: (S ∩ⁿ T) +∩ⁿ-<:-∩ : ∀ {S T} → Normal S → Normal T → (S ∩ⁿ T) <: (S ∩ T) +∩-<:-∩ⁿˢ : ∀ {S T} → Normal S → Scalar T → (S ∩ T) <: (S ∩ⁿˢ T) +∩ⁿˢ-<:-∩ : ∀ {S T} → Normal S → Scalar T → (S ∩ⁿˢ T) <: (S ∩ T) +∪ᶠ-<:-∪ : ∀ {F G} → FunType F → FunType G → (F ∪ᶠ G) <: (F ∪ G) +∪ⁿ-<:-∪ : ∀ {S T} → Normal S → Normal T → (S ∪ⁿ T) <: (S ∪ T) +∪-<:-∪ⁿ : ∀ {S T} → Normal S → Normal T → (S ∪ T) <: (S ∪ⁿ T) +∪ⁿˢ-<:-∪ : ∀ {S T} → Normal S → OptScalar T → (S ∪ⁿˢ T) <: (S ∪ T) +∪-<:-∪ⁿˢ : ∀ {S T} → Normal S → OptScalar T → (S ∪ T) <: (S ∪ⁿˢ T) + +∩-<:-∩ⁿ S never = <:-∩-right +∩-<:-∩ⁿ S unknown = <:-∩-left +∩-<:-∩ⁿ S (T ∪ U) = <:-trans <:-∩-distl-∪ (<:-trans (<:-union (∩-<:-∩ⁿ S T) (∩-<:-∩ⁿˢ S U)) (∪-<:-∪ⁿˢ (normal-∩ⁿ S T) (normal-∩ⁿˢ S U)) ) +∩-<:-∩ⁿ never (T ⇒ U) = <:-∩-left +∩-<:-∩ⁿ unknown (T ⇒ U) = <:-∩-right +∩-<:-∩ⁿ (R ⇒ S) (T ⇒ U) = <:-refl +∩-<:-∩ⁿ (R ∩ S) (T ⇒ U) = <:-refl +∩-<:-∩ⁿ (R ∪ S) (T ⇒ U) = <:-trans <:-∩-distr-∪ (<:-trans (<:-union (∩-<:-∩ⁿ R (T ⇒ U)) (<:-trans <:-∩-symm (∩-<:-∩ⁿˢ (T ⇒ U) S))) (<:-∪-lub <:-refl <:-never)) +∩-<:-∩ⁿ never (T ∩ U) = <:-∩-left +∩-<:-∩ⁿ unknown (T ∩ U) = <:-∩-right +∩-<:-∩ⁿ (R ⇒ S) (T ∩ U) = <:-refl +∩-<:-∩ⁿ (R ∩ S) (T ∩ U) = <:-refl +∩-<:-∩ⁿ (R ∪ S) (T ∩ U) = <:-trans <:-∩-distr-∪ (<:-trans (<:-union (∩-<:-∩ⁿ R (T ∩ U)) (<:-trans <:-∩-symm (∩-<:-∩ⁿˢ (T ∩ U) S))) (<:-∪-lub <:-refl <:-never)) + +∩ⁿ-<:-∩ S never = <:-never +∩ⁿ-<:-∩ S unknown = <:-∩-glb <:-refl <:-unknown +∩ⁿ-<:-∩ S (T ∪ U) = <:-trans (∪ⁿˢ-<:-∪ (normal-∩ⁿ S T) (normal-∩ⁿˢ S U)) (<:-trans (<:-union (∩ⁿ-<:-∩ S T) (∩ⁿˢ-<:-∩ S U)) ∩-distl-∪-<:) +∩ⁿ-<:-∩ never (T ⇒ U) = <:-never +∩ⁿ-<:-∩ unknown (T ⇒ U) = <:-∩-glb <:-unknown <:-refl +∩ⁿ-<:-∩ (R ⇒ S) (T ⇒ U) = <:-refl +∩ⁿ-<:-∩ (R ∩ S) (T ⇒ U) = <:-refl +∩ⁿ-<:-∩ (R ∪ S) (T ⇒ U) = <:-trans (∩ⁿ-<:-∩ R (T ⇒ U)) (<:-∩-glb (<:-trans <:-∩-left <:-∪-left) <:-∩-right) +∩ⁿ-<:-∩ never (T ∩ U) = <:-never +∩ⁿ-<:-∩ unknown (T ∩ U) = <:-∩-glb <:-unknown <:-refl +∩ⁿ-<:-∩ (R ⇒ S) (T ∩ U) = <:-refl +∩ⁿ-<:-∩ (R ∩ S) (T ∩ U) = <:-refl +∩ⁿ-<:-∩ (R ∪ S) (T ∩ U) = <:-trans (∩ⁿ-<:-∩ R (T ∩ U)) (<:-∩-glb (<:-trans <:-∩-left <:-∪-left) <:-∩-right) + +∩-<:-∩ⁿˢ never number = <:-∩-left +∩-<:-∩ⁿˢ never boolean = <:-∩-left +∩-<:-∩ⁿˢ never string = <:-∩-left +∩-<:-∩ⁿˢ never nil = <:-∩-left +∩-<:-∩ⁿˢ unknown T = <:-∩-right +∩-<:-∩ⁿˢ (R ⇒ S) T = scalar-∩-fun-<:-never (R ⇒ S) T +∩-<:-∩ⁿˢ (F ∩ G) T = scalar-∩-fun-<:-never (F ∩ G) T +∩-<:-∩ⁿˢ (R ∪ number) number = <:-∩-right +∩-<:-∩ⁿˢ (R ∪ boolean) number = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R number) (scalar-≢-∩-<:-never boolean number (λ ()))) +∩-<:-∩ⁿˢ (R ∪ string) number = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R number) (scalar-≢-∩-<:-never string number (λ ()))) +∩-<:-∩ⁿˢ (R ∪ nil) number = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R number) (scalar-≢-∩-<:-never nil number (λ ()))) +∩-<:-∩ⁿˢ (R ∪ number) boolean = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R boolean) (scalar-≢-∩-<:-never number boolean (λ ()))) +∩-<:-∩ⁿˢ (R ∪ boolean) boolean = <:-∩-right +∩-<:-∩ⁿˢ (R ∪ string) boolean = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R boolean) (scalar-≢-∩-<:-never string boolean (λ ()))) +∩-<:-∩ⁿˢ (R ∪ nil) boolean = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R boolean) (scalar-≢-∩-<:-never nil boolean (λ ()))) +∩-<:-∩ⁿˢ (R ∪ number) string = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R string) (scalar-≢-∩-<:-never number string (λ ()))) +∩-<:-∩ⁿˢ (R ∪ boolean) string = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R string) (scalar-≢-∩-<:-never boolean string (λ ()))) +∩-<:-∩ⁿˢ (R ∪ string) string = <:-∩-right +∩-<:-∩ⁿˢ (R ∪ nil) string = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R string) (scalar-≢-∩-<:-never nil string (λ ()))) +∩-<:-∩ⁿˢ (R ∪ number) nil = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R nil) (scalar-≢-∩-<:-never number nil (λ ()))) +∩-<:-∩ⁿˢ (R ∪ boolean) nil = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R nil) (scalar-≢-∩-<:-never boolean nil (λ ()))) +∩-<:-∩ⁿˢ (R ∪ string) nil = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R nil) (scalar-≢-∩-<:-never string nil (λ ()))) +∩-<:-∩ⁿˢ (R ∪ nil) nil = <:-∩-right + +∩ⁿˢ-<:-∩ never T = <:-never +∩ⁿˢ-<:-∩ unknown T = <:-∩-glb <:-unknown <:-refl +∩ⁿˢ-<:-∩ (R ⇒ S) T = <:-never +∩ⁿˢ-<:-∩ (F ∩ G) T = <:-never +∩ⁿˢ-<:-∩ (R ∪ number) number = <:-∩-glb <:-∪-right <:-refl +∩ⁿˢ-<:-∩ (R ∪ boolean) number = <:-trans (∩ⁿˢ-<:-∩ R number) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ string) number = <:-trans (∩ⁿˢ-<:-∩ R number) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ nil) number = <:-trans (∩ⁿˢ-<:-∩ R number) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ number) boolean = <:-trans (∩ⁿˢ-<:-∩ R boolean) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ boolean) boolean = <:-∩-glb <:-∪-right <:-refl +∩ⁿˢ-<:-∩ (R ∪ string) boolean = <:-trans (∩ⁿˢ-<:-∩ R boolean) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ nil) boolean = <:-trans (∩ⁿˢ-<:-∩ R boolean) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ number) string = <:-trans (∩ⁿˢ-<:-∩ R string) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ boolean) string = <:-trans (∩ⁿˢ-<:-∩ R string) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ string) string = <:-∩-glb <:-∪-right <:-refl +∩ⁿˢ-<:-∩ (R ∪ nil) string = <:-trans (∩ⁿˢ-<:-∩ R string) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ number) nil = <:-trans (∩ⁿˢ-<:-∩ R nil) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ boolean) nil = <:-trans (∩ⁿˢ-<:-∩ R nil) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ string) nil = <:-trans (∩ⁿˢ-<:-∩ R nil) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ nil) nil = <:-∩-glb <:-∪-right <:-refl + +∪ᶠ-<:-∪ (R ⇒ S) (T ⇒ U) = <:-trans (<:-function (∩-<:-∩ⁿ R T) (∪ⁿ-<:-∪ S U)) <:-function-∪-∩ +∪ᶠ-<:-∪ (R ⇒ S) (G ∩ H) = <:-trans (<:-intersect (∪ᶠ-<:-∪ (R ⇒ S) G) (∪ᶠ-<:-∪ (R ⇒ S) H)) ∪-distl-∩-<: +∪ᶠ-<:-∪ (E ∩ F) G = <:-trans (<:-intersect (∪ᶠ-<:-∪ E G) (∪ᶠ-<:-∪ F G)) ∪-distr-∩-<: + +∪-<:-∪ᶠ : ∀ {F G} → FunType F → FunType G → (F ∪ G) <: (F ∪ᶠ G) +∪-<:-∪ᶠ (R ⇒ S) (T ⇒ U) = <:-trans <:-function-∪ (<:-function (∩ⁿ-<:-∩ R T) (∪-<:-∪ⁿ S U)) +∪-<:-∪ᶠ (R ⇒ S) (G ∩ H) = <:-trans <:-∪-distl-∩ (<:-intersect (∪-<:-∪ᶠ (R ⇒ S) G) (∪-<:-∪ᶠ (R ⇒ S) H)) +∪-<:-∪ᶠ (E ∩ F) G = <:-trans <:-∪-distr-∩ (<:-intersect (∪-<:-∪ᶠ E G) (∪-<:-∪ᶠ F G)) + +∪ⁿˢ-<:-∪ S never = <:-∪-left +∪ⁿˢ-<:-∪ never number = <:-refl +∪ⁿˢ-<:-∪ never boolean = <:-refl +∪ⁿˢ-<:-∪ never string = <:-refl +∪ⁿˢ-<:-∪ never nil = <:-refl +∪ⁿˢ-<:-∪ unknown number = <:-∪-left +∪ⁿˢ-<:-∪ unknown boolean = <:-∪-left +∪ⁿˢ-<:-∪ unknown string = <:-∪-left +∪ⁿˢ-<:-∪ unknown nil = <:-∪-left +∪ⁿˢ-<:-∪ (R ⇒ S) number = <:-refl +∪ⁿˢ-<:-∪ (R ⇒ S) boolean = <:-refl +∪ⁿˢ-<:-∪ (R ⇒ S) string = <:-refl +∪ⁿˢ-<:-∪ (R ⇒ S) nil = <:-refl +∪ⁿˢ-<:-∪ (R ∩ S) number = <:-refl +∪ⁿˢ-<:-∪ (R ∩ S) boolean = <:-refl +∪ⁿˢ-<:-∪ (R ∩ S) string = <:-refl +∪ⁿˢ-<:-∪ (R ∩ S) nil = <:-refl +∪ⁿˢ-<:-∪ (R ∪ number) number = <:-union <:-∪-left <:-refl +∪ⁿˢ-<:-∪ (R ∪ boolean) number = <:-trans (<:-union (∪ⁿˢ-<:-∪ R number) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ string) number = <:-trans (<:-union (∪ⁿˢ-<:-∪ R number) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ nil) number = <:-trans (<:-union (∪ⁿˢ-<:-∪ R number) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ number) boolean = <:-trans (<:-union (∪ⁿˢ-<:-∪ R boolean) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ boolean) boolean = <:-union <:-∪-left <:-refl +∪ⁿˢ-<:-∪ (R ∪ string) boolean = <:-trans (<:-union (∪ⁿˢ-<:-∪ R boolean) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ nil) boolean = <:-trans (<:-union (∪ⁿˢ-<:-∪ R boolean) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ number) string = <:-trans (<:-union (∪ⁿˢ-<:-∪ R string) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ boolean) string = <:-trans (<:-union (∪ⁿˢ-<:-∪ R string) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ string) string = <:-union <:-∪-left <:-refl +∪ⁿˢ-<:-∪ (R ∪ nil) string = <:-trans (<:-union (∪ⁿˢ-<:-∪ R string) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ number) nil = <:-trans (<:-union (∪ⁿˢ-<:-∪ R nil) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ boolean) nil = <:-trans (<:-union (∪ⁿˢ-<:-∪ R nil) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ string) nil = <:-trans (<:-union (∪ⁿˢ-<:-∪ R nil) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ nil) nil = <:-union <:-∪-left <:-refl + +∪-<:-∪ⁿˢ T never = <:-∪-lub <:-refl <:-never +∪-<:-∪ⁿˢ never number = <:-refl +∪-<:-∪ⁿˢ never boolean = <:-refl +∪-<:-∪ⁿˢ never string = <:-refl +∪-<:-∪ⁿˢ never nil = <:-refl +∪-<:-∪ⁿˢ unknown number = <:-unknown +∪-<:-∪ⁿˢ unknown boolean = <:-unknown +∪-<:-∪ⁿˢ unknown string = <:-unknown +∪-<:-∪ⁿˢ unknown nil = <:-unknown +∪-<:-∪ⁿˢ (R ⇒ S) number = <:-refl +∪-<:-∪ⁿˢ (R ⇒ S) boolean = <:-refl +∪-<:-∪ⁿˢ (R ⇒ S) string = <:-refl +∪-<:-∪ⁿˢ (R ⇒ S) nil = <:-refl +∪-<:-∪ⁿˢ (R ∩ S) number = <:-refl +∪-<:-∪ⁿˢ (R ∩ S) boolean = <:-refl +∪-<:-∪ⁿˢ (R ∩ S) string = <:-refl +∪-<:-∪ⁿˢ (R ∩ S) nil = <:-refl +∪-<:-∪ⁿˢ (R ∪ number) number = <:-∪-lub <:-refl <:-∪-right +∪-<:-∪ⁿˢ (R ∪ boolean) number = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R number) <:-refl) +∪-<:-∪ⁿˢ (R ∪ string) number = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R number) <:-refl) +∪-<:-∪ⁿˢ (R ∪ nil) number = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R number) <:-refl) +∪-<:-∪ⁿˢ (R ∪ number) boolean = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R boolean) <:-refl) +∪-<:-∪ⁿˢ (R ∪ boolean) boolean = <:-∪-lub <:-refl <:-∪-right +∪-<:-∪ⁿˢ (R ∪ string) boolean = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R boolean) <:-refl) +∪-<:-∪ⁿˢ (R ∪ nil) boolean = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R boolean) <:-refl) +∪-<:-∪ⁿˢ (R ∪ number) string = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R string) <:-refl) +∪-<:-∪ⁿˢ (R ∪ boolean) string = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R string) <:-refl) +∪-<:-∪ⁿˢ (R ∪ string) string = <:-∪-lub <:-refl <:-∪-right +∪-<:-∪ⁿˢ (R ∪ nil) string = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R string) <:-refl) +∪-<:-∪ⁿˢ (R ∪ number) nil = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R nil) <:-refl) +∪-<:-∪ⁿˢ (R ∪ boolean) nil = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R nil) <:-refl) +∪-<:-∪ⁿˢ (R ∪ string) nil = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R nil) <:-refl) +∪-<:-∪ⁿˢ (R ∪ nil) nil = <:-∪-lub <:-refl <:-∪-right + +∪ⁿ-<:-∪ S never = <:-∪-left +∪ⁿ-<:-∪ S unknown = <:-∪-right +∪ⁿ-<:-∪ never (T ⇒ U) = <:-∪-right +∪ⁿ-<:-∪ unknown (T ⇒ U) = <:-∪-left +∪ⁿ-<:-∪ (R ⇒ S) (T ⇒ U) = ∪ᶠ-<:-∪ (R ⇒ S) (T ⇒ U) +∪ⁿ-<:-∪ (R ∩ S) (T ⇒ U) = ∪ᶠ-<:-∪ (R ∩ S) (T ⇒ U) +∪ⁿ-<:-∪ (R ∪ S) (T ⇒ U) = <:-trans (<:-union (∪ⁿ-<:-∪ R (T ⇒ U)) <:-refl) (<:-∪-lub (<:-∪-lub (<:-trans <:-∪-left <:-∪-left) <:-∪-right) (<:-trans <:-∪-right <:-∪-left)) +∪ⁿ-<:-∪ never (T ∩ U) = <:-∪-right +∪ⁿ-<:-∪ unknown (T ∩ U) = <:-∪-left +∪ⁿ-<:-∪ (R ⇒ S) (T ∩ U) = ∪ᶠ-<:-∪ (R ⇒ S) (T ∩ U) +∪ⁿ-<:-∪ (R ∩ S) (T ∩ U) = ∪ᶠ-<:-∪ (R ∩ S) (T ∩ U) +∪ⁿ-<:-∪ (R ∪ S) (T ∩ U) = <:-trans (<:-union (∪ⁿ-<:-∪ R (T ∩ U)) <:-refl) (<:-∪-lub (<:-∪-lub (<:-trans <:-∪-left <:-∪-left) <:-∪-right) (<:-trans <:-∪-right <:-∪-left)) +∪ⁿ-<:-∪ S (T ∪ U) = <:-∪-lub (<:-trans (∪ⁿ-<:-∪ S T) (<:-union <:-refl <:-∪-left)) (<:-trans <:-∪-right <:-∪-right) + +∪-<:-∪ⁿ S never = <:-∪-lub <:-refl <:-never +∪-<:-∪ⁿ S unknown = <:-unknown +∪-<:-∪ⁿ never (T ⇒ U) = <:-∪-lub <:-never <:-refl +∪-<:-∪ⁿ unknown (T ⇒ U) = <:-unknown +∪-<:-∪ⁿ (R ⇒ S) (T ⇒ U) = ∪-<:-∪ᶠ (R ⇒ S) (T ⇒ U) +∪-<:-∪ⁿ (R ∩ S) (T ⇒ U) = ∪-<:-∪ᶠ (R ∩ S) (T ⇒ U) +∪-<:-∪ⁿ (R ∪ S) (T ⇒ U) = <:-trans <:-∪-assocr (<:-trans (<:-union <:-refl <:-∪-symm) (<:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ R (T ⇒ U)) <:-refl))) +∪-<:-∪ⁿ never (T ∩ U) = <:-∪-lub <:-never <:-refl +∪-<:-∪ⁿ unknown (T ∩ U) = <:-unknown +∪-<:-∪ⁿ (R ⇒ S) (T ∩ U) = ∪-<:-∪ᶠ (R ⇒ S) (T ∩ U) +∪-<:-∪ⁿ (R ∩ S) (T ∩ U) = ∪-<:-∪ᶠ (R ∩ S) (T ∩ U) +∪-<:-∪ⁿ (R ∪ S) (T ∩ U) = <:-trans <:-∪-assocr (<:-trans (<:-union <:-refl <:-∪-symm) (<:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ R (T ∩ U)) <:-refl))) +∪-<:-∪ⁿ never (T ∪ U) = <:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ never T) <:-refl) +∪-<:-∪ⁿ unknown (T ∪ U) = <:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ unknown T) <:-refl) +∪-<:-∪ⁿ (R ⇒ S) (T ∪ U) = <:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ (R ⇒ S) T) <:-refl) +∪-<:-∪ⁿ (R ∩ S) (T ∪ U) = <:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ (R ∩ S) T) <:-refl) +∪-<:-∪ⁿ (R ∪ S) (T ∪ U) = <:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ (R ∪ S) T) <:-refl) + +normalize-<: : ∀ T → normalize T <: T +<:-normalize : ∀ T → T <: normalize T + +<:-normalize nil = <:-∪-right +<:-normalize (S ⇒ T) = <:-function (normalize-<: S) (<:-normalize T) +<:-normalize never = <:-refl +<:-normalize unknown = <:-refl +<:-normalize boolean = <:-∪-right +<:-normalize number = <:-∪-right +<:-normalize string = <:-∪-right +<:-normalize (S ∪ T) = <:-trans (<:-union (<:-normalize S) (<:-normalize T)) (∪-<:-∪ⁿ (normal S) (normal T)) +<:-normalize (S ∩ T) = <:-trans (<:-intersect (<:-normalize S) (<:-normalize T)) (∩-<:-∩ⁿ (normal S) (normal T)) + +normalize-<: nil = <:-∪-lub <:-never <:-refl +normalize-<: (S ⇒ T) = <:-function (<:-normalize S) (normalize-<: T) +normalize-<: never = <:-refl +normalize-<: unknown = <:-refl +normalize-<: boolean = <:-∪-lub <:-never <:-refl +normalize-<: number = <:-∪-lub <:-never <:-refl +normalize-<: string = <:-∪-lub <:-never <:-refl +normalize-<: (S ∪ T) = <:-trans (∪ⁿ-<:-∪ (normal S) (normal T)) (<:-union (normalize-<: S) (normalize-<: T)) +normalize-<: (S ∩ T) = <:-trans (∩ⁿ-<:-∩ (normal S) (normal T)) (<:-intersect (normalize-<: S) (normalize-<: T)) + + From bd6d44f5e33511e1a9241b7a57c0e3a00e3aa6eb Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 28 Apr 2022 18:24:24 -0700 Subject: [PATCH 050/102] Sync to upstream/release/525 (#467) --- Analysis/src/Frontend.cpp | 2 +- Analysis/src/Linter.cpp | 6 +- Analysis/src/Substitution.cpp | 2 +- Analysis/src/TypeInfer.cpp | 211 ++++++------------ Analysis/src/TypeVar.cpp | 11 +- Analysis/src/Unifier.cpp | 89 +++----- Ast/src/Lexer.cpp | 4 +- CLI/FileUtils.cpp | 4 +- CLI/Repl.cpp | 9 +- Compiler/include/Luau/BytecodeBuilder.h | 2 +- Compiler/src/BytecodeBuilder.cpp | 24 ++- Compiler/src/Compiler.cpp | 8 +- Compiler/src/CostModel.cpp | 2 +- Sources.cmake | 3 +- VM/include/lua.h | 2 +- VM/src/lapi.cpp | 2 +- VM/src/lstate.h | 2 +- VM/src/ltable.cpp | 55 ++--- VM/src/ludata.cpp | 13 +- fuzz/proto.cpp | 27 ++- tests/Autocomplete.test.cpp | 1 - tests/Compiler.test.cpp | 50 +++-- tests/Conformance.test.cpp | 2 +- tests/Frontend.test.cpp | 18 +- tests/Module.test.cpp | 1 - tests/NonstrictMode.test.cpp | 1 - tests/Parser.test.cpp | 3 - tests/RuntimeLimits.test.cpp | 270 ++++++++++++++++++++++++ tests/TypeInfer.aliases.test.cpp | 19 +- tests/TypeInfer.annotations.test.cpp | 4 - tests/TypeInfer.functions.test.cpp | 7 - tests/TypeInfer.generics.test.cpp | 24 +-- tests/TypeInfer.loops.test.cpp | 18 ++ tests/TypeInfer.operators.test.cpp | 4 - tests/TypeInfer.primitives.test.cpp | 2 - tests/TypeInfer.provisional.test.cpp | 236 --------------------- tests/TypeInfer.singletons.test.cpp | 16 -- tests/TypeInfer.tables.test.cpp | 10 - tests/TypeInfer.tryUnify.test.cpp | 2 - tests/TypeVar.test.cpp | 2 - 40 files changed, 527 insertions(+), 641 deletions(-) create mode 100644 tests/RuntimeLimits.test.cpp diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 34ccdac4..b8f7836d 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -24,7 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) LUAU_FASTFLAGVARIABLE(LuauDirtySourceModule, false) -LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 0) +LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) namespace Luau { diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 5608e4b3..200b7d1b 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -2653,12 +2653,12 @@ static void lintComments(LintContext& context, const std::vector& ho } else { - std::string::size_type space = hc.content.find_first_of(" \t"); + size_t space = hc.content.find_first_of(" \t"); std::string_view first = std::string_view(hc.content).substr(0, space); if (first == "nolint") { - std::string::size_type notspace = hc.content.find_first_not_of(" \t", space); + size_t notspace = hc.content.find_first_not_of(" \t", space); if (space == std::string::npos || notspace == std::string::npos) { @@ -2827,7 +2827,7 @@ uint64_t LintWarning::parseMask(const std::vector& hotcomments) if (hc.content.compare(0, 6, "nolint") != 0) continue; - std::string::size_type name = hc.content.find_first_not_of(" \t", 6); + size_t name = hc.content.find_first_not_of(" \t", 6); // --!nolint disables everything if (name == std::string::npos) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 1b51fa3d..30d8574a 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,7 +8,7 @@ #include LUAU_FASTFLAG(LuauLowerBoundsCalculation) -LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) +LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAG(LuauTypecheckOptPass) LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowPossibleMutations, false) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 6411e2ab..ba91ae1e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -22,29 +22,25 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000) +LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) -LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauSeparateTypechecks) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauAutocompleteSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false) +LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauInferStatFunction, false) -LUAU_FASTFLAGVARIABLE(LuauSealExports, false) +LUAU_FASTFLAGVARIABLE(LuauInstantiateFollows, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) -LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) -LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) -LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) @@ -54,12 +50,9 @@ LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree2) LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) -LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) -LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) -LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); @@ -1160,7 +1153,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) } else { - iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); + if (FFlag::LuauInstantiateFollows) + iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); + else + iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); } const FunctionTypeVar* iterFunc = get(iterTy); @@ -1172,7 +1168,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(varTy, var, forin.location); if (!get(iterTy) && !get(iterTy) && !get(iterTy)) - reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); + { + if (FFlag::LuauDoNotRelyOnNextBinding) + reportError(firstValue->location, CannotCallNonFunction{iterTy}); + else + reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); + } return check(loopScope, *forin.body); } @@ -1427,8 +1428,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ftv->forwardedTypeAlias = true; bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; - if (FFlag::LuauFixIncorrectLineNumberDuplicateType) - scope->typeAliasLocations[name] = typealias.location; + scope->typeAliasLocations[name] = typealias.location; } } else @@ -2217,7 +2217,7 @@ TypeId TypeChecker::checkExprTable( if (isNonstrictMode() && !getTableType(exprType) && !get(exprType)) exprType = anyType; - if (FFlag::LuauPropertiesGetExpectedType && expectedTable) + if (expectedTable) { auto it = expectedTable->props.find(key->value.data); if (it != expectedTable->props.end()) @@ -2309,9 +2309,8 @@ ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprT } } } - else if (FFlag::LuauExpectedTypesOfProperties) - if (const UnionTypeVar* utv = get(follow(*expectedType))) - expectedUnion = utv; + else if (const UnionTypeVar* utv = get(follow(*expectedType))) + expectedUnion = utv; } for (size_t i = 0; i < expr.items.size; ++i) @@ -2334,7 +2333,7 @@ ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprT if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; } - else if (FFlag::LuauExpectedTypesOfProperties && expectedUnion) + else if (expectedUnion) { std::vector expectedResultTypes; for (TypeId expectedOption : expectedUnion) @@ -2713,8 +2712,6 @@ TypeId TypeChecker::checkBinaryOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); } } @@ -2754,7 +2751,7 @@ TypeId TypeChecker::checkBinaryOperation( reportErrors(state.errors); bool hasErrors = !state.errors.empty(); - if (FFlag::LuauErrorRecoveryType && hasErrors) + if (hasErrors) { // If there are unification errors, the return type may still be unknown // so we loosen the argument types to see if that helps. @@ -2768,8 +2765,7 @@ TypeId TypeChecker::checkBinaryOperation( if (state.errors.empty()) state.log.commit(); } - - if (!hasErrors) + else { state.log.commit(); } @@ -3196,16 +3192,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T } else { - if (!ttv) - { - if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) - // This error now gets reported when we check the function body. - reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); - - return errorRecoveryType(scope); - } - - if (lhsType->persistent || ttv->state == TableState::Sealed) + if (!ttv || lhsType->persistent || ttv->state == TableState::Sealed) return errorRecoveryType(scope); } @@ -3532,32 +3519,6 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } // Returns the minimum number of arguments the argument list can accept. -static size_t getMinParameterCount_DEPRECATED(TypePackId tp) -{ - size_t minCount = 0; - size_t optionalCount = 0; - - auto it = begin(tp); - auto endIter = end(tp); - - while (it != endIter) - { - TypeId ty = *it; - if (isOptional(ty)) - ++optionalCount; - else - { - minCount += optionalCount; - optionalCount = 0; - minCount++; - } - - ++it; - } - - return minCount; -} - static size_t getMinParameterCount(TxnLog* log, TypePackId tp) { size_t minCount = 0; @@ -3597,19 +3558,14 @@ void TypeChecker::checkArgumentList( size_t paramIndex = 0; - size_t minParams = FFlag::LuauFixIncorrectLineNumberDuplicateType ? 0 : getMinParameterCount_DEPRECATED(paramPack); - - auto reportCountMismatchError = [&state, &argLocations, minParams, paramPack, argPack]() { + auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack]() { // For this case, we want the error span to cover every errant extra parameter Location location = state.location; if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - size_t mp = minParams; - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - mp = getMinParameterCount(&state.log, paramPack); - - state.reportError(TypeError{location, CountMismatch{mp, std::distance(begin(argPack), end(argPack))}}); + size_t minParams = getMinParameterCount(&state.log, paramPack); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); }; while (true) @@ -3707,16 +3663,10 @@ void TypeChecker::checkArgumentList( } // ok else { - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); + size_t minParams = getMinParameterCount(&state.log, paramPack); - bool isVariadic = false; - if (FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic) - { - std::optional tail = flatten(paramPack, state.log).second; - if (tail) - isVariadic = Luau::isVariadic(*tail); - } + std::optional tail = flatten(paramPack, state.log).second; + bool isVariadic = tail && Luau::isVariadic(*tail); state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}}); return; @@ -3863,7 +3813,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = instantiate(scope, functionType, expr.func->location); } - actualFunctionType = follow(actualFunctionType); + if (!FFlag::LuauInstantiateFollows) + actualFunctionType = follow(actualFunctionType); TypePackId retPack; if (FFlag::LuauLowerBoundsCalculation || !FFlag::LuauWidenIfSupertypeIsFree2) @@ -3930,16 +3881,13 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); - if (FFlag::LuauErrorRecoveryType) - { - const FunctionTypeVar* overload = nullptr; - if (!overloadsThatMatchArgCount.empty()) - overload = get(overloadsThatMatchArgCount[0]); - if (!overload && !overloadsThatDont.empty()) - overload = get(overloadsThatDont[0]); - if (overload) - return {errorRecoveryTypePack(overload->retType)}; - } + const FunctionTypeVar* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return {errorRecoveryTypePack(overload->retType)}; return {errorRecoveryTypePack(retPack)}; } @@ -4129,7 +4077,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!argMismatch) overloadsThatMatchArgCount.push_back(fn); - else if (FFlag::LuauErrorRecoveryType) + else overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); @@ -4715,7 +4663,7 @@ bool Anyification::isDirty(TypeId ty) return false; if (const TableTypeVar* ttv = log->getMutable(ty)) - return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); + return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); else if (log->getMutable(ty)) return true; else if (get(ty)) @@ -4743,12 +4691,9 @@ TypeId Anyification::clean(TypeId ty) TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; - if (FFlag::LuauSealExports) - { - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.tags = ttv->tags; - } + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.tags = ttv->tags; TypeId res = addType(std::move(clone)); asMutable(res)->normal = ty->normal; return res; @@ -4791,9 +4736,12 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { + if (FFlag::LuauInstantiateFollows) + ty = follow(ty); + if (FFlag::LuauTypecheckOptPass) { - const FunctionTypeVar* ftv = get(follow(ty)); + const FunctionTypeVar* ftv = get(FFlag::LuauInstantiateFollows ? ty : follow(ty)); if (ftv && ftv->hasNoGenerics) return ty; } @@ -5175,8 +5123,6 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); parameterCountErrorReported = true; - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); } } @@ -5294,33 +5240,25 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation reportError( TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); - if (FFlag::LuauErrorRecoveryType) - { - // Pad the types out with error recovery types - while (typeParams.size() < tf->typeParams.size()) - typeParams.push_back(errorRecoveryType(scope)); - while (typePackParams.size() < tf->typePackParams.size()) - typePackParams.push_back(errorRecoveryTypePack(scope)); - } - else - return errorRecoveryType(scope); + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) + typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); } - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { - return itp == tp.ty; + bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal( + typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; }); - bool sameTps = std::equal( - typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { - return itpp == tpp.tp; - }); - // If the generic parameters and the type arguments are the same, we are about to - // perform an identity substitution, which we can just short-circuit. - if (sameTys && sameTps) - return tf->type; - } + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + if (sameTys && sameTps) + return tf->type; return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); } @@ -5483,7 +5421,7 @@ bool ApplyTypeFunction::isDirty(TypeId ty) return true; else if (const FreeTypeVar* ftv = get(ty)) { - if (FFlag::LuauRecursiveTypeParameterRestriction && ftv->forwardedTypeAlias) + if (ftv->forwardedTypeAlias) encounteredForwardedType = true; return false; } @@ -5562,7 +5500,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, reportError(location, UnificationTooComplex{}); return errorRecoveryType(scope); } - if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) + if (applyTypeFunction.encounteredForwardedType) { reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); return errorRecoveryType(scope); @@ -5632,7 +5570,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st } TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction && (!FFlag::LuauGenericFunctionsDontCacheTypeParams || useCache)) + if (useCache) { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) @@ -5667,21 +5605,12 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypePackId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; - if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); - g = cached; - } - else - { - g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); - } + TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; + if (!cached) + cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); - genericPacks.push_back({g, defaultValue}); - scope->privateTypePackBindings[n] = g; + genericPacks.push_back({cached, defaultValue}); + scope->privateTypePackBindings[n] = cached; } return {generics, genericPacks}; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4d42573c..463b4651 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,7 +23,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAGVARIABLE(LuauAnyInIsOptionalIsOptional, false) @@ -775,18 +774,12 @@ TypePackId SingletonTypes::errorRecoveryTypePack() TypeId SingletonTypes::errorRecoveryType(TypeId guess) { - if (FFlag::LuauErrorRecoveryType) - return guess; - else - return &errorType_; + return guess; } TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) { - if (FFlag::LuauErrorRecoveryType) - return guess; - else - return &errorTypePack_; + return guess; } SingletonTypes& getSingletonTypes() diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 9862d7b3..334806ce 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -23,10 +23,7 @@ LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAG(LuauTypecheckOptPass) @@ -1021,7 +1018,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (superTp == subTp) return; - if (FFlag::LuauTxnLogSeesTypePacks2 && log.haveSeen(superTp, subTp)) + if (log.haveSeen(superTp, subTp)) return; if (log.getMutable(superTp)) @@ -1265,12 +1262,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal log.pushSeen(superFunction->generics[i], subFunction->generics[i]); } - if (FFlag::LuauTxnLogSeesTypePacks2) + for (size_t i = 0; i < numGenericPacks; i++) { - for (size_t i = 0; i < numGenericPacks; i++) - { - log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); - } + log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } CountMismatch::Context context = ctx; @@ -1330,12 +1324,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal ctx = context; - if (FFlag::LuauTxnLogSeesTypePacks2) + for (int i = int(numGenericPacks) - 1; 0 <= i; i--) { - for (int i = int(numGenericPacks) - 1; 0 <= i; i--) - { - log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); - } + log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } for (int i = int(numGenerics) - 1; 0 <= i; i--) @@ -1499,20 +1490,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else missingProperties.push_back(name); - if (FFlag::LuauTxnLogCheckForInvalidation) + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) { - // Recursive unification can change the txn log, and invalidate the old - // table. If we detect that this has happened, we start over, with the updated - // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; } } @@ -1570,20 +1558,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else extraProperties.push_back(name); - if (FFlag::LuauTxnLogCheckForInvalidation) + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) { - // Recursive unification can change the txn log, and invalidate the old - // table. If we detect that this has happened, we start over, with the updated - // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; } } @@ -1630,27 +1615,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } } - if (FFlag::LuauTxnLogDontRetryForIndexers) - { - // Changing the indexer can invalidate the table pointers. - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } - else if (FFlag::LuauTxnLogCheckForInvalidation) - { - // Recursive unification can change the txn log, and invalidate the old - // table. If we detect that this has happened, we start over, with the updated - // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } - } + // Changing the indexer can invalidate the table pointers. + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); if (!missingProperties.empty()) { diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 5dd4f04e..a1f1d469 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkip, false) - namespace Luau { @@ -361,7 +359,7 @@ const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) while (isSpace(peekch())) consume(); - if (!FFlag::LuauParseLocationIgnoreCommentSkip || updatePrevLocation) + if (updatePrevLocation) prevLocation = lexeme.location; lexeme = readNext(); diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index fb6ac373..39a14ec7 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -240,7 +240,7 @@ std::optional getParentPath(const std::string& path) return std::nullopt; #endif - std::string::size_type slash = path.find_last_of("\\/", path.size() - 1); + size_t slash = path.find_last_of("\\/", path.size() - 1); if (slash == 0) return "/"; @@ -253,7 +253,7 @@ std::optional getParentPath(const std::string& path) static std::string getExtension(const std::string& path) { - std::string::size_type dot = path.find_last_of(".\\/"); + size_t dot = path.find_last_of(".\\/"); if (dot == std::string::npos || path[dot] != '.') return ""; diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 345cb7ac..4cb22346 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -34,7 +34,8 @@ enum class CliMode enum class CompileFormat { Text, - Binary + Binary, + Null }; constexpr int MaxTraversalLimit = 50; @@ -594,6 +595,8 @@ static bool compileFile(const char* name, CompileFormat format) case CompileFormat::Binary: fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); break; + case CompileFormat::Null: + break; } return true; @@ -716,6 +719,10 @@ int replMain(int argc, char** argv) { compileFormat = CompileFormat::Text; } + else if (strcmp(argv[1], "--compile=null") == 0) + { + compileFormat = CompileFormat::Null; + } else { fprintf(stderr, "Error: Unrecognized value for '--compile' specified.\n"); diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index 67b93028..b00440ae 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -232,7 +232,7 @@ private: DenseHashMap stringTable; - DenseHashMap debugRemarks; + std::vector> debugRemarks; std::string debugRemarkBuffer; BytecodeEncoder* encoder = nullptr; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 6c6f1225..871a1484 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -181,7 +181,6 @@ BytecodeBuilder::BytecodeBuilder(BytecodeEncoder* encoder) : constantMap({Constant::Type_Nil, ~0ull}) , tableShapeMap(TableShape()) , stringTable({nullptr, 0}) - , debugRemarks(~0u) , encoder(encoder) { LUAU_ASSERT(stringTable.find(StringRef{"", 0}) == nullptr); @@ -257,6 +256,8 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) void BytecodeBuilder::setMainFunction(uint32_t fid) { + LUAU_ASSERT(fid < functions.size()); + mainFunction = fid; } @@ -531,7 +532,7 @@ void BytecodeBuilder::addDebugRemark(const char* format, ...) // we null-terminate all remarks to avoid storing remark length debugRemarkBuffer += '\0'; - debugRemarks[uint32_t(insns.size())] = uint32_t(offset); + debugRemarks.emplace_back(uint32_t(insns.size()), uint32_t(offset)); } void BytecodeBuilder::finalize() @@ -1719,6 +1720,7 @@ std::string BytecodeBuilder::dumpCurrentFunction() const const uint32_t* codeEnd = insns.data() + insns.size(); int lastLine = -1; + size_t nextRemark = 0; std::string result; @@ -1741,6 +1743,7 @@ std::string BytecodeBuilder::dumpCurrentFunction() const while (code != codeEnd) { uint8_t op = LUAU_INSN_OP(*code); + uint32_t pc = uint32_t(code - insns.data()); if (op == LOP_PREPVARARGS) { @@ -1751,15 +1754,16 @@ std::string BytecodeBuilder::dumpCurrentFunction() const if (dumpFlags & Dump_Remarks) { - const uint32_t* remark = debugRemarks.find(uint32_t(code - insns.data())); - - if (remark) - formatAppend(result, "REMARK %s\n", debugRemarkBuffer.c_str() + *remark); + while (nextRemark < debugRemarks.size() && debugRemarks[nextRemark].first == pc) + { + formatAppend(result, "REMARK %s\n", debugRemarkBuffer.c_str() + debugRemarks[nextRemark].second); + nextRemark++; + } } if (dumpFlags & Dump_Source) { - int line = lines[code - insns.data()]; + int line = lines[pc]; if (line > 0 && line != lastLine) { @@ -1771,7 +1775,7 @@ std::string BytecodeBuilder::dumpCurrentFunction() const if (dumpFlags & Dump_Lines) { - formatAppend(result, "%d: ", lines[code - insns.data()]); + formatAppend(result, "%d: ", lines[pc]); } code = dumpInstruction(code, result); @@ -1784,11 +1788,11 @@ void BytecodeBuilder::setDumpSource(const std::string& source) { dumpSource.clear(); - std::string::size_type pos = 0; + size_t pos = 0; while (pos != std::string::npos) { - std::string::size_type next = source.find('\n', pos); + size_t next = source.find('\n', pos); if (next == std::string::npos) { diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 810caaee..0f17ee02 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -2206,9 +2206,15 @@ struct Compiler return false; } + if (Variable* lv = variables.find(stat->var); lv && lv->written) + { + bytecode.addDebugRemark("loop unroll failed: mutable loop variable"); + return false; + } + int tripCount = (to - from) / step + 1; - if (tripCount > thresholdBase * thresholdMaxBoost / 100) + if (tripCount > thresholdBase) { bytecode.addDebugRemark("loop unroll failed: too many iterations (%d)", tripCount); return false; diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index d8511bdb..9afd09f6 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -249,7 +249,7 @@ int computeCost(uint64_t model, const bool* varsConst, size_t varCount) return cost; for (size_t i = 0; i < varCount && i < 7; ++i) - cost -= int((model >> (8 * i + 8)) & 0x7f) * varsConst[i]; + cost -= int((model >> (i * 8 + 8)) & 0x7f) * varsConst[i]; return cost; } diff --git a/Sources.cmake b/Sources.cmake index 60e5dfda..f9263b24 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -220,8 +220,8 @@ if(TARGET Luau.UnitTest) tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp - tests/CostModel.test.cpp tests/Config.test.cpp + tests/CostModel.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp tests/JsonEncoder.test.cpp @@ -232,6 +232,7 @@ if(TARGET Luau.UnitTest) tests/Normalize.test.cpp tests/Parser.test.cpp tests/RequireTracer.test.cpp + tests/RuntimeLimits.test.cpp tests/StringUtils.test.cpp tests/Symbol.test.cpp tests/ToDot.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index d08b73ea..c3ebadb1 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -299,7 +299,7 @@ LUA_API uintptr_t lua_encodepointer(lua_State* L, uintptr_t p); LUA_API double lua_clock(); -LUA_API void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)); +LUA_API void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*)); LUA_API void lua_clonefunction(lua_State* L, int idx); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 431f7e59..1f3b0943 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1323,7 +1323,7 @@ void lua_unref(lua_State* L, int ref) return; } -void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)) +void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*)) { api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); L->global->udatagc[tag] = dtor; diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 45d9ba2c..423514a7 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -200,7 +200,7 @@ typedef struct global_State uint64_t rngstate; /* PCG random number generator state */ uint64_t ptrenckey[4]; /* pointer encoding key for display */ - void (*udatagc[LUA_UTAG_LIMIT])(void*); /* for each userdata tag, a gc callback to be called immediately before freeing memory */ + void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); /* for each userdata tag, a gc callback to be called immediately before freeing memory */ lua_Callbacks cb; diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index dc40b6ef..3dc3bd1b 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -33,7 +33,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false) LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary2, false) // max size of both array and hash part is 2^MAXBITS @@ -400,16 +399,9 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) { if (!ttisnil(&t->array[i])) { - if (FFlag::LuauTableRehashRework) - { - TValue ok; - setnvalue(&ok, cast_num(i + 1)); - setobjt2t(L, newkey(L, t, &ok), &t->array[i]); - } - else - { - setobjt2t(L, luaH_setnum(L, t, i + 1), &t->array[i]); - } + TValue ok; + setnvalue(&ok, cast_num(i + 1)); + setobjt2t(L, newkey(L, t, &ok), &t->array[i]); } } /* shrink array */ @@ -418,30 +410,14 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) /* used for the migration check at the end */ TValue* anew = t->array; /* re-insert elements from hash part */ - if (FFlag::LuauTableRehashRework) + for (int i = twoto(oldhsize) - 1; i >= 0; i--) { - for (int i = twoto(oldhsize) - 1; i >= 0; i--) + LuaNode* old = nold + i; + if (!ttisnil(gval(old))) { - LuaNode* old = nold + i; - if (!ttisnil(gval(old))) - { - TValue ok; - getnodekey(L, &ok, old); - setobjt2t(L, arrayornewkey(L, t, &ok), gval(old)); - } - } - } - else - { - for (int i = twoto(oldhsize) - 1; i >= 0; i--) - { - LuaNode* old = nold + i; - if (!ttisnil(gval(old))) - { - TValue ok; - getnodekey(L, &ok, old); - setobjt2t(L, luaH_set(L, t, &ok), gval(old)); - } + TValue ok; + getnodekey(L, &ok, old); + setobjt2t(L, arrayornewkey(L, t, &ok), gval(old)); } } @@ -559,7 +535,7 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) { rehash(L, t, key); /* grow table */ - // after rehash, numeric keys might be located in the new array part, but won't be found in the node part + /* after rehash, numeric keys might be located in the new array part, but won't be found in the node part */ return arrayornewkey(L, t, key); } @@ -571,15 +547,8 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) { /* cannot find a free place? */ rehash(L, t, key); /* grow table */ - if (!FFlag::LuauTableRehashRework) - { - return luaH_set(L, t, key); /* re-insert key into grown table */ - } - else - { - // after rehash, numeric keys might be located in the new array part, but won't be found in the node part - return arrayornewkey(L, t, key); - } + /* after rehash, numeric keys might be located in the new array part, but won't be found in the node part */ + return arrayornewkey(L, t, key); } LUAU_ASSERT(n != dummynode); TValue mk; diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index 819d1863..28152689 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -22,14 +22,21 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) void luaU_freeudata(lua_State* L, Udata* u, lua_Page* page) { - void (*dtor)(void*) = nullptr; if (u->tag < LUA_UTAG_LIMIT) + { + void (*dtor)(lua_State*, void*) = nullptr; dtor = L->global->udatagc[u->tag]; + if (dtor) + dtor(L, u->data); + } else if (u->tag == UTAG_IDTOR) + { + void (*dtor)(void*) = nullptr; memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); + if (dtor) + dtor(u->data); + } - if (dtor) - dtor(u->data); luaM_freegco(L, u, sizeudata(u->len), u->memcat, page); } diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index a48f068b..22483f9e 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -137,6 +137,21 @@ int registerTypes(Luau::TypeChecker& env) return 0; } + +static void setupFrontend(Luau::Frontend& frontend) +{ + registerTypes(frontend.typeChecker); + Luau::freeze(frontend.typeChecker.globalTypes); + + registerTypes(frontend.typeCheckerForAutocomplete); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + + frontend.iceHandler.onInternalError = [](const char* error) { + printf("ICE: %s\n", error); + LUAU_ASSERT(!"ICE"); + }; +} + struct FuzzFileResolver : Luau::FileResolver { std::optional readSource(const Luau::ModuleName& name) override @@ -238,19 +253,11 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) if (kFuzzTypeck) { static FuzzFileResolver fileResolver; - static Luau::NullConfigResolver configResolver; + static FuzzConfigResolver configResolver; static Luau::FrontendOptions options{true, true}; static Luau::Frontend frontend(&fileResolver, &configResolver, options); - static int once = registerTypes(frontend.typeChecker); - (void)once; - static int once2 = (Luau::freeze(frontend.typeChecker.globalTypes), 0); - (void)once2; - - frontend.iceHandler.onInternalError = [](const char* error) { - printf("ICE: %s\n", error); - LUAU_ASSERT(!"ICE"); - }; + static int once = (setupFrontend(frontend), 0); // restart frontend.clear(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index f66e23ed..5b70481b 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2761,7 +2761,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; - ScopedFastFlag luauExpectedTypesOfProperties{"LuauExpectedTypesOfProperties", true}; check(R"( type tag = "cat" | "dog" diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index f3e60690..7b4bfc72 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -2698,16 +2698,22 @@ TEST_CASE("DebugRemarks") uint32_t fid = bcb.beginFunction(0); - bcb.addDebugRemark("test remark #%d", 42); + bcb.addDebugRemark("test remark #%d", 1); + bcb.emitABC(LOP_LOADNIL, 0, 0, 0); + bcb.addDebugRemark("test remark #%d", 2); + bcb.addDebugRemark("test remark #%d", 3); bcb.emitABC(LOP_RETURN, 0, 1, 0); - bcb.endFunction(0, 0); + bcb.endFunction(1, 0); bcb.setMainFunction(fid); bcb.finalize(); CHECK_EQ("\n" + bcb.dumpFunction(0), R"( -REMARK test remark #42 +REMARK test remark #1 +LOADNIL R0 +REMARK test remark #2 +REMARK test remark #3 RETURN R0 0 )"); } @@ -4332,7 +4338,7 @@ RETURN R0 1 // loops with body that's long but has a high boost factor due to constant folding CHECK_EQ("\n" + compileFunction(R"( local t = {} -for i=1,30 do +for i=1,25 do t[i] = i * i * i end return t @@ -4390,16 +4396,6 @@ LOADN R1 13824 SETTABLEN R1 R0 24 LOADN R1 15625 SETTABLEN R1 R0 25 -LOADN R1 17576 -SETTABLEN R1 R0 26 -LOADN R1 19683 -SETTABLEN R1 R0 27 -LOADN R1 21952 -SETTABLEN R1 R0 28 -LOADN R1 24389 -SETTABLEN R1 R0 29 -LOADN R1 27000 -SETTABLEN R1 R0 30 RETURN R0 1 )"); @@ -4431,4 +4427,30 @@ RETURN R0 1 )"); } +TEST_CASE("LoopUnrollMutable") +{ + // can't unroll loops that mutate iteration variable + CHECK_EQ("\n" + compileFunction(R"( +for i=1,3 do + i = 3 + print(i) -- should print 3 three times in a row +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 3 +LOADN R1 1 +FORNPREP R0 +7 +MOVE R3 R2 +LOADN R3 3 +GETIMPORT R4 1 +MOVE R5 R3 +CALL R4 1 0 +FORNLOOP R0 -7 +RETURN R0 0 +)"); +} + + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 0ed7dc44..6f136d36 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1056,7 +1056,7 @@ TEST_CASE("UserdataApi") lua_State* L = globalState.get(); // setup dtor for tag 42 (created later) - lua_setuserdatadtor(L, 42, [](void* data) { + lua_setuserdatadtor(L, 42, [](lua_State* l, void* data) { dtorhits += *(int*)data; }); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 9fc0a005..e771b6b1 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -975,8 +975,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") TEST_CASE_FIXTURE(FrontendFixture, "imported_table_modification_2") { - ScopedFastFlag sffs("LuauSealExports", true); - frontend.options.retainFullTypeGraphs = false; fileResolver.source["Module/A"] = R"( @@ -1035,4 +1033,20 @@ return false; fix.frontend.check("Module/B"); } +TEST_CASE("check_without_builtin_next") +{ + ScopedFastFlag luauDoNotRelyOnNextBinding{"LuauDoNotRelyOnNextBinding", true}; + + TestFileResolver fileResolver; + TestConfigResolver configResolver; + Frontend frontend(&fileResolver, &configResolver); + + fileResolver.source["Module/A"] = "for k,v in 2 do end"; + fileResolver.source["Module/B"] = "return next"; + + // We don't care about the result. That we haven't crashed is enough. + frontend.check("Module/A"); + frontend.check("Module/B"); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index af7d76de..44cc20a7 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -199,7 +199,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_free_types") { ScopedFastFlag sff[]{ - {"LuauErrorRecoveryType", true}, {"LuauLosslessClone", true}, }; diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 9748eb27..feeaf2c2 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -283,7 +283,6 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") ScopedFastFlag sff[]{ {"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}, - {"LuauSealExports", true}, }; CheckResult result = check(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index b941103d..55eafe3c 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1606,8 +1606,6 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_of_functions_unions_and_intersections") TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") { - ScopedFastFlag luauParseLocationIgnoreCommentSkip{"LuauParseLocationIgnoreCommentSkip", true}; - AstStatBlock* block = parse(R"( type F = number --comment @@ -1620,7 +1618,6 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture") { - ScopedFastFlag luauParseLocationIgnoreCommentSkip{"LuauParseLocationIgnoreCommentSkip", true}; ScopedFastFlag luauParseLocationIgnoreCommentSkipInCapture{"LuauParseLocationIgnoreCommentSkipInCapture", true}; // Same should hold when comments are captured diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp new file mode 100644 index 00000000..dcbf0b61 --- /dev/null +++ b/tests/RuntimeLimits.test.cpp @@ -0,0 +1,270 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +/* Tests in this source file are meant to be a bellwether to verify that the numeric limits we've set are sufficient for + * most real-world scripts. + * + * If a change breaks a test in this source file, please don't adjust the flag values set in the fixture. Instead, + * consider it a latent performance problem by default. + * + * We should periodically revisit this to retest the limits. + */ + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + +struct LimitFixture : Fixture +{ +#if defined(_NOOPT) || defined(_DEBUG) + ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; +#endif + + ScopedFastFlag LuauJustOneCallFrameForHaveSeen{"LuauJustOneCallFrameForHaveSeen", true}; +}; + +template +bool hasError(const CheckResult& result, T* = nullptr) +{ + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](const TypeError& a) { + return nullptr != get(a); + }); + return it != result.errors.end(); +} + +TEST_SUITE_BEGIN("RuntimeLimitTests"); + +TEST_CASE_FIXTURE(LimitFixture, "bail_early_on_typescript_port_of_Result_type" * doctest::timeout(1.0)) +{ + constexpr const char* src = R"LUA( + --!strict + local TS = _G[script] + local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet + local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit + local Iterator + lazyGet("Iterator", function(c) + Iterator = c + end) + local Option + lazyGet("Option", function(c) + Option = c + end) + local Vec + lazyGet("Vec", function(c) + Vec = c + end) + local Result + do + Result = setmetatable({}, { + __tostring = function() + return "Result" + end, + }) + Result.__index = Result + function Result.new(...) + local self = setmetatable({}, Result) + self:constructor(...) + return self + end + function Result:constructor(okValue, errValue) + self.okValue = okValue + self.errValue = errValue + end + function Result:ok(val) + return Result.new(val, nil) + end + function Result:err(val) + return Result.new(nil, val) + end + function Result:fromCallback(c) + local _0 = c + local _1, _2 = pcall(_0) + local result = _1 and { + success = true, + value = _2, + } or { + success = false, + error = _2, + } + return result.success and Result:ok(result.value) or Result:err(Option:wrap(result.error)) + end + function Result:fromVoidCallback(c) + local _0 = c + local _1, _2 = pcall(_0) + local result = _1 and { + success = true, + value = _2, + } or { + success = false, + error = _2, + } + return result.success and Result:ok(unit()) or Result:err(Option:wrap(result.error)) + end + Result.fromPromise = TS.async(function(self, p) + local _0, _1 = TS.try(function() + return TS.TRY_RETURN, { Result:ok(TS.await(p)) } + end, function(e) + return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } + end) + if _0 then + return unpack(_1) + end + end) + Result.fromVoidPromise = TS.async(function(self, p) + local _0, _1 = TS.try(function() + TS.await(p) + return TS.TRY_RETURN, { Result:ok(unit()) } + end, function(e) + return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } + end) + if _0 then + return unpack(_1) + end + end) + function Result:isOk() + return self.okValue ~= nil + end + function Result:isErr() + return self.errValue ~= nil + end + function Result:contains(x) + return self.okValue == x + end + function Result:containsErr(x) + return self.errValue == x + end + function Result:okOption() + return Option:wrap(self.okValue) + end + function Result:errOption() + return Option:wrap(self.errValue) + end + function Result:map(func) + return self:isOk() and Result:ok(func(self.okValue)) or Result:err(self.errValue) + end + function Result:mapOr(def, func) + local _0 + if self:isOk() then + _0 = func(self.okValue) + else + _0 = def + end + return _0 + end + function Result:mapOrElse(def, func) + local _0 + if self:isOk() then + _0 = func(self.okValue) + else + _0 = def(self.errValue) + end + return _0 + end + function Result:mapErr(func) + return self:isErr() and Result:err(func(self.errValue)) or Result:ok(self.okValue) + end + Result["and"] = function(self, other) + return self:isErr() and Result:err(self.errValue) or other + end + function Result:andThen(func) + return self:isErr() and Result:err(self.errValue) or func(self.okValue) + end + Result["or"] = function(self, other) + return self:isOk() and Result:ok(self.okValue) or other + end + function Result:orElse(other) + return self:isOk() and Result:ok(self.okValue) or other(self.errValue) + end + function Result:expect(msg) + if self:isOk() then + return self.okValue + else + error(msg) + end + end + function Result:unwrap() + return self:expect("called `Result.unwrap()` on an `Err` value: " .. tostring(self.errValue)) + end + function Result:unwrapOr(def) + local _0 + if self:isOk() then + _0 = self.okValue + else + _0 = def + end + return _0 + end + function Result:unwrapOrElse(gen) + local _0 + if self:isOk() then + _0 = self.okValue + else + _0 = gen(self.errValue) + end + return _0 + end + function Result:expectErr(msg) + if self:isErr() then + return self.errValue + else + error(msg) + end + end + function Result:unwrapErr() + return self:expectErr("called `Result.unwrapErr()` on an `Ok` value: " .. tostring(self.okValue)) + end + function Result:transpose() + return self:isOk() and self.okValue:map(function(some) + return Result:ok(some) + end) or Option:some(Result:err(self.errValue)) + end + function Result:flatten() + return self:isOk() and Result.new(self.okValue.okValue, self.okValue.errValue) or Result:err(self.errValue) + end + function Result:match(ifOk, ifErr) + local _0 + if self:isOk() then + _0 = ifOk(self.okValue) + else + _0 = ifErr(self.errValue) + end + return _0 + end + function Result:asPtr() + local _0 = (self.okValue) + if _0 == nil then + _0 = (self.errValue) + end + return _0 + end + end + local resultMeta = Result + resultMeta.__eq = function(a, b) + return b:match(function(ok) + return a:contains(ok) + end, function(err) + return a:containsErr(err) + end) + end + resultMeta.__tostring = function(result) + return result:match(function(ok) + return "Result.ok(" .. tostring(ok) .. ")" + end, function(err) + return "Result.err(" .. tostring(err) .. ")" + end) + end + return { + Result = Result, + } + )LUA"; + + if (FFlag::LuauLowerBoundsCalculation) + (void)check(src); + else + CHECK_THROWS_AS(check(src), std::exception); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index b2e76052..b0eb31ce 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -7,8 +7,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauFixIncorrectLineNumberDuplicateType) - TEST_SUITE_BEGIN("TypeAliases"); TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") @@ -257,11 +255,7 @@ TEST_CASE_FIXTURE(Fixture, "reported_location_is_correct_when_type_alias_are_dup auto dtd = get(result.errors[0]); REQUIRE(dtd); CHECK_EQ(dtd->name, "B"); - - if (FFlag::LuauFixIncorrectLineNumberDuplicateType) - CHECK_EQ(dtd->previousLocation.begin.line + 1, 3); - else - CHECK_EQ(dtd->previousLocation.begin.line + 1, 1); + CHECK_EQ(dtd->previousLocation.begin.line + 1, 3); } TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") @@ -495,8 +489,6 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( -- OK because forwarded types are used with their parameters. type Tree = { data: T, children: Forest } @@ -508,8 +500,6 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( -- Not OK because forwarded types are used with different types than their parameters. type Forest = {Tree<{T}>} @@ -531,8 +521,6 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( type Tree1 = { data: T, children: {Tree2} } type Tree2 = { data: U, children: {Tree1} } @@ -647,9 +635,6 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni { ScopedFastFlag sff[] = { {"LuauTwoPassAliasDefinitionFix", true}, - - // We also force this flag because it surfaced an unfortunate interaction. - {"LuauErrorRecoveryType", true}, }; CheckResult result = check(R"( @@ -687,8 +672,6 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( -- this would be an infinite type if we allowed it type Tree = { data: T, children: {Tree<{T}>} } diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index e2971ad5..7f1c757a 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -221,8 +221,6 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") TEST_CASE_FIXTURE(Fixture, "as_expr_warns_on_unrelated_cast") { - ScopedFastFlag sff2{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( local a = 55 :: string )"); @@ -407,8 +405,6 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr") TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( type A = B type B = A diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 7cd7bec3..0e071217 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -951,8 +951,6 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( type Overload = ((string) -> string) & ((number, number) -> number) local abc: Overload @@ -1538,7 +1536,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") { - ScopedFastFlag sff{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; CheckResult result = check(R"( function test(a: number, b: string, ...) end @@ -1560,8 +1557,6 @@ TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic") { - ScopedFastFlag sff1{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; - ScopedFastFlag sff2{"LuauFixArgumentCountMismatchAmountWithGenericTypes", true}; CheckResult result = check(R"( function test(a: number, b: string, ...) return 1 @@ -1587,8 +1582,6 @@ wrapper(test) TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic2") { - ScopedFastFlag sff1{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; - ScopedFastFlag sff2{"LuauFixArgumentCountMismatchAmountWithGenericTypes", true}; CheckResult result = check(R"( function test(a: number, b: string, ...) return 1 diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 49d31fc6..91be2c1c 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -11,8 +11,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauFixArgumentCountMismatchAmountWithGenericTypes) - TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -679,8 +677,6 @@ local d: D = c TEST_CASE_FIXTURE(Fixture, "generic_functions_dont_cache_type_parameters") { - ScopedFastFlag sff{"LuauGenericFunctionsDontCacheTypeParams", true}; - CheckResult result = check(R"( -- See https://github.com/Roblox/luau/issues/332 -- This function has a type parameter with the same name as clones, @@ -707,8 +703,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") ScopedFastFlag sffs[] = { {"LuauTableSubtypingVariance2", true}, {"LuauUnsealedTableLiteral", true}, - {"LuauPropertiesGetExpectedType", true}, - {"LuauRecursiveTypeParameterRestriction", true}, }; CheckResult result = check(R"( @@ -733,8 +727,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") { - ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; - CheckResult result = check(R"( --!strict type Dispatcher = { @@ -753,8 +745,6 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification2") { - ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; - CheckResult result = check(R"( --!strict type Dispatcher = { @@ -773,8 +763,6 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification3") { - ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; - CheckResult result = check(R"( --!strict type Dispatcher = { @@ -805,11 +793,7 @@ wrapper(test) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 1 is specified)"); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); } TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") @@ -826,11 +810,7 @@ wrapper(test2, 1, "", 3) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 3 arguments, but 4 are specified)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 4 are specified)"); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 3 arguments, but 4 are specified)"); } TEST_CASE_FIXTURE(Fixture, "generic_function") diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 30df717b..960c6edf 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -78,6 +78,8 @@ TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") { + ScopedFastFlag luauDoNotRelyOnNextBinding{"LuauDoNotRelyOnNextBinding", true}; + CheckResult result = check(R"( local foo = "bar" for i, v in foo do @@ -85,6 +87,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Cannot call non-function string", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") @@ -470,4 +473,19 @@ TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") LUAU_REQUIRE_ERROR_COUNT(2, result); } +TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") +{ + ScopedFastFlag luauInstantiateFollows{"LuauInstantiateFollows", true}; + + // Just check that this doesn't assert + check(R"( + --!nonstrict + function _(l0:number) + return _ + end + for _ in _(8) do + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 5f2e2404..a2787cad 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -142,8 +142,6 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( --!strict local Vec3 = {} @@ -178,8 +176,6 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( --!strict local Vec3 = {} diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 3ddf9813..e1684df7 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -85,8 +85,6 @@ TEST_CASE_FIXTURE(Fixture, "string_function_other") TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( local x: number = 9999 function x:y(z: number) diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 4b5075d9..2ef77419 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -268,242 +268,6 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doct } } -TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doctest::timeout(1.0)) -{ - ScopedFastInt sffi{"LuauTarjanChildLimit", 400}; - - CheckResult result = check(R"LUA( - --!strict - local TS = _G[script] - local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet - local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit - local Iterator - lazyGet("Iterator", function(c) - Iterator = c - end) - local Option - lazyGet("Option", function(c) - Option = c - end) - local Vec - lazyGet("Vec", function(c) - Vec = c - end) - local Result - do - Result = setmetatable({}, { - __tostring = function() - return "Result" - end, - }) - Result.__index = Result - function Result.new(...) - local self = setmetatable({}, Result) - self:constructor(...) - return self - end - function Result:constructor(okValue, errValue) - self.okValue = okValue - self.errValue = errValue - end - function Result:ok(val) - return Result.new(val, nil) - end - function Result:err(val) - return Result.new(nil, val) - end - function Result:fromCallback(c) - local _0 = c - local _1, _2 = pcall(_0) - local result = _1 and { - success = true, - value = _2, - } or { - success = false, - error = _2, - } - return result.success and Result:ok(result.value) or Result:err(Option:wrap(result.error)) - end - function Result:fromVoidCallback(c) - local _0 = c - local _1, _2 = pcall(_0) - local result = _1 and { - success = true, - value = _2, - } or { - success = false, - error = _2, - } - return result.success and Result:ok(unit()) or Result:err(Option:wrap(result.error)) - end - Result.fromPromise = TS.async(function(self, p) - local _0, _1 = TS.try(function() - return TS.TRY_RETURN, { Result:ok(TS.await(p)) } - end, function(e) - return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } - end) - if _0 then - return unpack(_1) - end - end) - Result.fromVoidPromise = TS.async(function(self, p) - local _0, _1 = TS.try(function() - TS.await(p) - return TS.TRY_RETURN, { Result:ok(unit()) } - end, function(e) - return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } - end) - if _0 then - return unpack(_1) - end - end) - function Result:isOk() - return self.okValue ~= nil - end - function Result:isErr() - return self.errValue ~= nil - end - function Result:contains(x) - return self.okValue == x - end - function Result:containsErr(x) - return self.errValue == x - end - function Result:okOption() - return Option:wrap(self.okValue) - end - function Result:errOption() - return Option:wrap(self.errValue) - end - function Result:map(func) - return self:isOk() and Result:ok(func(self.okValue)) or Result:err(self.errValue) - end - function Result:mapOr(def, func) - local _0 - if self:isOk() then - _0 = func(self.okValue) - else - _0 = def - end - return _0 - end - function Result:mapOrElse(def, func) - local _0 - if self:isOk() then - _0 = func(self.okValue) - else - _0 = def(self.errValue) - end - return _0 - end - function Result:mapErr(func) - return self:isErr() and Result:err(func(self.errValue)) or Result:ok(self.okValue) - end - Result["and"] = function(self, other) - return self:isErr() and Result:err(self.errValue) or other - end - function Result:andThen(func) - return self:isErr() and Result:err(self.errValue) or func(self.okValue) - end - Result["or"] = function(self, other) - return self:isOk() and Result:ok(self.okValue) or other - end - function Result:orElse(other) - return self:isOk() and Result:ok(self.okValue) or other(self.errValue) - end - function Result:expect(msg) - if self:isOk() then - return self.okValue - else - error(msg) - end - end - function Result:unwrap() - return self:expect("called `Result.unwrap()` on an `Err` value: " .. tostring(self.errValue)) - end - function Result:unwrapOr(def) - local _0 - if self:isOk() then - _0 = self.okValue - else - _0 = def - end - return _0 - end - function Result:unwrapOrElse(gen) - local _0 - if self:isOk() then - _0 = self.okValue - else - _0 = gen(self.errValue) - end - return _0 - end - function Result:expectErr(msg) - if self:isErr() then - return self.errValue - else - error(msg) - end - end - function Result:unwrapErr() - return self:expectErr("called `Result.unwrapErr()` on an `Ok` value: " .. tostring(self.okValue)) - end - function Result:transpose() - return self:isOk() and self.okValue:map(function(some) - return Result:ok(some) - end) or Option:some(Result:err(self.errValue)) - end - function Result:flatten() - return self:isOk() and Result.new(self.okValue.okValue, self.okValue.errValue) or Result:err(self.errValue) - end - function Result:match(ifOk, ifErr) - local _0 - if self:isOk() then - _0 = ifOk(self.okValue) - else - _0 = ifErr(self.errValue) - end - return _0 - end - function Result:asPtr() - local _0 = (self.okValue) - if _0 == nil then - _0 = (self.errValue) - end - return _0 - end - end - local resultMeta = Result - resultMeta.__eq = function(a, b) - return b:match(function(ok) - return a:contains(ok) - end, function(err) - return a:containsErr(err) - end) - end - resultMeta.__tostring = function(result) - return result:match(function(ok) - return "Result.ok(" .. tostring(ok) .. ")" - end, function(err) - return "Result.err(" .. tostring(err) .. ")" - end) - end - return { - Result = Result, - } - )LUA"); - - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& a) { - return nullptr != get(a); - }); - if (it == result.errors.end()) - { - dumpErrors(result); - FAIL("Expected a UnificationTooComplex error"); - } -} - // Should be in TypeInfer.tables.test.cpp // It's unsound to instantiate tables containing generic methods, // since mutating properties means table properties should be invariant. diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 2b01c29e..8d6682b8 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -164,10 +164,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Dog = { tag: "Dog", howls: boolean } type Cat = { tag: "Cat", meows: boolean } @@ -281,10 +277,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Cat = { tag: 'cat', catfood: string } type Dog = { tag: 'dog', dogfood: string } @@ -302,10 +294,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Good = { success: true, result: string } type Bad = { success: false, error: string } @@ -323,10 +311,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Cat = { tag: 'cat', catfood: string } type Dog = { tag: 'dog', dogfood: string } diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2a727bb3..5bd522a3 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2122,8 +2122,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { ScopedFastFlag sffs[]{ - {"LuauPropertiesGetExpectedType", true}, - {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, }; @@ -2143,8 +2141,6 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") { ScopedFastFlag sffs[]{ - {"LuauPropertiesGetExpectedType", true}, - {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, {"LuauUnsealedTableLiteral", true}, }; @@ -2171,8 +2167,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") { ScopedFastFlag sffs[]{ - {"LuauPropertiesGetExpectedType", true}, - {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, }; @@ -2377,8 +2371,6 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf1") { - ScopedFastFlag sff{"LuauTxnLogCheckForInvalidation", true}; - CheckResult result = check(R"( -- This example produced a UAF at one point, caused by pointers to table types becoming -- invalidated by child unifiers. (Calling log.concat can cause pointers to become invalid.) @@ -2409,8 +2401,6 @@ end TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf2") { - ScopedFastFlag sff{"LuauTxnLogCheckForInvalidation", true}; - CheckResult result = check(R"( -- Another example that UAFd, this time found by fuzzing. local _ diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index c21e1625..b6e93265 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -126,8 +126,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( function f(arg: number) return arg end local a diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index d03bb03c..e033fe22 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -184,8 +184,6 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_empty_union") TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { - ScopedFastFlag sff{"LuauSealExports", true}; - TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; TypePackVar tp24{TypePack{{&ftv11}}}; From 448f03218f8c42b1af06a724083252f5197a3653 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 29 Apr 2022 09:33:30 -0700 Subject: [PATCH 051/102] Add attribution for Result.ts (#468) --- tests/RuntimeLimits.test.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index dcbf0b61..42411de2 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -41,6 +41,11 @@ TEST_CASE_FIXTURE(LimitFixture, "bail_early_on_typescript_port_of_Result_type" * { constexpr const char* src = R"LUA( --!strict + + -- Big thanks to Dionysusnu by letting us use this code as part of our test suite! + -- https://github.com/Dionysusnu/rbxts-rust-classes + -- Licensed under the MPL 2.0: https://raw.githubusercontent.com/Dionysusnu/rbxts-rust-classes/master/LICENSE + local TS = _G[script] local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit From 9bc71c4b133aeb0e9734cfa5455f48d2829edf9d Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Tue, 3 May 2022 15:29:01 -0700 Subject: [PATCH 052/102] April 2022 recap (#470) --- .../2022-05-02-luau-recap-april-2022.md | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 docs/_posts/2022-05-02-luau-recap-april-2022.md diff --git a/docs/_posts/2022-05-02-luau-recap-april-2022.md b/docs/_posts/2022-05-02-luau-recap-april-2022.md new file mode 100644 index 00000000..f5a14e3e --- /dev/null +++ b/docs/_posts/2022-05-02-luau-recap-april-2022.md @@ -0,0 +1,51 @@ +--- +layout: single +title: "Luau Recap: April 2022" +--- + +Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). + +[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-april-2022/).] + +It's been a bit of a quiet month. We mostly have small optimizations and bugfixes for you. + +It is now allowed to define functions on sealed tables that have string indexers. These functions will be typechecked against the indexer type. For example, the following is now valid: + +```lua +local a : {[string]: () -> number} = {} + +function b.y() return 4 end -- OK +``` + +Autocomplete will now provide string literal suggestions for singleton types. eg + +```lua +local function f(x: "a" | "b") end +f("_") -- suggest "a" and "b" +``` + +Improve error recovery in the case where we encounter a type pack variable in a place where one is not allowed. eg `type Foo = { value: A... }` + +When code does not pass enough arguments to a variadic function, the error feedback is now better. + +For example, the following script now produces a much nicer error message: +```lua +type A = { [number]: number } +type B = { [number]: string } + +local a: A = { 1, 2, 3 } + +-- ERROR: Type 'A' could not be converted into 'B' +-- caused by: +-- Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string' +local b: B = a +``` + +If the following code were to error because `Hello` was undefined, we would erroneously include the comment in the span of the error. This is now fixed. +```lua +type Foo = Hello -- some comment over here +``` + +Fix a crash that could occur when strict scripts have cyclic require() dependencies. + +Add an option to autocomplete to cause it to abort processing after a certain amount of time has elapsed. From 47a8d28aa92379a724a0a9646d176c8375a7f111 Mon Sep 17 00:00:00 2001 From: Alexander McCord <11488393+alexmccord@users.noreply.github.com> Date: Tue, 3 May 2022 16:12:59 -0700 Subject: [PATCH 053/102] Fix a typo in recap. (#472) --- docs/_posts/2022-05-02-luau-recap-april-2022.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_posts/2022-05-02-luau-recap-april-2022.md b/docs/_posts/2022-05-02-luau-recap-april-2022.md index f5a14e3e..dd6b2c0c 100644 --- a/docs/_posts/2022-05-02-luau-recap-april-2022.md +++ b/docs/_posts/2022-05-02-luau-recap-april-2022.md @@ -14,7 +14,7 @@ It is now allowed to define functions on sealed tables that have string indexers ```lua local a : {[string]: () -> number} = {} -function b.y() return 4 end -- OK +function a.y() return 4 end -- OK ``` Autocomplete will now provide string literal suggestions for singleton types. eg From 9156b5ae6db17d937c75657674aa2470d6752b07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?byte-chan=E2=84=A2?= Date: Wed, 4 May 2022 21:27:12 +0200 Subject: [PATCH 054/102] Fix non-C locale issues in REPL (#474) --- CLI/Repl.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 4cb22346..83060f5b 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -21,6 +21,8 @@ #include #endif +#include + LUAU_FASTFLAG(DebugLuauTimeTracing) enum class CliMode @@ -435,6 +437,9 @@ static void runReplImpl(lua_State* L) { ic_set_default_completer(completeRepl, L); + // Reset the locale to C + setlocale(LC_ALL, "C"); + // Make brace matching easier to see ic_style_def("ic-bracematch", "teal"); From 57016582a7f0ac07d1ce11767347ba5346fb004c Mon Sep 17 00:00:00 2001 From: phoebe <13684891+phoebethewitch@users.noreply.github.com> Date: Thu, 5 May 2022 17:37:27 -0400 Subject: [PATCH 055/102] fix feed link (#476) --- docs/_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_config.yml b/docs/_config.yml index 71308686..33a85609 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -10,7 +10,7 @@ logo: /assets/images/luau-88.png plugins: ["jekyll-include-cache", "jekyll-feed"] include: ["_pages"] atom_feed: - path: feed.xml + path: "/feed.xml" defaults: # _docs From e9cc76a3d5278533377eeff5b1de27a2b0f800e4 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 5 May 2022 17:03:43 -0700 Subject: [PATCH 056/102] Sync to upstream/release/526 (#477) --- Analysis/include/Luau/Frontend.h | 1 - Analysis/include/Luau/VisitTypeVar.h | 319 +++++++- Analysis/src/Autocomplete.cpp | 54 +- Analysis/src/Frontend.cpp | 34 +- Analysis/src/Normalize.cpp | 329 ++++++-- Analysis/src/Quantify.cpp | 50 +- Analysis/src/ToString.cpp | 50 +- Analysis/src/TxnLog.cpp | 34 +- Analysis/src/TypeInfer.cpp | 125 ++-- Analysis/src/Unifier.cpp | 134 +++- Ast/include/Luau/Ast.h | 2 +- Ast/src/Parser.cpp | 5 +- Compiler/include/Luau/Bytecode.h | 5 + Compiler/src/BytecodeBuilder.cpp | 10 + Compiler/src/Compiler.cpp | 338 ++++++++- Compiler/src/ConstantFolding.cpp | 53 +- Sources.cmake | 1 + VM/src/lapi.cpp | 2 +- VM/src/lbuiltins.cpp | 4 +- VM/src/lgc.h | 2 +- VM/src/ltable.cpp | 48 +- VM/src/ltm.cpp | 4 +- VM/src/ltm.h | 3 +- VM/src/lvmexecute.cpp | 196 ++++- .../test_LargeTableSum_loop_iter.lua | 17 + bench/tests/sunspider/3d-cube.lua | 30 +- bench/tests/sunspider/3d-morph.lua | 2 +- bench/tests/sunspider/3d-raytrace.lua | 44 +- bench/tests/sunspider/access-binary-trees.lua | 69 -- .../tests/sunspider/controlflow-recursive.lua | 8 +- bench/tests/sunspider/crypto-aes.lua | 148 ++-- bench/tests/sunspider/math-cordic.lua | 10 +- bench/tests/sunspider/math-partial-sums.lua | 2 +- bench/tests/sunspider/math-spectral-norm.lua | 72 -- tests/Autocomplete.test.cpp | 8 - tests/Compiler.test.cpp | 708 +++++++++++++++++- tests/Conformance.test.cpp | 12 + tests/Frontend.test.cpp | 4 - tests/Parser.test.cpp | 4 - tests/RuntimeLimits.test.cpp | 4 +- tests/TypeInfer.loops.test.cpp | 67 ++ tests/TypeInfer.modules.test.cpp | 1 - tests/TypeInfer.tables.test.cpp | 4 +- tests/TypeInfer.test.cpp | 41 + tests/TypeInfer.tryUnify.test.cpp | 4 - tests/TypeVar.test.cpp | 22 +- tests/VisitTypeVar.test.cpp | 48 ++ tests/conformance/iter.lua | 196 +++++ tests/conformance/nextvar.lua | 55 +- tools/lldb_formatters.py | 2 +- 50 files changed, 2658 insertions(+), 727 deletions(-) create mode 100644 bench/micro_tests/test_LargeTableSum_loop_iter.lua delete mode 100644 bench/tests/sunspider/access-binary-trees.lua delete mode 100644 bench/tests/sunspider/math-spectral-norm.lua create mode 100644 tests/VisitTypeVar.test.cpp create mode 100644 tests/conformance/iter.lua diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 59125470..37e3cfdc 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -145,7 +145,6 @@ struct Frontend */ std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); - CheckResult check(const SourceModule& module); // OLD. TODO KILL LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 045190ea..67fce5ed 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -1,9 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include + #include "Luau/DenseHash.h" -#include "Luau/TypeVar.h" +#include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" +#include "Luau/TypeVar.h" + +LUAU_FASTFLAG(LuauUseVisitRecursionLimit) +LUAU_FASTINT(LuauVisitRecursionLimit) namespace Luau { @@ -219,24 +225,321 @@ void visit(TypePackId tp, F& f, Set& seen) } // namespace visit_detail +template +struct GenericTypeVarVisitor +{ + using Set = S; + + Set seen; + int recursionCounter = 0; + + GenericTypeVarVisitor() = default; + + explicit GenericTypeVarVisitor(Set seen) + : seen(std::move(seen)) + { + } + + virtual void cycle(TypeId) {} + virtual void cycle(TypePackId) {} + + virtual bool visit(TypeId ty) + { + return true; + } + virtual bool visit(TypeId ty, const BoundTypeVar& btv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const FreeTypeVar& ftv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const GenericTypeVar& gtv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ErrorTypeVar& etv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const FunctionTypeVar& ftv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const TableTypeVar& ttv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const MetatableTypeVar& mtv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ClassTypeVar& ctv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const AnyTypeVar& atv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const UnionTypeVar& utv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const IntersectionTypeVar& itv) + { + return visit(ty); + } + + virtual bool visit(TypePackId tp) + { + return true; + } + virtual bool visit(TypePackId tp, const BoundTypePack& btp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const GenericTypePack& gtp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const Unifiable::Error& etp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const TypePack& pack) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const VariadicTypePack& vtp) + { + return visit(tp); + } + + void traverse(TypeId ty) + { + RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit, "TypeVarVisitor"}; + + if (visit_detail::hasSeen(seen, ty)) + { + cycle(ty); + return; + } + + if (auto btv = get(ty)) + { + if (visit(ty, *btv)) + traverse(btv->boundTo); + } + + else if (auto ftv = get(ty)) + visit(ty, *ftv); + + else if (auto gtv = get(ty)) + visit(ty, *gtv); + + else if (auto etv = get(ty)) + visit(ty, *etv); + + else if (auto ctv = get(ty)) + { + if (visit(ty, *ctv)) + { + for (TypeId part : ctv->parts) + traverse(part); + } + } + + else if (auto ptv = get(ty)) + visit(ty, *ptv); + + else if (auto ftv = get(ty)) + { + if (visit(ty, *ftv)) + { + traverse(ftv->argTypes); + traverse(ftv->retType); + } + } + + else if (auto ttv = get(ty)) + { + // Some visitors want to see bound tables, that's why we traverse the original type + if (visit(ty, *ttv)) + { + if (ttv->boundTo) + { + traverse(*ttv->boundTo); + } + else + { + for (auto& [_name, prop] : ttv->props) + traverse(prop.type); + + if (ttv->indexer) + { + traverse(ttv->indexer->indexType); + traverse(ttv->indexer->indexResultType); + } + } + } + } + + else if (auto mtv = get(ty)) + { + if (visit(ty, *mtv)) + { + traverse(mtv->table); + traverse(mtv->metatable); + } + } + + else if (auto ctv = get(ty)) + { + if (visit(ty, *ctv)) + { + for (const auto& [name, prop] : ctv->props) + traverse(prop.type); + + if (ctv->parent) + traverse(*ctv->parent); + + if (ctv->metatable) + traverse(*ctv->metatable); + } + } + + else if (auto atv = get(ty)) + visit(ty, *atv); + + else if (auto utv = get(ty)) + { + if (visit(ty, *utv)) + { + for (TypeId optTy : utv->options) + traverse(optTy); + } + } + + else if (auto itv = get(ty)) + { + if (visit(ty, *itv)) + { + for (TypeId partTy : itv->parts) + traverse(partTy); + } + } + + visit_detail::unsee(seen, ty); + } + + void traverse(TypePackId tp) + { + if (visit_detail::hasSeen(seen, tp)) + { + cycle(tp); + return; + } + + if (auto btv = get(tp)) + { + if (visit(tp, *btv)) + traverse(btv->boundTo); + } + + else if (auto ftv = get(tp)) + visit(tp, *ftv); + + else if (auto gtv = get(tp)) + visit(tp, *gtv); + + else if (auto etv = get(tp)) + visit(tp, *etv); + + else if (auto pack = get(tp)) + { + visit(tp, *pack); + + for (TypeId ty : pack->head) + traverse(ty); + + if (pack->tail) + traverse(*pack->tail); + } + else if (auto pack = get(tp)) + { + visit(tp, *pack); + traverse(pack->ty); + } + else + LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!"); + + visit_detail::unsee(seen, tp); + } +}; + +/** Visit each type under a given type. Skips over cycles and keeps recursion depth under control. + * + * The same type may be visited multiple times if there are multiple distinct paths to it. If this is undesirable, use + * TypeVarOnceVisitor. + */ +struct TypeVarVisitor : GenericTypeVarVisitor> +{ +}; + +/// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it. +struct TypeVarOnceVisitor : GenericTypeVarVisitor> +{ + TypeVarOnceVisitor() + : GenericTypeVarVisitor{DenseHashSet{nullptr}} + { + } +}; + +// Clip with FFlagLuauUseVisitRecursionLimit template -void visitTypeVar(TID ty, F& f, std::unordered_set& seen) +void DEPRECATED_visitTypeVar(TID ty, F& f, std::unordered_set& seen) { visit_detail::visit(ty, f, seen); } +// Delete and inline when clipping FFlagLuauUseVisitRecursionLimit template -void visitTypeVar(TID ty, F& f) +void DEPRECATED_visitTypeVar(TID ty, F& f) { - std::unordered_set seen; - visit_detail::visit(ty, f, seen); + if (FFlag::LuauUseVisitRecursionLimit) + f.traverse(ty); + else + { + std::unordered_set seen; + visit_detail::visit(ty, f, seen); + } } +// Delete and inline when clipping FFlagLuauUseVisitRecursionLimit template -void visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) +void DEPRECATED_visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) { - seen.clear(); - visit_detail::visit(ty, f, seen); + if (FFlag::LuauUseVisitRecursionLimit) + f.traverse(ty); + else + { + seen.clear(); + visit_detail::visit(ty, f, seen); + } } } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index dec12d01..19d06cfc 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,7 +14,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteSingletonTypes, false); LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteClassSecurityLevel, false); LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) @@ -1341,38 +1340,21 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul scope = scope->parent; } - if (FFlag::LuauAutocompleteSingletonTypes) - { - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); - TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType); - TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType); - TypeCorrectKind correctForFunction = - functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForTrue}; - result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForFalse}; - result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; - result["not"] = {AutocompleteEntryKind::Keyword}; - result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForFalse}; + result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; + result["not"] = {AutocompleteEntryKind::Keyword}; + result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; - if (auto ty = findExpectedTypeAt(module, node, position)) - autocompleteStringSingleton(*ty, true, result); - } - else - { - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); - TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); - TypeCorrectKind correctForFunction = - functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; - result["not"] = {AutocompleteEntryKind::Keyword}; - result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; - } + if (auto ty = findExpectedTypeAt(module, node, position)) + autocompleteStringSingleton(*ty, true, result); } } @@ -1680,11 +1662,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { AutocompleteEntryMap result; - if (FFlag::LuauAutocompleteSingletonTypes) - { - if (auto it = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*it, false, result); - } + if (auto it = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*it, false, result); if (finder.ancestry.size() >= 2) { @@ -1693,8 +1672,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto it = module->astTypes.find(idxExpr->expr)) autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result); } - else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as(); - binExpr && FFlag::LuauAutocompleteSingletonTypes) + else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) { if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b8f7836d..56c0ac2c 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,7 +18,6 @@ LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTarjanChildLimit) -LUAU_FASTFLAG(LuauCyclicModuleTypeSurface) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) @@ -433,8 +432,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional Frontend::lintFragment(std::string_view sour return {std::move(sourceModule), classifyLints(warnings, config)}; } -CheckResult Frontend::check(const SourceModule& module) -{ - LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); - - const Config& config = configResolver->getConfig(module.name); - - Mode mode = module.mode.value_or(config.mode); - - double timestamp = getTimestamp(); - - ModulePtr checkedModule = typeChecker.check(module, mode); - - stats.timeCheck += getTimestamp() - timestamp; - stats.filesStrict += mode == Mode::Strict; - stats.filesNonstrict += mode == Mode::Nonstrict; - - if (checkedModule == nullptr) - throw std::runtime_error("Frontend::check produced a nullptr module for module " + module.name); - moduleResolver.modules[module.name] = checkedModule; - - return CheckResult{checkedModule->errors}; -} - LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) { LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 043526ed..d8c11388 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -304,37 +304,23 @@ static bool areNormal(TypePackId tp, const std::unordered_set& seen, Inte ++iterationLimit; \ } while (false) -struct Normalize +struct Normalize final : TypeVarVisitor { + using TypeVarVisitor::Set; + + Normalize(TypeArena& arena, InternalErrorReporter& ice) + : arena(arena) + , ice(ice) + { + } + TypeArena& arena; InternalErrorReporter& ice; - // Debug data. Types being normalized are invalidated but trying to see what's going on is painful. - // To actually see the original type, read it by using the pointer of the type being normalized. - // e.g. in lldb, `e dump(originalTys[ty])`. - SeenTypes originalTys; - SeenTypePacks originalTps; - int iterationLimit = 0; bool limitExceeded = false; - template - bool operator()(TypePackId, const T&) - { - return true; - } - - template - void cycle(TID) - { - } - - bool operator()(TypeId ty, const FreeTypeVar&) - { - LUAU_ASSERT(!ty->normal); - return false; - } - + // TODO: Clip with FFlag::LuauUseVisitRecursionLimit bool operator()(TypeId ty, const BoundTypeVar& btv, std::unordered_set& seen) { // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. @@ -349,27 +335,22 @@ struct Normalize return !ty->normal; } - bool operator()(TypeId ty, const PrimitiveTypeVar&) + bool operator()(TypeId ty, const FreeTypeVar& ftv) { - LUAU_ASSERT(ty->normal); - return false; + return visit(ty, ftv); } - - bool operator()(TypeId ty, const GenericTypeVar&) + bool operator()(TypeId ty, const PrimitiveTypeVar& ptv) { - if (!ty->normal) - asMutable(ty)->normal = true; - - return false; + return visit(ty, ptv); } - - bool operator()(TypeId ty, const ErrorTypeVar&) + bool operator()(TypeId ty, const GenericTypeVar& gtv) { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; + return visit(ty, gtv); + } + bool operator()(TypeId ty, const ErrorTypeVar& etv) + { + return visit(ty, etv); } - bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -470,17 +451,12 @@ struct Normalize bool operator()(TypeId ty, const ClassTypeVar& ctv) { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; + return visit(ty, ctv); } - - bool operator()(TypeId ty, const AnyTypeVar&) + bool operator()(TypeId ty, const AnyTypeVar& atv) { - LUAU_ASSERT(ty->normal); - return false; + return visit(ty, atv); } - bool operator()(TypeId ty, const UnionTypeVar& utvRef, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -570,8 +546,257 @@ struct Normalize return false; } - bool operator()(TypeId ty, const LazyTypeVar&) + // TODO: Clip with FFlag::LuauUseVisitRecursionLimit + template + bool operator()(TypePackId, const T&) { + return true; + } + + // TODO: Clip with FFlag::LuauUseVisitRecursionLimit + template + void cycle(TID) + { + } + + bool visit(TypeId ty, const FreeTypeVar&) override + { + LUAU_ASSERT(!ty->normal); + return false; + } + + bool visit(TypeId ty, const BoundTypeVar& btv) override + { + // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. + // So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack. + if (seen.find(asMutable(btv.boundTo)) != seen.end()) + return false; + + // It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases. + LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); + + asMutable(ty)->normal = btv.boundTo->normal; + return !ty->normal; + } + + bool visit(TypeId ty, const PrimitiveTypeVar&) override + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool visit(TypeId ty, const GenericTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + + return false; + } + + bool visit(TypeId ty, const ErrorTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override + { + CHECK_ITERATION_LIMIT(false); + + ConstrainedTypeVar* ctv = const_cast(&ctvRef); + + std::vector parts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId part : parts) + traverse(part); + + std::vector newParts = normalizeUnion(parts); + + const bool normal = areNormal(newParts, seen, ice); + + if (newParts.size() == 1) + *asMutable(ty) = BoundTypeVar{newParts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + + asMutable(ty)->normal = normal; + + return false; + } + + bool visit(TypeId ty, const FunctionTypeVar& ftv) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + traverse(ftv.argTypes); + traverse(ftv.retType); + + asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice); + + return false; + } + + bool visit(TypeId ty, const TableTypeVar& ttv) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + bool normal = true; + + auto checkNormal = [&](TypeId t) { + // if t is on the stack, it is possible that this type is normal. + // If t is not normal and it is not on the stack, this type is definitely not normal. + if (!t->normal && seen.find(asMutable(t)) == seen.end()) + normal = false; + }; + + if (ttv.boundTo) + { + traverse(*ttv.boundTo); + asMutable(ty)->normal = (*ttv.boundTo)->normal; + return false; + } + + for (const auto& [_name, prop] : ttv.props) + { + traverse(prop.type); + checkNormal(prop.type); + } + + if (ttv.indexer) + { + traverse(ttv.indexer->indexType); + checkNormal(ttv.indexer->indexType); + traverse(ttv.indexer->indexResultType); + checkNormal(ttv.indexer->indexResultType); + } + + asMutable(ty)->normal = normal; + + return false; + } + + bool visit(TypeId ty, const MetatableTypeVar& mtv) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + traverse(mtv.table); + traverse(mtv.metatable); + + asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; + + return false; + } + + bool visit(TypeId ty, const ClassTypeVar& ctv) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool visit(TypeId ty, const AnyTypeVar&) override + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool visit(TypeId ty, const UnionTypeVar& utvRef) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + UnionTypeVar* utv = &const_cast(utvRef); + std::vector options = std::move(utv->options); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId option : options) + traverse(option); + + std::vector newOptions = normalizeUnion(options); + + const bool normal = areNormal(newOptions, seen, ice); + + LUAU_ASSERT(!newOptions.empty()); + + if (newOptions.size() == 1) + *asMutable(ty) = BoundTypeVar{newOptions[0]}; + else + utv->options = std::move(newOptions); + + asMutable(ty)->normal = normal; + + return false; + } + + bool visit(TypeId ty, const IntersectionTypeVar& itvRef) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + IntersectionTypeVar* itv = &const_cast(itvRef); + + std::vector oldParts = std::move(itv->parts); + + for (TypeId part : oldParts) + traverse(part); + + std::vector tables; + for (TypeId part : oldParts) + { + part = follow(part); + if (get(part)) + tables.push_back(part); + else + { + Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD + combineIntoIntersection(replacer, itv, part); + } + } + + // Don't allocate a new table if there's just one in the intersection. + if (tables.size() == 1) + itv->parts.push_back(tables[0]); + else if (!tables.empty()) + { + const TableTypeVar* first = get(tables[0]); + LUAU_ASSERT(first); + + TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); + TableTypeVar* ttv = getMutable(newTable); + for (TypeId part : tables) + { + // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need + // to be rewritten to point at 'newTable' in the clone. + Replacer replacer{&arena, part, newTable}; + combineIntoTable(replacer, ttv, part); + } + + itv->parts.push_back(newTable); + } + + asMutable(ty)->normal = areNormal(itv->parts, seen, ice); + + if (itv->parts.size() == 1) + { + TypeId part = itv->parts[0]; + *asMutable(ty) = BoundTypeVar{part}; + } + return false; } @@ -778,9 +1003,9 @@ std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorRepo if (FFlag::DebugLuauCopyBeforeNormalizing) (void)clone(ty, arena, state); - Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; + Normalize n{arena, ice}; std::unordered_set seen; - visitTypeVar(ty, n, seen); + DEPRECATED_visitTypeVar(ty, n, seen); return {ty, !n.limitExceeded}; } @@ -803,9 +1028,9 @@ std::pair normalize(TypePackId tp, TypeArena& arena, InternalE if (FFlag::DebugLuauCopyBeforeNormalizing) (void)clone(tp, arena, state); - Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; + Normalize n{arena, ice}; std::unordered_set seen; - visitTypeVar(tp, n, seen); + DEPRECATED_visitTypeVar(tp, n, seen); return {tp, !n.limitExceeded}; } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 305f83ce..4f3e4469 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -9,7 +9,7 @@ LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { -struct Quantifier +struct Quantifier final : TypeVarOnceVisitor { TypeLevel level; std::vector generics; @@ -17,26 +17,17 @@ struct Quantifier bool seenGenericType = false; bool seenMutableType = false; - Quantifier(TypeLevel level) + explicit Quantifier(TypeLevel level) : level(level) { } - void cycle(TypeId) {} - void cycle(TypePackId) {} + void cycle(TypeId) override {} + void cycle(TypePackId) override {} bool operator()(TypeId ty, const FreeTypeVar& ftv) { - if (FFlag::LuauTypecheckOptPass) - seenMutableType = true; - - if (!level.subsumes(ftv.level)) - return false; - - *asMutable(ty) = GenericTypeVar{level}; - generics.push_back(ty); - - return false; + return visit(ty, ftv); } template @@ -56,8 +47,33 @@ struct Quantifier return true; } - bool operator()(TypeId ty, const TableTypeVar&) + bool operator()(TypeId ty, const TableTypeVar& ttv) { + return visit(ty, ttv); + } + + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp, ftp); + } + + bool visit(TypeId ty, const FreeTypeVar& ftv) override + { + if (FFlag::LuauTypecheckOptPass) + seenMutableType = true; + + if (!level.subsumes(ftv.level)) + return false; + + *asMutable(ty) = GenericTypeVar{level}; + generics.push_back(ty); + + return false; + } + + bool visit(TypeId ty, const TableTypeVar&) override + { + LUAU_ASSERT(getMutable(ty)); TableTypeVar& ttv = *getMutable(ty); if (FFlag::LuauTypecheckOptPass) @@ -93,7 +109,7 @@ struct Quantifier return true; } - bool operator()(TypePackId tp, const FreeTypePack& ftp) + bool visit(TypePackId tp, const FreeTypePack& ftp) override { if (FFlag::LuauTypecheckOptPass) seenMutableType = true; @@ -111,7 +127,7 @@ void quantify(TypeId ty, TypeLevel level) { Quantifier q{level}; DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, q, seen); + DEPRECATED_visitTypeVarOnce(ty, q, seen); FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 610842da..b5d6a550 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -26,7 +26,7 @@ namespace Luau namespace { -struct FindCyclicTypes +struct FindCyclicTypes final : TypeVarVisitor { FindCyclicTypes() = default; FindCyclicTypes(const FindCyclicTypes&) = delete; @@ -38,20 +38,22 @@ struct FindCyclicTypes std::set cycles; std::set cycleTPs; - void cycle(TypeId ty) + void cycle(TypeId ty) override { cycles.insert(ty); } - void cycle(TypePackId tp) + void cycle(TypePackId tp) override { cycleTPs.insert(tp); } + // TODO: Clip all the operator()s when we clip FFlagLuauUseVisitRecursionLimit + template bool operator()(TypeId ty, const T&) { - return visited.insert(ty).second; + return visit(ty); } bool operator()(TypeId ty, const TableTypeVar& ttv) = delete; @@ -64,10 +66,10 @@ struct FindCyclicTypes if (ttv.name || ttv.syntheticName) { for (TypeId itp : ttv.instantiatedTypeParams) - visitTypeVar(itp, *this, seen); + DEPRECATED_visitTypeVar(itp, *this, seen); for (TypePackId itp : ttv.instantiatedTypePackParams) - visitTypeVar(itp, *this, seen); + DEPRECATED_visitTypeVar(itp, *this, seen); return exhaustive; } @@ -82,9 +84,43 @@ struct FindCyclicTypes template bool operator()(TypePackId tp, const T&) + { + return visit(tp); + } + + bool visit(TypeId ty) override + { + return visited.insert(ty).second; + } + + bool visit(TypePackId tp) override { return visitedPacks.insert(tp).second; } + + bool visit(TypeId ty, const TableTypeVar& ttv) override + { + if (!visited.insert(ty).second) + return false; + + if (ttv.name || ttv.syntheticName) + { + for (TypeId itp : ttv.instantiatedTypeParams) + traverse(itp); + + for (TypePackId itp : ttv.instantiatedTypePackParams) + traverse(itp); + + return exhaustive; + } + + return true; + } + + bool visit(TypeId ty, const ClassTypeVar&) override + { + return false; + } }; template @@ -92,7 +128,7 @@ void findCyclicTypes(std::set& cycles, std::set& cycleTPs, T { FindCyclicTypes fct; fct.exhaustive = exhaustive; - visitTypeVar(ty, fct); + DEPRECATED_visitTypeVar(ty, fct); cycles = std::move(fct.cycles); cycleTPs = std::move(fct.cycleTPs); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index a5f9d26c..1fb5a61a 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,7 +7,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauTxnLogPreserveOwner, false) LUAU_FASTFLAGVARIABLE(LuauJustOneCallFrameForHaveSeen, false) namespace Luau @@ -81,31 +80,20 @@ void TxnLog::concat(TxnLog rhs) void TxnLog::commit() { - if (FFlag::LuauTxnLogPreserveOwner) + for (auto& [ty, rep] : typeVarChanges) { - for (auto& [ty, rep] : typeVarChanges) - { - TypeArena* owningArena = ty->owningArena; - TypeVar* mtv = asMutable(ty); - *mtv = rep.get()->pending; - mtv->owningArena = owningArena; - } - - for (auto& [tp, rep] : typePackChanges) - { - TypeArena* owningArena = tp->owningArena; - TypePackVar* mpv = asMutable(tp); - *mpv = rep.get()->pending; - mpv->owningArena = owningArena; - } + TypeArena* owningArena = ty->owningArena; + TypeVar* mtv = asMutable(ty); + *mtv = rep.get()->pending; + mtv->owningArena = owningArena; } - else - { - for (auto& [ty, rep] : typeVarChanges) - *asMutable(ty) = rep.get()->pending; - for (auto& [tp, rep] : typePackChanges) - *asMutable(tp) = rep.get()->pending; + for (auto& [tp, rep] : typePackChanges) + { + TypeArena* owningArena = tp->owningArena; + TypePackVar* mpv = asMutable(tp); + *mpv = rep.get()->pending; + mpv->owningArena = owningArena; } clear(); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index ba91ae1e..4466ede2 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -26,11 +26,11 @@ LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) +LUAU_FASTFLAGVARIABLE(LuauUseVisitRecursionLimit, false) +LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauSeparateTypechecks) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) -LUAU_FASTFLAG(LuauAutocompleteSingletonTypes) -LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false) LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. @@ -40,6 +40,7 @@ LUAU_FASTFLAGVARIABLE(LuauInferStatFunction, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateFollows, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) +LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) @@ -57,6 +58,7 @@ LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAG(LuauLosslessClone) +LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); namespace Luau { @@ -1159,6 +1161,47 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); } + if (FFlag::LuauTypecheckIter) + { + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) + { + // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions + // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments + // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + for (TypeId var : varTypes) + unify(anyType, var, forin.location); + + return check(loopScope, *forin.body); + } + else if (const TableTypeVar* iterTable = get(iterTy)) + { + // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer + // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting + if (iterTable->indexer) + { + if (varTypes.size() > 0) + unify(iterTable->indexer->indexType, varTypes[0], forin.location); + + if (varTypes.size() > 1) + unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); + + for (size_t i = 2; i < varTypes.size(); ++i) + unify(nilType, varTypes[i], forin.location); + } + else + { + TypeId varTy = errorRecoveryType(loopScope); + + for (TypeId var : varTypes) + unify(varTy, var, forin.location); + + reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); + } + + return check(loopScope, *forin.body); + } + } + const FunctionTypeVar* iterFunc = get(iterTy); if (!iterFunc) { @@ -2026,15 +2069,29 @@ std::vector TypeChecker::reduceUnion(const std::vector& types) if (const UnionTypeVar* utv = get(t)) { - std::vector r = reduceUnion(utv->options); - for (TypeId ty : r) + if (FFlag::LuauReduceUnionRecursion) { - ty = follow(ty); - if (get(ty) || get(ty)) - return {ty}; + for (TypeId ty : utv) + { + if (get(ty) || get(ty)) + return {ty}; - if (std::find(result.begin(), result.end(), ty) == result.end()) - result.push_back(ty); + if (result.end() == std::find(result.begin(), result.end(), ty)) + result.push_back(ty); + } + } + else + { + std::vector r = reduceUnion(utv->options); + for (TypeId ty : r) + { + ty = follow(ty); + if (get(ty) || get(ty)) + return {ty}; + + if (std::find(result.begin(), result.end(), ty) == result.end()) + result.push_back(ty); + } } } else if (std::find(result.begin(), result.end(), t) == result.end()) @@ -4372,17 +4429,12 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module } // Types of requires that transitively refer to current module have to be replaced with 'any' - std::string humanReadableName; + std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - if (FFlag::LuauCyclicModuleTypeSurface) + for (const auto& [location, path] : requireCycles) { - humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - - for (const auto& [location, path] : requireCycles) - { - if (!path.empty() && path.front() == humanReadableName) - return anyType; - } + if (!path.empty() && path.front() == humanReadableName) + return anyType; } ModulePtr module = resolver->getModule(moduleInfo.name); @@ -4392,32 +4444,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // either the file does not exist or there's a cycle. If there's a cycle // we will already have reported the error. if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) - { - if (FFlag::LuauCyclicModuleTypeSurface) - { - reportError(TypeError{location, UnknownRequire{humanReadableName}}); - } - else - { - std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(TypeError{location, UnknownRequire{reportedModulePath}}); - } - } + reportError(TypeError{location, UnknownRequire{humanReadableName}}); return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { - if (FFlag::LuauCyclicModuleTypeSurface) - { - reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); - } - else - { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); - } + reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); return errorRecoveryType(scope); } @@ -4429,15 +4463,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module std::optional moduleType = first(modulePack); if (!moduleType) { - if (FFlag::LuauCyclicModuleTypeSurface) - { - reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); - } - else - { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); - } + reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); return errorRecoveryType(scope); } @@ -4947,10 +4973,7 @@ TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::singletonType(bool value) { - if (FFlag::LuauAutocompleteSingletonTypes) - return value ? getSingletonTypes().trueType : getSingletonTypes().falseType; - - return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BooleanSingleton{value}))); + return value ? getSingletonTypes().trueType : getSingletonTypes().falseType; } TypeId TypeChecker::singletonType(std::string value) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 334806ce..f5c1dde9 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,7 +22,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) -LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) +LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAG(LuauTypecheckOptPass) @@ -30,7 +30,7 @@ LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { -struct PromoteTypeLevels +struct PromoteTypeLevels final : TypeVarOnceVisitor { TxnLog& log; const TypeArena* typeArena = nullptr; @@ -53,13 +53,34 @@ struct PromoteTypeLevels } } + // TODO cycle and operator() need to be clipped when FFlagLuauUseVisitRecursionLimit is clipped template void cycle(TID) { } - template bool operator()(TID ty, const T&) + { + return visit(ty); + } + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + return visit(ty, ftv); + } + bool operator()(TypeId ty, const FunctionTypeVar& ftv) + { + return visit(ty, ftv); + } + bool operator()(TypeId ty, const TableTypeVar& ttv) + { + return visit(ty, ttv); + } + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp, ftp); + } + + bool visit(TypeId ty) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -68,7 +89,16 @@ struct PromoteTypeLevels return true; } - bool operator()(TypeId ty, const FreeTypeVar&) + bool visit(TypePackId tp) override + { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (tp->owningArena != typeArena) + return false; + + return true; + } + + bool visit(TypeId ty, const FreeTypeVar&) override { // Surprise, it's actually a BoundTypeVar that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. @@ -79,7 +109,7 @@ struct PromoteTypeLevels return true; } - bool operator()(TypeId ty, const FunctionTypeVar&) + bool visit(TypeId ty, const FunctionTypeVar&) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -89,7 +119,7 @@ struct PromoteTypeLevels return true; } - bool operator()(TypeId ty, const TableTypeVar& ttv) + bool visit(TypeId ty, const TableTypeVar& ttv) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -102,7 +132,7 @@ struct PromoteTypeLevels return true; } - bool operator()(TypePackId tp, const FreeTypePack&) + bool visit(TypePackId tp, const FreeTypePack&) override { // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. @@ -122,7 +152,7 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel PromoteTypeLevels ptl{log, typeArena, minLevel}; DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, ptl, seen); + DEPRECATED_visitTypeVarOnce(ty, ptl, seen); } void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) @@ -133,10 +163,10 @@ void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLev PromoteTypeLevels ptl{log, typeArena, minLevel}; DenseHashSet seen{nullptr}; - visitTypeVarOnce(tp, ptl, seen); + DEPRECATED_visitTypeVarOnce(tp, ptl, seen); } -struct SkipCacheForType +struct SkipCacheForType final : TypeVarOnceVisitor { SkipCacheForType(const DenseHashMap& skipCacheForType, const TypeArena* typeArena) : skipCacheForType(skipCacheForType) @@ -144,28 +174,68 @@ struct SkipCacheForType { } - void cycle(TypeId) {} - void cycle(TypePackId) {} + // TODO cycle() and operator() can be clipped with FFlagLuauUseVisitRecursionLimit + void cycle(TypeId) override {} + void cycle(TypePackId) override {} bool operator()(TypeId ty, const FreeTypeVar& ftv) { - result = true; - return false; + return visit(ty, ftv); } - bool operator()(TypeId ty, const BoundTypeVar& btv) { - result = true; - return false; + return visit(ty, btv); + } + bool operator()(TypeId ty, const GenericTypeVar& gtv) + { + return visit(ty, gtv); + } + bool operator()(TypeId ty, const TableTypeVar& ttv) + { + return visit(ty, ttv); + } + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp, ftp); + } + bool operator()(TypePackId tp, const BoundTypePack& ftp) + { + return visit(tp, ftp); + } + bool operator()(TypePackId tp, const GenericTypePack& ftp) + { + return visit(tp, ftp); + } + template + bool operator()(TypeId ty, const T& t) + { + return visit(ty); + } + template + bool operator()(TypePackId tp, const T&) + { + return visit(tp); } - bool operator()(TypeId ty, const GenericTypeVar& btv) + bool visit(TypeId, const FreeTypeVar&) override { result = true; return false; } - bool operator()(TypeId ty, const TableTypeVar&) + bool visit(TypeId, const BoundTypeVar&) override + { + result = true; + return false; + } + + bool visit(TypeId, const GenericTypeVar&) override + { + result = true; + return false; + } + + bool visit(TypeId ty, const TableTypeVar&) override { // Types from other modules don't contain mutable elements and are ok to cache if (ty->owningArena != typeArena) @@ -188,8 +258,7 @@ struct SkipCacheForType return true; } - template - bool operator()(TypeId ty, const T& t) + bool visit(TypeId ty) override { // Types from other modules don't contain mutable elements and are ok to cache if (ty->owningArena != typeArena) @@ -206,8 +275,7 @@ struct SkipCacheForType return true; } - template - bool operator()(TypePackId tp, const T&) + bool visit(TypePackId tp) override { // Types from other modules don't contain mutable elements and are ok to cache if (tp->owningArena != typeArena) @@ -216,19 +284,19 @@ struct SkipCacheForType return true; } - bool operator()(TypePackId tp, const FreeTypePack& ftp) + bool visit(TypePackId tp, const FreeTypePack&) override { result = true; return false; } - bool operator()(TypePackId tp, const BoundTypePack& ftp) + bool visit(TypePackId tp, const BoundTypePack&) override { result = true; return false; } - bool operator()(TypePackId tp, const GenericTypePack& ftp) + bool visit(TypePackId tp, const GenericTypePack&) override { result = true; return false; @@ -578,7 +646,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId failed = true; } - if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) + if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter2) { } else @@ -593,7 +661,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId } // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. - if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) + if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter2) { auto tryBind = [this, subTy](TypeId superOption) { superOption = log.follow(superOption); @@ -603,6 +671,14 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) return; + // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype + // test is successful. + if (auto subUnion = get(subTy)) + { + if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) + return; + } + // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. if (log.haveSeen(subTy, superOption)) @@ -822,7 +898,7 @@ bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) auto skipCacheFor = [this](TypeId ty) { SkipCacheForType visitor{sharedState.skipCacheForType, types}; - visitTypeVarOnce(ty, visitor, sharedState.seenAny); + DEPRECATED_visitTypeVarOnce(ty, visitor, sharedState.seenAny); sharedState.skipCacheForType[ty] = visitor.result; diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 31cd01cc..6f39e3fd 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -313,7 +313,7 @@ template struct AstArray { T* data; - std::size_t size; + size_t size; const T* begin() const { diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 31ff3f77..91f5cd25 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,7 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParseRecoverUnexpectedPack, false) LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false) namespace Luau @@ -1430,7 +1429,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isIntersection = true; } - else if (FFlag::LuauParseRecoverUnexpectedPack && c == Lexeme::Dot3) + else if (c == Lexeme::Dot3) { report(lexer.current().location, "Unexpected '...' after type annotation"); nextLexeme(); @@ -1551,7 +1550,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) prefix = name.name; name = parseIndexName("field name", pointPosition); } - else if (FFlag::LuauParseRecoverUnexpectedPack && lexer.current().type == Lexeme::Dot3) + else if (lexer.current().type == Lexeme::Dot3) { report(lexer.current().location, "Unexpected '...' after type name; type pack is not allowed in this context"); nextLexeme(); diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index c6e5a03b..f71d893c 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -353,6 +353,11 @@ enum LuauOpcode // AUX: constant index LOP_FASTCALL2K, + // FORGPREP: prepare loop variables for a generic for loop, jump to the loop backedge unconditionally + // A: target register; generic for loops assume a register layout [generator, state, index, variables...] + // D: jump offset (-32768..32767) + LOP_FORGPREP, + // Enum entry for number of opcodes, not a valid opcode by itself! LOP__COUNT }; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 871a1484..fb70392e 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -96,6 +96,7 @@ inline bool isJumpD(LuauOpcode op) case LOP_JUMPIFNOTLT: case LOP_FORNPREP: case LOP_FORNLOOP: + case LOP_FORGPREP: case LOP_FORGLOOP: case LOP_FORGPREP_INEXT: case LOP_FORGLOOP_INEXT: @@ -1269,6 +1270,11 @@ void BytecodeBuilder::validate() const VJUMP(LUAU_INSN_D(insn)); break; + case LOP_FORGPREP: + VREG(LUAU_INSN_A(insn) + 2 + 1); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + VJUMP(LUAU_INSN_D(insn)); + break; + case LOP_FORGLOOP: VREG( LUAU_INSN_A(insn) + 2 + insns[i + 1]); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables @@ -1622,6 +1628,10 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri formatAppend(result, "FORNLOOP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); break; + case LOP_FORGPREP: + formatAppend(result, "FORGPREP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + case LOP_FORGLOOP: formatAppend(result, "FORGLOOP R%d %+d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn), *code++); break; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 0f17ee02..4fe26222 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -17,9 +17,19 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauCompileSupportInlining, false) + +LUAU_FASTFLAGVARIABLE(LuauCompileIter, false) +LUAU_FASTFLAGVARIABLE(LuauCompileIterNoReserve, false) +LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false) + LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThresholdMaxBoost, 300) +LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) +LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) +LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) + namespace Luau { @@ -147,6 +157,52 @@ struct Compiler } } + AstExprFunction* getFunctionExpr(AstExpr* node) + { + if (AstExprLocal* le = node->as()) + { + Variable* lv = variables.find(le->local); + + if (!lv || lv->written || !lv->init) + return nullptr; + + return getFunctionExpr(lv->init); + } + else if (AstExprGroup* ge = node->as()) + return getFunctionExpr(ge->expr); + else + return node->as(); + } + + bool canInlineFunctionBody(AstStat* stat) + { + struct CanInlineVisitor : AstVisitor + { + bool result = true; + + bool visit(AstExpr* node) override + { + // nested functions may capture function arguments, and our upval handling doesn't handle elided variables (constant) + // TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues + // TODO: additionally we would need to change upvalue handling in compileExprFunction to handle upvalue->local migration + result = result && !node->is(); + return result; + } + + bool visit(AstStat* node) override + { + // loops may need to be unrolled which can result in cost amplification + result = result && !node->is(); + return result; + } + }; + + CanInlineVisitor canInline; + stat->visit(&canInline); + + return canInline.result; + } + uint32_t compileFunction(AstExprFunction* func) { LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); @@ -214,13 +270,21 @@ struct Compiler bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size())); - stackSize = 0; - Function& f = functions[func]; f.id = fid; f.upvals = upvals; + // record information for inlining + if (FFlag::LuauCompileSupportInlining && options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) && + !getfenvUsed && !setfenvUsed) + { + f.canInline = true; + f.stackSize = stackSize; + f.costModel = modelCost(func->body, func->args.data, func->args.size); + } + upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes + stackSize = 0; return fid; } @@ -390,12 +454,183 @@ struct Compiler } } + bool tryCompileInlinedCall(AstExprCall* expr, AstExprFunction* func, uint8_t target, uint8_t targetCount, bool multRet, int thresholdBase, + int thresholdMaxBoost, int depthLimit) + { + Function* fi = functions.find(func); + LUAU_ASSERT(fi); + + // make sure we have enough register space + if (regTop > 128 || fi->stackSize > 32) + { + bytecode.addDebugRemark("inlining failed: high register pressure"); + return false; + } + + // we should ideally aggregate the costs during recursive inlining, but for now simply limit the depth + if (int(inlineFrames.size()) >= depthLimit) + { + bytecode.addDebugRemark("inlining failed: too many inlined frames"); + return false; + } + + // compiling recursive inlining is difficult because we share constant/variable state but need to bind variables to different registers + for (InlineFrame& frame : inlineFrames) + if (frame.func == func) + { + bytecode.addDebugRemark("inlining failed: can't inline recursive calls"); + return false; + } + + // TODO: we can compile multret functions if all returns of the function are multret as well + if (multRet) + { + bytecode.addDebugRemark("inlining failed: can't convert fixed returns to multret"); + return false; + } + + // TODO: we can compile functions with mismatching arity at call site but it's more annoying + if (func->args.size != expr->args.size) + { + bytecode.addDebugRemark("inlining failed: argument count mismatch (expected %d, got %d)", int(func->args.size), int(expr->args.size)); + return false; + } + + // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to inlining + bool varc[8] = {}; + for (size_t i = 0; i < expr->args.size && i < 8; ++i) + varc[i] = isConstant(expr->args.data[i]); + + int inlinedCost = computeCost(fi->costModel, varc, std::min(int(expr->args.size), 8)); + int baselineCost = computeCost(fi->costModel, nullptr, 0) + 3; + int inlineProfit = (inlinedCost == 0) ? thresholdMaxBoost : std::min(thresholdMaxBoost, 100 * baselineCost / inlinedCost); + + int threshold = thresholdBase * inlineProfit / 100; + + if (inlinedCost > threshold) + { + bytecode.addDebugRemark("inlining failed: too expensive (cost %d, profit %.2fx)", inlinedCost, double(inlineProfit) / 100); + return false; + } + + bytecode.addDebugRemark( + "inlining succeeded (cost %d, profit %.2fx, depth %d)", inlinedCost, double(inlineProfit) / 100, int(inlineFrames.size())); + + compileInlinedCall(expr, func, target, targetCount); + return true; + } + + void compileInlinedCall(AstExprCall* expr, AstExprFunction* func, uint8_t target, uint8_t targetCount) + { + RegScope rs(this); + + size_t oldLocals = localStack.size(); + + // note that we push the frame early; this is needed to block recursive inline attempts + inlineFrames.push_back({func, target, targetCount}); + + // evaluate all arguments; note that we don't emit code for constant arguments (relying on constant folding) + for (size_t i = 0; i < func->args.size; ++i) + { + AstLocal* var = func->args.data[i]; + AstExpr* arg = expr->args.data[i]; + + if (Variable* vv = variables.find(var); vv && vv->written) + { + // if the argument is mutated, we need to allocate a fresh register even if it's a constant + uint8_t reg = allocReg(arg, 1); + compileExprTemp(arg, reg); + pushLocal(var, reg); + } + else if (const Constant* cv = constants.find(arg); cv && cv->type != Constant::Type_Unknown) + { + // since the argument is not mutated, we can simply fold the value into the expressions that need it + locstants[var] = *cv; + } + else + { + AstExprLocal* le = arg->as(); + Variable* lv = le ? variables.find(le->local) : nullptr; + + // if the argument is a local that isn't mutated, we will simply reuse the existing register + if (isExprLocalReg(arg) && (!lv || !lv->written)) + { + uint8_t reg = getLocal(le->local); + pushLocal(var, reg); + } + else + { + uint8_t reg = allocReg(arg, 1); + compileExprTemp(arg, reg); + pushLocal(var, reg); + } + } + } + + // fold constant values updated above into expressions in the function body + foldConstants(constants, variables, locstants, func->body); + + bool usedFallthrough = false; + + for (size_t i = 0; i < func->body->body.size; ++i) + { + AstStat* stat = func->body->body.data[i]; + + if (AstStatReturn* ret = stat->as()) + { + // Optimization: use fallthrough when compiling return at the end of the function to avoid an extra JUMP + compileInlineReturn(ret, /* fallthrough= */ true); + // TODO: This doesn't work when return is part of control flow; ideally we would track the state somehow and generalize this + usedFallthrough = true; + break; + } + else + compileStat(stat); + } + + // for the fallthrough path we need to ensure we clear out target registers + if (!usedFallthrough && !allPathsEndWithReturn(func->body)) + { + for (size_t i = 0; i < targetCount; ++i) + bytecode.emitABC(LOP_LOADNIL, uint8_t(target + i), 0, 0); + } + + popLocals(oldLocals); + + size_t returnLabel = bytecode.emitLabel(); + patchJumps(expr, inlineFrames.back().returnJumps, returnLabel); + + inlineFrames.pop_back(); + + // clean up constant state for future inlining attempts + for (size_t i = 0; i < func->args.size; ++i) + if (Constant* var = locstants.find(func->args.data[i])) + var->type = Constant::Type_Unknown; + + foldConstants(constants, variables, locstants, func->body); + } + void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) { LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop); setDebugLine(expr); // normally compileExpr sets up line info, but compileExprCall can be called directly + // try inlining the function + if (options.optimizationLevel >= 2 && !expr->self) + { + AstExprFunction* func = getFunctionExpr(expr->func); + Function* fi = func ? functions.find(func) : nullptr; + + if (fi && fi->canInline && + tryCompileInlinedCall(expr, func, target, targetCount, multRet, FInt::LuauCompileInlineThreshold, + FInt::LuauCompileInlineThresholdMaxBoost, FInt::LuauCompileInlineDepth)) + return; + + if (fi && !fi->canInline) + bytecode.addDebugRemark("inlining failed: complex constructs in function body"); + } + RegScope rs(this); unsigned int regCount = std::max(unsigned(1 + expr->self + expr->args.size), unsigned(targetCount)); @@ -760,7 +995,7 @@ struct Compiler { const Constant* c = constants.find(node); - if (!c) + if (!c || c->type == Constant::Type_Unknown) return -1; int cid = -1; @@ -1395,27 +1630,29 @@ struct Compiler { RegScope rs(this); + // note: cv may be invalidated by compileExpr* so we stop using it before calling compile recursively const Constant* cv = constants.find(expr->index); if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && double(int(cv->valueNumber)) == cv->valueNumber) { - uint8_t rt = compileExprAuto(expr->expr, rs); uint8_t i = uint8_t(int(cv->valueNumber) - 1); + uint8_t rt = compileExprAuto(expr->expr, rs); + setDebugLine(expr->index); bytecode.emitABC(LOP_GETTABLEN, target, rt, i); } else if (cv && cv->type == Constant::Type_String) { - uint8_t rt = compileExprAuto(expr->expr, rs); - BytecodeBuilder::StringRef iname = sref(cv->getString()); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + uint8_t rt = compileExprAuto(expr->expr, rs); + setDebugLine(expr->index); bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname))); @@ -1561,8 +1798,9 @@ struct Compiler } else if (AstExprLocal* expr = node->as()) { - if (expr->upvalue) + if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue) { + LUAU_ASSERT(expr->upvalue); uint8_t uid = getUpval(expr->local); bytecode.emitABC(LOP_GETUPVAL, target, uid, 0); @@ -1650,12 +1888,12 @@ struct Compiler // initializes target..target+targetCount-1 range using expressions from the list // if list has fewer expressions, and last expression is a call, we assume the call returns the rest of the values // if list has fewer expressions, and last expression isn't a call, we fill the rest with nil - // assumes target register range can be clobbered and is at the top of the register space - void compileExprListTop(const AstArray& list, uint8_t target, uint8_t targetCount) + // assumes target register range can be clobbered and is at the top of the register space if targetTop = true + void compileExprListTemp(const AstArray& list, uint8_t target, uint8_t targetCount, bool targetTop) { // we assume that target range is at the top of the register space and can be clobbered // this is what allows us to compile the last call expression - if it's a call - using targetTop=true - LUAU_ASSERT(unsigned(target + targetCount) == regTop); + LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop); if (list.size == targetCount) { @@ -1683,7 +1921,7 @@ struct Compiler if (AstExprCall* expr = last->as()) { - compileExprCall(expr, uint8_t(target + list.size - 1), uint8_t(targetCount - (list.size - 1)), /* targetTop= */ true); + compileExprCall(expr, uint8_t(target + list.size - 1), uint8_t(targetCount - (list.size - 1)), targetTop); } else if (AstExprVarargs* expr = last->as()) { @@ -1765,8 +2003,10 @@ struct Compiler if (AstExprLocal* expr = node->as()) { - if (expr->upvalue) + if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue) { + LUAU_ASSERT(expr->upvalue); + LValue result = {LValue::Kind_Upvalue}; result.upval = getUpval(expr->local); result.location = node->location; @@ -1873,7 +2113,7 @@ struct Compiler bool isExprLocalReg(AstExpr* expr) { AstExprLocal* le = expr->as(); - if (!le || le->upvalue) + if (!le || (!FFlag::LuauCompileSupportInlining && le->upvalue)) return false; Local* l = locals.find(le->local); @@ -2080,6 +2320,23 @@ struct Compiler loops.pop_back(); } + void compileInlineReturn(AstStatReturn* stat, bool fallthrough) + { + setDebugLine(stat); // normally compileStat sets up line info, but compileInlineReturn can be called directly + + InlineFrame frame = inlineFrames.back(); + + compileExprListTemp(stat->list, frame.target, frame.targetCount, /* targetTop= */ false); + + if (!fallthrough) + { + size_t jumpLabel = bytecode.emitLabel(); + bytecode.emitAD(LOP_JUMP, 0, 0); + + inlineFrames.back().returnJumps.push_back(jumpLabel); + } + } + void compileStatReturn(AstStatReturn* stat) { RegScope rs(this); @@ -2138,7 +2395,7 @@ struct Compiler // note: allocReg in this case allocates into parent block register - note that we don't have RegScope here uint8_t vars = allocReg(stat, unsigned(stat->vars.size)); - compileExprListTop(stat->values, vars, uint8_t(stat->vars.size)); + compileExprListTemp(stat->values, vars, uint8_t(stat->vars.size), /* targetTop= */ true); for (size_t i = 0; i < stat->vars.size; ++i) pushLocal(stat->vars.data[i], uint8_t(vars + i)); @@ -2168,6 +2425,7 @@ struct Compiler bool visit(AstExpr* node) override { // functions may capture loop variable, and our upval handling doesn't handle elided variables (constant) + // TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues result = result && !node->is(); return result; } @@ -2251,6 +2509,11 @@ struct Compiler compileStat(stat->body); } + // clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again + locstants[var].type = Constant::Type_Unknown; + + foldConstants(constants, variables, locstants, stat); + return true; } @@ -2336,12 +2599,17 @@ struct Compiler uint8_t regs = allocReg(stat, 3); // this puts initial values of (generator, state, index) into the loop registers - compileExprListTop(stat->values, regs, 3); + compileExprListTemp(stat->values, regs, 3, /* targetTop= */ true); - // for the general case, we will execute a CALL for every iteration that needs to evaluate "variables... = generator(state, index)" - // this requires at least extra 3 stack slots after index - // note that these stack slots overlap with the variables so we only need to reserve them to make sure stack frame is large enough - reserveReg(stat, 3); + // we don't need this because the extra stack space is just for calling the function with a loop protocol which is similar to calling + // metamethods - it should fit into the extra stack reservation + if (!FFlag::LuauCompileIterNoReserve) + { + // for the general case, we will execute a CALL for every iteration that needs to evaluate "variables... = generator(state, index)" + // this requires at least extra 3 stack slots after index + // note that these stack slots overlap with the variables so we only need to reserve them to make sure stack frame is large enough + reserveReg(stat, 3); + } // note that we reserve at least 2 variables; this allows our fast path to assume that we need 2 variables instead of 1 or 2 uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u)); @@ -2350,7 +2618,7 @@ struct Compiler // Optimization: when we iterate through pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration // index These instructions dynamically check if generator is equal to next/inext and bail out They assume that the generator produces 2 // variables, which is why we allocate at least 2 above (see vars assignment) - LuauOpcode skipOp = LOP_JUMP; + LuauOpcode skipOp = FFlag::LuauCompileIter ? LOP_FORGPREP : LOP_JUMP; LuauOpcode loopOp = LOP_FORGLOOP; if (options.optimizationLevel >= 1 && stat->vars.size <= 2) @@ -2367,7 +2635,7 @@ struct Compiler else if (builtin.isGlobal("pairs")) // for .. in pairs(t) { skipOp = LOP_FORGPREP_NEXT; - loopOp = LOP_FORGLOOP_NEXT; + loopOp = FFlag::LuauCompileIterNoPairs ? LOP_FORGLOOP : LOP_FORGLOOP_NEXT; } } else if (stat->values.size == 2) @@ -2377,7 +2645,7 @@ struct Compiler if (builtin.isGlobal("next")) // for .. in next,t { skipOp = LOP_FORGPREP_NEXT; - loopOp = LOP_FORGLOOP_NEXT; + loopOp = FFlag::LuauCompileIterNoPairs ? LOP_FORGLOOP : LOP_FORGLOOP_NEXT; } } } @@ -2514,10 +2782,10 @@ struct Compiler // compute values into temporaries uint8_t regs = allocReg(stat, unsigned(stat->vars.size)); - compileExprListTop(stat->values, regs, uint8_t(stat->vars.size)); + compileExprListTemp(stat->values, regs, uint8_t(stat->vars.size), /* targetTop= */ true); - // assign variables that have associated values; note that if we have fewer values than variables, we'll assign nil because compileExprListTop - // will generate nils + // assign variables that have associated values; note that if we have fewer values than variables, we'll assign nil because + // compileExprListTemp will generate nils for (size_t i = 0; i < stat->vars.size; ++i) { setDebugLine(stat->vars.data[i]); @@ -2675,7 +2943,10 @@ struct Compiler } else if (AstStatReturn* stat = node->as()) { - compileStatReturn(stat); + if (options.optimizationLevel >= 2 && !inlineFrames.empty()) + compileInlineReturn(stat, /* fallthrough= */ false); + else + compileStatReturn(stat); } else if (AstStatExpr* stat = node->as()) { @@ -3069,6 +3340,10 @@ struct Compiler { uint32_t id; std::vector upvals; + + uint64_t costModel = 0; + unsigned int stackSize = 0; + bool canInline = false; }; struct Local @@ -3098,6 +3373,16 @@ struct Compiler AstExpr* untilCondition; }; + struct InlineFrame + { + AstExprFunction* func; + + uint8_t target; + uint8_t targetCount; + + std::vector returnJumps; + }; + BytecodeBuilder& bytecode; CompileOptions options; @@ -3120,6 +3405,7 @@ struct Compiler std::vector upvals; std::vector loopJumps; std::vector loops; + std::vector inlineFrames; }; void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 7ad91d4b..52ece73e 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -3,6 +3,8 @@ #include +LUAU_FASTFLAG(LuauCompileSupportInlining) + namespace Luau { namespace Compile @@ -314,12 +316,35 @@ struct ConstantVisitor : AstVisitor LUAU_ASSERT(!"Unknown expression type"); } - if (result.type != Constant::Type_Unknown) - constants[node] = result; + recordConstant(constants, node, result); return result; } + template + void recordConstant(DenseHashMap& map, T key, const Constant& value) + { + if (value.type != Constant::Type_Unknown) + map[key] = value; + else if (!FFlag::LuauCompileSupportInlining) + ; + else if (Constant* old = map.find(key)) + old->type = Constant::Type_Unknown; + } + + void recordValue(AstLocal* local, const Constant& value) + { + // note: we rely on trackValues to have been run before us + Variable* v = variables.find(local); + LUAU_ASSERT(v); + + if (!v->written) + { + v->constant = (value.type != Constant::Type_Unknown); + recordConstant(locals, local, value); + } + } + bool visit(AstExpr* node) override { // note: we short-circuit the visitor traversal through any expression trees by returning false @@ -336,18 +361,7 @@ struct ConstantVisitor : AstVisitor { Constant arg = analyze(node->values.data[i]); - if (arg.type != Constant::Type_Unknown) - { - // note: we rely on trackValues to have been run before us - Variable* v = variables.find(node->vars.data[i]); - LUAU_ASSERT(v); - - if (!v->written) - { - locals[node->vars.data[i]] = arg; - v->constant = true; - } - } + recordValue(node->vars.data[i], arg); } if (node->vars.size > node->values.size) @@ -361,15 +375,8 @@ struct ConstantVisitor : AstVisitor { for (size_t i = node->values.size; i < node->vars.size; ++i) { - // note: we rely on trackValues to have been run before us - Variable* v = variables.find(node->vars.data[i]); - LUAU_ASSERT(v); - - if (!v->written) - { - locals[node->vars.data[i]].type = Constant::Type_Nil; - v->constant = true; - } + Constant nil = {Constant::Type_Nil}; + recordValue(node->vars.data[i], nil); } } } diff --git a/Sources.cmake b/Sources.cmake index f9263b24..d2430cc9 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -264,6 +264,7 @@ if(TARGET Luau.UnitTest) tests/TypePack.test.cpp tests/TypeVar.test.cpp tests/Variant.test.cpp + tests/VisitTypeVar.test.cpp tests/main.cpp) endif() diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 1f3b0943..f8baefaf 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1270,7 +1270,7 @@ const char* lua_setupvalue(lua_State* L, int funcindex, int n) L->top--; setobj(L, val, L->top); luaC_barrier(L, clvalue(fi), L->top); - luaC_upvalbarrier(L, NULL, val); + luaC_upvalbarrier(L, cast_to(UpVal*, NULL), val); } return name; } diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 718d387d..60149199 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -15,6 +15,8 @@ #include #endif +LUAU_FASTFLAGVARIABLE(LuauFixBuiltinsStackLimit, false) + // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -1003,7 +1005,7 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St else if (nparams == 3 && ttisnumber(args) && ttisnumber(args + 1) && nvalue(args) == 1.0) n = int(nvalue(args + 1)); - if (n >= 0 && n <= t->sizearray && cast_int(L->stack_last - res) >= n) + if (n >= 0 && n <= t->sizearray && cast_int(L->stack_last - res) >= n && (!FFlag::LuauFixBuiltinsStackLimit || n + nparams <= LUAI_MAXCSTACK)) { TValue* array = t->array; for (int i = 0; i < n; ++i) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 08d1ff5d..797284a2 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -120,7 +120,7 @@ #define luaC_upvalbarrier(L, uv, tv) \ { \ - if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || ((UpVal*)uv)->v != &((UpVal*)uv)->u.value)) \ + if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || (uv)->v != &(uv)->u.value)) \ luaC_barrierupval(L, gcvalue(tv)); \ } diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 3dc3bd1b..8251b51c 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -33,8 +33,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary2, false) - // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -431,7 +429,6 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) static int adjustasize(Table* t, int size, const TValue* ek) { - LUAU_ASSERT(FFlag::LuauTableNewBoundary2); bool tbound = t->node != dummynode || size < t->sizearray; int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; /* move the array size up until the boundary is guaranteed to be inside the array part */ @@ -443,7 +440,7 @@ static int adjustasize(Table* t, int size, const TValue* ek) void luaH_resizearray(lua_State* L, Table* t, int nasize) { int nsize = (t->node == dummynode) ? 0 : sizenode(t); - int asize = FFlag::LuauTableNewBoundary2 ? adjustasize(t, nasize, NULL) : nasize; + int asize = adjustasize(t, nasize, NULL); resize(L, t, asize, nsize); } @@ -468,8 +465,7 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) int na = computesizes(nums, &nasize); int nh = totaluse - na; /* enforce the boundary invariant; for performance, only do hash lookups if we must */ - if (FFlag::LuauTableNewBoundary2) - nasize = adjustasize(t, nasize, ek); + nasize = adjustasize(t, nasize, ek); /* resize the table to new computed sizes */ resize(L, t, nasize, nh); } @@ -531,7 +527,7 @@ static LuaNode* getfreepos(Table* t) static TValue* newkey(lua_State* L, Table* t, const TValue* key) { /* enforce boundary invariant */ - if (FFlag::LuauTableNewBoundary2 && ttisnumber(key) && nvalue(key) == t->sizearray + 1) + if (ttisnumber(key) && nvalue(key) == t->sizearray + 1) { rehash(L, t, key); /* grow table */ @@ -713,37 +709,6 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key) } } -static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j) -{ - LUAU_ASSERT(!FFlag::LuauTableNewBoundary2); - unsigned int i = j; /* i is zero or a present index */ - j++; - /* find `i' and `j' such that i is present and j is not */ - while (!ttisnil(luaH_getnum(t, j))) - { - i = j; - j *= 2; - if (j > cast_to(unsigned int, INT_MAX)) - { /* overflow? */ - /* table was built with bad purposes: resort to linear search */ - i = 1; - while (!ttisnil(luaH_getnum(t, i))) - i++; - return i - 1; - } - } - /* now do a binary search between them */ - while (j - i > 1) - { - unsigned int m = (i + j) / 2; - if (ttisnil(luaH_getnum(t, m))) - j = m; - else - i = m; - } - return i; -} - static int updateaboundary(Table* t, int boundary) { if (boundary < t->sizearray && ttisnil(&t->array[boundary - 1])) @@ -800,17 +765,12 @@ int luaH_getn(Table* t) maybesetaboundary(t, boundary); return boundary; } - else if (FFlag::LuauTableNewBoundary2) + else { /* validate boundary invariant */ LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1))); return j; } - /* else must find a boundary in hash part */ - else if (t->node == dummynode) /* hash part is empty? */ - return j; /* that is easy... */ - else - return unbound_search(t, j); } Table* luaH_clone(lua_State* L, Table* tt) diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 106efb2b..9b99506b 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -37,6 +37,8 @@ const char* const luaT_eventname[] = { "__newindex", "__mode", "__namecall", + "__call", + "__iter", "__eq", @@ -54,13 +56,13 @@ const char* const luaT_eventname[] = { "__lt", "__le", "__concat", - "__call", "__type", }; // clang-format on static_assert(sizeof(luaT_typenames) / sizeof(luaT_typenames[0]) == LUA_T_COUNT, "luaT_typenames size mismatch"); static_assert(sizeof(luaT_eventname) / sizeof(luaT_eventname[0]) == TM_N, "luaT_eventname size mismatch"); +static_assert(TM_EQ < 8, "fasttm optimization stores a bitfield with metamethods in a byte"); void luaT_init(lua_State* L) { diff --git a/VM/src/ltm.h b/VM/src/ltm.h index 0e4e915d..e1b95c21 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -16,6 +16,8 @@ typedef enum TM_NEWINDEX, TM_MODE, TM_NAMECALL, + TM_CALL, + TM_ITER, TM_EQ, /* last tag method with `fast' access */ @@ -33,7 +35,6 @@ typedef enum TM_LT, TM_LE, TM_CONCAT, - TM_CALL, TM_TYPE, TM_N /* number of elements in the enum */ diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 39c60eac..3c7c276a 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,7 +16,10 @@ #include -LUAU_FASTFLAG(LuauTableNewBoundary2) +LUAU_FASTFLAGVARIABLE(LuauIter, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauIterCallTelemetry, false) + +void (*lua_iter_call_telemetry)(lua_State* L); // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ @@ -110,7 +113,7 @@ LUAU_FASTFLAG(LuauTableNewBoundary2) VM_DISPATCH_OP(LOP_FORGLOOP_NEXT), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_JUMPIFEQK), VM_DISPATCH_OP(LOP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ - VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), + VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), #if defined(__GNUC__) || defined(__clang__) #define VM_USE_CGOTO 1 @@ -150,8 +153,20 @@ LUAU_NOINLINE static void luau_prepareFORN(lua_State* L, StkId plimit, StkId pst LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c) { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) StkId ra = &L->base[a]; - LUAU_ASSERT(ra + 6 <= L->top); + LUAU_ASSERT(ra + 3 <= L->top); + + if (DFFlag::LuauIterCallTelemetry) + { + /* TODO: we might be able to stop supporting this depending on whether it's used in practice */ + void (*telemetrycb)(lua_State* L) = lua_iter_call_telemetry; + + if (telemetrycb && ttistable(ra) && fasttm(L, hvalue(ra)->metatable, TM_CALL)) + telemetrycb(L); + if (telemetrycb && ttisuserdata(ra) && fasttm(L, uvalue(ra)->metatable, TM_CALL)) + telemetrycb(L); + } setobjs2s(L, ra + 3 + 2, ra + 2); setobjs2s(L, ra + 3 + 1, ra + 1); @@ -2204,20 +2219,149 @@ static void luau_execute(lua_State* L) } } + VM_CASE(LOP_FORGPREP) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + if (ttisfunction(ra)) + { + /* will be called during FORGLOOP */ + } + else if (FFlag::LuauIter) + { + Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); + + if (const TValue* fn = fasttm(L, mt, TM_ITER)) + { + setobj2s(L, ra + 1, ra); + setobj2s(L, ra, fn); + + L->top = ra + 2; /* func + self arg */ + LUAU_ASSERT(L->top <= L->stack_last); + + VM_PROTECT(luaD_call(L, ra, 3)); + L->top = L->ci->top; + } + else if (fasttm(L, mt, TM_CALL)) + { + /* table or userdata with __call, will be called during FORGLOOP */ + /* TODO: we might be able to stop supporting this depending on whether it's used in practice */ + } + else if (ttistable(ra)) + { + /* set up registers for builtin iteration */ + setobj2s(L, ra + 1, ra); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setnilvalue(ra); + } + else + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } + } + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + VM_CASE(LOP_FORGLOOP) { VM_INTERRUPT(); Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); uint32_t aux = *pc; - // note: this is a slow generic path, fast-path is FORGLOOP_INEXT/NEXT - bool stop; - VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux)); + if (!FFlag::LuauIter) + { + bool stop; + VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux)); - // note that we need to increment pc by 1 to exit the loop since we need to skip over aux - pc += stop ? 1 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); + // note that we need to increment pc by 1 to exit the loop since we need to skip over aux + pc += stop ? 1 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + // fast-path: builtin table iteration + if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) + { + Table* h = hvalue(ra + 1); + int index = int(reinterpret_cast(pvalue(ra + 2))); + + int sizearray = h->sizearray; + int sizenode = 1 << h->lsizenode; + + // clear extra variables since we might have more than two + if (LUAU_UNLIKELY(aux > 2)) + for (int i = 2; i < int(aux); ++i) + setnilvalue(ra + 3 + i); + + // first we advance index through the array portion + while (unsigned(index) < unsigned(sizearray)) + { + if (!ttisnil(&h->array[index])) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setnvalue(ra + 3, double(index + 1)); + setobj2s(L, ra + 4, &h->array[index]); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + index++; + } + + // then we advance index through the hash portion + while (unsigned(index - sizearray) < unsigned(sizenode)) + { + LuaNode* n = &h->node[index - sizearray]; + + if (!ttisnil(gval(n))) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + getnodekey(L, ra + 3, n); + setobj2s(L, ra + 4, gval(n)); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + index++; + } + + // fallthrough to exit + pc++; + VM_NEXT(); + } + else + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + setobjs2s(L, ra + 3 + 2, ra + 2); + setobjs2s(L, ra + 3 + 1, ra + 1); + setobjs2s(L, ra + 3, ra); + + L->top = ra + 3 + 3; /* func + 2 args (state and index) */ + LUAU_ASSERT(L->top <= L->stack_last); + + VM_PROTECT(luaD_call(L, ra + 3, aux)); + L->top = L->ci->top; + + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + + // copy first variable back into the iteration index + setobjs2s(L, ra + 2, ra + 3); + + // note that we need to increment pc by 1 to exit the loop since we need to skip over aux + pc += ttisnil(ra + 3) ? 1 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } } VM_CASE(LOP_FORGPREP_INEXT) @@ -2228,8 +2372,15 @@ static void luau_execute(lua_State* L) // fast-path: ipairs/inext if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) { + if (FFlag::LuauIter) + setnilvalue(ra); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } + else if (FFlag::LuauIter && !ttisfunction(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } pc += LUAU_INSN_D(insn); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -2268,23 +2419,9 @@ static void luau_execute(lua_State* L) VM_NEXT(); } } - else if (FFlag::LuauTableNewBoundary2 || (h->lsizenode == 0 && ttisnil(gval(h->node)))) - { - // fallthrough to exit - VM_NEXT(); - } else { - // the table has a hash part; index + 1 may appear in it in which case we need to iterate through the hash portion as well - const TValue* val = luaH_getnum(h, index + 1); - - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); - setnvalue(ra + 3, double(index + 1)); - setobj2s(L, ra + 4, val); - - // note that nil elements inside the array terminate the traversal - pc += ttisnil(ra + 4) ? 0 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + // fallthrough to exit VM_NEXT(); } } @@ -2308,8 +2445,15 @@ static void luau_execute(lua_State* L) // fast-path: pairs/next if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) { + if (FFlag::LuauIter) + setnilvalue(ra); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } + else if (FFlag::LuauIter && !ttisfunction(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } pc += LUAU_INSN_D(insn); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -2704,7 +2848,7 @@ static void luau_execute(lua_State* L) { VM_PROTECT_PC(); - int n = f(L, ra, arg, nresults, nullptr, nparams); + int n = f(L, ra, arg, nresults, NULL, nparams); if (n >= 0) { diff --git a/bench/micro_tests/test_LargeTableSum_loop_iter.lua b/bench/micro_tests/test_LargeTableSum_loop_iter.lua new file mode 100644 index 00000000..057420f6 --- /dev/null +++ b/bench/micro_tests/test_LargeTableSum_loop_iter.lua @@ -0,0 +1,17 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local t = {} + + for i=1,1000000 do t[i] = i end + + local ts0 = os.clock() + local sum = 0 + for k,v in t do sum = sum + v end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "LargeTableSum: for k,v in {}") diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua index 5d162ab9..77fa0854 100644 --- a/bench/tests/sunspider/3d-cube.lua +++ b/bench/tests/sunspider/3d-cube.lua @@ -25,7 +25,7 @@ local DisplArea = {} DisplArea.Width = 300; DisplArea.Height = 300; -function DrawLine(From, To) +local function DrawLine(From, To) local x1 = From.V[1]; local x2 = To.V[1]; local y1 = From.V[2]; @@ -81,7 +81,7 @@ function DrawLine(From, To) Q.LastPx = NumPix; end -function CalcCross(V0, V1) +local function CalcCross(V0, V1) local Cross = {}; Cross[1] = V0[2]*V1[3] - V0[3]*V1[2]; Cross[2] = V0[3]*V1[1] - V0[1]*V1[3]; @@ -89,7 +89,7 @@ function CalcCross(V0, V1) return Cross; end -function CalcNormal(V0, V1, V2) +local function CalcNormal(V0, V1, V2) local A = {}; local B = {}; for i = 1,3 do A[i] = V0[i] - V1[i]; @@ -102,14 +102,14 @@ function CalcNormal(V0, V1, V2) return A; end -function CreateP(X,Y,Z) +local function CreateP(X,Y,Z) local result = {} result.V = {X,Y,Z,1}; return result end -- multiplies two matrices -function MMulti(M1, M2) +local function MMulti(M1, M2) local M = {{},{},{},{}}; for i = 1,4 do for j = 1,4 do @@ -120,7 +120,7 @@ function MMulti(M1, M2) end -- multiplies matrix with vector -function VMulti(M, V) +local function VMulti(M, V) local Vect = {}; for i = 1,4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; @@ -128,7 +128,7 @@ function VMulti(M, V) return Vect; end -function VMulti2(M, V) +local function VMulti2(M, V) local Vect = {}; for i = 1,3 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; @@ -137,7 +137,7 @@ function VMulti2(M, V) end -- add to matrices -function MAdd(M1, M2) +local function MAdd(M1, M2) local M = {{},{},{},{}}; for i = 1,4 do for j = 1,4 do @@ -147,7 +147,7 @@ function MAdd(M1, M2) return M; end -function Translate(M, Dx, Dy, Dz) +local function Translate(M, Dx, Dy, Dz) local T = { {1,0,0,Dx}, {0,1,0,Dy}, @@ -157,7 +157,7 @@ function Translate(M, Dx, Dy, Dz) return MMulti(T, M); end -function RotateX(M, Phi) +local function RotateX(M, Phi) local a = Phi; a = a * math.pi / 180; local Cos = math.cos(a); @@ -171,7 +171,7 @@ function RotateX(M, Phi) return MMulti(R, M); end -function RotateY(M, Phi) +local function RotateY(M, Phi) local a = Phi; a = a * math.pi / 180; local Cos = math.cos(a); @@ -185,7 +185,7 @@ function RotateY(M, Phi) return MMulti(R, M); end -function RotateZ(M, Phi) +local function RotateZ(M, Phi) local a = Phi; a = a * math.pi / 180; local Cos = math.cos(a); @@ -199,7 +199,7 @@ function RotateZ(M, Phi) return MMulti(R, M); end -function DrawQube() +local function DrawQube() -- calc current normals local CurN = {}; local i = 5; @@ -245,7 +245,7 @@ function DrawQube() Q.LastPx = 0; end -function Loop() +local function Loop() if (Testing.LoopCount > Testing.LoopMax) then return; end local TestingStr = tostring(Testing.LoopCount); while (#TestingStr < 3) do TestingStr = "0" .. TestingStr; end @@ -265,7 +265,7 @@ function Loop() Loop(); end -function Init(CubeSize) +local function Init(CubeSize) -- init/reset vars Origin.V = {150,150,20,1}; Testing.LoopCount = 0; diff --git a/bench/tests/sunspider/3d-morph.lua b/bench/tests/sunspider/3d-morph.lua index f73f173b..79e91419 100644 --- a/bench/tests/sunspider/3d-morph.lua +++ b/bench/tests/sunspider/3d-morph.lua @@ -31,7 +31,7 @@ local loops = 15 local nx = 120 local nz = 120 -function morph(a, f) +local function morph(a, f) local PI2nx = math.pi * 8/nx local sin = math.sin local f30 = -(50 * sin(f*math.pi*2)) diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua index c8f6b5dc..3d5276c7 100644 --- a/bench/tests/sunspider/3d-raytrace.lua +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -28,40 +28,40 @@ function test() local size = 30 -function createVector(x,y,z) +local function createVector(x,y,z) return { x,y,z }; end -function sqrLengthVector(self) +local function sqrLengthVector(self) return self[1] * self[1] + self[2] * self[2] + self[3] * self[3]; end -function lengthVector(self) +local function lengthVector(self) return math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); end -function addVector(self, v) +local function addVector(self, v) self[1] = self[1] + v[1]; self[2] = self[2] + v[2]; self[3] = self[3] + v[3]; return self; end -function subVector(self, v) +local function subVector(self, v) self[1] = self[1] - v[1]; self[2] = self[2] - v[2]; self[3] = self[3] - v[3]; return self; end -function scaleVector(self, scale) +local function scaleVector(self, scale) self[1] = self[1] * scale; self[2] = self[2] * scale; self[3] = self[3] * scale; return self; end -function normaliseVector(self) +local function normaliseVector(self) local len = math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); self[1] = self[1] / len; self[2] = self[2] / len; @@ -69,39 +69,39 @@ function normaliseVector(self) return self; end -function add(v1, v2) +local function add(v1, v2) return { v1[1] + v2[1], v1[2] + v2[2], v1[3] + v2[3] }; end -function sub(v1, v2) +local function sub(v1, v2) return { v1[1] - v2[1], v1[2] - v2[2], v1[3] - v2[3] }; end -function scalev(v1, v2) +local function scalev(v1, v2) return { v1[1] * v2[1], v1[2] * v2[2], v1[3] * v2[3] }; end -function dot(v1, v2) +local function dot(v1, v2) return v1[1] * v2[1] + v1[2] * v2[2] + v1[3] * v2[3]; end -function scale(v, scale) +local function scale(v, scale) return { v[1] * scale, v[2] * scale, v[3] * scale }; end -function cross(v1, v2) +local function cross(v1, v2) return { v1[2] * v2[3] - v1[3] * v2[2], v1[3] * v2[1] - v1[1] * v2[3], v1[1] * v2[2] - v1[2] * v2[1] }; end -function normalise(v) +local function normalise(v) local len = lengthVector(v); return { v[1] / len, v[2] / len, v[3] / len }; end -function transformMatrix(self, v) +local function transformMatrix(self, v) local vals = self; local x = vals[1] * v[1] + vals[2] * v[2] + vals[3] * v[3] + vals[4]; local y = vals[5] * v[1] + vals[6] * v[2] + vals[7] * v[3] + vals[8]; @@ -109,7 +109,7 @@ function transformMatrix(self, v) return { x, y, z }; end -function invertMatrix(self) +local function invertMatrix(self) local temp = {} local tx = -self[4]; local ty = -self[8]; @@ -131,7 +131,7 @@ function invertMatrix(self) end -- Triangle intersection using barycentric coord method -function Triangle(p1, p2, p3) +local function Triangle(p1, p2, p3) local this = {} local edge1 = sub(p3, p1); @@ -205,7 +205,7 @@ function Triangle(p1, p2, p3) return this end -function Scene(a_triangles) +local function Scene(a_triangles) local this = {} this.triangles = a_triangles; this.lights = {}; @@ -302,7 +302,7 @@ local zero = { 0,0,0 }; -- this camera code is from notes i made ages ago, it is from *somewhere* -- i cannot remember where -- that somewhere is -function Camera(origin, lookat, up) +local function Camera(origin, lookat, up) local this = {} local zaxis = normaliseVector(subVector(lookat, origin)); @@ -357,7 +357,7 @@ function Camera(origin, lookat, up) return this end -function raytraceScene() +local function raytraceScene() local startDate = 13154863; local numTriangles = 2 * 6; local triangles = {}; -- numTriangles); @@ -450,7 +450,7 @@ function raytraceScene() return pixels; end -function arrayToCanvasCommands(pixels) +local function arrayToCanvasCommands(pixels) local s = {}; table.insert(s, 'Test\nvar pixels = ['); for y = 0,size-1 do @@ -485,7 +485,7 @@ for (var y = 0; y < size; y++) {\n\ return table.concat(s); end -testOutput = arrayToCanvasCommands(raytraceScene()); +local testOutput = arrayToCanvasCommands(raytraceScene()); --local f = io.output("output.html") --f:write(testOutput) diff --git a/bench/tests/sunspider/access-binary-trees.lua b/bench/tests/sunspider/access-binary-trees.lua deleted file mode 100644 index 9eb93588..00000000 --- a/bench/tests/sunspider/access-binary-trees.lua +++ /dev/null @@ -1,69 +0,0 @@ ---[[ - The Great Computer Language Shootout - http://shootout.alioth.debian.org/ - contributed by Isaac Gouy -]] - -local bench = script and require(script.Parent.bench_support) or require("bench_support") - -function test() - -function TreeNode(left,right,item) - local this = {} - this.left = left; - this.right = right; - this.item = item; - - this.itemCheck = function(self) - if (self.left==nil) then return self.item; - else return self.item + self.left:itemCheck() - self.right:itemCheck(); end - end - - return this -end - -function bottomUpTree(item,depth) - if (depth>0) then - return TreeNode( - bottomUpTree(2*item-1, depth-1) - ,bottomUpTree(2*item, depth-1) - ,item - ); - else - return TreeNode(nil,nil,item); - end -end - -local ret = 0; - -for n = 4,7,1 do - local minDepth = 4; - local maxDepth = math.max(minDepth + 2, n); - local stretchDepth = maxDepth + 1; - - local check = bottomUpTree(0,stretchDepth):itemCheck(); - - local longLivedTree = bottomUpTree(0,maxDepth); - - for depth = minDepth,maxDepth,2 do - local iterations = 2.0 ^ (maxDepth - depth + minDepth - 1) -- 1 << (maxDepth - depth + minDepth); - - check = 0; - for i = 1,iterations do - check = check + bottomUpTree(i,depth):itemCheck(); - check = check + bottomUpTree(-i,depth):itemCheck(); - end - end - - ret = ret + longLivedTree:itemCheck(); -end - -local expected = -4; - -if (ret ~= expected) then - assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. ret); -end - -end - -bench.runCode(test, "access-binary-trees") diff --git a/bench/tests/sunspider/controlflow-recursive.lua b/bench/tests/sunspider/controlflow-recursive.lua index d0791626..a2591b2f 100644 --- a/bench/tests/sunspider/controlflow-recursive.lua +++ b/bench/tests/sunspider/controlflow-recursive.lua @@ -7,18 +7,18 @@ local bench = script and require(script.Parent.bench_support) or require("bench_ function test() -function ack(m,n) +local function ack(m,n) if (m==0) then return n+1; end if (n==0) then return ack(m-1,1); end return ack(m-1, ack(m,n-1) ); end -function fib(n) +local function fib(n) if (n < 2) then return 1; end return fib(n-2) + fib(n-1); end -function tak(x,y,z) +local function tak(x,y,z) if (y >= x) then return z; end return tak(tak(x-1,y,z), tak(y-1,z,x), tak(z-1,x,y)); end @@ -27,7 +27,7 @@ local result = 0; for i = 3,5 do result = result + ack(3,i); - result = result + fib(17.0+i); + result = result + fib(17+i); result = result + tak(3*i+3,2*i+2,i+1); end diff --git a/bench/tests/sunspider/crypto-aes.lua b/bench/tests/sunspider/crypto-aes.lua index 3b289729..8dd0cec6 100644 --- a/bench/tests/sunspider/crypto-aes.lua +++ b/bench/tests/sunspider/crypto-aes.lua @@ -42,7 +42,68 @@ local Rcon = { { 0x00, 0x00, 0x00, 0x00 }, {0x1b, 0x00, 0x00, 0x00}, {0x36, 0x00, 0x00, 0x00} }; -function Cipher(input, w) -- main Cipher function [§5.1] +local function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1] + for r = 0,3 do + for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end + end + return s; +end + + +local function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2] + local t = {}; + for r = 1,3 do + for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy + for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back + end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES): + return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf +end + + +local function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3] + for c = 0,3 do + local a = {}; -- 'a' is a copy of the current column from 's' + local b = {}; -- 'b' is a•{02} in GF(2^8) + for i = 0,3 do + a[i + 1] = s[i + 1][c + 1]; + + if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then + b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b); + else + b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1); + end + end + -- a[n] ^ b[n] is a•{03} in GF(2^8) + s[1][c + 1] = bit32.bxor(b[1], a[2], b[2], a[3], a[4]); -- 2*a0 + 3*a1 + a2 + a3 + s[2][c + 1] = bit32.bxor(a[1], b[2], a[3], b[3], a[4]); -- a0 * 2*a1 + 3*a2 + a3 + s[3][c + 1] = bit32.bxor(a[1], a[2], b[3], a[4], b[4]); -- a0 + a1 + 2*a2 + 3*a3 + s[4][c + 1] = bit32.bxor(a[1], b[1], a[2], a[3], b[4]); -- 3*a0 + a1 + a2 + 2*a3 +end + return s; +end + + +local function SubWord(w) -- apply SBox to 4-byte word w + for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end + return w; +end + +local function RotWord(w) -- rotate 4-byte word w left by one byte + w[5] = w[1]; + for i = 0,3 do w[i + 1] = w[i + 2]; end + return w; +end + + + +local function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4] + for r = 0,3 do + for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end + end + return state; +end + +local function Cipher(input, w) -- main Cipher function [§5.1] local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) local Nr = #w / Nb - 1; -- no of rounds: 10/12/14 for 128/192/256-bit keys @@ -69,56 +130,7 @@ function Cipher(input, w) -- main Cipher function [§5.1] end -function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1] - for r = 0,3 do - for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end - end - return s; -end - - -function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2] - local t = {}; - for r = 1,3 do - for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy - for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back - end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES): - return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf -end - - -function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3] - for c = 0,3 do - local a = {}; -- 'a' is a copy of the current column from 's' - local b = {}; -- 'b' is a•{02} in GF(2^8) - for i = 0,3 do - a[i + 1] = s[i + 1][c + 1]; - - if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then - b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b); - else - b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1); - end - end - -- a[n] ^ b[n] is a•{03} in GF(2^8) - s[1][c + 1] = bit32.bxor(b[1], a[2], b[2], a[3], a[4]); -- 2*a0 + 3*a1 + a2 + a3 - s[2][c + 1] = bit32.bxor(a[1], b[2], a[3], b[3], a[4]); -- a0 * 2*a1 + 3*a2 + a3 - s[3][c + 1] = bit32.bxor(a[1], a[2], b[3], a[4], b[4]); -- a0 + a1 + 2*a2 + 3*a3 - s[4][c + 1] = bit32.bxor(a[1], b[1], a[2], a[3], b[4]); -- 3*a0 + a1 + a2 + 2*a3 -end - return s; -end - - -function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4] - for r = 0,3 do - for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end - end - return state; -end - - -function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2] +local function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2] local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) local Nk = #key / 4 -- key length (in words): 4/6/8 for 128/192/256-bit keys local Nr = Nk + 6; -- no of rounds: 10/12/14 for 128/192/256-bit keys @@ -146,17 +158,17 @@ function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from return w; end -function SubWord(w) -- apply SBox to 4-byte word w - for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end - return w; +local function escCtrlChars(str) -- escape control chars which might cause problems handling ciphertext + return string.gsub(str, "[\0\t\n\v\f\r\'\"!-]", function(c) return '!' .. string.byte(c, 1) .. '!'; end); end -function RotWord(w) -- rotate 4-byte word w left by one byte - w[5] = w[1]; - for i = 0,3 do w[i + 1] = w[i + 2]; end - return w; -end +local function unescCtrlChars(str) -- unescape potentially problematic control characters + return string.gsub(str, "!%d%d?%d?!", function(c) + local sc = string.sub(c, 2,-2) + return string.char(tonumber(sc)); + end); +end --[[ * Use AES to encrypt 'plaintext' with 'password' using 'nBits' key, in 'Counter' mode of operation @@ -166,7 +178,7 @@ end * - cipherblock = plaintext xor outputblock ]] -function AESEncryptCtr(plaintext, password, nBits) +local function AESEncryptCtr(plaintext, password, nBits) if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys -- for this example script, generate the key by applying Cipher to 1st 16/24/32 chars of password; @@ -243,7 +255,7 @@ end * - cipherblock = plaintext xor outputblock ]] -function AESDecryptCtr(ciphertext, password, nBits) +local function AESDecryptCtr(ciphertext, password, nBits) if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys local nBytes = nBits/8; -- no bytes in key @@ -300,19 +312,7 @@ function AESDecryptCtr(ciphertext, password, nBits) return table.concat(plaintext) end -function escCtrlChars(str) -- escape control chars which might cause problems handling ciphertext - return string.gsub(str, "[\0\t\n\v\f\r\'\"!-]", function(c) return '!' .. string.byte(c, 1) .. '!'; end); -end - -function unescCtrlChars(str) -- unescape potentially problematic control characters - return string.gsub(str, "!%d%d?%d?!", function(c) - local sc = string.sub(c, 2,-2) - - return string.char(tonumber(sc)); - end); -end - -function test() +local function test() local plainText = "ROMEO: But, soft! what light through yonder window breaks?\n\ It is the east, and Juliet is the sun.\n\ diff --git a/bench/tests/sunspider/math-cordic.lua b/bench/tests/sunspider/math-cordic.lua index 94a64f45..cdb10fa2 100644 --- a/bench/tests/sunspider/math-cordic.lua +++ b/bench/tests/sunspider/math-cordic.lua @@ -31,15 +31,15 @@ function test() local AG_CONST = 0.6072529350; -function FIXED(X) +local function FIXED(X) return X * 65536.0; end -function FLOAT(X) +local function FLOAT(X) return X / 65536.0; end -function DEG2RAD(X) +local function DEG2RAD(X) return 0.017453 * (X); end @@ -52,7 +52,7 @@ local Angles = { local Target = 28.027; -function cordicsincos(Target) +local function cordicsincos(Target) local X; local Y; local TargetAngle; @@ -85,7 +85,7 @@ end local total = 0; -function cordic( runs ) +local function cordic( runs ) for i = 1,runs do total = total + cordicsincos(Target); end diff --git a/bench/tests/sunspider/math-partial-sums.lua b/bench/tests/sunspider/math-partial-sums.lua index 3c222876..9977ceff 100644 --- a/bench/tests/sunspider/math-partial-sums.lua +++ b/bench/tests/sunspider/math-partial-sums.lua @@ -7,7 +7,7 @@ local bench = script and require(script.Parent.bench_support) or require("bench_ function test() -function partial(n) +local function partial(n) local a1, a2, a3, a4, a5, a6, a7, a8, a9 = 0, 0, 0, 0, 0, 0, 0, 0, 0; local twothirds = 2.0/3.0; local alt = -1.0; diff --git a/bench/tests/sunspider/math-spectral-norm.lua b/bench/tests/sunspider/math-spectral-norm.lua deleted file mode 100644 index 7d7ec163..00000000 --- a/bench/tests/sunspider/math-spectral-norm.lua +++ /dev/null @@ -1,72 +0,0 @@ ---[[ -The Great Computer Language Shootout -http://shootout.alioth.debian.org/ - -contributed by Ian Osgood -]] -local bench = script and require(script.Parent.bench_support) or require("bench_support") - -function test() - -function A(i,j) - return 1/((i+j)*(i+j+1)/2+i+1); -end - -function Au(u,v) - for i = 0,#u-1 do - local t = 0; - for j = 0,#u-1 do - t = t + A(i,j) * u[j + 1]; - end - v[i + 1] = t; - end -end - -function Atu(u,v) - for i = 0,#u-1 do - local t = 0; - for j = 0,#u-1 do - t = t + A(j,i) * u[j + 1]; - end - v[i + 1] = t; - end -end - -function AtAu(u,v,w) - Au(u,w); - Atu(w,v); -end - -function spectralnorm(n) - local u, v, w, vv, vBv = {}, {}, {}, 0, 0; - for i = 1,n do - u[i] = 1; v[i] = 0; w[i] = 0; - end - for i = 0,9 do - AtAu(u,v,w); - AtAu(v,u,w); - end - for i = 1,n do - vBv = vBv + u[i]*v[i]; - vv = vv + v[i]*v[i]; - end - return math.sqrt(vBv/vv); -end - -local total = 0; -local i = 6 - -while i <= 48 do - total = total + spectralnorm(i); - i = i * 2 -end - -local expected = 5.086694231303284; - -if (total ~= expected) then - assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. total) -end - -end - -bench.runCode(test, "math-spectral-norm") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 5b70481b..1c284f1f 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2760,8 +2760,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { - ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; - check(R"( type tag = "cat" | "dog" local function f(a: tag) end @@ -2798,8 +2796,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") { - ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; - check(R"( type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} local x: tagged = {tag="cat", fieldx=2} @@ -2821,8 +2817,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") TEST_CASE_FIXTURE(ACFixture, "autocomplete_boolean_singleton") { - ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; - check(R"( local function f(x: true) end f(@1) @@ -2838,8 +2832,6 @@ f(@1) TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") { - ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; - check(R"( type tag = "strange\t\"cat\"" | 'nice\t"dog"' local function f(x: tag) end diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 7b4bfc72..f206438f 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,6 +261,9 @@ RETURN R0 0 TEST_CASE("ForBytecode") { + ScopedFastFlag sff("LuauCompileIter", true); + ScopedFastFlag sff2("LuauCompileIterNoPairs", false); + // basic for loop: variable directly refers to internal iteration index (R2) CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( LOADN R2 1 @@ -295,7 +298,7 @@ GETIMPORT R0 2 LOADK R1 K3 LOADK R2 K4 CALL R0 2 3 -JUMP +4 +FORGPREP R0 +4 GETIMPORT R5 6 MOVE R6 R3 CALL R5 1 0 @@ -347,6 +350,8 @@ RETURN R0 0 TEST_CASE("ForBytecodeBuiltin") { + ScopedFastFlag sff("LuauCompileIter", true); + // we generally recognize builtins like pairs/ipairs and emit special opcodes CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( GETIMPORT R0 1 @@ -385,7 +390,7 @@ GETIMPORT R0 3 MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 -JUMP +0 +FORGPREP R1 +0 FORGLOOP R1 -1 2 RETURN R0 0 )"); @@ -397,7 +402,7 @@ SETGLOBAL R0 K2 GETGLOBAL R0 K2 NEWTABLE R1 0 0 CALL R0 1 3 -JUMP +0 +FORGPREP R0 +0 FORGLOOP R0 -1 2 RETURN R0 0 )"); @@ -407,7 +412,7 @@ RETURN R0 0 GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 -JUMP +0 +FORGPREP R0 +0 FORGLOOP R0 -1 2 RETURN R0 0 )"); @@ -2260,6 +2265,8 @@ TEST_CASE("TypeAliasing") TEST_CASE("DebugLineInfo") { + ScopedFastFlag sff("LuauCompileIterNoPairs", false); + Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -2316,6 +2323,8 @@ return result TEST_CASE("DebugLineInfoFor") { + ScopedFastFlag sff("LuauCompileIter", true); + Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -2336,7 +2345,7 @@ end 5: LOADN R0 1 7: LOADN R1 2 9: LOADN R2 3 -9: JUMP +4 +9: FORGPREP R0 +4 11: GETIMPORT R5 1 11: MOVE R6 R3 11: CALL R5 1 0 @@ -2541,6 +2550,8 @@ a TEST_CASE("DebugSource") { + ScopedFastFlag sff("LuauCompileIterNoPairs", false); + const char* source = R"( local kSelectedBiomes = { ['Mountains'] = true, @@ -2616,6 +2627,8 @@ RETURN R1 1 TEST_CASE("DebugLocals") { + ScopedFastFlag sff("LuauCompileIterNoPairs", false); + const char* source = R"( function foo(e, f) local a = 1 @@ -3767,6 +3780,8 @@ RETURN R0 1 TEST_CASE("SharedClosure") { + ScopedFastFlag sff("LuauCompileIterNoPairs", false); + // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... @@ -4452,5 +4467,688 @@ RETURN R0 0 )"); } +TEST_CASE("InlineBasic") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // inline function that returns a constant + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return 42 +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // inline function that returns the argument + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // inline function that returns one of the two arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b, c) + if a then + return b + else + return c + end +end + +local x = foo(true, math.random(), 5) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETIMPORT R2 3 +CALL R2 0 1 +MOVE R1 R2 +RETURN R1 1 +RETURN R1 1 +)"); + + // inline function that returns one of the two arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b, c) + if a then + return b + else + return c + end +end + +local x = foo(true, 5, math.random()) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETIMPORT R2 3 +CALL R2 0 1 +LOADN R1 5 +RETURN R1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineMutate") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // if the argument is mutated, it gets a register even if the value is constant + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 5 + return a +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 42 +ORK R2 R2 K1 +MOVE R1 R2 +RETURN R1 1 +)"); + + // if the argument is a local, it can be used directly + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = ... +local y = foo(x) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R2 R1 +RETURN R2 1 +)"); + + // ... but if it's mutated, we move it in case it is mutated through a capture during the inlined function + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = ... +x = nil +local y = foo(x) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +LOADNIL R1 +MOVE R3 R1 +MOVE R2 R3 +RETURN R2 1 +)"); + + // we also don't inline functions if they have been assigned to + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +foo = foo + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R0 R0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineUpval") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // if the argument is an upvalue, we naturally need to copy it to a local + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local b = ... + +function bar() + local x = foo(b) + return x +end +)", + 1, 2), + R"( +GETUPVAL R1 0 +MOVE R0 R1 +RETURN R0 1 +)"); + + // if the function uses an upvalue it's more complicated, because the lexical upvalue may become a local + CHECK_EQ("\n" + compileFunction(R"( +local b = ... + +local function foo(a) + return a + b +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +GETVARARGS R0 1 +DUPCLOSURE R1 K0 +CAPTURE VAL R0 +LOADN R3 42 +ADD R2 R3 R0 +RETURN R2 1 +)"); + + // sometimes the lexical upvalue is deep enough that it's still an upvalue though + CHECK_EQ("\n" + compileFunction(R"( +local b = ... + +function bar() + local function foo(a) + return a + b + end + + local x = foo(42) + return x +end +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +CAPTURE UPVAL U0 +LOADN R2 42 +GETUPVAL R3 0 +ADD R1 R2 R3 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineFallthrough") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // if the function doesn't return, we still fill the results with nil + CHECK_EQ("\n" + compileFunction(R"( +local function foo() +end + +local a, b = foo() + +return a, b +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADNIL R2 +MOVE R3 R1 +MOVE R4 R2 +RETURN R3 2 +)"); + + // this happens even if the function returns conditionally + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + if a then return 42 end +end + +local a, b = foo(false) + +return a, b +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADNIL R2 +MOVE R3 R1 +MOVE R4 R2 +RETURN R3 2 +)"); + + // note though that we can't inline a function like this in multret context + // this is because we don't have a SETTOP instruction + CHECK_EQ("\n" + compileFunction(R"( +local function foo() +end + +return foo() +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 -1 +RETURN R1 -1 +)"); +} + +TEST_CASE("InlineCapture") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // can't inline function with nested functions that capture locals because they might be constants + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + local function bar() + return a + end + return bar() +end +)", + 1, 2), + R"( +NEWCLOSURE R1 P0 +CAPTURE VAL R0 +MOVE R2 R1 +CALL R2 0 -1 +RETURN R2 -1 +)"); +} + +TEST_CASE("InlineArgMismatch") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // when inlining a function, we must respect all the usual rules + + // caller might not have enough arguments + // TODO: we don't inline this atm + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // caller might be using multret for arguments + // TODO: we don't inline this atm + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x = foo(math.modf(1.5)) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADK R3 K1 +FASTCALL1 20 R3 +2 +GETIMPORT R2 4 +CALL R2 1 -1 +CALL R1 -1 1 +RETURN R1 1 +)"); + + // caller might have too many arguments, but we still need to compute them for side effects + // TODO: we don't inline this atm + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo(42, print()) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +GETIMPORT R3 2 +CALL R3 0 -1 +CALL R1 -1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineMultiple") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // we call this with a different set of variable/constant args + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x, y = ... +local a = foo(x, 1) +local b = foo(1, x) +local c = foo(1, 2) +local d = foo(x, y) +return a, b, c, d +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 2 +ADDK R3 R1 K1 +LOADN R5 1 +ADD R4 R5 R1 +LOADN R5 3 +ADD R6 R1 R2 +MOVE R7 R3 +MOVE R8 R4 +MOVE R9 R5 +MOVE R10 R6 +RETURN R7 4 +)"); +} + +TEST_CASE("InlineChain") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // inline a chain of functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local function bar(x) + return foo(x, 1) * foo(x, -1) +end + +local function baz() + return (bar(42)) +end + +return (baz()) +)", + 3, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +DUPCLOSURE R2 K2 +LOADN R4 43 +LOADN R5 41 +MUL R3 R4 R5 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineThresholds") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + ScopedFastInt sfis[] = { + {"LuauCompileInlineThreshold", 25}, + {"LuauCompileInlineThresholdMaxBoost", 300}, + {"LuauCompileInlineDepth", 2}, + }; + + // this function has enormous register pressure (50 regs) so we choose not to inline it + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return {{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}} +end + +return (foo()) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // this function has less register pressure but a large cost + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return {},{},{},{},{} +end + +return (foo()) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // this chain of function is of length 3 but our limit in this test is 2, so we call foo twice + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local function bar(x) + return foo(x, 1) * foo(x, -1) +end + +local function baz() + return (bar(42)) +end + +return (baz()) +)", + 3, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +DUPCLOSURE R2 K2 +MOVE R4 R0 +LOADN R5 42 +LOADN R6 1 +CALL R4 2 1 +MOVE R5 R0 +LOADN R6 42 +LOADN R7 -1 +CALL R5 2 1 +MUL R3 R4 R5 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineIIFE") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // IIFE with arguments + CHECK_EQ("\n" + compileFunction(R"( +function choose(a, b, c) + return ((function(a, b, c) if a then return b else return c end end)(a, b, c)) +end +)", + 1, 2), + R"( +JUMPIFNOT R0 +2 +MOVE R3 R1 +RETURN R3 1 +MOVE R3 R2 +RETURN R3 1 +RETURN R3 1 +)"); + + // IIFE with upvalues + CHECK_EQ("\n" + compileFunction(R"( +function choose(a, b, c) + return ((function() if a then return b else return c end end)()) +end +)", + 1, 2), + R"( +JUMPIFNOT R0 +2 +MOVE R3 R1 +RETURN R3 1 +MOVE R3 R2 +RETURN R3 1 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineRecurseArguments") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // we can't inline a function if it's used to compute its own arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) +end +foo(foo(foo,foo(foo,foo))[foo]) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +MOVE R4 R0 +MOVE R5 R0 +MOVE R6 R0 +CALL R4 2 1 +LOADNIL R3 +GETTABLE R2 R3 R0 +CALL R1 1 0 +RETURN R0 0 +)"); +} + +TEST_CASE("InlineFastCallK") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + CHECK_EQ("\n" + compileFunction(R"( +local function set(l0) + rawset({}, l0) +end + +set(false) +set({}) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +NEWTABLE R2 0 0 +FASTCALL2K 49 R2 K1 +4 +LOADK R3 K1 +GETIMPORT R1 3 +CALL R1 2 0 +NEWTABLE R1 0 0 +NEWTABLE R3 0 0 +FASTCALL2 49 R3 R1 +4 +MOVE R4 R1 +GETIMPORT R2 3 +CALL R2 2 0 +RETURN R0 0 +)"); +} + +TEST_CASE("InlineExprIndexK") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + CHECK_EQ("\n" + compileFunction(R"( +local _ = function(l0) +local _ = nil +while _(_)[_] do +end +end +local _ = _(0)[""] +if _ then +do +for l0=0,8 do +end +end +elseif _ then +_ = nil +do +for l0=0,8 do +return true +end +end +end +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R4 +LOADNIL R5 +CALL R4 1 1 +LOADNIL R5 +GETTABLE R3 R4 R5 +JUMPIFNOT R3 +1 +JUMPBACK -7 +LOADNIL R2 +GETTABLEKS R1 R2 K1 +JUMPIFNOT R1 +1 +RETURN R0 0 +JUMPIFNOT R1 +19 +LOADNIL R1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +RETURN R0 0 +)"); +} TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 6f136d36..a23ea470 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -241,6 +241,8 @@ TEST_CASE("Math") TEST_CASE("Table") { + ScopedFastFlag sff("LuauFixBuiltinsStackLimit", true); + runConformance("nextvar.lua"); } @@ -1099,4 +1101,14 @@ TEST_CASE("UserdataApi") CHECK(dtorhits == 42); } +TEST_CASE("Iter") +{ + ScopedFastFlag sffs[] = { + { "LuauCompileIter", true }, + { "LuauIter", true }, + }; + + runConformance("iter.lua"); +} + TEST_SUITE_END(); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index e771b6b1..a10e8f7f 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -386,8 +386,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_error_paths") TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface") { - ScopedFastFlag luauCyclicModuleTypeSurface{"LuauCyclicModuleTypeSurface", true}; - fileResolver.source["game/A"] = R"( return {hello = 2} )"; @@ -410,8 +408,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface") TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") { - ScopedFastFlag luauCyclicModuleTypeSurface{"LuauCyclicModuleTypeSurface", true}; - fileResolver.source["game/A"] = R"( return {mod_a = 2} )"; diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 55eafe3c..69ff73ad 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2041,8 +2041,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") TEST_CASE_FIXTURE(Fixture, "parse_type_pack_errors") { - ScopedFastFlag luauParseRecoverUnexpectedPack{"LuauParseRecoverUnexpectedPack", true}; - matchParseError("type Y = {a: T..., b: number}", "Unexpected '...' after type name; type pack is not allowed in this context", Location{{0, 20}, {0, 23}}); matchParseError("type Y = {a: (number | string)...", "Unexpected '...' after type annotation", Location{{0, 36}, {0, 39}}); @@ -2618,8 +2616,6 @@ type Y = (T...) -> U... TEST_CASE_FIXTURE(Fixture, "recover_unexpected_type_pack") { - ScopedFastFlag luauParseRecoverUnexpectedPack{"LuauParseRecoverUnexpectedPack", true}; - ParseResult result = tryParse(R"( type X = { a: T..., b: number } type Y = { a: T..., b: number } diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index 42411de2..538f3576 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -35,9 +35,9 @@ bool hasError(const CheckResult& result, T* = nullptr) return it != result.errors.end(); } -TEST_SUITE_BEGIN("RuntimeLimitTests"); +TEST_SUITE_BEGIN("RuntimeLimits"); -TEST_CASE_FIXTURE(LimitFixture, "bail_early_on_typescript_port_of_Result_type" * doctest::timeout(1.0)) +TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") { constexpr const char* src = R"LUA( --!strict diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 960c6edf..f9b510c1 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -488,4 +488,71 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") )"); } +TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") +{ + ScopedFastFlag sff{"LuauTypecheckIter", true}; + + CheckResult result = check(R"( + local t: {string} = {} + local key + for k: number in t do + end + for k: number, v: string in t do + end + for k, v in t do + key = k + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + CHECK_EQ(*typeChecker.numberType, *requireType("key")); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") +{ + ScopedFastFlag sff{"LuauTypecheckIter", true}; + + CheckResult result = check(R"( + local t: {string} = {} + local extra + for k, v, e in t do + extra = e + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + CHECK_EQ(*typeChecker.nilType, *requireType("extra")); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") +{ + ScopedFastFlag sff{"LuauTypecheckIter", true}; + + CheckResult result = check(R"( + local t = {} + for k, v in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Cannot iterate over a table without indexer", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_iter_metamethod") +{ + ScopedFastFlag sff{"LuauTypecheckIter", true}; + + CheckResult result = check(R"( + local t = {} + setmetatable(t, { __iter = function(o) return next, o.children end }) + for k: number, v: string in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index fa1f519c..b6f49f9f 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -5,7 +5,6 @@ #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" #include "Fixture.h" diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 5bd522a3..8e535995 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2331,7 +2331,7 @@ TEST_CASE_FIXTURE(Fixture, "confusing_indexing") TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") { - ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; + ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter2", true}; CheckResult result = check(R"( local a: {x: number, y: number, [any]: any} | {y: number} @@ -2351,7 +2351,7 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") { - ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; + ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter2", true}; CheckResult result = check(R"( local a: {y: number} | {x: number, y: number, [any]: any} diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index a578b1cf..e81ef1a9 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1034,4 +1034,45 @@ TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") LUAU_REQUIRE_NO_ERRORS(result); } +/** + * The problem we had here was that the type of q in B.h was initially inferring to {} | {prop: free} before we bound + * that second table to the enclosing union. + */ +TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_table") +{ + ScopedFastFlag flag[] = { + {"LuauStatFunctionSimplify4", true}, + {"LuauLowerBoundsCalculation", true}, + {"LuauDifferentOrderOfUnificationDoesntMatter2", true}, + }; + + CheckResult result = check(R"( + --!strict + + local A = {} + + function A:f() + local t = {} + + for key, value in pairs(self) do + t[key] = value + end + + return t + end + + local B = A:f() + + function B.g(t) + assert(type(t) == "table") + assert(t.prop ~= nil) + end + + function B.h(q) + q = q or {} + return q or {} + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index b6e93265..87562644 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -242,8 +242,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") { - ScopedFastFlag luauTxnLogPreserveOwner{"LuauTxnLogPreserveOwner", true}; - TypeId a = arena.addType(TypeVar{FreeTypeVar{TypeLevel{}}}); TypeId b = typeChecker.numberType; @@ -255,8 +253,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") { - ScopedFastFlag luauTxnLogPreserveOwner{"LuauTxnLogPreserveOwner", true}; - TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); TypePackId b = typeChecker.anyTypePack; diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index e033fe22..a45af39c 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -313,23 +313,33 @@ TEST_CASE("tagging_props") CHECK(Luau::hasTag(prop, "foo")); } -struct VisitCountTracker +struct VisitCountTracker final : TypeVarOnceVisitor { std::unordered_map tyVisits; std::unordered_map tpVisits; - void cycle(TypeId) {} - void cycle(TypePackId) {} + void cycle(TypeId) override {} + void cycle(TypePackId) override {} template bool operator()(TypeId ty, const T& t) + { + return visit(ty); + } + + template + bool operator()(TypePackId tp, const T&) + { + return visit(tp); + } + + bool visit(TypeId ty) override { tyVisits[ty]++; return true; } - template - bool operator()(TypePackId tp, const T&) + bool visit(TypePackId tp) override { tpVisits[tp]++; return true; @@ -348,7 +358,7 @@ local b: (T, T, T) -> T VisitCountTracker tester; DenseHashSet seen{nullptr}; - visitTypeVarOnce(bType, tester, seen); + DEPRECATED_visitTypeVarOnce(bType, tester, seen); for (auto [_, count] : tester.tyVisits) CHECK_EQ(count, 1); diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitTypeVar.test.cpp new file mode 100644 index 00000000..3d426f10 --- /dev/null +++ b/tests/VisitTypeVar.test.cpp @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "Luau/RecursionCounter.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauUseVisitRecursionLimit) +LUAU_FASTINT(LuauVisitRecursionLimit) + +struct VisitTypeVarFixture : Fixture +{ + ScopedFastFlag flag1 = {"LuauUseVisitRecursionLimit", true}; + ScopedFastFlag flag2 = {"LuauRecursionLimitException", true}; +}; + +TEST_SUITE_BEGIN("VisitTypeVar"); + +TEST_CASE_FIXTURE(VisitTypeVarFixture, "throw_when_limit_is_exceeded") +{ + ScopedFastInt sfi{"LuauVisitRecursionLimit", 3}; + + CheckResult result = check(R"( + local t : {a: {b: {c: {d: {e: boolean}}}}} + )"); + + TypeId tType = requireType("t"); + + CHECK_THROWS_AS(toString(tType), RecursionLimitException); +} + +TEST_CASE_FIXTURE(VisitTypeVarFixture, "dont_throw_when_limit_is_high_enough") +{ + ScopedFastInt sfi{"LuauVisitRecursionLimit", 8}; + + CheckResult result = check(R"( + local t : {a: {b: {c: {d: {e: boolean}}}}} + )"); + + TypeId tType = requireType("t"); + + (void)toString(tType); +} + +TEST_SUITE_END(); diff --git a/tests/conformance/iter.lua b/tests/conformance/iter.lua new file mode 100644 index 00000000..468ffafb --- /dev/null +++ b/tests/conformance/iter.lua @@ -0,0 +1,196 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing iteration') + +-- basic for loop tests +do + local a + for a,b in pairs{} do error("not here") end + for i=1,0 do error("not here") end + for i=0,1,-1 do error("not here") end + a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) + a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) + a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11) +end + +-- precision tests for for loops +do + local a + --a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101) + a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10) + a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1) + a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1) + a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0) + a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0) + a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1) +end + +-- for loops do string->number coercion +do + local a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5) +end + +-- generic for with function iterators +do + local function f (n, p) + local t = {}; for i=1,p do t[i] = i*10 end + return function (_,n) + if n > 0 then + n = n-1 + return n, unpack(t) + end + end, nil, n + end + + local x = 0 + for n,a,b,c,d in f(5,3) do + x = x+1 + assert(a == 10 and b == 20 and c == 30 and d == nil) + end + assert(x == 5) +end + +-- generic for with __call (tables) +do + local f = {} + setmetatable(f, { __call = function(_, _, n) if n > 0 then return n - 1 end end }) + + local x = 0 + for n in f, nil, 5 do + x += n + end + assert(x == 10) +end + +-- generic for with __call (userdata) +do + local f = newproxy(true) + getmetatable(f).__call = function(_, _, n) if n > 0 then return n - 1 end end + + local x = 0 + for n in f, nil, 5 do + x += n + end + assert(x == 10) +end + +-- generic for with pairs +do + local x = 0 + for k, v in pairs({a = 1, b = 2, c = 3}) do + x += v + end + assert(x == 6) +end + +-- generic for with pairs with holes +do + local x = 0 + for k, v in pairs({1, 2, 3, nil, 5}) do + x += v + end + assert(x == 11) +end + +-- generic for with ipairs +do + local x = 0 + for k, v in ipairs({1, 2, 3, nil, 5}) do + x += v + end + assert(x == 6) +end + +-- generic for with __iter (tables) +do + local f = {} + setmetatable(f, { __iter = function(x) + assert(f == x) + return next, {1, 2, 3, 4} + end }) + + local x = 0 + for n in f do + x += n + end + assert(x == 10) +end + +-- generic for with __iter (userdata) +do + local f = newproxy(true) + getmetatable(f).__iter = function(x) + assert(f == x) + return next, {1, 2, 3, 4} + end + + local x = 0 + for n in f do + x += n + end + assert(x == 10) +end + +-- generic for with tables (dictionary) +do + local x = 0 + for k, v in {a = 1, b = 2, c = 3} do + print(k, v) + x += v + end + assert(x == 6) +end + +-- generic for with tables (arrays) +do + local x = '' + for k, v in {1, 2, 3, nil, 5} do + x ..= tostring(v) + end + assert(x == "1235") +end + +-- generic for with tables (mixed) +do + local x = 0 + for k, v in {1, 2, 3, nil, 5, a = 1, b = 2, c = 3} do + x += v + end + assert(x == 17) +end + +-- generic for over a non-iterable object +do + local ok, err = pcall(function() for x in 42 do end end) + assert(not ok and err:match("attempt to iterate")) +end + +-- generic for over an iterable object that doesn't return a function +do + local obj = {} + setmetatable(obj, { __iter = function() end }) + + local ok, err = pcall(function() for x in obj do end end) + assert(not ok and err:match("attempt to call a nil value")) +end + +-- it's okay to iterate through a table with a single variable +do + local x = 0 + for k in {1, 2, 3, 4, 5} do + x += k + end + assert(x == 15) +end + +-- all extra variables should be set to nil during builtin traversal +do + local x = 0 + for k,v,a,b,c,d,e in {1, 2, 3, 4, 5} do + x += k + assert(a == nil and b == nil and c == nil and d == nil and e == nil) + end + assert(x == 15) +end + +return"OK" diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index c8176456..0dba8fa6 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -368,48 +368,6 @@ assert(next(a,nil) == 1000 and next(a,1000) == nil) assert(next({}) == nil) assert(next({}, nil) == nil) -for a,b in pairs{} do error("not here") end -for i=1,0 do error("not here") end -for i=0,1,-1 do error("not here") end -a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) -a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) - -a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11) --- precision problems ---a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101) -a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10) -a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1) -a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1) -a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0) -a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0) -a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1) - --- conversion -a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5) - - -collectgarbage() - - --- testing generic 'for' - -local function f (n, p) - local t = {}; for i=1,p do t[i] = i*10 end - return function (_,n) - if n > 0 then - n = n-1 - return n, unpack(t) - end - end, nil, n -end - -local x = 0 -for n,a,b,c,d in f(5,3) do - x = x+1 - assert(a == 10 and b == 20 and c == 30 and d == nil) -end -assert(x == 5) - -- testing table.create and table.find do local t = table.create(5) @@ -596,4 +554,17 @@ do assert(#t2 == 6) end +-- test table.unpack fastcall for rejecting large unpacks +do + local ok, res = pcall(function() + local a = table.create(7999, 0) + local b = table.create(8000, 0) + + local at = { table.unpack(a) } + local bt = { table.unpack(b) } + end) + + assert(not ok) +end + return"OK" diff --git a/tools/lldb_formatters.py b/tools/lldb_formatters.py index b3d2b4f5..ff610d09 100644 --- a/tools/lldb_formatters.py +++ b/tools/lldb_formatters.py @@ -97,7 +97,7 @@ class LuauVariantSyntheticChildrenProvider: if self.current_type: storage = self.valobj.GetChildMemberWithName("storage") - self.stored_value = storage.Cast(self.current_type.GetPointerType()).Dereference() + self.stored_value = storage.Cast(self.current_type) else: self.stored_value = None else: From 72d8d443431875607fd457a13fe36ea62804d327 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 5 May 2022 17:05:57 -0700 Subject: [PATCH 057/102] Add documentation for generalized iteration (#475) --- docs/_pages/syntax.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/_pages/syntax.md b/docs/_pages/syntax.md index 4d39e462..fe825fda 100644 --- a/docs/_pages/syntax.md +++ b/docs/_pages/syntax.md @@ -196,3 +196,26 @@ local sign = if x < 0 then -1 elseif x > 0 then 1 else 0 ``` **Note:** In Luau, the `if-then-else` expression is preferred vs the standard Lua idiom of writing `a and b or c` (which roughly simulates a ternary operator). However, the Lua idiom may return an unexpected result if `b` evaluates to false. The `if-then-else` expression will behave as expected in all situations. + +## Generalized iteration + +Luau uses the standard Lua syntax for iterating through containers, `for vars in values`, but extends the semantics with support for generalized iteration. In Lua, to iterate over a table you need to use an iterator like `next` or a function that returns one like `pairs` or `ipairs`. In Luau, you can simply iterate over a table: + +```lua +for k, v in {1, 4, 9} do + assert(k * k == v) +end +``` + +This works for tables but can also be extended for tables or userdata by implementing `__iter` metamethod that is called before the iteration begins, and should return an iterator function like `next` (or a custom one): + +```lua +local obj = { items = {1, 4, 9} } +setmetatable(obj, { __iter = function(o) return next, o.items end }) + +for k, v in obj do + assert(k * k == v) +end +``` + +The default iteration order for tables is specified to be consecutive for elements `1..#t` and unordered after that, visiting every element; similarly to iteration using `pairs`, modifying the table entries for keys other than the current one results in unspecified behavior. From 7935f9f8b66e81175ab2997e9c15b431fb1c1dfd Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Mon, 9 May 2022 18:33:53 -0700 Subject: [PATCH 058/102] Update sandbox.md Reword the GC docs to avoid back-referencing the thread identity mechanism, since it's entirely Roblox-side and isn't fully documented here anymore. --- docs/_pages/sandbox.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/_pages/sandbox.md b/docs/_pages/sandbox.md index 409a0929..04e72658 100644 --- a/docs/_pages/sandbox.md +++ b/docs/_pages/sandbox.md @@ -4,7 +4,7 @@ title: Sandboxing toc: true --- -Luau is safe to embed. Broadly speaking, this means that even in the face of untrusted (and in Roblox case, actively malicious) code, the language and the standard library don't allow any unsafe access to the underlying system, and don't have any bugs that allow escaping out of the sandbox (e.g. to gain native code execution through ROP gadgets et al). Additionally, the VM provides extra features to implement isolation of privileged code from unprivileged code and protect one from the other; this is important if the embedding environment (Roblox) decides to expose some APIs that may not be safe to call from untrusted code, for example because they do provide controlled access to the underlying system or risk PII exposure through fingerprinting etc. +Luau is safe to embed. Broadly speaking, this means that even in the face of untrusted (and in Roblox case, actively malicious) code, the language and the standard library don't allow any unsafe access to the underlying system, and don't have any bugs that allow escaping out of the sandbox (e.g. to gain native code execution through ROP gadgets et al). Additionally, the VM provides extra features to implement isolation of privileged code from unprivileged code and protect one from the other; this is important if the embedding environment decides to expose some APIs that may not be safe to call from untrusted code, for example because they do provide controlled access to the underlying system or risk PII exposure through fingerprinting etc. This safety is achieved through a combination of removing features from the standard library that are unsafe, adding features to the VM that make it possible to implement sandboxing and isolation, and making sure the implementation is safe from memory safety issues using fuzzing. @@ -54,7 +54,7 @@ This mechanism is bad for performance, memory safety and isolation: - In Lua 5.1, `__gc` support requires traversing userdata lists redundantly during garbage collection to filter out finalizable objects - In later versions of Lua, userdata that implement `__gc` are split into separate lists; however, finalization prolongs the lifetime of the finalized objects which results in less prompt memory reclamation, and two-step destruction results in extra cache misses for userdata -- `__gc` runs during garbage collection in context of an arbitrary thread which makes the thread identity mechanism described above invalid +- `__gc` runs during garbage collection in context of an arbitrary thread which makes the thread identity mechanism used in Roblox to support trusted Luau code invalid - Objects can be removed from weak tables *after* being finalized, which means that accessing these objects can result in memory safety bugs, unless all exposed userdata methods guard against use-after-gc. - If `__gc` method ever leaks to scripts, they can call it directly on an object and use any method exposed by that object after that. This means that `__gc` and all other exposed methods must support memory safety when called on a destroyed object. From be0b7d07e24549e18c54836d3c6e2c3f5f94cd57 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Mon, 9 May 2022 18:34:31 -0700 Subject: [PATCH 059/102] Update sandbox.md Replace debug.getinfo with debug.info --- docs/_pages/sandbox.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_pages/sandbox.md b/docs/_pages/sandbox.md index 04e72658..d1d7d118 100644 --- a/docs/_pages/sandbox.md +++ b/docs/_pages/sandbox.md @@ -19,7 +19,7 @@ The following libraries and global functions have been removed as a result: - `io.` library has been removed entirely, as it gives access to files and allows running processes - `package.` library has been removed entirely, as it gives access to files and allows loading native modules - `os.` library has been cleaned up from file and environment access functions (`execute`, `exit`, etc.). The only supported functions in the library are `clock`, `date`, `difftime` and `time`. -- `debug.` library has been removed to a large extent, as it has functions that aren't memory safe and other functions break isolation; the only supported functions are `traceback` ~~and `getinfo` (with reduced functionality)~~. +- `debug.` library has been removed to a large extent, as it has functions that aren't memory safe and other functions break isolation; the only supported functions are `traceback` and `info` (which is similar to `debug.getinfo` but has a slightly different interface). - `dofile` and `loadfile` allowed access to file system and have been removed. To achieve memory safety, access to function bytecode has been removed. Bytecode is hard to validate and using untrusted bytecode may lead to exploits. Thus, `loadstring` doesn't work with bytecode inputs, and `string.dump`/`load` have been removed as they aren't necessary anymore. When embedding Luau, bytecode should be encrypted/signed to prevent MITM attacks as well, as the VM assumes that the bytecode was generated by the Luau compiler (which never produces invalid/unsafe bytecode). From f3f231ea6b1b063a511c7ef693ca46f02dc01087 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Mon, 9 May 2022 18:38:10 -0700 Subject: [PATCH 060/102] Update compatibility.md Update `__pairs` note with `__iter`, change `__len` to unsure as with `__iter` lack of `__len` on tables is the only issue preventing complete user created containers. --- docs/_pages/compatibility.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/_pages/compatibility.md b/docs/_pages/compatibility.md index 00d883e2..eab1aac8 100644 --- a/docs/_pages/compatibility.md +++ b/docs/_pages/compatibility.md @@ -54,7 +54,7 @@ Sandboxing challenges are [covered in the dedicated section](sandbox). | goto statement | ❌ | this complicates the compiler, makes control flow unstructured and doesn't address a significant need | | finalizers for tables | ❌ | no `__gc` support due to sandboxing and performance/complexity | | no more fenv for threads or functions | 😞 | we love this, but it breaks compatibility | -| tables honor the `__len` metamethod | ❌ | performance implications, no strong use cases +| tables honor the `__len` metamethod | 🤷‍♀️ | performance implications, no strong use cases | hex and `\z` escapes in strings | ✔️ | | | support for hexadecimal floats | 🤷‍♀️ | no strong use cases | | order metamethods work for different types | ❌ | no strong use cases and more complicated semantics + compat | @@ -63,7 +63,7 @@ Sandboxing challenges are [covered in the dedicated section](sandbox). | arguments for function called through `xpcall` | ✔️ | | | optional base in `math.log` | ✔️ | | | optional separator in `string.rep` | 🤷‍♀️ | no real use cases | -| new metamethods `__pairs` and `__ipairs` | ❌ | would like to reevaluate iteration design long term | +| new metamethods `__pairs` and `__ipairs` | ❌ | superseded by `__iter` | | frontier patterns | ✔️ | | | `%g` in patterns | ✔️ | | | `\0` in patterns | ✔️ | | @@ -72,7 +72,7 @@ Sandboxing challenges are [covered in the dedicated section](sandbox). Two things that are important to call out here are various new metamethods for tables and yielding in metamethods. In both cases, there are performance implications to supporting this - our implementation is *very* highly tuned for performance, so any changes that affect the core fundamentals of how Lua works have a price. To support yielding in metamethods we'd need to make the core of the VM more involved, since almost every single "interesting" opcode would need to learn how to be resumable - which also complicates future JIT/AOT story. Metamethods in general are important for extensibility, but very challenging to deal with in implementation, so we err on the side of not supporting any new metamethods unless a strong need arises. -For `__pairs`/`__ipairs`, we aren't sure that this is the right design choice - self-iterating tables via `__iter` are very appealing, and if we can resolve some challenges with array iteration order, that would make the language more accessible so we may go that route instead. +For `__pairs`/`__ipairs`, we felt that extending library functions to enable custom containers wasn't the right choice. Instead we revisited iteration design to allow for self-iterating objects via `__iter` metamethod, which results in a cleaner iteration design that also makes it easier to iterate over tables. As such, we have no plans to support `__pairs`/`__ipairs` as all use cases for it can now be solved by `__iter`. Ephemeron tables may be implemented at some point since they do have valid uses and they make weak tables semantically cleaner, however the cleanup mechanism for these is expensive and complicated, and as such this can only be considered after the pending GC rework is complete. From 105e74c7d940c2a5b416b5fdf0af718b3d040ed1 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Wed, 11 May 2022 15:14:51 -0700 Subject: [PATCH 061/102] Update STATUS.md Both generalized iteration and LBC are implemented but not fully enabled in Roblox yet. --- rfcs/STATUS.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rfcs/STATUS.md b/rfcs/STATUS.md index e3e227a0..15a232ce 100644 --- a/rfcs/STATUS.md +++ b/rfcs/STATUS.md @@ -39,10 +39,10 @@ This document tracks unimplemented RFCs. [RFC: Generalized iteration](https://github.com/Roblox/luau/blob/master/rfcs/generalized-iteration.md) -**Status**: Needs implementation +**Status**: Implemented but not fully rolled out yet. ## Lower Bounds Calculation [RFC: Lower bounds calculation](https://github.com/Roblox/luau/blob/master/rfcs/lower-bounds-calculation.md) -**Status**: Needs implementation +**Status**: Implemented but not fully rolled out yet. From a775e6dc8e5261efd090e6a5b62a191b2f007055 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 12 May 2022 10:08:10 -0700 Subject: [PATCH 062/102] Mark last table subtyping RFC as implemented --- rfcs/unsealed-table-subtyping-strips-optional-properties.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rfcs/unsealed-table-subtyping-strips-optional-properties.md b/rfcs/unsealed-table-subtyping-strips-optional-properties.md index deecfdb3..d99c1f81 100644 --- a/rfcs/unsealed-table-subtyping-strips-optional-properties.md +++ b/rfcs/unsealed-table-subtyping-strips-optional-properties.md @@ -1,5 +1,7 @@ # Only strip optional properties from unsealed tables during subtyping +**Status**: Implemented + ## Summary Currently subtyping allows optional properties to be stripped from table types during subtyping. From 87fe15ac510dbafa1263dad28e6bd1d41d11d1fd Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 12 May 2022 10:08:36 -0700 Subject: [PATCH 063/102] Update STATUS.md Mark last table subtyping RFC as implemented --- rfcs/STATUS.md | 6 ------ 1 file changed, 6 deletions(-) diff --git a/rfcs/STATUS.md b/rfcs/STATUS.md index 15a232ce..ef55b5c4 100644 --- a/rfcs/STATUS.md +++ b/rfcs/STATUS.md @@ -15,12 +15,6 @@ This document tracks unimplemented RFCs. **Status**: Needs implementation -## Sealed/unsealed typing changes - -[RFC: Only strip optional properties from unsealed tables during subtyping](https://github.com/Roblox/luau/blob/master/rfcs/unsealed-table-subtyping-strips-optional-properties.md) - -**Status**: Implemented but not fully rolled out yet. - ## Safe navigation operator [RFC: Safe navigation postfix operator (?)](https://github.com/Roblox/luau/blob/master/rfcs/syntax-safe-navigation-operator.md) From a36b1eb29ba5c6c414305b4bbd2b6e2c58ed7d06 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 13 May 2022 12:36:37 -0700 Subject: [PATCH 064/102] Sync to upstream/release/527 (#481) --- Analysis/include/Luau/TypeVar.h | 2 +- Analysis/src/Clone.cpp | 4 +- Analysis/src/Normalize.cpp | 3 +- Analysis/src/Substitution.cpp | 4 +- Analysis/src/ToString.cpp | 16 +- Analysis/src/TypeInfer.cpp | 191 ++++++++------------- Ast/src/Parser.cpp | 3 +- CMakeLists.txt | 6 + Compiler/src/Compiler.cpp | 30 +++- Compiler/src/ConstantFolding.cpp | 6 +- Compiler/src/CostModel.cpp | 2 +- extern/isocline/src/bbcode.c | 1 + tests/AstQuery.test.cpp | 2 +- tests/Autocomplete.test.cpp | 44 +++-- tests/BuiltinDefinitions.test.cpp | 4 +- tests/Compiler.test.cpp | 71 ++++++-- tests/CostModel.test.cpp | 14 +- tests/Fixture.cpp | 19 +- tests/Fixture.h | 5 + tests/Frontend.test.cpp | 2 +- tests/Linter.test.cpp | 6 +- tests/Module.test.cpp | 4 +- tests/NonstrictMode.test.cpp | 6 +- tests/Normalize.test.cpp | 7 +- tests/Parser.test.cpp | 2 - tests/RuntimeLimits.test.cpp | 2 +- tests/ToDot.test.cpp | 2 +- tests/ToString.test.cpp | 4 +- tests/TypeInfer.aliases.test.cpp | 10 +- tests/TypeInfer.annotations.test.cpp | 8 +- tests/TypeInfer.anyerror.test.cpp | 4 +- tests/TypeInfer.builtins.test.cpp | 117 +++++++------ tests/TypeInfer.classes.test.cpp | 2 +- tests/TypeInfer.functions.test.cpp | 38 ++-- tests/TypeInfer.generics.test.cpp | 12 +- tests/TypeInfer.intersectionTypes.test.cpp | 6 +- tests/TypeInfer.loops.test.cpp | 28 +-- tests/TypeInfer.modules.test.cpp | 28 +-- tests/TypeInfer.oop.test.cpp | 4 +- tests/TypeInfer.operators.test.cpp | 71 ++++++-- tests/TypeInfer.provisional.test.cpp | 10 +- tests/TypeInfer.refinements.test.cpp | 20 +-- tests/TypeInfer.singletons.test.cpp | 2 +- tests/TypeInfer.tables.test.cpp | 74 ++++---- tests/TypeInfer.test.cpp | 17 +- tests/TypeInfer.tryUnify.test.cpp | 2 +- tests/TypeInfer.typePacks.cpp | 6 +- tests/TypeInfer.unionTypes.test.cpp | 6 +- 48 files changed, 511 insertions(+), 416 deletions(-) diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 84576758..9cacbc6d 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -329,7 +329,7 @@ struct TableTypeVar // We need to know which is which when we stringify types. std::optional syntheticName; - std::map methodDefinitionLocations; + std::map methodDefinitionLocations; // TODO: Remove with FFlag::LuauNoMethodLocations std::vector instantiatedTypeParams; std::vector instantiatedTypePackParams; ModuleName definitionModuleName; diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index d5bd9dab..1aa556eb 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -11,6 +11,7 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypecheckOptPass) LUAU_FASTFLAGVARIABLE(LuauLosslessClone, false) +LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau { @@ -277,7 +278,8 @@ void TypeCloner::operator()(const TableTypeVar& t) } ttv->definitionModuleName = t.definitionModuleName; - ttv->methodDefinitionLocations = t.methodDefinitionLocations; + if (!FFlag::LuauNoMethodLocations) + ttv->methodDefinitionLocations = t.methodDefinitionLocations; ttv->tags = t.tags; } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index d8c11388..ef5377a1 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); -LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineIntersectionFix, false); namespace Luau { @@ -863,7 +862,7 @@ struct Normalize final : TypeVarVisitor TypeId theTable = result->parts.back(); - if (!get(FFlag::LuauNormalizeCombineIntersectionFix ? follow(theTable) : theTable)) + if (!get(follow(theTable))) { result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); theTable = result->parts.back(); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 30d8574a..c5c7977a 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -12,6 +12,7 @@ LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAG(LuauTypecheckOptPass) LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowPossibleMutations, false) +LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau { @@ -408,7 +409,8 @@ TypeId Substitution::clone(TypeId ty) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index b5d6a550..51665f7f 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -248,6 +248,13 @@ struct StringifierState result.name += s; } + void emit(TypeLevel level) + { + emit(std::to_string(level.level)); + emit("-"); + emit(std::to_string(level.subLevel)); + } + void emit(const char* s) { if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) @@ -379,7 +386,7 @@ struct TypeVarStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(std::to_string(ftv.level.level)); + state.emit(ftv.level); } } @@ -403,7 +410,10 @@ struct TypeVarStringifier { state.result.invalid = true; - state.emit("[["); + state.emit("["); + if (FFlag::DebugLuauVerboseTypeNames) + state.emit(ctv.level); + state.emit("["); bool first = true; for (TypeId ty : ctv.parts) @@ -947,7 +957,7 @@ struct TypePackStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(std::to_string(pack.level.level)); + state.emit(pack.level); } state.emit("..."); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4466ede2..a13abd53 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,13 +36,11 @@ LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauInferStatFunction, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateFollows, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) -LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) @@ -53,12 +51,13 @@ LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) -LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAG(LuauLosslessClone) LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); +LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) +LUAU_FASTFLAGVARIABLE(LuauNoMethodLocations, false); namespace Luau { @@ -587,7 +586,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A { std::optional expectedType; - if (FFlag::LuauInferStatFunction && !fun->func->self) + if (!fun->func->self) { if (auto name = fun->name->as()) { @@ -1307,7 +1306,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } - else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify4) + else if (auto name = function.name->as()) { TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); @@ -1341,7 +1340,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } - else if (FFlag::LuauStatFunctionSimplify4) + else { LUAU_ASSERT(function.name->is()); @@ -1349,71 +1348,6 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); } - else if (function.func->self) - { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify4); - - AstExprIndexName* indexName = function.name->as(); - if (!indexName) - ice("member function declaration has malformed name expression"); - - TypeId selfTy = checkExpr(scope, *indexName->expr).type; - TableTypeVar* tableSelf = getMutableTableType(selfTy); - if (!tableSelf) - { - if (isTableIntersection(selfTy)) - reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); - else if (!get(selfTy) && !get(selfTy)) - reportError(TypeError{function.location, OnlyTablesCanHaveMethods{selfTy}}); - } - else if (tableSelf->state == TableState::Sealed) - reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); - - const bool tableIsExtendable = tableSelf && tableSelf->state != TableState::Sealed; - - ty = follow(ty); - - if (tableIsExtendable) - tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; - - const FunctionTypeVar* funTy = get(ty); - if (!funTy) - ice("Methods should be functions"); - - std::optional arg0 = first(funTy->argTypes); - if (!arg0) - ice("Methods should always have at least 1 argument (self)"); - - checkFunctionBody(funScope, ty, *function.func); - - if (tableIsExtendable) - tableSelf->props[indexName->index.value] = { - follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; - } - else - { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify4); - - TypeId leftType = checkLValueBinding(scope, *function.name); - - checkFunctionBody(funScope, ty, *function.func); - - unify(ty, leftType, function.location); - - LUAU_ASSERT(function.name->is() || function.name->is()); - - if (auto exprIndexName = function.name->as()) - { - if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) - { - if (auto ttv = getMutableTableType(*typeIt)) - { - if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) - it->second.type = follow(quantify(funScope, leftType, function.name->location)); - } - } - } - } } void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function) @@ -1523,7 +1457,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; @@ -2652,13 +2587,58 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); - // TODO: this check seems odd, the second part is redundant - // is it meant to be if (leftMetatable && rightMetatable && leftMetatable != rightMetatable) - if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) + if (FFlag::LuauSuccessTypingForEqualityOperations) { - reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorRecoveryType(booleanType); + if (leftMetatable != rightMetatable) + { + bool matches = false; + if (isEquality) + { + if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) + { + for (TypeId leftOption : utv) + { + if (getMetatable(follow(leftOption)) == rightMetatable) + { + matches = true; + break; + } + } + } + + if (!matches) + { + if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) + { + for (TypeId rightOption : utv) + { + if (getMetatable(follow(rightOption)) == leftMetatable) + { + matches = true; + break; + } + } + } + } + } + + + if (!matches) + { + reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + return errorRecoveryType(booleanType); + } + } + } + else + { + if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) + { + reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + return errorRecoveryType(booleanType); + } } if (leftMetatable) @@ -2754,22 +2734,11 @@ TypeId TypeChecker::checkBinaryOperation( lhsType = follow(lhsType); rhsType = follow(rhsType); - if (FFlag::LuauDecoupleOperatorInferenceFromUnifiedTypeInference) + if (!isNonstrictMode() && get(lhsType)) { - if (!isNonstrictMode() && get(lhsType)) - { - auto name = getIdentifierOfBaseVar(expr.left); - reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - // We will fall-through to the `return anyType` check below. - } - } - else - { - if (!isNonstrictMode() && get(lhsType)) - { - auto name = getIdentifierOfBaseVar(expr.left); - reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - } + auto name = getIdentifierOfBaseVar(expr.left); + reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); + // We will fall-through to the `return anyType` check below. } // If we know nothing at all about the lhs type, we can usually say nothing about the result. @@ -3231,43 +3200,27 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T else if (auto indexName = funName.as()) { TypeId lhsType = checkExpr(scope, *indexName->expr).type; - - if (!FFlag::LuauStatFunctionSimplify4 && (get(lhsType) || get(lhsType))) - return lhsType; - TableTypeVar* ttv = getMutableTableType(lhsType); - if (FFlag::LuauStatFunctionSimplify4) + if (!ttv || ttv->state == TableState::Sealed) { - if (!ttv || ttv->state == TableState::Sealed) - { - if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, false)) - return *ty; + if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, false)) + return *ty; - return errorRecoveryType(scope); - } - } - else - { - if (!ttv || lhsType->persistent || ttv->state == TableState::Sealed) - return errorRecoveryType(scope); + return errorRecoveryType(scope); } Name name = indexName->index.value; if (ttv->props.count(name)) - { - if (FFlag::LuauStatFunctionSimplify4) - return ttv->props[name].type; - else - return errorRecoveryType(scope); - } + return ttv->props[name].type; Property& property = ttv->props[name]; property.type = freshTy(); property.location = indexName->indexLocation; - ttv->methodDefinitionLocations[name] = funName.location; + if (!FFlag::LuauNoMethodLocations) + ttv->methodDefinitionLocations[name] = funName.location; return property.type; } else if (funName.is()) @@ -4669,7 +4622,8 @@ TypeId ReplaceGenerics::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; return addType(std::move(clone)); } @@ -4715,7 +4669,8 @@ TypeId Anyification::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 91f5cd25..c053e6bd 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,7 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false) namespace Luau { @@ -2821,7 +2820,7 @@ void Parser::nextLexeme() hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); } - type = lexer.next(/* skipComments= */ false, !FFlag::LuauParseLocationIgnoreCommentSkipInCapture).type; + type = lexer.next(/* skipComments= */ false, /* updatePrevLocation= */ false).type; } } else diff --git a/CMakeLists.txt b/CMakeLists.txt index af03b33a..ea352309 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -110,6 +110,12 @@ if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) endif() +if(MSVC AND LUAU_BUILD_CLI) + # the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger + set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) + set_target_properties(Luau.Repl.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) +endif() + # embed .natvis inside the library debug information if(MSVC) target_link_options(Luau.Ast INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Ast.natvis) diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 4fe26222..e177e928 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -628,7 +628,12 @@ struct Compiler return; if (fi && !fi->canInline) - bytecode.addDebugRemark("inlining failed: complex constructs in function body"); + { + if (func->vararg) + bytecode.addDebugRemark("inlining failed: function is variadic"); + else + bytecode.addDebugRemark("inlining failed: complex constructs in function body"); + } } RegScope rs(this); @@ -2342,17 +2347,28 @@ struct Compiler RegScope rs(this); uint8_t temp = 0; + bool consecutive = false; bool multRet = false; - // Optimization: return local value directly instead of copying it into a temporary - if (stat->list.size == 1 && isExprLocalReg(stat->list.data[0])) + // Optimization: return locals directly instead of copying them into a temporary + // this is very important for a single return value and occasionally effective for multiple values + if (stat->list.size > 0 && isExprLocalReg(stat->list.data[0])) { - AstExprLocal* le = stat->list.data[0]->as(); - LUAU_ASSERT(le); + temp = getLocal(stat->list.data[0]->as()->local); + consecutive = true; - temp = getLocal(le->local); + for (size_t i = 1; i < stat->list.size; ++i) + { + AstExpr* v = stat->list.data[i]; + if (!isExprLocalReg(v) || getLocal(v->as()->local) != temp + i) + { + consecutive = false; + break; + } + } } - else if (stat->list.size > 0) + + if (!consecutive && stat->list.size > 0) { temp = allocReg(stat, unsigned(stat->list.size)); diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 52ece73e..e4d59ea1 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -195,12 +195,16 @@ struct ConstantVisitor : AstVisitor DenseHashMap& variables; DenseHashMap& locals; + bool wasEmpty = false; + ConstantVisitor( DenseHashMap& constants, DenseHashMap& variables, DenseHashMap& locals) : constants(constants) , variables(variables) , locals(locals) { + // since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries + wasEmpty = constants.empty() && locals.empty(); } Constant analyze(AstExpr* node) @@ -326,7 +330,7 @@ struct ConstantVisitor : AstVisitor { if (value.type != Constant::Type_Unknown) map[key] = value; - else if (!FFlag::LuauCompileSupportInlining) + else if (!FFlag::LuauCompileSupportInlining || wasEmpty) ; else if (Constant* old = map.find(key)) old->type = Constant::Type_Unknown; diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index 9afd09f6..f804e9de 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -187,7 +187,7 @@ struct CostVisitor : AstVisitor if (node->is()) result += 2; else if (node->is() || node->is() || node->is() || node->is()) - result += 2; + result += 5; else if (node->is() || node->is()) result += 1; diff --git a/extern/isocline/src/bbcode.c b/extern/isocline/src/bbcode.c index 4d11ac38..8722cbd6 100644 --- a/extern/isocline/src/bbcode.c +++ b/extern/isocline/src/bbcode.c @@ -575,6 +575,7 @@ ic_private const char* parse_tag_value( tag_t* tag, char* idbuf, const char* s, } // limit name and attr to 128 bytes char valbuf[128]; + valbuf[0] = 0; // fixes gcc uninitialized warning ic_strncpy( idbuf, 128, id, idend - id); ic_strncpy( valbuf, 128, val, valend - val); ic_str_tolower(idbuf); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 292625b0..12c68450 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -7,7 +7,7 @@ using namespace Luau; -struct DocumentationSymbolFixture : Fixture +struct DocumentationSymbolFixture : BuiltinsFixture { std::optional getDocSymbol(const std::string& source, Position position) { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 1c284f1f..b4e9340c 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -27,7 +27,7 @@ template struct ACFixtureImpl : BaseType { ACFixtureImpl() - : Fixture(true, true) + : BaseType(true, true) { } @@ -111,6 +111,18 @@ struct ACFixtureImpl : BaseType }; struct ACFixture : ACFixtureImpl +{ + ACFixture() + : ACFixtureImpl() + { + addGlobalBinding(frontend.typeChecker, "table", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeChecker, "math", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeCheckerForAutocomplete, "table", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeCheckerForAutocomplete, "math", Binding{typeChecker.anyType}); + } +}; + +struct ACBuiltinsFixture : ACFixtureImpl { }; @@ -277,7 +289,7 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") CHECK(ac.entryMap.count("test")); } -TEST_CASE_FIXTURE(ACFixture, "get_member_completions") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "get_member_completions") { check(R"( local a = table.@1 @@ -376,7 +388,7 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") CHECK(ac.entryMap.count("c3")); } -TEST_CASE_FIXTURE(ACFixture, "get_string_completions") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "get_string_completions") { check(R"( local a = ("foo"):@1 @@ -427,7 +439,7 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") CHECK(!ac.entryMap.count("math")); } -TEST_CASE_FIXTURE(ACFixture, "method_call_inside_if_conditional") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "method_call_inside_if_conditional") { check(R"( if table: @1 @@ -1884,7 +1896,7 @@ ex.b(function(x: CHECK(!ac.entryMap.count("(done) -> number")); } -TEST_CASE_FIXTURE(ACFixture, "suggest_external_module_type") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "suggest_external_module_type") { fileResolver.source["Module/A"] = R"( export type done = { x: number, y: number } @@ -2235,7 +2247,7 @@ local a: aaa.do CHECK(ac.entryMap.count("other")); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteSource") { std::string_view source = R"( local a = table. -- Line 1 @@ -2269,7 +2281,7 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_comments") CHECK_EQ(0, ac.entryMap.size()); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteProp_index_function_metamethod_is_variadic") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod_is_variadic") { std::string_view source = R"( type Foo = {x: number} @@ -2720,7 +2732,7 @@ type A = () -> T CHECK(ac.entryMap.count("string")); } -TEST_CASE_FIXTURE(ACFixture, "autocomplete_oop_implicit_self") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_oop_implicit_self") { check(R"( --!strict @@ -2728,15 +2740,15 @@ local Class = {} Class.__index = Class type Class = typeof(setmetatable({} :: { x: number }, Class)) function Class.new(x: number): Class - return setmetatable({x = x}, Class) + return setmetatable({x = x}, Class) end function Class.getx(self: Class) - return self.x + return self.x end function test() - local c = Class.new(42) - local n = c:@1 - print(n) + local c = Class.new(42) + local n = c:@1 + print(n) end )"); @@ -2745,7 +2757,7 @@ end CHECK(ac.entryMap.count("getx")); } -TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") { check(R"( --!strict @@ -2989,7 +3001,7 @@ s.@1 CHECK(ac.entryMap["sub"].wrongIndexType == true); } -TEST_CASE_FIXTURE(ACFixture, "string_library_non_self_calls_are_fine") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "string_library_non_self_calls_are_fine") { ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; @@ -3007,7 +3019,7 @@ string.@1 CHECK(ac.entryMap["sub"].wrongIndexType == false); } -TEST_CASE_FIXTURE(ACFixture, "string_library_self_calls_are_invalid") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "string_library_self_calls_are_invalid") { ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; diff --git a/tests/BuiltinDefinitions.test.cpp b/tests/BuiltinDefinitions.test.cpp index dbe80f2c..496df4b4 100644 --- a/tests/BuiltinDefinitions.test.cpp +++ b/tests/BuiltinDefinitions.test.cpp @@ -10,8 +10,10 @@ using namespace Luau; TEST_SUITE_BEGIN("BuiltinDefinitionsTest"); -TEST_CASE_FIXTURE(Fixture, "lib_documentation_symbols") +TEST_CASE_FIXTURE(BuiltinsFixture, "lib_documentation_symbols") { + CHECK(!typeChecker.globalScope->bindings.empty()); + for (const auto& [name, binding] : typeChecker.globalScope->bindings) { std::string nameString(name.c_str()); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index f206438f..b032060e 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4713,7 +4713,6 @@ local function foo() end local a, b = foo() - return a, b )", 1, 2), @@ -4721,9 +4720,7 @@ return a, b DUPCLOSURE R0 K0 LOADNIL R1 LOADNIL R2 -MOVE R3 R1 -MOVE R4 R2 -RETURN R3 2 +RETURN R1 2 )"); // this happens even if the function returns conditionally @@ -4733,7 +4730,6 @@ local function foo(a) end local a, b = foo(false) - return a, b )", 1, 2), @@ -4741,9 +4737,7 @@ return a, b DUPCLOSURE R0 K0 LOADNIL R1 LOADNIL R2 -MOVE R3 R1 -MOVE R4 R2 -RETURN R3 2 +RETURN R1 2 )"); // note though that we can't inline a function like this in multret context @@ -4880,11 +4874,7 @@ LOADN R5 1 ADD R4 R5 R1 LOADN R5 3 ADD R6 R1 R2 -MOVE R7 R3 -MOVE R8 R4 -MOVE R9 R5 -MOVE R10 R6 -RETURN R7 4 +RETURN R3 4 )"); } @@ -5151,4 +5141,59 @@ RETURN R0 0 )"); } +TEST_CASE("ReturnConsecutive") +{ + // we can return a single local directly + CHECK_EQ("\n" + compileFunction0(R"( +local x = ... +return x +)"), + R"( +GETVARARGS R0 1 +RETURN R0 1 +)"); + + // or multiple, when they are allocated in consecutive registers + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return x, y +)"), + R"( +GETVARARGS R0 2 +RETURN R0 2 +)"); + + // but not if it's an expression + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return x, y + 1 +)"), + R"( +GETVARARGS R0 2 +MOVE R2 R0 +ADDK R3 R1 K0 +RETURN R2 2 +)"); + + // or a local with wrong register number + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return y, x +)"), + R"( +GETVARARGS R0 2 +MOVE R2 R1 +MOVE R3 R0 +RETURN R2 2 +)"); + + // also double check the optimization doesn't trip on no-argument return (these are rare) + CHECK_EQ("\n" + compileFunction0(R"( +return +)"), + R"( +RETURN R0 0 +)"); +} + TEST_SUITE_END(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp index aa5b7284..2fa0659b 100644 --- a/tests/CostModel.test.cpp +++ b/tests/CostModel.test.cpp @@ -76,9 +76,9 @@ end const bool args1[] = {false}; const bool args2[] = {true}; - // loop baseline cost is 2 - CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); - CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); + // loop baseline cost is 5 + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("MutableVariable") @@ -154,8 +154,8 @@ end const bool args1[] = {false}; const bool args2[] = {true}; - CHECK_EQ(38, Luau::Compile::computeCost(model, args1, 1)); - CHECK_EQ(37, Luau::Compile::computeCost(model, args2, 1)); + CHECK_EQ(50, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(49, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("Conditional") @@ -219,8 +219,8 @@ end const bool args1[] = {false}; const bool args2[] = {true}; - CHECK_EQ(4, Luau::Compile::computeCost(model, args1, 1)); - CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); + CHECK_EQ(7, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); } TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index d8b37a65..03f3e15c 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -92,10 +92,6 @@ Fixture::Fixture(bool freeze, bool prepareAutocomplete) configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.parseOptions.captureComments = true; - registerBuiltinTypes(frontend.typeChecker); - if (prepareAutocomplete) - registerBuiltinTypes(frontend.typeCheckerForAutocomplete); - registerTestTypes(); Luau::freeze(frontend.typeChecker.globalTypes); Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); @@ -410,6 +406,21 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) return result; } +BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) + : Fixture(freeze, prepareAutocomplete) +{ + Luau::unfreeze(frontend.typeChecker.globalTypes); + Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + + registerBuiltinTypes(frontend.typeChecker); + if (prepareAutocomplete) + registerBuiltinTypes(frontend.typeCheckerForAutocomplete); + registerTestTypes(); + + Luau::freeze(frontend.typeChecker.globalTypes); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); +} + ModuleName fromString(std::string_view name) { return ModuleName(name); diff --git a/tests/Fixture.h b/tests/Fixture.h index 0d1233bf..901f7d42 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -151,6 +151,11 @@ struct Fixture LoadDefinitionFileResult loadDefinition(const std::string& source); }; +struct BuiltinsFixture : Fixture +{ + BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); +}; + ModuleName fromString(std::string_view name); template diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index a10e8f7f..33b81be8 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -77,7 +77,7 @@ struct NaiveFileResolver : NullFileResolver } // namespace -struct FrontendFixture : Fixture +struct FrontendFixture : BuiltinsFixture { FrontendFixture() { diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 6649cb7f..202aeceb 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -75,7 +75,7 @@ _ = 6 CHECK_EQ(result.warnings.size(), 0); } -TEST_CASE_FIXTURE(Fixture, "BuiltinGlobalWrite") +TEST_CASE_FIXTURE(BuiltinsFixture, "BuiltinGlobalWrite") { LintResult result = lint(R"( math = {} @@ -309,7 +309,7 @@ print(arg) CHECK_EQ(result.warnings[0].text, "Variable 'arg' shadows previous declaration at line 2"); } -TEST_CASE_FIXTURE(Fixture, "LocalShadowGlobal") +TEST_CASE_FIXTURE(BuiltinsFixture, "LocalShadowGlobal") { LintResult result = lint(R"( local math = math @@ -1470,7 +1470,7 @@ end CHECK_EQ(result.warnings[2].text, "Member 'Instance.DataCost' is deprecated"); } -TEST_CASE_FIXTURE(Fixture, "TableOperations") +TEST_CASE_FIXTURE(BuiltinsFixture, "TableOperations") { LintResult result = lintTyped(R"( local t = {} diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 44cc20a7..4a999861 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -113,7 +113,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") CHECK_EQ(2, dest.typeVars.size()); // One table and one function } -TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") { CheckResult result = check(R"( return {sign=math.sign} @@ -250,7 +250,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") CHECK_EQ(getSingletonTypes().stringType, ctv->parts[1]); } -TEST_CASE_FIXTURE(Fixture, "clone_self_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") { ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index feeaf2c2..69430b1c 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -13,7 +13,7 @@ using namespace Luau; TEST_SUITE_BEGIN("NonstrictModeTests"); -TEST_CASE_FIXTURE(Fixture, "function_returns_number_or_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_number_or_string") { ScopedFastFlag sff[]{ {"LuauReturnTypeInferenceInNonstrict", true}, @@ -224,7 +224,7 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") CHECK_MESSAGE(get(ttv->props["three"].type), "Should be a function: " << *ttv->props["three"].type); } -TEST_CASE_FIXTURE(Fixture, "for_in_iterator_variables_are_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_iterator_variables_are_any") { CheckResult result = check(R"( --!nonstrict @@ -243,7 +243,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_iterator_variables_are_any") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_dot_insert_and_recursive_calls") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_insert_and_recursive_calls") { CheckResult result = check(R"( --!nonstrict diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index d3778f67..41830682 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -739,7 +739,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly") CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true})); } -TEST_CASE_FIXTURE(Fixture, "union_of_distinct_free_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_distinct_free_types") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, @@ -760,7 +760,7 @@ TEST_CASE_FIXTURE(Fixture, "union_of_distinct_free_types") CHECK("(a, b) -> a | b" == toString(requireType("fussy"))); } -TEST_CASE_FIXTURE(Fixture, "constrained_intersection_of_intersections") +TEST_CASE_FIXTURE(BuiltinsFixture, "constrained_intersection_of_intersections") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, @@ -951,7 +951,7 @@ TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "visiting_a_type_twice_is_not_considered_normal") +TEST_CASE_FIXTURE(BuiltinsFixture, "visiting_a_type_twice_is_not_considered_normal") { ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; @@ -976,7 +976,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, - {"LuauNormalizeCombineIntersectionFix", true}, }; CheckResult result = check(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 69ff73ad..c9d8d0b8 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1618,8 +1618,6 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture") { - ScopedFastFlag luauParseLocationIgnoreCommentSkipInCapture{"LuauParseLocationIgnoreCommentSkipInCapture", true}; - // Same should hold when comments are captured ParseOptions opts; opts.captureComments = true; diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index 538f3576..c16f60d5 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -17,7 +17,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); -struct LimitFixture : Fixture +struct LimitFixture : BuiltinsFixture { #if defined(_NOOPT) || defined(_DEBUG) ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 332a4b22..e9fa5b26 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -224,7 +224,7 @@ n1 -> n4 [label="typePackParam"]; (void)toDot(requireType("a")); } -TEST_CASE_FIXTURE(Fixture, "metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable") { CheckResult result = check(R"( local a: typeof(setmetatable({}, {})) diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index ccf5c583..50d0838e 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -60,7 +60,7 @@ TEST_CASE_FIXTURE(Fixture, "named_table") CHECK_EQ("TheTable", toString(&table)); } -TEST_CASE_FIXTURE(Fixture, "exhaustive_toString_of_cyclic_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") { CheckResult result = check(R"( --!strict @@ -338,7 +338,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") REQUIRE_EQ("c", toString(params[2], opts)); } -TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") +TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") { ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index b0eb31ce..7562a4d7 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -279,7 +279,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") CHECK_EQ("Node", toString(e->wantedType)); } -TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_multi_assign") { fileResolver.source["workspace/A"] = R"( export type myvec2 = {x: number, y: number} @@ -317,7 +317,7 @@ TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") REQUIRE(bType->props.size() == 3); } -TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") { CheckResult result = check("type t10 = typeof(table)"); LUAU_REQUIRE_NO_ERRORS(result); @@ -385,7 +385,7 @@ type Cool = typeof(c) CHECK_EQ(ttv->name, "Cool"); } -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_of_an_imported_recursive_type") { fileResolver.source["game/A"] = R"( export type X = { a: number, b: X? } @@ -410,7 +410,7 @@ type X = Import.X CHECK_EQ(follow(*ty1), follow(*ty2)); } -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_of_an_imported_recursive_generic_type") { fileResolver.source["game/A"] = R"( export type X = { a: T, b: U, C: X? } @@ -564,7 +564,7 @@ TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") * * We solved this by ascribing a unique subLevel to each prototyped alias. */ -TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_quantify_unresolved_aliases") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 7f1c757a..b9e1ae96 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -528,7 +528,7 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti CHECK_EQ(recordType, bType); } -TEST_CASE_FIXTURE(Fixture, "use_type_required_from_another_file") +TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") { addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); @@ -554,7 +554,7 @@ TEST_CASE_FIXTURE(Fixture, "use_type_required_from_another_file") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "cannot_use_nonexported_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") { addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); @@ -580,7 +580,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_nonexported_type") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "builtin_types_are_not_exported") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_are_not_exported") { addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); @@ -676,7 +676,7 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") )"); } -TEST_CASE_FIXTURE(Fixture, "luau_print_is_magic_if_the_flag_is_set") +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_print_is_magic_if_the_flag_is_set") { // Luau::resetPrintLine(); ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 5224b5d8..bc55940e 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -237,7 +237,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") CHECK_EQ("*unknown*", toString(requireType("a"))); } -TEST_CASE_FIXTURE(Fixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") { CheckResult result = check(R"( local a: any @@ -285,7 +285,7 @@ end LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "metatable_of_any_can_be_a_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_of_any_can_be_a_table") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 1ae65947..b710ea0d 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -12,7 +12,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); TEST_SUITE_BEGIN("BuiltinTests"); -TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_things_are_defined") { CheckResult result = check(R"( local a00 = math.frexp @@ -50,7 +50,7 @@ TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "next_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "next_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( local a: string, b: number = next({ 1 }) @@ -63,7 +63,7 @@ TEST_CASE_FIXTURE(Fixture, "next_iterator_should_infer_types_and_type_check") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "pairs_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( type Map = { [K]: V } @@ -75,7 +75,7 @@ TEST_CASE_FIXTURE(Fixture, "pairs_iterator_should_infer_types_and_type_check") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "ipairs_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( type Map = { [K]: V } @@ -87,7 +87,7 @@ TEST_CASE_FIXTURE(Fixture, "ipairs_iterator_should_infer_types_and_type_check") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_dot_remove_optionally_returns_generic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_remove_optionally_returns_generic") { CheckResult result = check(R"( local t = { 1 } @@ -98,7 +98,7 @@ TEST_CASE_FIXTURE(Fixture, "table_dot_remove_optionally_returns_generic") CHECK_EQ(toString(requireType("n")), "number?"); } -TEST_CASE_FIXTURE(Fixture, "table_concat_returns_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_concat_returns_string") { CheckResult result = check(R"( local r = table.concat({1,2,3,4}, ",", 2); @@ -108,7 +108,7 @@ TEST_CASE_FIXTURE(Fixture, "table_concat_returns_string") CHECK_EQ(*typeChecker.stringType, *requireType("r")); } -TEST_CASE_FIXTURE(Fixture, "sort") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort") { CheckResult result = check(R"( local t = {1, 2, 3}; @@ -118,7 +118,7 @@ TEST_CASE_FIXTURE(Fixture, "sort") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sort_with_predicate") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_predicate") { CheckResult result = check(R"( --!strict @@ -130,7 +130,7 @@ TEST_CASE_FIXTURE(Fixture, "sort_with_predicate") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sort_with_bad_predicate") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") { CheckResult result = check(R"( --!strict @@ -140,6 +140,12 @@ TEST_CASE_FIXTURE(Fixture, "sort_with_bad_predicate") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '(number, number) -> boolean' could not be converted into '((a, a) -> boolean)?' +caused by: + None of the union options are compatible. For example: Type '(number, number) -> boolean' could not be converted into '(a, a) -> boolean' +caused by: + Argument #1 type is not compatible. Type 'string' could not be converted into 'number')", + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "strings_have_methods") @@ -152,7 +158,7 @@ TEST_CASE_FIXTURE(Fixture, "strings_have_methods") CHECK_EQ(*typeChecker.stringType, *requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "math_max_variatic") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_variatic") { CheckResult result = check(R"( local n = math.max(1,2,3,4,5,6,7,8,9,0) @@ -162,16 +168,17 @@ TEST_CASE_FIXTURE(Fixture, "math_max_variatic") CHECK_EQ(*typeChecker.numberType, *requireType("n")); } -TEST_CASE_FIXTURE(Fixture, "math_max_checks_for_numbers") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_checks_for_numbers") { CheckResult result = check(R"( local n = math.max(1,2,"3") )"); CHECK(!result.errors.empty()); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "builtin_tables_sealed") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_tables_sealed") { CheckResult result = check(R"LUA( local b = bit32 @@ -183,7 +190,7 @@ TEST_CASE_FIXTURE(Fixture, "builtin_tables_sealed") CHECK_EQ(bit32t->state, TableState::Sealed); } -TEST_CASE_FIXTURE(Fixture, "lua_51_exported_globals_all_exist") +TEST_CASE_FIXTURE(BuiltinsFixture, "lua_51_exported_globals_all_exist") { // Extracted from lua5.1 CheckResult result = check(R"( @@ -340,7 +347,7 @@ TEST_CASE_FIXTURE(Fixture, "lua_51_exported_globals_all_exist") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "setmetatable_unpacks_arg_types_correctly") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_unpacks_arg_types_correctly") { CheckResult result = check(R"( setmetatable({}, setmetatable({}, {})) @@ -348,7 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_unpacks_arg_types_correctly") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_overload") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_2_args_overload") { CheckResult result = check(R"( local t = {} @@ -360,7 +367,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_o CHECK_EQ(typeChecker.stringType, requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_3_args_overload") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_3_args_overload") { CheckResult result = check(R"( local t = {} @@ -372,7 +379,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_3_args_o CHECK_EQ("string", toString(requireType("s"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack") { CheckResult result = check(R"( local t = table.pack(1, "foo", true) @@ -382,7 +389,7 @@ TEST_CASE_FIXTURE(Fixture, "table_pack") CHECK_EQ("{| [number]: boolean | number | string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_variadic") { CheckResult result = check(R"( --!strict @@ -397,7 +404,7 @@ local t = table.pack(f()) CHECK_EQ("{| [number]: number | string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack_reduce") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_reduce") { CheckResult result = check(R"( local t = table.pack(1, 2, true) @@ -414,7 +421,7 @@ TEST_CASE_FIXTURE(Fixture, "table_pack_reduce") CHECK_EQ("{| [number]: string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "gcinfo") +TEST_CASE_FIXTURE(BuiltinsFixture, "gcinfo") { CheckResult result = check(R"( local n = gcinfo() @@ -424,12 +431,12 @@ TEST_CASE_FIXTURE(Fixture, "gcinfo") CHECK_EQ(*typeChecker.numberType, *requireType("n")); } -TEST_CASE_FIXTURE(Fixture, "getfenv") +TEST_CASE_FIXTURE(BuiltinsFixture, "getfenv") { LUAU_REQUIRE_NO_ERRORS(check("getfenv(1)")); } -TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "os_time_takes_optional_date_table") { CheckResult result = check(R"( local n1 = os.time() @@ -443,7 +450,7 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") CHECK_EQ(*typeChecker.numberType, *requireType("n3")); } -TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "thread_is_a_type") { CheckResult result = check(R"( local co = coroutine.create(function() end) @@ -453,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") CHECK_EQ(*typeChecker.threadType, *requireType("co")); } -TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") +TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_resume_anything_goes") { CheckResult result = check(R"( local function nifty(x, y) @@ -471,7 +478,7 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") +TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_wrap_anything_goes") { CheckResult result = check(R"( --!nonstrict @@ -490,7 +497,7 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_should_not_mutate_persisted_types") { CheckResult result = check(R"( local string = string @@ -505,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") REQUIRE(ttv); } -TEST_CASE_FIXTURE(Fixture, "string_format_arg_types_inference") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_arg_types_inference") { CheckResult result = check(R"( --!strict @@ -518,7 +525,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_arg_types_inference") CHECK_EQ("(number, number, string) -> string", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "string_format_arg_count_mismatch") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_arg_count_mismatch") { CheckResult result = check(R"( --!strict @@ -534,7 +541,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_arg_count_mismatch") CHECK_EQ(result.errors[2].location.begin.line, 4); } -TEST_CASE_FIXTURE(Fixture, "string_format_correctly_ordered_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_correctly_ordered_types") { CheckResult result = check(R"( --!strict @@ -548,7 +555,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_correctly_ordered_types") CHECK_EQ(tm->givenType, typeChecker.numberType); } -TEST_CASE_FIXTURE(Fixture, "xpcall") +TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall") { CheckResult result = check(R"( --!strict @@ -564,7 +571,7 @@ TEST_CASE_FIXTURE(Fixture, "xpcall") CHECK_EQ("boolean", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "see_thru_select") +TEST_CASE_FIXTURE(BuiltinsFixture, "see_thru_select") { CheckResult result = check(R"( local a:number, b:boolean = select(2,"hi", 10, true) @@ -573,7 +580,7 @@ TEST_CASE_FIXTURE(Fixture, "see_thru_select") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "see_thru_select_count") +TEST_CASE_FIXTURE(BuiltinsFixture, "see_thru_select_count") { CheckResult result = check(R"( local a = select("#","hi", 10, true) @@ -583,7 +590,7 @@ TEST_CASE_FIXTURE(Fixture, "see_thru_select_count") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "select_with_decimal_argument_is_rounded_down") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_decimal_argument_is_rounded_down") { CheckResult result = check(R"( local a: number, b: boolean = select(2.9, "foo", 1, true) @@ -608,7 +615,7 @@ TEST_CASE_FIXTURE(Fixture, "bad_select_should_not_crash") CHECK_LE(0, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "select_way_out_of_range") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_way_out_of_range") { CheckResult result = check(R"( select(5432598430953240958) @@ -619,7 +626,7 @@ TEST_CASE_FIXTURE(Fixture, "select_way_out_of_range") CHECK_EQ("bad argument #1 to select (index out of range)", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "select_slightly_out_of_range") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_slightly_out_of_range") { CheckResult result = check(R"( select(3, "a", 1) @@ -630,7 +637,7 @@ TEST_CASE_FIXTURE(Fixture, "select_slightly_out_of_range") CHECK_EQ("bad argument #1 to select (index out of range)", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail") { CheckResult result = check(R"( --!nonstrict @@ -649,7 +656,7 @@ TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail") CHECK_EQ("any", toString(requireType("quux"))); } -TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail_and_string_head") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_string_head") { CheckResult result = check(R"( --!nonstrict @@ -703,7 +710,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument2") CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "debug_traceback_is_crazy") +TEST_CASE_FIXTURE(BuiltinsFixture, "debug_traceback_is_crazy") { CheckResult result = check(R"( local co: thread = ... @@ -720,7 +727,7 @@ debug.traceback(co, "msg", 1) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "debug_info_is_crazy") +TEST_CASE_FIXTURE(BuiltinsFixture, "debug_info_is_crazy") { CheckResult result = check(R"( local co: thread, f: ()->() = ... @@ -734,7 +741,7 @@ debug.info(f, "n") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "aliased_string_format") +TEST_CASE_FIXTURE(BuiltinsFixture, "aliased_string_format") { CheckResult result = check(R"( local fmt = string.format @@ -745,7 +752,7 @@ TEST_CASE_FIXTURE(Fixture, "aliased_string_format") CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "string_lib_self_noself") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_lib_self_noself") { CheckResult result = check(R"( --!nonstrict @@ -764,7 +771,7 @@ TEST_CASE_FIXTURE(Fixture, "string_lib_self_noself") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "gmatch_definition") +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_definition") { CheckResult result = check(R"_( local a, b, c = ("hey"):gmatch("(.)(.)(.)")() @@ -777,7 +784,7 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "select_on_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_on_variadic") { CheckResult result = check(R"( local function f(): (number, ...(boolean | number)) @@ -793,7 +800,7 @@ TEST_CASE_FIXTURE(Fixture, "select_on_variadic") CHECK_EQ("any", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_positions") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_report_all_type_errors_at_correct_positions") { CheckResult result = check(R"( ("%s%d%s"):format(1, "hello", true) @@ -825,7 +832,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[5].data); } -TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type") { CheckResult result = check(R"( --!strict @@ -836,7 +843,7 @@ TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") +TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type2") { CheckResult result = check(R"( --!strict @@ -846,7 +853,7 @@ TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_add_definitions_to_persistent_types") { CheckResult result = check(R"( local f = math.sin @@ -868,7 +875,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") REQUIRE(gtv->definition); } -TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types") { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, @@ -889,7 +896,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, @@ -907,7 +914,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, @@ -924,7 +931,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types_even_from_type_pack_tail_ CHECK_EQ("(...number?) -> (number, ...number?)", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, @@ -941,7 +948,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_fir CHECK_EQ("(nil) -> nil", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "table_freeze_is_generic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") { CheckResult result = check(R"( local t1: {a: number} = {a = 42} @@ -968,7 +975,7 @@ TEST_CASE_FIXTURE(Fixture, "table_freeze_is_generic") CHECK_EQ("*unknown*", toString(requireType("d"))); } -TEST_CASE_FIXTURE(Fixture, "set_metatable_needs_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { ScopedFastFlag sff{"LuauSetMetaTableArgsCheck", true}; CheckResult result = check(R"( @@ -991,7 +998,7 @@ local function f(a: typeof(f)) end CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") +TEST_CASE_FIXTURE(BuiltinsFixture, "no_persistent_typelevel_change") { TypeId mathTy = requireType(typeChecker.globalScope, "math"); REQUIRE(mathTy); @@ -1008,7 +1015,7 @@ TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") CHECK(ftv->level.subLevel == original.subLevel); } -TEST_CASE_FIXTURE(Fixture, "global_singleton_types_are_sealed") +TEST_CASE_FIXTURE(BuiltinsFixture, "global_singleton_types_are_sealed") { CheckResult result = check(R"( local function f(x: string) diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 5a6e4032..d90129d7 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; using std::nullopt; -struct ClassFixture : Fixture +struct ClassFixture : BuiltinsFixture { ClassFixture() { diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 0e071217..14f1f703 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -85,7 +85,7 @@ TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") +TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") { CheckResult result = check(R"( local T = {} @@ -555,7 +555,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") CHECK(bool(argType->indexer)); } -TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") +TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") { CheckResult result = check(R"( function bottomupmerge(comp, a, b, left, mid, right) @@ -620,7 +620,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); } -TEST_CASE_FIXTURE(Fixture, "mutual_recursion") +TEST_CASE_FIXTURE(BuiltinsFixture, "mutual_recursion") { CheckResult result = check(R"( --!strict @@ -639,7 +639,7 @@ TEST_CASE_FIXTURE(Fixture, "mutual_recursion") dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") +TEST_CASE_FIXTURE(BuiltinsFixture, "toposort_doesnt_break_mutual_recursion") { CheckResult result = check(R"( --!strict @@ -676,7 +676,7 @@ TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") { CheckResult result = check(R"( function onerror() end @@ -794,7 +794,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields }})); } -TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "calling_function_with_anytypepack_doesnt_leak_free_types") { ScopedFastFlag sff[]{ {"LuauReturnTypeInferenceInNonstrict", true}, @@ -966,7 +966,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") CHECK_EQ("string", toString(requireType("z"))); } -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation CheckResult result = check(R"( @@ -1068,7 +1068,7 @@ f(function(x) return x * 2 end) } } -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation CheckResult result = check(R"( @@ -1287,10 +1287,8 @@ caused by: Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); } -TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_quantify_right_type") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - fileResolver.source["game/isAMagicMock"] = R"( --!nonstrict return function(value) @@ -1311,10 +1309,8 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - CheckResult result = check(R"( function string.len(): number return 1 @@ -1333,11 +1329,8 @@ print(string.len('hello')) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite_2") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite_2") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - ScopedFastFlag inferStatFunction{"LuauInferStatFunction", true}; - CheckResult result = check(R"( local t: { f: ((x: number) -> number)? } = {} @@ -1477,11 +1470,8 @@ TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_th LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_unsealed_overwrite") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_unsealed_overwrite") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - ScopedFastFlag inferStatFunction{"LuauInferStatFunction", true}; - CheckResult result = check(R"( local t = { f = nil :: ((x: number) -> number)? } @@ -1518,8 +1508,6 @@ TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - CheckResult result = check(R"( local t: {[string]: () -> number} = {} @@ -1580,7 +1568,7 @@ wrapper(test) CHECK(acm->isVariadic); } -TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic2") +TEST_CASE_FIXTURE(BuiltinsFixture, "too_few_arguments_variadic_generic2") { CheckResult result = check(R"( function test(a: number, b: string, ...) diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 91be2c1c..de0c9391 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -67,7 +67,7 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") +TEST_CASE_FIXTURE(BuiltinsFixture, "inferred_local_vars_can_be_polytypes") { CheckResult result = check(R"( local function id(x) return x end @@ -79,7 +79,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") +TEST_CASE_FIXTURE(BuiltinsFixture, "local_vars_can_be_instantiated_polytypes") { CheckResult result = check(R"( local function id(x) return x end @@ -609,7 +609,7 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") CHECK(requireType("y1") == requireType("y2")); } -TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") +TEST_CASE_FIXTURE(BuiltinsFixture, "bound_tables_do_not_clone_original_fields") { CheckResult result = check(R"( local exports = {} @@ -675,7 +675,7 @@ local d: D = c R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); } -TEST_CASE_FIXTURE(Fixture, "generic_functions_dont_cache_type_parameters") +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_functions_dont_cache_type_parameters") { CheckResult result = check(R"( -- See https://github.com/Roblox/luau/issues/332 @@ -1013,7 +1013,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") CHECK(it != result.errors.end()); } -TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") { ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; @@ -1078,7 +1078,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded" LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_lib_function_function_argument") { CheckResult result = check(R"( local a = {{x=4}, {x=7}, {x=1}} diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 3675919f..41bc0c21 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -316,8 +316,6 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - CheckResult result = check(R"( type X = { x: (number) -> number } type Y = { y: (string) -> string } @@ -351,8 +349,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - // After normalization, previous 'table_intersection_write_sealed_indirect' is identical to this one CheckResult result = check(R"( type XY = { x: (number) -> number, y: (string) -> string } @@ -375,7 +371,7 @@ caused by: CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'XY'"); } -TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_intersection_setmetatable") { CheckResult result = check(R"( local t: {} & {} diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index f9b510c1..765419c6 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -29,7 +29,7 @@ TEST_CASE_FIXTURE(Fixture, "for_loop") CHECK_EQ(*typeChecker.numberType, *requireType("q")); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") { CheckResult result = check(R"( local n @@ -46,7 +46,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop") CHECK_EQ(*typeChecker.stringType, *requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_next") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") { CheckResult result = check(R"( local n @@ -90,7 +90,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") CHECK_EQ("Cannot call non-function string", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_just_one_iterator_is_ok") { CheckResult result = check(R"( local function keys(dictionary) @@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "for_in_with_a_custom_iterator_should_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_a_custom_iterator_should_type_check") { CheckResult result = check(R"( local function range(l, h): () -> number @@ -161,7 +161,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") REQUIRE(get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") { CheckResult result = check(R"( local function hasDivisors(value: number, table) @@ -210,7 +210,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right CHECK_EQ(typeChecker.stringType, tm->givenType); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") { CheckResult result = check(R"( function prime_iter(state, index) @@ -288,7 +288,7 @@ TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") +TEST_CASE_FIXTURE(BuiltinsFixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") { CheckResult result = check(R"( repeat @@ -301,7 +301,7 @@ TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") +TEST_CASE_FIXTURE(BuiltinsFixture, "varlist_declared_by_for_in_loop_should_be_free") { CheckResult result = check(R"( local T = {} @@ -316,7 +316,7 @@ TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "properly_infer_iteratee_is_a_free_table") { // In this case, we cannot know the element type of the table {}. It could be anything. // We therefore must initially ascribe a free typevar to iter. @@ -329,7 +329,7 @@ TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_while") { CheckResult result = check(R"( while true do @@ -346,7 +346,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") CHECK_EQ(us->name, "a"); } -TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indices") +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_produces_integral_indices") { CheckResult result = check(R"( local key @@ -378,7 +378,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") )"); } -TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") +TEST_CASE_FIXTURE(BuiltinsFixture, "unreachable_code_after_infinite_loop") { { CheckResult result = check(R"( @@ -460,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") } } -TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") { CheckResult result = check(R"( local t = {} @@ -541,7 +541,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") CHECK_EQ("Cannot iterate over a table without indexer", ge->message); } -TEST_CASE_FIXTURE(Fixture, "loop_iter_iter_metamethod") +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_iter_metamethod") { ScopedFastFlag sff{"LuauTypecheckIter", true}; diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index b6f49f9f..efa2a98d 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -16,7 +16,7 @@ LUAU_FASTFLAG(LuauTableSubtypingVariance2) TEST_SUITE_BEGIN("TypeInferModules"); -TEST_CASE_FIXTURE(Fixture, "require") +TEST_CASE_FIXTURE(BuiltinsFixture, "require") { fileResolver.source["game/A"] = R"( local function hooty(x: number): string @@ -54,7 +54,7 @@ TEST_CASE_FIXTURE(Fixture, "require") REQUIRE_EQ("number", toString(*hType)); } -TEST_CASE_FIXTURE(Fixture, "require_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_types") { fileResolver.source["workspace/A"] = R"( export type Point = {x: number, y: number} @@ -69,7 +69,7 @@ TEST_CASE_FIXTURE(Fixture, "require_types") )"; CheckResult bResult = frontend.check("workspace/B"); - dumpErrors(bResult); + LUAU_REQUIRE_NO_ERRORS(bResult); ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; REQUIRE(b != nullptr); @@ -78,7 +78,7 @@ TEST_CASE_FIXTURE(Fixture, "require_types") REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); } -TEST_CASE_FIXTURE(Fixture, "require_a_variadic_function") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_a_variadic_function") { fileResolver.source["game/A"] = R"( local T = {} @@ -121,7 +121,7 @@ TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") REQUIRE_EQ(result.errors[0], (TypeError{Location{{1, 17}, {1, 40}}, UnknownSymbol{"SomeModule.DoesNotExist"}})); } -TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export") { const std::string sourceA = R"( )"; @@ -148,7 +148,7 @@ TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") CHECK_EQ("*unknown*", toString(hootyType)); } -TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") +TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript") { fileResolver.source["Modules/A"] = ""; fileResolver.sourceTypes["Modules/A"] = SourceCode::Local; @@ -164,7 +164,7 @@ TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") CHECK(get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "general_require_call_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_call_expression") { fileResolver.source["game/A"] = R"( --!strict @@ -183,7 +183,7 @@ a = tbl.abc.def CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "general_require_type_mismatch") +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_type_mismatch") { fileResolver.source["game/A"] = R"( return { def = 4 } @@ -219,7 +219,7 @@ return m LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "custom_require_global") +TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") { CheckResult result = check(R"( --!nonstrict @@ -231,7 +231,7 @@ local crash = require(game.A) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "require_failed_module") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_failed_module") { fileResolver.source["game/A"] = R"( return unfortunately() @@ -267,7 +267,7 @@ function x:Destroy(): () end LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_2") { fileResolver.source["game/A"] = R"( export type Type = { x: { a: number } } @@ -285,7 +285,7 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_3") { fileResolver.source["game/A"] = R"( local y = setmetatable({}, {}) @@ -304,7 +304,7 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "module_type_conflict") +TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict") { fileResolver.source["game/A"] = R"( export type T = { x: number } @@ -338,7 +338,7 @@ caused by: } } -TEST_CASE_FIXTURE(Fixture, "module_type_conflict_instantiated") +TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict_instantiated") { fileResolver.source["game/A"] = R"( export type Wrap = { x: T } diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 5cd3f3ba..41690704 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -142,7 +142,7 @@ TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocat CHECK_GE(50, module->internalTypes.typeVars.size()); } -TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") +TEST_CASE_FIXTURE(BuiltinsFixture, "object_constructor_can_refer_to_method_of_self") { // CLI-30902 CheckResult result = check(R"( @@ -243,7 +243,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_ )"); } -TEST_CASE_FIXTURE(Fixture, "table_oop") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_oop") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index a2787cad..51f6fdfb 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -77,7 +77,7 @@ TEST_CASE_FIXTURE(Fixture, "and_or_ternary") CHECK_EQ(toString(*requireType("s")), "number | string"); } -TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") { CheckResult result = check(R"( function add(a: number, b: string) @@ -140,7 +140,7 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") CHECK_EQ("number", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection") { CheckResult result = check(R"( --!strict @@ -174,7 +174,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("e"))); } -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") { CheckResult result = check(R"( --!strict @@ -245,7 +245,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_m REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); } -TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") { CheckResult result = check(R"( local M = {} @@ -266,7 +266,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_ov REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); } -TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") { CheckResult result = check(R"( --!strict @@ -289,7 +289,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_meta REQUIRE_EQ((Location{{11, 18}, {11, 23}}), result.errors[1].location); } -TEST_CASE_FIXTURE(Fixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") +TEST_CASE_FIXTURE(BuiltinsFixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") { CheckResult result = check(R"( --!strict @@ -361,7 +361,7 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); } -TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") { CheckResult result = check(R"( --!strict @@ -381,7 +381,7 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") CHECK_EQ(0, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_mismatch_metatable") { CheckResult result = check(R"( --!strict @@ -428,7 +428,7 @@ local x = false LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") { CheckResult result = check(R"( --!strict @@ -461,7 +461,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); } -TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") +TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean") { CheckResult result = check(R"( local b = not "string" @@ -473,7 +473,7 @@ TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") REQUIRE_EQ("boolean", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") +TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") { CheckResult result = check(R"( --!strict @@ -573,7 +573,7 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "and_binexps_dont_unify") +TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") { CheckResult result = check(R"( --!strict @@ -628,7 +628,7 @@ TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") +TEST_CASE_FIXTURE(BuiltinsFixture, "UnknownGlobalCompoundAssign") { // In non-strict mode, global definition is still allowed { @@ -755,8 +755,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_and_or") TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") { - ScopedFastFlag sff{"LuauDecoupleOperatorInferenceFromUnifiedTypeInference", true}; - CheckResult result = check(Mode::Strict, R"( local function f(x, y) return x + y @@ -779,4 +777,47 @@ TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") // the case right now, though. } +TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_branch_succeeds") +{ + ScopedFastFlag sff("LuauSuccessTypingForEqualityOperations", true); + + CheckResult result = check(R"( + local mm = {} + type Foo = typeof(setmetatable({}, mm)) + local x: Foo + local y: Foo? + + local v1 = x == y + local v2 = y == x + local v3 = x ~= y + local v4 = y ~= x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CheckResult result2 = check(R"( + local mm1 = { + x = "foo", + } + + local mm2 = { + y = "bar", + } + + type Foo = typeof(setmetatable({}, mm1)) + type Bar = typeof(setmetatable({}, mm2)) + + local x1: Foo + local x2: Foo? + local y1: Bar + local y2: Bar? + + local v1 = x1 == y1 + local v2 = x2 == y2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result2); + CHECK(toString(result2.errors[0]) == "Types Foo and Bar cannot be compared with == because they do not have the same metatable"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 2ef77419..ee3ae972 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -53,7 +53,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") CHECK_EQ(expected, decorateWithTypes(code)); } -TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") +TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") { const std::string code = R"( local a, b, c = xpcall(function() return 1, "foo" end, function() return "foo", 1 end) @@ -105,7 +105,7 @@ TEST_CASE_FIXTURE(Fixture, "it_should_be_agnostic_of_actual_size") // Ideally setmetatable's second argument would be an optional free table. // For now, infer it as just a free table. -TEST_CASE_FIXTURE(Fixture, "setmetatable_constrains_free_type_into_free_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_constrains_free_type_into_free_table") { CheckResult result = check(R"( local a = {} @@ -146,7 +146,7 @@ TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") // Originally from TypeInfer.test.cpp. // I dont think type checking the metamethod at every site of == is the correct thing to do. // We should be type checking the metamethod at the call site of setmetatable. -TEST_CASE_FIXTURE(Fixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") { CheckResult result = check(R"( local tab = {a = 1} @@ -428,7 +428,7 @@ TEST_CASE_FIXTURE(Fixture, "pcall_returns_at_least_two_value_but_function_return } // Belongs in TypeInfer.builtins.test.cpp. -TEST_CASE_FIXTURE(Fixture, "choose_the_right_overload_for_pcall") +TEST_CASE_FIXTURE(BuiltinsFixture, "choose_the_right_overload_for_pcall") { CheckResult result = check(R"( local function f(): number @@ -449,7 +449,7 @@ TEST_CASE_FIXTURE(Fixture, "choose_the_right_overload_for_pcall") } // Belongs in TypeInfer.builtins.test.cpp. -TEST_CASE_FIXTURE(Fixture, "function_returns_many_things_but_first_of_it_is_forgotten") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_many_things_but_first_of_it_is_forgotten") { CheckResult result = check(R"( local function f(): (number, string, boolean) diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 136ca00a..8c130490 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -240,7 +240,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); } -TEST_CASE_FIXTURE(Fixture, "typeguard_in_assert_position") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") { CheckResult result = check(R"( local a @@ -300,7 +300,7 @@ TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "impossible_type_narrow_is_not_an_error") +TEST_CASE_FIXTURE(BuiltinsFixture, "impossible_type_narrow_is_not_an_error") { // This unit test serves as a reminder to not implement this warning until Luau is intelligent enough. // For instance, getting a value out of the indexer and checking whether the value exists is not an error. @@ -333,7 +333,7 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") CHECK_EQ("number?", toString(requireType("bar"))); } -TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "index_on_a_refined_property") { CheckResult result = check(R"( local t: {x: {y: string}?} = {x = {y = "hello!"}} @@ -346,7 +346,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_constraints") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_non_binary_expressions_actually_resolve_constraints") { CheckResult result = check(R"( local foo: string? = "hello" @@ -730,7 +730,7 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } -TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") { CheckResult result = check(R"( local function f(t: {x: number}) @@ -846,7 +846,7 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); } -TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") { CheckResult result = check(R"( local a: (number | string)? @@ -862,7 +862,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") CHECK_EQ("number", toString(requireTypeAtPosition({5, 18}))); } -TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") +TEST_CASE_FIXTURE(BuiltinsFixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") { // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. CheckResult result = check(R"( @@ -899,7 +899,7 @@ TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_n CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); } -TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_ifelse_expression") { CheckResult result = check(R"( function f(v:string?) @@ -913,7 +913,7 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") CHECK_EQ("nil", toString(requireTypeAtPosition({2, 45}))); } -TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_ifelse_expression") { CheckResult result = check(R"( function f(v:string?) @@ -945,7 +945,7 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } -TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") { CheckResult result = check(R"( local foo: string? = "hi" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 8d6682b8..79eeb824 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -415,7 +415,7 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") { ScopedFastFlag sff[]{ {"LuauWidenIfSupertypeIsFree2", true}, diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 8e535995..5078b0bf 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -201,7 +201,7 @@ TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon") REQUIRE(it != result.errors.end()); } -TEST_CASE_FIXTURE(Fixture, "used_colon_correctly") +TEST_CASE_FIXTURE(BuiltinsFixture, "used_colon_correctly") { CheckResult result = check(R"( --!nonstrict @@ -883,7 +883,7 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s CHECK_EQ(*typeChecker.stringType, *propertyA); } -TEST_CASE_FIXTURE(Fixture, "oop_indexer_works") +TEST_CASE_FIXTURE(BuiltinsFixture, "oop_indexer_works") { CheckResult result = check(R"( local clazz = {} @@ -906,7 +906,7 @@ TEST_CASE_FIXTURE(Fixture, "oop_indexer_works") CHECK_EQ(*typeChecker.stringType, *requireType("words")); } -TEST_CASE_FIXTURE(Fixture, "indexer_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_table") { CheckResult result = check(R"( local clazz = {a="hello"} @@ -919,7 +919,7 @@ TEST_CASE_FIXTURE(Fixture, "indexer_table") CHECK_EQ(*typeChecker.stringType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "indexer_fn") +TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_fn") { CheckResult result = check(R"( local instanace = setmetatable({}, {__index=function() return 10 end}) @@ -930,7 +930,7 @@ TEST_CASE_FIXTURE(Fixture, "indexer_fn") CHECK_EQ(*typeChecker.numberType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "meta_add") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add") { // Note: meta_add_inferred and this unit test are currently the same exact thing. // We'll want to change this one in particular when we add real syntax for metatables. @@ -947,7 +947,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add") CHECK_EQ(follow(requireType("a")), follow(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "meta_add_inferred") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_inferred") { CheckResult result = check(R"( local a = {} @@ -960,7 +960,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add_inferred") CHECK_EQ(*requireType("a"), *requireType("c")); } -TEST_CASE_FIXTURE(Fixture, "meta_add_both_ways") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_both_ways") { CheckResult result = check(R"( type VectorMt = { __add: (Vector, number) -> Vector } @@ -980,7 +980,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add_both_ways") // This test exposed a bug where we let go of the "seen" stack while unifying table types // As a result, type inference crashed with a stack overflow. -TEST_CASE_FIXTURE(Fixture, "unification_of_unions_in_a_self_referential_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "unification_of_unions_in_a_self_referential_type") { CheckResult result = check(R"( type A = {} @@ -1009,7 +1009,7 @@ TEST_CASE_FIXTURE(Fixture, "unification_of_unions_in_a_self_referential_type") CHECK_EQ(bmtv->metatable, requireType("bmt")); } -TEST_CASE_FIXTURE(Fixture, "oop_polymorphic") +TEST_CASE_FIXTURE(BuiltinsFixture, "oop_polymorphic") { CheckResult result = check(R"( local animal = {} @@ -1060,7 +1060,7 @@ TEST_CASE_FIXTURE(Fixture, "user_defined_table_types_are_named") CHECK_EQ("Vector3", toString(requireType("v"))); } -TEST_CASE_FIXTURE(Fixture, "result_is_always_any_if_lhs_is_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "result_is_always_any_if_lhs_is_any") { CheckResult result = check(R"( type Vector3MT = { @@ -1133,7 +1133,7 @@ TEST_CASE_FIXTURE(Fixture, "nice_error_when_trying_to_fetch_property_of_boolean" CHECK_EQ("Type 'boolean' does not have key 'some_prop'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_builtin_sealed_table_must_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "defining_a_method_for_a_builtin_sealed_table_must_fail") { CheckResult result = check(R"( function string.m() end @@ -1142,7 +1142,7 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_builtin_sealed_table_must_fa LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_must_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "defining_a_self_method_for_a_builtin_sealed_table_must_fail") { CheckResult result = check(R"( function string:m() end @@ -1261,7 +1261,7 @@ TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_function_call") CHECK_EQ(toString(te), "Key 'fOo' not found in table 't'. Did you mean 'Foo'?"); } -TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_property_access") +TEST_CASE_FIXTURE(BuiltinsFixture, "found_like_key_in_table_property_access") { CheckResult result = check(R"( local t = {X = 1} @@ -1286,7 +1286,7 @@ TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_property_access") CHECK_EQ(toString(te), "Key 'x' not found in table 't'. Did you mean 'X'?"); } -TEST_CASE_FIXTURE(Fixture, "found_multiple_like_keys") +TEST_CASE_FIXTURE(BuiltinsFixture, "found_multiple_like_keys") { CheckResult result = check(R"( local t = {Foo = 1, foO = 2} @@ -1312,7 +1312,7 @@ TEST_CASE_FIXTURE(Fixture, "found_multiple_like_keys") CHECK_EQ(toString(te), "Key 'foo' not found in table 't'. Did you mean one of 'Foo', 'foO'?"); } -TEST_CASE_FIXTURE(Fixture, "dont_suggest_exact_match_keys") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_suggest_exact_match_keys") { CheckResult result = check(R"( local t = {} @@ -1339,7 +1339,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_exact_match_keys") CHECK_EQ(toString(te), "Key 'Foo' not found in table 't'. Did you mean 'foO'?"); } -TEST_CASE_FIXTURE(Fixture, "getmetatable_returns_pointer_to_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_pointer_to_metatable") { CheckResult result = check(R"( local t = {x = 1} @@ -1352,7 +1352,7 @@ TEST_CASE_FIXTURE(Fixture, "getmetatable_returns_pointer_to_metatable") CHECK_EQ(*requireType("mt"), *requireType("returnedMT")); } -TEST_CASE_FIXTURE(Fixture, "metatable_mismatch_should_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_mismatch_should_fail") { CheckResult result = check(R"( local t1 = {x = 1} @@ -1374,7 +1374,7 @@ TEST_CASE_FIXTURE(Fixture, "metatable_mismatch_should_fail") CHECK_EQ(*tm->givenType, *requireType("t2")); } -TEST_CASE_FIXTURE(Fixture, "property_lookup_through_tabletypevar_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "property_lookup_through_tabletypevar_metatable") { CheckResult result = check(R"( local t = {x = 1} @@ -1393,7 +1393,7 @@ TEST_CASE_FIXTURE(Fixture, "property_lookup_through_tabletypevar_metatable") CHECK_EQ(up->key, "z"); } -TEST_CASE_FIXTURE(Fixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") +TEST_CASE_FIXTURE(BuiltinsFixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") { CheckResult result = check(R"( local t = {x = 1} @@ -1742,7 +1742,7 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") CHECK_EQ("Cannot add property 'b' to table '{| x: number |}'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "builtin_table_names") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") { CheckResult result = check(R"( os.h = 2 @@ -1755,7 +1755,7 @@ TEST_CASE_FIXTURE(Fixture, "builtin_table_names") CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "persistent_sealed_table_is_immutable") +TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") { CheckResult result = check(R"( --!nonstrict @@ -1858,7 +1858,7 @@ local foos: {Foo} = { LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "quantifying_a_bound_var_works") +TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") { CheckResult result = check(R"( local clazz = {} @@ -1983,7 +1983,7 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") { CheckResult result = check(R"( --!nonstrict @@ -1996,7 +1996,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_strict") { ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; @@ -2052,7 +2052,7 @@ caused by: Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); } -TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_detailed_metatable_prop") { ScopedFastFlag sff[]{ {"LuauTableSubtypingVariance2", true}, @@ -2183,7 +2183,7 @@ a.p = { x = 9 } LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_type_call") { ScopedFastFlag sff[]{ {"LuauUnsealedTableLiteral", true}, @@ -2277,7 +2277,7 @@ local y = #x LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") { ScopedFastFlag sff{"LuauTerminateCyclicMetatableIndexLookup", true}; @@ -2296,7 +2296,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable CHECK_EQ("Type 't' does not have key 'p'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "give_up_after_one_metatable_index_look_up") +TEST_CASE_FIXTURE(BuiltinsFixture, "give_up_after_one_metatable_index_look_up") { CheckResult result = check(R"( local data = { x = 5 } @@ -2478,7 +2478,7 @@ TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") )"); } -TEST_CASE_FIXTURE(Fixture, "table_unifies_into_map") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_unifies_into_map") { CheckResult result = check(R"( local Instance: any @@ -2564,7 +2564,7 @@ TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") * the generalization process), then it loses the knowledge that its metatable will have an :incr() * method. */ -TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_scope") { CheckResult result = check(R"( local Counter = {} @@ -2606,7 +2606,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") } // TODO: CLI-39624 -TEST_CASE_FIXTURE(Fixture, "instantiate_tables_at_scope_level") +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_tables_at_scope_level") { CheckResult result = check(R"( --!strict @@ -2690,7 +2690,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_crash_when_setmetatable_does_not_produce_a_meta LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning") +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning") { CheckResult result = check(R"( --!nonstrict @@ -2711,7 +2711,7 @@ type t0 = any CHECK(ttv->instantiatedTypeParams.empty()); } -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_2") +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning_2") { ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; @@ -2767,7 +2767,7 @@ local baz = foo[bar] CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); } -TEST_CASE_FIXTURE(Fixture, "table_simple_call") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") { CheckResult result = check(R"( local a = setmetatable({ x = 2 }, { @@ -2783,7 +2783,7 @@ local c = a(2) -- too many arguments CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "access_index_metamethod_that_returns_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "access_index_metamethod_that_returns_variadic") { CheckResult result = check(R"( type Foo = {x: string} @@ -2878,7 +2878,7 @@ TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") )"); } -TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_function_check_use_after_free") { CheckResult result = check(R"( local t = {} @@ -2916,7 +2916,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the } // The real bug here was that we weren't always uncondionally typechecking a trailing return statement last. -TEST_CASE_FIXTURE(Fixture, "dont_leak_free_table_props") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") { CheckResult result = check(R"( local function a(state) diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index e81ef1a9..48cd1c3d 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -161,7 +161,7 @@ TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "warn_on_lowercase_parent_property") { CheckResult result = check(R"( local M = require(script.parent.DoesNotMatter) @@ -175,7 +175,7 @@ TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") REQUIRE_EQ("parent", ed->symbol); } -TEST_CASE_FIXTURE(Fixture, "weird_case") +TEST_CASE_FIXTURE(BuiltinsFixture, "weird_case") { CheckResult result = check(R"( local function f() return 4 end @@ -419,7 +419,7 @@ TEST_CASE_FIXTURE(Fixture, "globals_everywhere") CHECK_EQ("any", toString(requireType("bar"))); } -TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_do") { CheckResult result = check(R"( do @@ -534,7 +534,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_assert") LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error") +TEST_CASE_FIXTURE(BuiltinsFixture, "tc_after_error_recovery_no_replacement_name_in_error") { { CheckResult result = check(R"( @@ -587,7 +587,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error } } -TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") +TEST_CASE_FIXTURE(BuiltinsFixture, "index_expr_should_be_checked") { CheckResult result = check(R"( local foo: any @@ -768,7 +768,7 @@ b, c = {2, "s"}, {"b", 4} LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types_mutable_lval") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_assignment_value_types_mutable_lval") { CheckResult result = check(R"( local a = {} @@ -836,7 +836,7 @@ local a: number? = if true then 1 else nil LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_3") +TEST_CASE_FIXTURE(BuiltinsFixture, "tc_if_else_expressions_expected_type_3") { CheckResult result = check(R"( local function times(n: any, f: () -> T) @@ -907,7 +907,7 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") )"); } -TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_crash") { CheckResult result = check(R"( local function getIt() @@ -1041,7 +1041,6 @@ TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_table") { ScopedFastFlag flag[] = { - {"LuauStatFunctionSimplify4", true}, {"LuauLowerBoundsCalculation", true}, {"LuauDifferentOrderOfUnificationDoesntMatter2", true}, }; diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 87562644..49deae71 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -196,7 +196,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") CHECK_EQ(toString(tm->wantedType), "string"); } -TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unification") +TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unification") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index f141622f..fd66b080 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -339,7 +339,7 @@ local c: Packed CHECK_EQ(toString(ttvC->instantiatedTypePackParams[0], {true}), "number, boolean"); } -TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_type_packs_import") { fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } @@ -369,7 +369,7 @@ local d: { a: typeof(c) } CHECK_EQ(toString(requireType("d")), "{| a: Packed |}"); } -TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_pack_type_parameters") { fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } @@ -784,7 +784,7 @@ local a: Y<...number> LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "type_alias_default_export") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_default_export") { fileResolver.source["Module/Types"] = R"( export type A = { a: T, b: U } diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 96bdd534..277f3887 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -104,7 +104,7 @@ TEST_CASE_FIXTURE(Fixture, "optional_arguments_table2") REQUIRE(!result.errors.empty()); } -TEST_CASE_FIXTURE(Fixture, "error_takes_optional_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_takes_optional_arguments") { CheckResult result = check(R"( error("message") @@ -517,10 +517,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_union_write_indirect") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - CheckResult result = check(R"( type A = { x: number, y: (number) -> string } | { z: number, y: (number) -> string } From ab4bb355a3f261d18ae8c0d09ce44a7af67ecaf9 Mon Sep 17 00:00:00 2001 From: JohnnyMorganz Date: Mon, 16 May 2022 17:50:15 +0100 Subject: [PATCH 065/102] Add `ToStringOptions.hideFunctionSelfArgument` (#486) Adds an option to hide the `self: type` argument as the first argument in the string representation of a named function type var if the ftv hasSelf. Also added in a test for the original output (i.e., if the option was disabled) I didn't apply this option in the normal `Luau::toString()` function, just the `Luau::toStringNamedFunction()` one (for my usecase, that is enough + I felt like a named function would include the method colon `:` to signify self). If this is unintuitive, I can also add it to the general `Luau::toString()` function. --- Analysis/include/Luau/ToString.h | 1 + Analysis/src/ToString.cpp | 8 +++++++ tests/ToString.test.cpp | 36 ++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index f4db5e35..3b380a60 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -28,6 +28,7 @@ struct ToStringOptions bool functionTypeArguments = false; // If true, output function type argument names when they are available bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. + bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self bool indent = false; size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 51665f7f..51f3f69f 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1230,6 +1230,14 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp size_t idx = 0; while (argPackIter != end(ftv.argTypes)) { + // ftv takes a self parameter as the first argument, skip it if specified in option + if (idx == 0 && ftv.hasSelf && opts.hideFunctionSelfArgument) + { + ++argPackIter; + ++idx; + continue; + } + if (!first) state.emit(", "); first = false; diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 50d0838e..4bdd45f7 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -617,4 +617,40 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") CHECK_EQ("test(first: a, second: string, ...: number): a", toStringNamedFunction("test", *ftv, opts)); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") +{ + ScopedFastFlag flag{"LuauDocFuncParameters", true}; + CheckResult result = check(R"( + local foo = {} + function foo:method(arg: string): () + end + )"); + + TypeId parentTy = requireType("foo"); + auto ttv = get(follow(parentTy)); + auto ftv = get(ttv->props.at("method").type); + + CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); +} + + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") +{ + ScopedFastFlag flag{"LuauDocFuncParameters", true}; + CheckResult result = check(R"( + local foo = {} + function foo:method(arg: string): () + end + )"); + + TypeId parentTy = requireType("foo"); + auto ttv = get(follow(parentTy)); + auto ftv = get(ttv->props.at("method").type); + + ToStringOptions opts; + opts.hideFunctionSelfArgument = true; + CHECK_EQ("foo:method(arg: string): ()", toStringNamedFunction("foo:method", *ftv, opts)); +} + + TEST_SUITE_END(); From f2191b9e4da6a4bb2d9d344ebd7941ec2f00844b Mon Sep 17 00:00:00 2001 From: JohnnyMorganz Date: Tue, 17 May 2022 19:22:54 +0100 Subject: [PATCH 066/102] Respect useLineBreaks for union/intersect toString (#487) * Respect useLineBreaks for union/intersect toString * Apply suggestions from code review Co-authored-by: Andy Friesen Co-authored-by: Andy Friesen --- Analysis/src/ToString.cpp | 10 ++++++++-- tests/ToString.test.cpp | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 51f3f69f..380ac456 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -745,7 +745,10 @@ struct TypeVarStringifier for (std::string& ss : results) { if (!first) - state.emit(" | "); + { + state.newline(); + state.emit("| "); + } state.emit(ss); first = false; } @@ -798,7 +801,10 @@ struct TypeVarStringifier for (std::string& ss : results) { if (!first) - state.emit(" & "); + { + state.newline(); + state.emit("& "); + } state.emit(ss); first = false; } diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 4bdd45f7..f38dd10a 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -126,6 +126,39 @@ TEST_CASE_FIXTURE(Fixture, "functions_are_always_parenthesized_in_unions_or_inte CHECK_EQ(toString(&itv), "((number, string) -> (string, number)) & ((string, number) -> (number, string))"); } +TEST_CASE_FIXTURE(Fixture, "intersections_respects_use_line_breaks") +{ + CheckResult result = check(R"( + local a: ((string) -> string) & ((number) -> number) + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + + //clang-format off + CHECK_EQ("((number) -> number)\n" + "& ((string) -> string)", + toString(requireType("a"), opts)); + //clang-format on +} + +TEST_CASE_FIXTURE(Fixture, "unions_respects_use_line_breaks") +{ + CheckResult result = check(R"( + local a: string | number | boolean + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + + //clang-format off + CHECK_EQ("boolean\n" + "| number\n" + "| string", + toString(requireType("a"), opts)); + //clang-format on +} + TEST_CASE_FIXTURE(Fixture, "quit_stringifying_table_type_when_length_is_exceeded") { TableTypeVar ttv{}; From 8b4c6aabc271c8bee2910c5b752759b21a6bcca9 Mon Sep 17 00:00:00 2001 From: JohnnyMorganz Date: Thu, 19 May 2022 00:26:05 +0100 Subject: [PATCH 067/102] Fix findAstAncestry when position is at eof (#490) --- Analysis/src/AstQuery.cpp | 20 ++++++++++++++++++-- tests/AstQuery.test.cpp | 13 +++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 0aed34c0..0522b1fa 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -71,9 +71,11 @@ struct FindFullAncestry final : public AstVisitor { std::vector nodes; Position pos; + Position documentEnd; - explicit FindFullAncestry(Position pos) + explicit FindFullAncestry(Position pos, Position documentEnd) : pos(pos) + , documentEnd(documentEnd) { } @@ -84,6 +86,16 @@ struct FindFullAncestry final : public AstVisitor nodes.push_back(node); return true; } + + // Edge case: If we ask for the node at the position that is the very end of the document + // return the innermost AST element that ends at that position. + + if (node->location.end == documentEnd && pos >= documentEnd) + { + nodes.push_back(node); + return true; + } + return false; } }; @@ -92,7 +104,11 @@ struct FindFullAncestry final : public AstVisitor std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) { - FindFullAncestry finder(pos); + const Position end = source.root->location.end; + if (pos > end) + pos = end; + + FindFullAncestry finder(pos, end); source.root->visit(&finder); return std::move(finder.nodes); } diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 12c68450..f0017509 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -92,4 +92,17 @@ bar(foo()) CHECK_EQ("number", toString(*expectedOty)); } +TEST_CASE_FIXTURE(Fixture, "ast_ancestry_at_eof") +{ + check(R"( +if true then + )"); + + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(2, 4)); + REQUIRE_GE(ancestry.size(), 2); + AstStat* parentStat = ancestry[ancestry.size() - 2]->asStat(); + REQUIRE(bool(parentStat)); + REQUIRE(parentStat->is()); +} + TEST_SUITE_END(); From f5923aefeb66f8ea3e5193d328d8eb74400e0b3b Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 19 May 2022 17:02:24 -0700 Subject: [PATCH 068/102] Sync to upstream/release/527 (#491) --- Analysis/include/Luau/Clone.h | 2 +- Analysis/include/Luau/Error.h | 4 +- Analysis/include/Luau/LValue.h | 4 - Analysis/include/Luau/Module.h | 36 +- Analysis/include/Luau/Substitution.h | 3 +- Analysis/include/Luau/TxnLog.h | 11 - Analysis/include/Luau/TypeArena.h | 42 ++ Analysis/include/Luau/TypeInfer.h | 19 +- Analysis/include/Luau/Unifier.h | 4 +- Analysis/include/Luau/VisitTypeVar.h | 20 +- Analysis/src/BuiltinDefinitions.cpp | 53 +- Analysis/src/Clone.cpp | 44 +- Analysis/src/Error.cpp | 13 +- Analysis/src/IostreamHelpers.cpp | 2 +- Analysis/src/LValue.cpp | 21 - Analysis/src/Module.cpp | 123 ++-- Analysis/src/Normalize.cpp | 32 +- Analysis/src/Quantify.cpp | 39 +- Analysis/src/Substitution.cpp | 113 ++-- Analysis/src/ToString.cpp | 7 +- Analysis/src/TxnLog.cpp | 34 +- Analysis/src/TypeArena.cpp | 88 +++ Analysis/src/TypeInfer.cpp | 650 +++++---------------- Analysis/src/TypeUtils.cpp | 11 +- Analysis/src/TypeVar.cpp | 9 +- Analysis/src/Unifier.cpp | 84 +-- Compiler/include/Luau/BytecodeBuilder.h | 1 + Compiler/src/BytecodeBuilder.cpp | 10 + Compiler/src/Compiler.cpp | 289 ++++++--- Compiler/src/ConstantFolding.cpp | 4 +- Sources.cmake | 2 + VM/src/ltablib.cpp | 27 - VM/src/lvmexecute.cpp | 14 - tests/Autocomplete.test.cpp | 4 + tests/Compiler.test.cpp | 235 ++++++-- tests/Module.test.cpp | 8 - tests/NonstrictMode.test.cpp | 4 - tests/Normalize.test.cpp | 45 ++ tests/RuntimeLimits.test.cpp | 2 - tests/ToString.test.cpp | 14 +- tests/TypeInfer.builtins.test.cpp | 14 - tests/TypeInfer.functions.test.cpp | 2 - tests/TypeInfer.generics.test.cpp | 74 +++ tests/TypeInfer.intersectionTypes.test.cpp | 2 - tests/TypeInfer.loops.test.cpp | 2 - tests/TypeInfer.operators.test.cpp | 2 - tests/TypeInfer.provisional.test.cpp | 36 +- tests/TypeInfer.refinements.test.cpp | 143 +---- tests/TypeInfer.singletons.test.cpp | 25 +- tests/TypeInfer.tables.test.cpp | 34 +- tests/TypeInfer.test.cpp | 4 - tests/TypeInfer.unionTypes.test.cpp | 1 - 52 files changed, 1097 insertions(+), 1369 deletions(-) create mode 100644 Analysis/include/Luau/TypeArena.h create mode 100644 Analysis/src/TypeArena.cpp diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 9b6ffa62..9fcbce04 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/TypeArena.h" #include "Luau/TypeVar.h" #include @@ -18,7 +19,6 @@ struct CloneState SeenTypePacks seenTypePacks; int recursionCount = 0; - bool encounteredFreeType = false; // TODO: Remove with LuauLosslessClone. }; TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 70683141..b4530674 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -5,6 +5,7 @@ #include "Luau/Location.h" #include "Luau/TypeVar.h" #include "Luau/Variant.h" +#include "Luau/TypeArena.h" namespace Luau { @@ -108,9 +109,6 @@ struct FunctionDoesNotTakeSelf struct FunctionRequiresSelf { - // TODO: Delete with LuauAnyInIsOptionalIsOptional - int requiredExtraNils = 0; - bool operator==(const FunctionRequiresSelf& rhs) const; }; diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index afb71415..1a92d52d 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -34,10 +34,6 @@ const LValue* baseof(const LValue& lvalue); std::optional tryGetLValue(const class AstExpr& expr); -// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. -// TODO: remove with FFlagLuauTypecheckOptPass -std::pair> getFullName(const LValue& lvalue); - // Utility function: breaks down an LValue to get at the Symbol Symbol getBaseSymbol(const LValue& lvalue); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 0dd44188..00e1e635 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -2,11 +2,10 @@ #pragma once #include "Luau/FileResolver.h" -#include "Luau/TypePack.h" -#include "Luau/TypedAllocator.h" #include "Luau/ParseOptions.h" #include "Luau/Error.h" #include "Luau/ParseResult.h" +#include "Luau/TypeArena.h" #include #include @@ -54,35 +53,6 @@ struct RequireCycle std::vector path; // one of the paths for a require() to go all the way back to the originating module }; -struct TypeArena -{ - TypedAllocator typeVars; - TypedAllocator typePacks; - - void clear(); - - template - TypeId addType(T tv) - { - if constexpr (std::is_same_v) - LUAU_ASSERT(tv.options.size() >= 2); - - return addTV(TypeVar(std::move(tv))); - } - - TypeId addTV(TypeVar&& tv); - - TypeId freshType(TypeLevel level); - - TypePackId addTypePack(std::initializer_list types); - TypePackId addTypePack(std::vector types); - TypePackId addTypePack(TypePack pack); - TypePackId addTypePack(TypePackVar pack); -}; - -void freeze(TypeArena& arena); -void unfreeze(TypeArena& arena); - struct Module { ~Module(); @@ -111,9 +81,7 @@ struct Module // Once a module has been typechecked, we clone its public interface into a separate arena. // This helps us to force TypeVar ownership into a DAG rather than a DCG. - // Returns true if there were any free types encountered in the public interface. This - // indicates a bug in the type checker that we want to surface. - bool clonePublicInterface(InternalErrorReporter& ice); + void clonePublicInterface(InternalErrorReporter& ice); }; } // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 6f5931e1..f3c3ae9a 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -1,8 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Module.h" -#include "Luau/ModuleResolver.h" +#include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/DenseHash.h" diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 995ed6c6..cd115e3b 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -7,8 +7,6 @@ #include "Luau/TypeVar.h" #include "Luau/TypePack.h" -LUAU_FASTFLAG(LuauTypecheckOptPass) - namespace Luau { @@ -93,15 +91,6 @@ struct TxnLog { } - TxnLog(TxnLog* parent, std::vector>* sharedSeen) - : typeVarChanges(nullptr) - , typePackChanges(nullptr) - , parent(parent) - , sharedSeen(sharedSeen) - { - LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); - } - TxnLog(const TxnLog&) = delete; TxnLog& operator=(const TxnLog&) = delete; diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h new file mode 100644 index 00000000..7c74158b --- /dev/null +++ b/Analysis/include/Luau/TypeArena.h @@ -0,0 +1,42 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypedAllocator.h" +#include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ + +struct TypeArena +{ + TypedAllocator typeVars; + TypedAllocator typePacks; + + void clear(); + + template + TypeId addType(T tv) + { + if constexpr (std::is_same_v) + LUAU_ASSERT(tv.options.size() >= 2); + + return addTV(TypeVar(std::move(tv))); + } + + TypeId addTV(TypeVar&& tv); + + TypeId freshType(TypeLevel level); + + TypePackId addTypePack(std::initializer_list types); + TypePackId addTypePack(std::vector types); + TypePackId addTypePack(TypePack pack); + TypePackId addTypePack(TypePackVar pack); +}; + +void freeze(TypeArena& arena); +void unfreeze(TypeArena& arena); + +} diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index ac880135..fcaf5baa 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -187,7 +187,6 @@ struct TypeChecker ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); @@ -395,7 +394,7 @@ private: const AstArray& genericNames, const AstArray& genericPackNames, bool useCache = false); public: - ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); + void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); private: void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate); @@ -403,14 +402,14 @@ private: std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); - void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); - void resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); - void resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); - void resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const PredicateVec& predicates, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); + void resolve(const Predicate& predicate, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); + void resolve(const TruthyPredicate& truthyP, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); + void resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const OrPredicate& orP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const TypeGuardPredicate& typeguardP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; bool useConstrainedIntersections() const; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 418d4ca4..0e24c8b0 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -5,7 +5,7 @@ #include "Luau/Location.h" #include "Luau/TxnLog.h" #include "Luau/TypeInfer.h" -#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header. +#include "Luau/TypeArena.h" #include "Luau/UnifierSharedState.h" #include @@ -55,8 +55,6 @@ struct Unifier UnifierSharedState& sharedState; Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); - Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, - UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 67fce5ed..2e98f526 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -10,6 +10,7 @@ LUAU_FASTFLAG(LuauUseVisitRecursionLimit) LUAU_FASTINT(LuauVisitRecursionLimit) +LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) namespace Luau { @@ -471,18 +472,21 @@ struct GenericTypeVarVisitor else if (auto pack = get(tp)) { - visit(tp, *pack); + bool res = visit(tp, *pack); + if (!FFlag::LuauNormalizeFlagIsConservative || res) + { + for (TypeId ty : pack->head) + traverse(ty); - for (TypeId ty : pack->head) - traverse(ty); - - if (pack->tail) - traverse(*pack->tail); + if (pack->tail) + traverse(*pack->tail); + } } else if (auto pack = get(tp)) { - visit(tp, *pack); - traverse(pack->ty); + bool res = visit(tp, *pack); + if (!FFlag::LuauNormalizeFlagIsConservative || res) + traverse(pack->ty); } else LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!"); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 3895b01b..5ed6de67 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,7 +8,6 @@ #include -LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) /** FIXME: Many of these type definitions are not quite completely accurate. @@ -408,41 +407,29 @@ static std::optional> magicFunctionAssert( { auto [paramPack, predicates] = exprResult; - if (FFlag::LuauAssertStripsFalsyTypes) + TypeArena& arena = typechecker.currentModule->internalTypes; + + auto [head, tail] = flatten(paramPack); + if (head.empty() && tail) { - TypeArena& arena = typechecker.currentModule->internalTypes; - - auto [head, tail] = flatten(paramPack); - if (head.empty() && tail) - { - std::optional fst = first(*tail); - if (!fst) - return ExprResult{paramPack}; - head.push_back(*fst); - } - - typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); - - if (head.size() > 0) - { - std::optional newhead = typechecker.pickTypesFromSense(head[0], true); - if (!newhead) - head = {typechecker.nilType}; - else - head[0] = *newhead; - } - - return ExprResult{arena.addTypePack(TypePack{std::move(head), tail})}; - } - else - { - if (expr.args.size < 1) + std::optional fst = first(*tail); + if (!fst) return ExprResult{paramPack}; - - typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); - - return ExprResult{paramPack}; + head.push_back(*fst); } + + typechecker.resolve(predicates, scope, true); + + if (head.size() > 0) + { + std::optional newhead = typechecker.pickTypesFromSense(head[0], true); + if (!newhead) + head = {typechecker.nilType}; + else + head[0] = *newhead; + } + + return ExprResult{arena.addTypePack(TypePack{std::move(head), tail})}; } static std::optional> magicFunctionPack( diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 1aa556eb..a3611f53 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -1,7 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Clone.h" -#include "Luau/Module.h" #include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" #include "Luau/Unifiable.h" @@ -9,8 +8,6 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) -LUAU_FASTFLAG(LuauTypecheckOptPass) -LUAU_FASTFLAGVARIABLE(LuauLosslessClone, false) LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau @@ -89,20 +86,8 @@ struct TypePackCloner void operator()(const Unifiable::Free& t) { - if (FFlag::LuauLosslessClone) - { - defaultClone(t); - } - else - { - cloneState.encounteredFreeType = true; - - TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); - TypePackId cloned = dest.addTypePack(*err); - seenTypePacks[typePackId] = cloned; - } + defaultClone(t); } - void operator()(const Unifiable::Generic& t) { defaultClone(t); @@ -152,18 +137,7 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { - if (FFlag::LuauLosslessClone) - { - defaultClone(t); - } - else - { - cloneState.encounteredFreeType = true; - - TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); - TypeId cloned = dest.addType(*err); - seenTypes[typeId] = cloned; - } + defaultClone(t); } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -191,9 +165,6 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) void TypeCloner::operator()(const ConstrainedTypeVar& t) { - if (!FFlag::LuauLosslessClone) - cloneState.encounteredFreeType = true; - TypeId res = dest.addType(ConstrainedTypeVar{t.level}); ConstrainedTypeVar* ctv = getMutable(res); LUAU_ASSERT(ctv); @@ -230,9 +201,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) ftv->argTypes = clone(t.argTypes, dest, cloneState); ftv->argNames = t.argNames; ftv->retType = clone(t.retType, dest, cloneState); - - if (FFlag::LuauTypecheckOptPass) - ftv->hasNoGenerics = t.hasNoGenerics; + ftv->hasNoGenerics = t.hasNoGenerics; } void TypeCloner::operator()(const TableTypeVar& t) @@ -270,13 +239,6 @@ void TypeCloner::operator()(const TableTypeVar& t) for (TypePackId& arg : ttv->instantiatedTypePackParams) arg = clone(arg, dest, cloneState); - if (!FFlag::LuauLosslessClone && ttv->state == TableState::Free) - { - cloneState.encounteredFreeType = true; - - ttv->state = TableState::Sealed; - } - ttv->definitionModuleName = t.definitionModuleName; if (!FFlag::LuauNoMethodLocations) ttv->methodDefinitionLocations = t.methodDefinitionLocations; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 24ed4ac1..f443a3cc 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -2,7 +2,6 @@ #include "Luau/Error.h" #include "Luau/Clone.h" -#include "Luau/Module.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" @@ -178,15 +177,7 @@ struct ErrorConverter std::string operator()(const Luau::FunctionRequiresSelf& e) const { - if (e.requiredExtraNils) - { - const char* plural = e.requiredExtraNils == 1 ? "" : "s"; - return format("This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a dot or " - "pass %i extra nil%s to suppress this warning", - e.requiredExtraNils, plural); - } - else - return "This function must be called with self. Did you mean to use a colon instead of a dot?"; + return "This function must be called with self. Did you mean to use a colon instead of a dot?"; } std::string operator()(const Luau::OccursCheckFailed&) const @@ -539,7 +530,7 @@ bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const bool FunctionRequiresSelf::operator==(const FunctionRequiresSelf& e) const { - return requiredExtraNils == e.requiredExtraNils; + return true; } bool OccursCheckFailed::operator==(const OccursCheckFailed&) const diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 0eaa485e..048167ae 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -48,7 +48,7 @@ static void errorToString(std::ostream& stream, const T& err) else if constexpr (std::is_same_v) stream << "FunctionDoesNotTakeSelf { }"; else if constexpr (std::is_same_v) - stream << "FunctionRequiresSelf { extraNils " << err.requiredExtraNils << " }"; + stream << "FunctionRequiresSelf { }"; else if constexpr (std::is_same_v) stream << "OccursCheckFailed { }"; else if constexpr (std::is_same_v) diff --git a/Analysis/src/LValue.cpp b/Analysis/src/LValue.cpp index 72555ab4..38dfe1ae 100644 --- a/Analysis/src/LValue.cpp +++ b/Analysis/src/LValue.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAG(LuauTypecheckOptPass) - namespace Luau { @@ -79,27 +77,8 @@ std::optional tryGetLValue(const AstExpr& node) return std::nullopt; } -std::pair> getFullName(const LValue& lvalue) -{ - LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); - - const LValue* current = &lvalue; - std::vector keys; - while (auto field = get(*current)) - { - keys.push_back(field->key); - current = baseof(*current); - } - - const Symbol* symbol = get(*current); - LUAU_ASSERT(symbol); - return {*symbol, std::vector(keys.rbegin(), keys.rend())}; -} - Symbol getBaseSymbol(const LValue& lvalue) { - LUAU_ASSERT(FFlag::LuauTypecheckOptPass); - const LValue* current = &lvalue; while (auto field = get(*current)) current = baseof(*current); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index bafd4371..074a41e6 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -13,9 +13,8 @@ #include -LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTFLAG(LuauLowerBoundsCalculation) -LUAU_FASTFLAG(LuauLosslessClone) +LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauNormalizeFlagIsConservative); namespace Luau { @@ -55,89 +54,25 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) return contains(pos, *iter); } -void TypeArena::clear() +struct ForceNormal : TypeVarOnceVisitor { - typeVars.clear(); - typePacks.clear(); -} + bool visit(TypeId ty) override + { + asMutable(ty)->normal = true; + return true; + } -TypeId TypeArena::addTV(TypeVar&& tv) -{ - TypeId allocated = typeVars.allocate(std::move(tv)); + bool visit(TypeId ty, const FreeTypeVar& ftv) override + { + visit(ty); + return true; + } - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypeId TypeArena::freshType(TypeLevel level) -{ - TypeId allocated = typeVars.allocate(FreeTypeVar{level}); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(std::initializer_list types) -{ - TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(std::vector types) -{ - TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(TypePack tp) -{ - TypePackId allocated = typePacks.allocate(std::move(tp)); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(TypePackVar tp) -{ - TypePackId allocated = typePacks.allocate(std::move(tp)); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -ScopePtr Module::getModuleScope() const -{ - LUAU_ASSERT(!scopes.empty()); - return scopes.front().second; -} - -void freeze(TypeArena& arena) -{ - if (!FFlag::DebugLuauFreezeArena) - return; - - arena.typeVars.freeze(); - arena.typePacks.freeze(); -} - -void unfreeze(TypeArena& arena) -{ - if (!FFlag::DebugLuauFreezeArena) - return; - - arena.typeVars.unfreeze(); - arena.typePacks.unfreeze(); -} + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + return true; + } +}; Module::~Module() { @@ -145,7 +80,7 @@ Module::~Module() unfreeze(internalTypes); } -bool Module::clonePublicInterface(InternalErrorReporter& ice) +void Module::clonePublicInterface(InternalErrorReporter& ice) { LUAU_ASSERT(interfaceTypes.typeVars.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); @@ -165,11 +100,22 @@ bool Module::clonePublicInterface(InternalErrorReporter& ice) normalize(*moduleScope->varargPack, interfaceTypes, ice); } + ForceNormal forceNormal; + for (auto& [name, tf] : moduleScope->exportedTypeBindings) { tf = clone(tf, interfaceTypes, cloneState); if (FFlag::LuauLowerBoundsCalculation) + { normalize(tf.type, interfaceTypes, ice); + + if (FFlag::LuauNormalizeFlagIsConservative) + { + // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables + // won't be marked normal. If the types aren't normal by now, they never will be. + forceNormal.traverse(tf.type); + } + } } for (TypeId ty : moduleScope->returnType) @@ -191,11 +137,12 @@ bool Module::clonePublicInterface(InternalErrorReporter& ice) freeze(internalTypes); freeze(interfaceTypes); +} - if (FFlag::LuauLosslessClone) - return false; // TODO: make function return void. - else - return cloneState.encounteredFreeType; +ScopePtr Module::getModuleScope() const +{ + LUAU_ASSERT(!scopes.empty()); + return scopes.front().second; } } // namespace Luau diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index ef5377a1..30fd4af2 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); namespace Luau { @@ -260,8 +261,13 @@ static bool areNormal_(const T& t, const std::unordered_set& seen, Intern if (count >= FInt::LuauNormalizeIterationLimit) ice.ice("Luau::areNormal hit iteration limit"); - // The follow is here because a bound type may not be normal, but the bound type is normal. - return ty->normal || follow(ty)->normal || seen.find(asMutable(ty)) != seen.end(); + if (FFlag::LuauNormalizeFlagIsConservative) + return ty->normal; + else + { + // The follow is here because a bound type may not be normal, but the bound type is normal. + return ty->normal || follow(ty)->normal || seen.find(asMutable(ty)) != seen.end(); + } }; return std::all_of(begin(t), end(t), isNormal); @@ -1003,8 +1009,15 @@ std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorRepo (void)clone(ty, arena, state); Normalize n{arena, ice}; - std::unordered_set seen; - DEPRECATED_visitTypeVar(ty, n, seen); + if (FFlag::LuauNormalizeFlagIsConservative) + { + DEPRECATED_visitTypeVar(ty, n); + } + else + { + std::unordered_set seen; + DEPRECATED_visitTypeVar(ty, n, seen); + } return {ty, !n.limitExceeded}; } @@ -1028,8 +1041,15 @@ std::pair normalize(TypePackId tp, TypeArena& arena, InternalE (void)clone(tp, arena, state); Normalize n{arena, ice}; - std::unordered_set seen; - DEPRECATED_visitTypeVar(tp, n, seen); + if (FFlag::LuauNormalizeFlagIsConservative) + { + DEPRECATED_visitTypeVar(tp, n); + } + else + { + std::unordered_set seen; + DEPRECATED_visitTypeVar(tp, n, seen); + } return {tp, !n.limitExceeded}; } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 4f3e4469..018d5632 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,7 +4,7 @@ #include "Luau/VisitTypeVar.h" -LUAU_FASTFLAG(LuauTypecheckOptPass) +LUAU_FASTFLAG(LuauAlwaysQuantify) namespace Luau { @@ -59,8 +59,7 @@ struct Quantifier final : TypeVarOnceVisitor bool visit(TypeId ty, const FreeTypeVar& ftv) override { - if (FFlag::LuauTypecheckOptPass) - seenMutableType = true; + seenMutableType = true; if (!level.subsumes(ftv.level)) return false; @@ -76,20 +75,17 @@ struct Quantifier final : TypeVarOnceVisitor LUAU_ASSERT(getMutable(ty)); TableTypeVar& ttv = *getMutable(ty); - if (FFlag::LuauTypecheckOptPass) - { - if (ttv.state == TableState::Generic) - seenGenericType = true; + if (ttv.state == TableState::Generic) + seenGenericType = true; - if (ttv.state == TableState::Free) - seenMutableType = true; - } + if (ttv.state == TableState::Free) + seenMutableType = true; if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) return false; if (!level.subsumes(ttv.level)) { - if (FFlag::LuauTypecheckOptPass && ttv.state == TableState::Unsealed) + if (ttv.state == TableState::Unsealed) seenMutableType = true; return false; } @@ -97,9 +93,7 @@ struct Quantifier final : TypeVarOnceVisitor if (ttv.state == TableState::Free) { ttv.state = TableState::Generic; - - if (FFlag::LuauTypecheckOptPass) - seenGenericType = true; + seenGenericType = true; } else if (ttv.state == TableState::Unsealed) ttv.state = TableState::Sealed; @@ -111,8 +105,7 @@ struct Quantifier final : TypeVarOnceVisitor bool visit(TypePackId tp, const FreeTypePack& ftp) override { - if (FFlag::LuauTypecheckOptPass) - seenMutableType = true; + seenMutableType = true; if (!level.subsumes(ftp.level)) return false; @@ -131,10 +124,18 @@ void quantify(TypeId ty, TypeLevel level) FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); - ftv->generics = q.generics; - ftv->genericPacks = q.genericPacks; + if (FFlag::LuauAlwaysQuantify) + { + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + } + else + { + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; + } - if (FFlag::LuauTypecheckOptPass && ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) ftv->hasNoGenerics = true; } diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index c5c7977a..e40bedb0 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -9,9 +9,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) -LUAU_FASTFLAG(LuauTypecheckOptPass) -LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) -LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowPossibleMutations, false) LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau @@ -19,26 +16,20 @@ namespace Luau void Tarjan::visitChildren(TypeId ty, int index) { - if (FFlag::LuauTypecheckOptPass) - LUAU_ASSERT(ty == log->follow(ty)); - else - ty = log->follow(ty); + LUAU_ASSERT(ty == log->follow(ty)); if (ignoreChildren(ty)) return; - if (FFlag::LuauTypecheckOptPass) - { - if (auto pty = log->pending(ty)) - ty = &pty->pending; - } + if (auto pty = log->pending(ty)) + ty = &pty->pending; - if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + if (const FunctionTypeVar* ftv = get(ty)) { visitChild(ftv->argTypes); visitChild(ftv->retType); } - else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const TableTypeVar* ttv = get(ty)) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) @@ -55,17 +46,17 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId itp : ttv->instantiatedTypePackParams) visitChild(itp); } - else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = get(ty)) { visitChild(mtv->table); visitChild(mtv->metatable); } - else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const UnionTypeVar* utv = get(ty)) { for (TypeId opt : utv->options) visitChild(opt); } - else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = get(ty)) { for (TypeId part : itv->parts) visitChild(part); @@ -79,28 +70,22 @@ void Tarjan::visitChildren(TypeId ty, int index) void Tarjan::visitChildren(TypePackId tp, int index) { - if (FFlag::LuauTypecheckOptPass) - LUAU_ASSERT(tp == log->follow(tp)); - else - tp = log->follow(tp); + LUAU_ASSERT(tp == log->follow(tp)); if (ignoreChildren(tp)) return; - if (FFlag::LuauTypecheckOptPass) - { - if (auto ptp = log->pending(tp)) - tp = &ptp->pending; - } + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; - if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) + if (const TypePack* tpp = get(tp)) { for (TypeId tv : tpp->head) visitChild(tv); if (tpp->tail) visitChild(*tpp->tail); } - else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) + else if (const VariadicTypePack* vtp = get(tp)) { visitChild(vtp->ty); } @@ -108,10 +93,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) std::pair Tarjan::indexify(TypeId ty) { - if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) - LUAU_ASSERT(ty == log->follow(ty)); - else - ty = log->follow(ty); + ty = log->follow(ty); bool fresh = !typeToIndex.contains(ty); int& index = typeToIndex[ty]; @@ -129,10 +111,7 @@ std::pair Tarjan::indexify(TypeId ty) std::pair Tarjan::indexify(TypePackId tp) { - if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) - LUAU_ASSERT(tp == log->follow(tp)); - else - tp = log->follow(tp); + tp = log->follow(tp); bool fresh = !packToIndex.contains(tp); int& index = packToIndex[tp]; @@ -150,8 +129,7 @@ std::pair Tarjan::indexify(TypePackId tp) void Tarjan::visitChild(TypeId ty) { - if (!FFlag::LuauSubstituteFollowPossibleMutations) - ty = log->follow(ty); + ty = log->follow(ty); edgesTy.push_back(ty); edgesTp.push_back(nullptr); @@ -159,8 +137,7 @@ void Tarjan::visitChild(TypeId ty) void Tarjan::visitChild(TypePackId tp) { - if (!FFlag::LuauSubstituteFollowPossibleMutations) - tp = log->follow(tp); + tp = log->follow(tp); edgesTy.push_back(nullptr); edgesTp.push_back(tp); @@ -389,13 +366,10 @@ TypeId Substitution::clone(TypeId ty) TypeId result = ty; - if (FFlag::LuauTypecheckOptPass) - { - if (auto pty = log->pending(ty)) - ty = &pty->pending; - } + if (auto pty = log->pending(ty)) + ty = &pty->pending; - if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + if (const FunctionTypeVar* ftv = get(ty)) { FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; @@ -405,7 +379,7 @@ TypeId Substitution::clone(TypeId ty) clone.argNames = ftv->argNames; result = addType(std::move(clone)); } - else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const TableTypeVar* ttv = get(ty)) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -419,19 +393,19 @@ TypeId Substitution::clone(TypeId ty) clone.tags = ttv->tags; result = addType(std::move(clone)); } - else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = get(ty)) { MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; clone.syntheticName = mtv->syntheticName; result = addType(std::move(clone)); } - else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const UnionTypeVar* utv = get(ty)) { UnionTypeVar clone; clone.options = utv->options; result = addType(std::move(clone)); } - else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = get(ty)) { IntersectionTypeVar clone; clone.parts = itv->parts; @@ -451,20 +425,17 @@ TypePackId Substitution::clone(TypePackId tp) { tp = log->follow(tp); - if (FFlag::LuauTypecheckOptPass) - { - if (auto ptp = log->pending(tp)) - tp = &ptp->pending; - } + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; - if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) + if (const TypePack* tpp = get(tp)) { TypePack clone; clone.head = tpp->head; clone.tail = tpp->tail; return addTypePack(std::move(clone)); } - else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) + else if (const VariadicTypePack* vtp = get(tp)) { VariadicTypePack clone; clone.ty = vtp->ty; @@ -476,28 +447,22 @@ TypePackId Substitution::clone(TypePackId tp) void Substitution::foundDirty(TypeId ty) { - if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) - LUAU_ASSERT(ty == log->follow(ty)); - else - ty = log->follow(ty); + ty = log->follow(ty); if (isDirty(ty)) - newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(ty)) : clean(ty); + newTypes[ty] = follow(clean(ty)); else - newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(ty)) : clone(ty); + newTypes[ty] = follow(clone(ty)); } void Substitution::foundDirty(TypePackId tp) { - if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) - LUAU_ASSERT(tp == log->follow(tp)); - else - tp = log->follow(tp); + tp = log->follow(tp); if (isDirty(tp)) - newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(tp)) : clean(tp); + newPacks[tp] = follow(clean(tp)); else - newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(tp)) : clone(tp); + newPacks[tp] = follow(clone(tp)); } TypeId Substitution::replace(TypeId ty) @@ -525,10 +490,7 @@ void Substitution::replaceChildren(TypeId ty) if (BoundTypeVar* btv = log->getMutable(ty); FFlag::LuauLowerBoundsCalculation && btv) btv->boundTo = replace(btv->boundTo); - if (FFlag::LuauTypecheckOptPass) - LUAU_ASSERT(ty == log->follow(ty)); - else - ty = log->follow(ty); + LUAU_ASSERT(ty == log->follow(ty)); if (ignoreChildren(ty)) return; @@ -579,10 +541,7 @@ void Substitution::replaceChildren(TypeId ty) void Substitution::replaceChildren(TypePackId tp) { - if (FFlag::LuauTypecheckOptPass) - LUAU_ASSERT(tp == log->follow(tp)); - else - tp = log->follow(tp); + LUAU_ASSERT(tp == log->follow(tp)); if (ignoreChildren(tp)) return; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 380ac456..f90f7019 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -219,6 +219,8 @@ struct StringifierState return generateName(s); } + int previousNameIndex = 0; + std::string getName(TypePackId ty) { const size_t s = result.nameMap.typePacks.size(); @@ -228,9 +230,10 @@ struct StringifierState for (int count = 0; count < 256; ++count) { - std::string candidate = generateName(usedNames.size() + count); + std::string candidate = generateName(previousNameIndex + count); if (!usedNames.count(candidate)) { + previousNameIndex += count; usedNames.insert(candidate); n = candidate; return candidate; @@ -399,6 +402,7 @@ struct TypeVarStringifier { if (gtv.explicitName) { + state.usedNames.insert(gtv.name); state.result.nameMap.typeVars[ty] = gtv.name; state.emit(gtv.name); } @@ -943,6 +947,7 @@ struct TypePackStringifier state.emit("gen-"); if (pack.explicitName) { + state.usedNames.insert(pack.name); state.result.nameMap.typePacks[tp] = pack.name; state.emit(pack.name); } diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 1fb5a61a..e45c0cbd 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,8 +7,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauJustOneCallFrameForHaveSeen, false) - namespace Luau { @@ -150,37 +148,13 @@ void TxnLog::popSeen(TypePackId lhs, TypePackId rhs) bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const { - if (FFlag::LuauJustOneCallFrameForHaveSeen && !FFlag::LuauTypecheckOptPass) + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) { - // This function will technically work if `this` is nullptr, but this - // indicates a bug, so we explicitly assert. - LUAU_ASSERT(static_cast(this) != nullptr); - - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - - for (const TxnLog* current = this; current; current = current->parent) - { - if (current->sharedSeen->end() != std::find(current->sharedSeen->begin(), current->sharedSeen->end(), sortedPair)) - return true; - } - - return false; + return true; } - else - { - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) - { - return true; - } - if (!FFlag::LuauTypecheckOptPass && parent) - { - return parent->haveSeen(lhs, rhs); - } - - return false; - } + return false; } void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp new file mode 100644 index 00000000..673b002d --- /dev/null +++ b/Analysis/src/TypeArena.cpp @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeArena.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false); + +namespace Luau +{ + +void TypeArena::clear() +{ + typeVars.clear(); + typePacks.clear(); +} + +TypeId TypeArena::addTV(TypeVar&& tv) +{ + TypeId allocated = typeVars.allocate(std::move(tv)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(TypeLevel level) +{ + TypeId allocated = typeVars.allocate(FreeTypeVar{level}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(std::initializer_list types) +{ + TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(std::vector types) +{ + TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(TypePack tp) +{ + TypePackId allocated = typePacks.allocate(std::move(tp)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(TypePackVar tp) +{ + TypePackId allocated = typePacks.allocate(std::move(tp)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +void freeze(TypeArena& arena) +{ + if (!FFlag::DebugLuauFreezeArena) + return; + + arena.typeVars.freeze(); + arena.typePacks.freeze(); +} + +void unfreeze(TypeArena& arena) +{ + if (!FFlag::DebugLuauFreezeArena) + return; + + arena.typeVars.unfreeze(); + arena.typePacks.unfreeze(); +} + +} diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index a13abd53..208b3f2f 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,32 +32,25 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauSeparateTypechecks) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false) -LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) +LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauInstantiateFollows, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) -LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) -LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) -LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. +LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree2) -LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) -LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) -LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) -LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) -LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); -LUAU_FASTFLAG(LuauLosslessClone) +LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) LUAU_FASTFLAGVARIABLE(LuauNoMethodLocations, false); +LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); namespace Luau { @@ -371,12 +364,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo prepareErrorsForDisplay(currentModule->errors); - bool encounteredFreeType = currentModule->clonePublicInterface(*iceHandler); - if (!FFlag::LuauLosslessClone && encounteredFreeType) - { - reportError(TypeError{module.root->location, - GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); - } + currentModule->clonePublicInterface(*iceHandler); // Clear unifier cache since it's keyed off internal types that get deallocated // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. @@ -701,7 +689,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) ExprResult result = checkExpr(scope, *statement.condition); ScopePtr ifScope = childScope(scope, statement.thenbody->location); - reportErrors(resolve(result.predicates, ifScope, true)); + resolve(result.predicates, ifScope, true); check(ifScope, *statement.thenbody); if (statement.elsebody) @@ -734,7 +722,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) ExprResult result = checkExpr(scope, *statement.condition); ScopePtr whileScope = childScope(scope, statement.body->location); - reportErrors(resolve(result.predicates, whileScope, true)); + resolve(result.predicates, whileScope, true); check(whileScope, *statement.body); } @@ -1154,10 +1142,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) } else { - if (FFlag::LuauInstantiateFollows) - iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); - else - iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); + iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } if (FFlag::LuauTypecheckIter) @@ -1849,23 +1834,11 @@ std::optional TypeChecker::getIndexTypeFromType( tablify(type); - if (FFlag::LuauDiscriminableUnions2) + if (isString(type)) { - if (isString(type)) - { - std::optional mtIndex = findMetatableEntry(stringType, "__index", location); - LUAU_ASSERT(mtIndex); - type = *mtIndex; - } - } - else - { - const PrimitiveTypeVar* primitiveType = get(type); - if (primitiveType && primitiveType->type == PrimitiveTypeVar::String) - { - if (std::optional mtIndex = findMetatableEntry(type, "__index", location)) - type = *mtIndex; - } + std::optional mtIndex = findMetatableEntry(stringType, "__index", location); + LUAU_ASSERT(mtIndex); + type = *mtIndex; } if (TableTypeVar* tableType = getMutableTableType(type)) @@ -1966,23 +1939,10 @@ std::optional TypeChecker::getIndexTypeFromType( return std::nullopt; } - if (FFlag::LuauDoNotTryToReduce) - { - if (parts.size() == 1) - return parts[0]; + if (parts.size() == 1) + return parts[0]; - return addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. - } - else - { - // TODO(amccord): Write some logic to correctly handle intersections. CLI-34659 - std::vector result = reduceUnion(parts); - - if (result.size() == 1) - return result[0]; - - return addType(IntersectionTypeVar{result}); - } + return addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. } if (addErrors) @@ -1993,103 +1953,55 @@ std::optional TypeChecker::getIndexTypeFromType( std::vector TypeChecker::reduceUnion(const std::vector& types) { - if (FFlag::LuauDoNotAccidentallyDependOnPointerOrdering) + std::vector result; + for (TypeId t : types) { - std::vector result; - for (TypeId t : types) + t = follow(t); + if (get(t) || get(t)) + return {t}; + + if (const UnionTypeVar* utv = get(t)) { - t = follow(t); - if (get(t) || get(t)) - return {t}; - - if (const UnionTypeVar* utv = get(t)) + if (FFlag::LuauReduceUnionRecursion) { - if (FFlag::LuauReduceUnionRecursion) + for (TypeId ty : utv) { - for (TypeId ty : utv) - { - if (get(ty) || get(ty)) - return {ty}; - - if (result.end() == std::find(result.begin(), result.end(), ty)) - result.push_back(ty); - } - } - else - { - std::vector r = reduceUnion(utv->options); - for (TypeId ty : r) - { + if (FFlag::LuauNormalizeFlagIsConservative) ty = follow(ty); - if (get(ty) || get(ty)) - return {ty}; + if (get(ty) || get(ty)) + return {ty}; - if (std::find(result.begin(), result.end(), ty) == result.end()) - result.push_back(ty); - } + if (result.end() == std::find(result.begin(), result.end(), ty)) + result.push_back(ty); } } - else if (std::find(result.begin(), result.end(), t) == result.end()) - result.push_back(t); - } - - return result; - } - else - { - std::set s; - - for (TypeId t : types) - { - if (const UnionTypeVar* utv = get(follow(t))) + else { std::vector r = reduceUnion(utv->options); for (TypeId ty : r) - s.insert(ty); + { + ty = follow(ty); + if (get(ty) || get(ty)) + return {ty}; + + if (std::find(result.begin(), result.end(), ty) == result.end()) + result.push_back(ty); + } } - else - s.insert(t); } - - // If any of them are ErrorTypeVars/AnyTypeVars, decay into them. - for (TypeId t : s) - { - t = follow(t); - if (get(t) || get(t)) - return {t}; - } - - std::vector r(s.begin(), s.end()); - std::sort(r.begin(), r.end()); - return r; + else if (std::find(result.begin(), result.end(), t) == result.end()) + result.push_back(t); } + + return result; } std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) { if (const UnionTypeVar* utv = get(ty)) { - if (FFlag::LuauAnyInIsOptionalIsOptional) - { - if (!std::any_of(begin(utv), end(utv), isNil)) - return ty; - } - else - { - bool hasNil = false; - - for (TypeId option : utv) - { - if (isNil(option)) - { - hasNil = true; - break; - } - } - - if (!hasNil) - return ty; - } + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; std::vector result; @@ -2110,32 +2022,18 @@ std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) { - if (FFlag::LuauAnyInIsOptionalIsOptional) + ty = follow(ty); + + if (auto utv = get(ty)) { - ty = follow(ty); - - if (auto utv = get(ty)) - { - if (!std::any_of(begin(utv), end(utv), isNil)) - return ty; - } - - if (std::optional strippedUnion = tryStripUnionFromNil(ty)) - { - reportError(location, OptionalValueAccess{ty}); - return follow(*strippedUnion); - } + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; } - else + + if (std::optional strippedUnion = tryStripUnionFromNil(ty)) { - if (isOptional(ty)) - { - if (std::optional strippedUnion = tryStripUnionFromNil(follow(ty))) - { - reportError(location, OptionalValueAccess{ty}); - return follow(*strippedUnion); - } - } + reportError(location, OptionalValueAccess{ty}); + return follow(*strippedUnion); } return ty; @@ -2194,8 +2092,7 @@ TypeId TypeChecker::checkExprTable( if (indexer) { - if (FFlag::LuauCheckImplicitNumbericKeys) - unify(numberType, indexer->indexType, value->location); + unify(numberType, indexer->indexType, value->location); unify(valueType, indexer->indexResultType, value->location); } else @@ -2219,7 +2116,8 @@ TypeId TypeChecker::checkExprTable( if (errors.empty()) exprType = expectedProp.type; } - else if (expectedTable->indexer && isString(expectedTable->indexer->indexType)) + else if (expectedTable->indexer && (FFlag::LuauExpectedPropTypeFromIndexer ? maybeString(expectedTable->indexer->indexType) + : isString(expectedTable->indexer->indexType))) { ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); if (errors.empty()) @@ -2259,26 +2157,13 @@ TypeId TypeChecker::checkExprTable( ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) { - if (FFlag::LuauTableUseCounterInstead) + RecursionCounter _rc(&checkRecursionCount); + if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { - RecursionCounter _rc(&checkRecursionCount); - if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) - { - reportErrorCodeTooComplex(expr.location); - return {errorRecoveryType(scope)}; - } - - return checkExpr_(scope, expr, expectedType); + reportErrorCodeTooComplex(expr.location); + return {errorRecoveryType(scope)}; } - else - { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "checkExpr for tables"); - return checkExpr_(scope, expr, expectedType); - } -} -ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) -{ std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; @@ -2324,6 +2209,8 @@ ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprT { if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; + else if (FFlag::LuauExpectedPropTypeFromIndexer && expectedIndexType && maybeString(*expectedIndexType)) + expectedResultType = expectedIndexResultType; } else if (expectedUnion) { @@ -2529,7 +2416,7 @@ TypeId TypeChecker::checkRelationalOperation( if (expr.op == AstExprBinary::Or && subexp->op == AstExprBinary::And) { ScopePtr subScope = childScope(scope, subexp->location); - reportErrors(resolve(predicates, subScope, true)); + resolve(predicates, subScope, true); return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); } } @@ -2851,8 +2738,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); - return {checkBinaryOperation(FFlag::LuauDiscriminableUnions2 ? scope : innerScope, expr, lhsTy, rhsTy), - {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; + return {checkBinaryOperation(scope, expr, lhsTy, rhsTy), {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::Or) { @@ -2864,7 +2750,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions2 ? scope : innerScope, expr, lhsTy, rhsTy, lhsPredicates); + TypeId result = checkBinaryOperation(scope, expr, lhsTy, rhsTy, lhsPredicates); return {result, {OrPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) @@ -2872,8 +2758,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions2); - ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions2); + ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); + ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); PredicateVec predicates; @@ -2931,12 +2817,12 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { ExprResult result = checkExpr(scope, *expr.condition); + ScopePtr trueScope = childScope(scope, expr.trueExpr->location); - reportErrors(resolve(result.predicates, trueScope, true)); + resolve(result.predicates, trueScope, true); ExprResult trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); - // Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope. resolve(result.predicates, falseScope, false); ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); @@ -3668,9 +3554,6 @@ void TypeChecker::checkArgumentList( else if (state.log.getMutable(t)) { } // ok - else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.get(t)) - { - } // ok else { size_t minParams = getMinParameterCount(&state.log, paramPack); @@ -3823,9 +3706,6 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = instantiate(scope, functionType, expr.func->location); } - if (!FFlag::LuauInstantiateFollows) - actualFunctionType = follow(actualFunctionType); - TypePackId retPack; if (FFlag::LuauLowerBoundsCalculation || !FFlag::LuauWidenIfSupertypeIsFree2) { @@ -4096,32 +3976,6 @@ std::optional> TypeChecker::checkCallOverload(const Scope { state.log.commit(); - if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) - { - // If we are running in nonstrict mode, passing fewer arguments than the function is declared to take AND - // the function is declared with colon notation AND we use dot notation, warn. - auto [providedArgs, providedTail] = flatten(argPack); - - // If we have a variadic tail, we can't say how many arguments were actually provided - if (!providedTail) - { - std::vector actualArgs = flatten(ftv->argTypes).first; - - size_t providedCount = providedArgs.size(); - size_t requiredCount = actualArgs.size(); - - // Ignore optional arguments - while (providedCount < requiredCount && requiredCount != 0 && isOptional(actualArgs[requiredCount - 1])) - requiredCount--; - - if (providedCount < requiredCount) - { - int requiredExtraNils = int(requiredCount - providedCount); - reportError(TypeError{expr.func->location, FunctionRequiresSelf{requiredExtraNils}}); - } - } - } - currentModule->astOverloadResolvedTypes[&expr] = fn; // We select this overload @@ -4525,7 +4379,7 @@ bool Instantiation::isDirty(TypeId ty) { if (const FunctionTypeVar* ftv = log->getMutable(ty)) { - if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics) + if (ftv->hasNoGenerics) return false; return true; @@ -4582,7 +4436,7 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionTypeVar* ftv = log->getMutable(ty)) { - if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics) + if (ftv->hasNoGenerics) return true; // We aren't recursing in the case of a generic function which @@ -4701,8 +4555,17 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location ty = follow(ty); const FunctionTypeVar* ftv = get(ty); - if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) - Luau::quantify(ty, scope->level); + + if (FFlag::LuauAlwaysQuantify) + { + if (ftv) + Luau::quantify(ty, scope->level); + } + else + { + if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) + Luau::quantify(ty, scope->level); + } if (FFlag::LuauLowerBoundsCalculation && ftv) { @@ -4717,15 +4580,11 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { - if (FFlag::LuauInstantiateFollows) - ty = follow(ty); + ty = follow(ty); - if (FFlag::LuauTypecheckOptPass) - { - const FunctionTypeVar* ftv = get(FFlag::LuauInstantiateFollows ? ty : follow(ty)); - if (ftv && ftv->hasNoGenerics) - return ty; - } + const FunctionTypeVar* ftv = get(ty); + if (ftv && ftv->hasNoGenerics) + return ty; Instantiation instantiation{log, ¤tModule->internalTypes, scope->level}; @@ -5392,10 +5251,9 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack bool ApplyTypeFunction::isDirty(TypeId ty) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics. - if (get(ty)) + if (FFlag::LuauApplyTypeFunctionFix && typeArguments.count(ty)) + return true; + else if (!FFlag::LuauApplyTypeFunctionFix && get(ty)) return true; else if (const FreeTypeVar* ftv = get(ty)) { @@ -5409,10 +5267,9 @@ bool ApplyTypeFunction::isDirty(TypeId ty) bool ApplyTypeFunction::isDirty(TypePackId tp) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics. - if (get(tp)) + if (FFlag::LuauApplyTypeFunctionFix && typePackArguments.count(tp)) + return true; + else if (!FFlag::LuauApplyTypeFunctionFix && get(tp)) return true; else return false; @@ -5436,11 +5293,13 @@ bool ApplyTypeFunction::ignoreChildren(TypePackId tp) TypeId ApplyTypeFunction::clean(TypeId ty) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics by free type variables. TypeId& arg = typeArguments[ty]; - if (arg) + if (FFlag::LuauApplyTypeFunctionFix) + { + LUAU_ASSERT(arg); + return arg; + } + else if (arg) return arg; else return addType(FreeTypeVar{level}); @@ -5448,11 +5307,13 @@ TypeId ApplyTypeFunction::clean(TypeId ty) TypePackId ApplyTypeFunction::clean(TypePackId tp) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics by free type variables. TypePackId& arg = typePackArguments[tp]; - if (arg) + if (FFlag::LuauApplyTypeFunctionFix) + { + LUAU_ASSERT(arg); + return arg; + } + else if (arg) return arg; else return addTypePack(FreeTypePack{level}); @@ -5596,8 +5457,6 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) { - LUAU_ASSERT(FFlag::LuauDiscriminableUnions2 || FFlag::LuauAssertStripsFalsyTypes); - const LValue* target = &lvalue; std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. @@ -5683,66 +5542,6 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV // We need to search in the provided Scope. Find t.x.y first. // We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x. // If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate. - if (!FFlag::LuauTypecheckOptPass) - { - const auto& [symbol, keys] = getFullName(lvalue); - - ScopePtr currentScope = scope; - while (currentScope) - { - std::optional found; - - std::vector childKeys; - const LValue* currentLValue = &lvalue; - while (currentLValue) - { - if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) - { - found = it->second; - break; - } - - childKeys.push_back(*currentLValue); - currentLValue = baseof(*currentLValue); - } - - if (!found) - { - // Should not be using scope->lookup. This is already recursive. - if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) - found = it->second.typeId; - else - { - // Nothing exists in this Scope. Just skip and try the parent one. - currentScope = currentScope->parent; - continue; - } - } - - for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) - { - const LValue& key = *it; - - // Symbol can happen. Skip. - if (get(key)) - continue; - else if (auto field = get(key)) - { - found = getIndexTypeFromType(scope, *found, field->key, Location(), false); - if (!found) - return std::nullopt; // Turns out this type doesn't have the property at all. We're done. - } - else - LUAU_ASSERT(!"New LValue alternative not handled here."); - } - - return found; - } - - // No entry for it at all. Can happen when LValue root is a global. - return std::nullopt; - } - const Symbol symbol = getBaseSymbol(lvalue); ScopePtr currentScope = scope; @@ -5820,85 +5619,47 @@ static bool isUndecidable(TypeId ty) return get(ty) || get(ty) || get(ty); } -ErrorVec TypeChecker::resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense) { - ErrorVec errVec; - resolve(predicates, errVec, scope->refinements, scope, sense); - return errVec; + resolve(predicates, scope->refinements, scope, sense); } -void TypeChecker::resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const PredicateVec& predicates, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { for (const Predicate& c : predicates) - resolve(c, errVec, refis, scope, sense, fromOr); + resolve(c, refis, scope, sense, fromOr); } -void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const Predicate& predicate, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { if (auto truthyP = get(predicate)) - resolve(*truthyP, errVec, refis, scope, sense, fromOr); + resolve(*truthyP, refis, scope, sense, fromOr); else if (auto andP = get(predicate)) - resolve(*andP, errVec, refis, scope, sense); + resolve(*andP, refis, scope, sense); else if (auto orP = get(predicate)) - resolve(*orP, errVec, refis, scope, sense); + resolve(*orP, refis, scope, sense); else if (auto notP = get(predicate)) - resolve(notP->predicates, errVec, refis, scope, !sense, fromOr); + resolve(notP->predicates, refis, scope, !sense, fromOr); else if (auto isaP = get(predicate)) - resolve(*isaP, errVec, refis, scope, sense); + resolve(*isaP, refis, scope, sense); else if (auto typeguardP = get(predicate)) - resolve(*typeguardP, errVec, refis, scope, sense); + resolve(*typeguardP, refis, scope, sense); else if (auto eqP = get(predicate)) - resolve(*eqP, errVec, refis, scope, sense); + resolve(*eqP, refis, scope, sense); else ice("Unhandled predicate kind"); } -void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const TruthyPredicate& truthyP, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { - if (FFlag::LuauAssertStripsFalsyTypes) - { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (ty && fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (ty && fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); - refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); - } - else - { - auto predicate = [sense](TypeId option) -> std::optional { - if (isUndecidable(option) || isBoolean(option) || isNil(option) != sense) - return option; - - return std::nullopt; - }; - - if (FFlag::LuauDiscriminableUnions2) - { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (ty && fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); - - refineLValue(truthyP.lvalue, refis, scope, predicate); - } - else - { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (!ty) - return; - - // This is a hack. :( - // Without this, the expression 'a or b' might refine 'b' to be falsy. - // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. - if (fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); - - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, truthyP.lvalue, *result); - } - } + refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); } -void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense) { if (!sense) { @@ -5907,14 +5668,14 @@ void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, Refinement {NotPredicate{std::move(andP.rhs)}}, }; - return resolve(orP, errVec, refis, scope, !sense); + return resolve(orP, refis, scope, !sense); } - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); + resolve(andP.lhs, refis, scope, sense); + resolve(andP.rhs, refis, scope, sense); } -void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const OrPredicate& orP, RefinementMap& refis, const ScopePtr& scope, bool sense) { if (!sense) { @@ -5923,28 +5684,24 @@ void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMa {NotPredicate{std::move(orP.rhs)}}, }; - return resolve(andP, errVec, refis, scope, !sense); + return resolve(andP, refis, scope, !sense); } - ErrorVec discarded; - RefinementMap leftRefis; - resolve(orP.lhs, errVec, leftRefis, scope, sense); + resolve(orP.lhs, leftRefis, scope, sense); RefinementMap rightRefis; - resolve(orP.lhs, discarded, rightRefis, scope, !sense); - resolve(orP.rhs, errVec, rightRefis, scope, sense, true); // :( + resolve(orP.lhs, rightRefis, scope, !sense); + resolve(orP.rhs, rightRefis, scope, sense, true); // :( merge(refis, leftRefis); merge(refis, rightRefis); } -void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense) { auto predicate = [&](TypeId option) -> std::optional { // This by itself is not truly enough to determine that A is stronger than B or vice versa. - // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. - // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) bool optionIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); bool targetIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); @@ -5985,32 +5742,15 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return res; }; - if (FFlag::LuauDiscriminableUnions2) - { - refineLValue(isaP.lvalue, refis, scope, predicate); - } - else - { - std::optional ty = resolveLValue(refis, scope, isaP.lvalue); - if (!ty) - return; - - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, isaP.lvalue, *result); - else - { - addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); - errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); - } - } + refineLValue(isaP.lvalue, refis, scope, predicate); } -void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& refis, const ScopePtr& scope, bool sense) { // Rewrite the predicate 'type(foo) == "vector"' to be 'typeof(foo) == "Vector3"'. They're exactly identical. // This allows us to avoid writing in edge cases. if (!typeguardP.isTypeof && typeguardP.kind == "vector") - return resolve(TypeGuardPredicate{std::move(typeguardP.lvalue), typeguardP.location, "Vector3", true}, errVec, refis, scope, sense); + return resolve(TypeGuardPredicate{std::move(typeguardP.lvalue), typeguardP.location, "Vector3", true}, refis, scope, sense); std::optional ty = resolveLValue(refis, scope, typeguardP.lvalue); if (!ty) @@ -6060,52 +5800,29 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) { - if (FFlag::LuauDiscriminableUnions2) - { - refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); - return; - } - else - { - if (std::optional result = filterMap(*ty, it->second(sense))) - addRefinement(refis, typeguardP.lvalue, *result); - else - { - addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - if (sense) - errVec.push_back( - TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); - } - - return; - } + refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); + return; } - auto fail = [&](const TypeErrorData& err) { - if (!FFlag::LuauDiscriminableUnions2) - errVec.push_back(TypeError{typeguardP.location, err}); - addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - }; - if (!typeguardP.isTypeof) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); auto typeFun = globalScope->lookupType(typeguardP.kind); if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); TypeId type = follow(typeFun->type); // We're only interested in the root class of any classes. if (auto ctv = get(type); !ctv || ctv->parent) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. // Until then, we rewrite this to be the same as using IsA. - return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); + return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, refis, scope, sense); } -void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. auto options = [](TypeId ty) -> std::vector { @@ -6114,82 +5831,33 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa return {ty}; }; - if (FFlag::LuauDiscriminableUnions2) - { - std::vector rhs = options(eqP.type); + std::vector rhs = options(eqP.type); - if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - auto predicate = [&](TypeId option) -> std::optional { - if (sense && isUndecidable(option)) - return FFlag::LuauWeakEqConstraint ? option : eqP.type; + auto predicate = [&](TypeId option) -> std::optional { + if (sense && isUndecidable(option)) + return FFlag::LuauWeakEqConstraint ? option : eqP.type; - if (!sense && isNil(eqP.type)) - return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; + if (!sense && isNil(eqP.type)) + return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; - if (maybeSingleton(eqP.type)) - { - // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. - if (!sense || canUnify(eqP.type, option, eqP.location).empty()) - return sense ? eqP.type : option; - - // local variable works around an odd gcc 9.3 warning: may be used uninitialized - std::optional res = std::nullopt; - return res; - } - - return option; - }; - - refineLValue(eqP.lvalue, refis, scope, predicate); - } - else - { - if (FFlag::LuauWeakEqConstraint) + if (maybeSingleton(eqP.type)) { - if (!sense && isNil(eqP.type)) - resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); + // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. + if (!sense || canUnify(eqP.type, option, eqP.location).empty()) + return sense ? eqP.type : option; - return; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; } - if (FFlag::LuauEqConstraint) - { - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; + return option; + }; - std::vector lhs = options(*ty); - std::vector rhs = options(eqP.type); - - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) - { - addRefinement(refis, eqP.lvalue, eqP.type); - return; - } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - - std::unordered_set set; - for (TypeId left : lhs) - { - for (TypeId right : rhs) - { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); - } - } - - if (set.empty()) - return; - - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); - } - } + refineLValue(eqP.lvalue, refis, scope, predicate); } bool TypeChecker::isNonstrictMode() const diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index c2435890..ba09df5f 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -5,8 +5,6 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" -LUAU_FASTFLAGVARIABLE(LuauTerminateCyclicMetatableIndexLookup, false) - namespace Luau { @@ -55,13 +53,10 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t { TypeId index = follow(*mtIndex); - if (FFlag::LuauTerminateCyclicMetatableIndexLookup) - { - if (count >= 100) - return std::nullopt; + if (count >= 100) + return std::nullopt; - ++count; - } + ++count; if (const auto& itt = getTableType(index)) { diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 463b4651..2355dab2 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -24,8 +24,6 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) -LUAU_FASTFLAG(LuauDiscriminableUnions2) -LUAU_FASTFLAGVARIABLE(LuauAnyInIsOptionalIsOptional, false) LUAU_FASTFLAGVARIABLE(LuauClassDefinitionModuleInError, false) namespace Luau @@ -204,14 +202,14 @@ bool isOptional(TypeId ty) ty = follow(ty); - if (FFlag::LuauAnyInIsOptionalIsOptional && get(ty)) + if (get(ty)) return true; auto utv = get(ty); if (!utv) return false; - return std::any_of(begin(utv), end(utv), FFlag::LuauAnyInIsOptionalIsOptional ? isOptional : isNil); + return std::any_of(begin(utv), end(utv), isOptional); } bool isTableIntersection(TypeId ty) @@ -378,8 +376,7 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) if (seen.contains(ty)) return true; - bool isStr = FFlag::LuauDiscriminableUnions2 ? isString(ty) : isPrim(ty, PrimitiveTypeVar::String); - if (isStr || get(ty) || get(ty) || get(ty)) + if (isString(ty) || get(ty) || get(ty) || get(ty)) return true; if (auto uty = get(ty)) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index f5c1dde9..9308e9ff 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -24,8 +24,6 @@ LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) -LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) -LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -382,19 +380,6 @@ Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) - : types(types) - , mode(mode) - , log(parentLog, sharedSeen) - , location(location) - , variance(variance) - , sharedState(sharedState) -{ - LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); - LUAU_ASSERT(sharedState.iceHandler); -} - void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { sharedState.counters.iterationCount = 0; @@ -1219,14 +1204,6 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal continue; } - // In nonstrict mode, any also marks an optional argument. - else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && - log.getMutable(log.follow(*superIter))) - { - superIter.advance(); - continue; - } - if (log.getMutable(superIter.packId)) { tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); @@ -1454,21 +1431,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - if (FFlag::LuauAnyInIsOptionalIsOptional) - { - if (subIter == subTable->props.end() && - (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) - missingProperties.push_back(propName); - } - else - { - bool isAny = log.getMutable(log.follow(superProp.type)); - - if (subIter == subTable->props.end() && - (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && - !isAny) - missingProperties.push_back(propName); - } + if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && + !isOptional(superProp.type)) + missingProperties.push_back(propName); } if (!missingProperties.empty()) @@ -1485,18 +1450,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto superIter = superTable->props.find(propName); - if (FFlag::LuauAnyInIsOptionalIsOptional) - { - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || !isOptional(subProp.type))) - extraProperties.push_back(propName); - } - else - { - bool isAny = log.is(log.follow(subProp.type)); - if (superIter == superTable->props.end() && - (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) - extraProperties.push_back(propName); - } + if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || !isOptional(subProp.type))) + extraProperties.push_back(propName); } if (!extraProperties.empty()) @@ -1540,21 +1495,12 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); } - else if (FFlag::LuauAnyInIsOptionalIsOptional && - (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) + else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) { } - else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && - (isOptional(prop.type) || get(follow(prop.type)))) - // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` - // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. - // TODO: should isOptional(anyType) be true? - // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) - { - } else if (subTable->state == TableState::Free) { PendingType* pendingSub = log.queue(subTy); @@ -1618,10 +1564,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (variance == Covariant) { } - else if (FFlag::LuauAnyInIsOptionalIsOptional && !FFlag::LuauSubtypingAddOptPropsToUnsealedTables && isOptional(prop.type)) - { - } - else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && (isOptional(prop.type) || get(follow(prop.type)))) + else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && isOptional(prop.type)) { } else if (superTable->state == TableState::Free) @@ -1753,9 +1696,7 @@ TypePackId Unifier::widen(TypePackId tp) TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { ty = follow(ty); - if (!FFlag::LuauAnyInIsOptionalIsOptional && get(ty)) - return ty; - else if (isOptional(ty)) + if (isOptional(ty)) return ty; else if (const TableTypeVar* ttv = get(ty)) { @@ -2666,14 +2607,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - if (FFlag::LuauTypecheckOptPass) - { - Unifier u = Unifier{types, mode, location, variance, sharedState, &log}; - u.anyIsTop = anyIsTop; - return u; - } - - Unifier u = Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; + Unifier u = Unifier{types, mode, location, variance, sharedState, &log}; u.anyIsTop = anyIsTop; return u; } diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index b00440ae..12465377 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -224,6 +224,7 @@ private: DenseHashMap constantMap; DenseHashMap tableShapeMap; + DenseHashMap protoMap; int debugLine = 0; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index fb70392e..beeda295 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAG(LuauCompileNestedClosureO2) + namespace Luau { @@ -181,6 +183,7 @@ size_t BytecodeBuilder::TableShapeHash::operator()(const TableShape& v) const BytecodeBuilder::BytecodeBuilder(BytecodeEncoder* encoder) : constantMap({Constant::Type_Nil, ~0ull}) , tableShapeMap(TableShape()) + , protoMap(~0u) , stringTable({nullptr, 0}) , encoder(encoder) { @@ -250,6 +253,7 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) constantMap.clear(); tableShapeMap.clear(); + protoMap.clear(); debugRemarks.clear(); debugRemarkBuffer.clear(); @@ -372,11 +376,17 @@ int32_t BytecodeBuilder::addConstantClosure(uint32_t fid) int16_t BytecodeBuilder::addChildFunction(uint32_t fid) { + if (FFlag::LuauCompileNestedClosureO2) + if (int16_t* cache = protoMap.find(fid)) + return *cache; + uint32_t id = uint32_t(protos.size()); if (id >= kMaxClosureCount) return -1; + if (FFlag::LuauCompileNestedClosureO2) + protoMap[fid] = int16_t(id); protos.push_back(fid); return int16_t(id); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index e177e928..4f26ceb9 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -17,8 +17,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileSupportInlining, false) - LUAU_FASTFLAGVARIABLE(LuauCompileIter, false) LUAU_FASTFLAGVARIABLE(LuauCompileIterNoReserve, false) LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false) @@ -30,6 +28,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) +LUAU_FASTFLAGVARIABLE(LuauCompileNestedClosureO2, false) + namespace Luau { @@ -100,13 +100,11 @@ struct Compiler upvals.reserve(16); } - uint8_t getLocal(AstLocal* local) + int getLocalReg(AstLocal* local) { Local* l = locals.find(local); - LUAU_ASSERT(l); - LUAU_ASSERT(l->allocated); - return l->reg; + return l && l->allocated ? l->reg : -1; } uint8_t getUpval(AstLocal* local) @@ -159,17 +157,19 @@ struct Compiler AstExprFunction* getFunctionExpr(AstExpr* node) { - if (AstExprLocal* le = node->as()) + if (AstExprLocal* expr = node->as()) { - Variable* lv = variables.find(le->local); + Variable* lv = variables.find(expr->local); if (!lv || lv->written || !lv->init) return nullptr; return getFunctionExpr(lv->init); } - else if (AstExprGroup* ge = node->as()) - return getFunctionExpr(ge->expr); + else if (AstExprGroup* expr = node->as()) + return getFunctionExpr(expr->expr); + else if (AstExprTypeAssertion* expr = node->as()) + return getFunctionExpr(expr->expr); else return node->as(); } @@ -180,13 +180,13 @@ struct Compiler { bool result = true; - bool visit(AstExpr* node) override + bool visit(AstExprFunction* node) override { - // nested functions may capture function arguments, and our upval handling doesn't handle elided variables (constant) - // TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues - // TODO: additionally we would need to change upvalue handling in compileExprFunction to handle upvalue->local migration - result = result && !node->is(); - return result; + if (!FFlag::LuauCompileNestedClosureO2) + result = false; + + // short-circuit to avoid analyzing nested closure bodies + return false; } bool visit(AstStat* node) override @@ -275,8 +275,7 @@ struct Compiler f.upvals = upvals; // record information for inlining - if (FFlag::LuauCompileSupportInlining && options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) && - !getfenvUsed && !setfenvUsed) + if (options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) && !getfenvUsed && !setfenvUsed) { f.canInline = true; f.stackSize = stackSize; @@ -346,8 +345,8 @@ struct Compiler uint8_t argreg; - if (isExprLocalReg(arg)) - argreg = getLocal(arg->as()->local); + if (int reg = getExprLocalReg(arg); reg >= 0) + argreg = uint8_t(reg); else { argreg = uint8_t(regs + 1); @@ -403,8 +402,8 @@ struct Compiler } } - if (isExprLocalReg(expr->args.data[i])) - args[i] = getLocal(expr->args.data[i]->as()->local); + if (int reg = getExprLocalReg(expr->args.data[i]); reg >= 0) + args[i] = uint8_t(reg); else { args[i] = uint8_t(regs + 1 + i); @@ -489,19 +488,18 @@ struct Compiler return false; } - // TODO: we can compile functions with mismatching arity at call site but it's more annoying - if (func->args.size != expr->args.size) - { - bytecode.addDebugRemark("inlining failed: argument count mismatch (expected %d, got %d)", int(func->args.size), int(expr->args.size)); - return false; - } - - // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to inlining + // compute constant bitvector for all arguments to feed the cost model bool varc[8] = {}; - for (size_t i = 0; i < expr->args.size && i < 8; ++i) + for (size_t i = 0; i < func->args.size && i < expr->args.size && i < 8; ++i) varc[i] = isConstant(expr->args.data[i]); - int inlinedCost = computeCost(fi->costModel, varc, std::min(int(expr->args.size), 8)); + // if the last argument only returns a single value, all following arguments are nil + if (expr->args.size != 0 && !(expr->args.data[expr->args.size - 1]->is() || expr->args.data[expr->args.size - 1]->is())) + for (size_t i = expr->args.size; i < func->args.size && i < 8; ++i) + varc[i] = true; + + // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to inlining + int inlinedCost = computeCost(fi->costModel, varc, std::min(int(func->args.size), 8)); int baselineCost = computeCost(fi->costModel, nullptr, 0) + 3; int inlineProfit = (inlinedCost == 0) ? thresholdMaxBoost : std::min(thresholdMaxBoost, 100 * baselineCost / inlinedCost); @@ -533,15 +531,44 @@ struct Compiler for (size_t i = 0; i < func->args.size; ++i) { AstLocal* var = func->args.data[i]; - AstExpr* arg = expr->args.data[i]; + AstExpr* arg = i < expr->args.size ? expr->args.data[i] : nullptr; - if (Variable* vv = variables.find(var); vv && vv->written) + if (i + 1 == expr->args.size && func->args.size > expr->args.size && (arg->is() || arg->is())) + { + // if the last argument can return multiple values, we need to compute all of them into the remaining arguments + unsigned int tail = unsigned(func->args.size - expr->args.size) + 1; + uint8_t reg = allocReg(arg, tail); + + if (AstExprCall* expr = arg->as()) + compileExprCall(expr, reg, tail, /* targetTop= */ true); + else if (AstExprVarargs* expr = arg->as()) + compileExprVarargs(expr, reg, tail); + else + LUAU_ASSERT(!"Unexpected expression type"); + + for (size_t j = i; j < func->args.size; ++j) + pushLocal(func->args.data[j], uint8_t(reg + (j - i))); + + // all remaining function arguments have been allocated and assigned to + break; + } + else if (Variable* vv = variables.find(var); vv && vv->written) { // if the argument is mutated, we need to allocate a fresh register even if it's a constant uint8_t reg = allocReg(arg, 1); - compileExprTemp(arg, reg); + + if (arg) + compileExprTemp(arg, reg); + else + bytecode.emitABC(LOP_LOADNIL, reg, 0, 0); + pushLocal(var, reg); } + else if (arg == nullptr) + { + // since the argument is not mutated, we can simply fold the value into the expressions that need it + locstants[var] = {Constant::Type_Nil}; + } else if (const Constant* cv = constants.find(arg); cv && cv->type != Constant::Type_Unknown) { // since the argument is not mutated, we can simply fold the value into the expressions that need it @@ -553,20 +580,26 @@ struct Compiler Variable* lv = le ? variables.find(le->local) : nullptr; // if the argument is a local that isn't mutated, we will simply reuse the existing register - if (isExprLocalReg(arg) && (!lv || !lv->written)) + if (int reg = le ? getExprLocalReg(le) : -1; reg >= 0 && (!lv || !lv->written)) { - uint8_t reg = getLocal(le->local); - pushLocal(var, reg); + pushLocal(var, uint8_t(reg)); } else { - uint8_t reg = allocReg(arg, 1); - compileExprTemp(arg, reg); - pushLocal(var, reg); + uint8_t temp = allocReg(arg, 1); + compileExprTemp(arg, temp); + pushLocal(var, temp); } } } + // evaluate extra expressions for side effects + for (size_t i = func->args.size; i < expr->args.size; ++i) + { + RegScope rsi(this); + compileExprAuto(expr->args.data[i], rsi); + } + // fold constant values updated above into expressions in the function body foldConstants(constants, variables, locstants, func->body); @@ -627,12 +660,15 @@ struct Compiler FInt::LuauCompileInlineThresholdMaxBoost, FInt::LuauCompileInlineDepth)) return; - if (fi && !fi->canInline) + // add a debug remark for cases when we didn't even call tryCompileInlinedCall + if (func && !(fi && fi->canInline)) { if (func->vararg) bytecode.addDebugRemark("inlining failed: function is variadic"); - else + else if (fi) bytecode.addDebugRemark("inlining failed: complex constructs in function body"); + else + bytecode.addDebugRemark("inlining failed: can't inline recursive calls"); } } @@ -677,9 +713,9 @@ struct Compiler LUAU_ASSERT(fi); // Optimization: use local register directly in NAMECALL if possible - if (isExprLocalReg(fi->expr)) + if (int reg = getExprLocalReg(fi->expr); reg >= 0) { - selfreg = getLocal(fi->expr->as()->local); + selfreg = uint8_t(reg); } else { @@ -785,6 +821,8 @@ struct Compiler void compileExprFunction(AstExprFunction* expr, uint8_t target) { + RegScope rs(this); + const Function* f = functions.find(expr); LUAU_ASSERT(f); @@ -795,6 +833,67 @@ struct Compiler if (pid < 0) CompileError::raise(expr->location, "Exceeded closure limit; simplify the code to compile"); + if (FFlag::LuauCompileNestedClosureO2) + { + captures.clear(); + captures.reserve(f->upvals.size()); + + for (AstLocal* uv : f->upvals) + { + LUAU_ASSERT(uv->functionDepth < expr->functionDepth); + + if (int reg = getLocalReg(uv); reg >= 0) + { + // note: we can't check if uv is an upvalue in the current frame because inlining can migrate from upvalues to locals + Variable* ul = variables.find(uv); + bool immutable = !ul || !ul->written; + + captures.push_back({immutable ? LCT_VAL : LCT_REF, uint8_t(reg)}); + } + else if (const Constant* uc = locstants.find(uv); uc && uc->type != Constant::Type_Unknown) + { + // inlining can result in an upvalue capture of a constant, in which case we can't capture without a temporary register + uint8_t reg = allocReg(expr, 1); + compileExprConstant(expr, uc, reg); + + captures.push_back({LCT_VAL, reg}); + } + else + { + LUAU_ASSERT(uv->functionDepth < expr->functionDepth - 1); + + // get upvalue from parent frame + // note: this will add uv to the current upvalue list if necessary + uint8_t uid = getUpval(uv); + + captures.push_back({LCT_UPVAL, uid}); + } + } + + // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure + // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it + // is used) + int16_t shared = -1; + + if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) + { + int32_t cid = bytecode.addConstantClosure(f->id); + + if (cid >= 0 && cid < 32768) + shared = int16_t(cid); + } + + if (shared >= 0) + bytecode.emitAD(LOP_DUPCLOSURE, target, shared); + else + bytecode.emitAD(LOP_NEWCLOSURE, target, pid); + + for (const Capture& c : captures) + bytecode.emitABC(LOP_CAPTURE, uint8_t(c.type), c.data, 0); + + return; + } + bool shared = false; // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure @@ -824,9 +923,10 @@ struct Compiler if (uv->functionDepth == expr->functionDepth - 1) { // get local variable - uint8_t reg = getLocal(uv); + int reg = getLocalReg(uv); + LUAU_ASSERT(reg >= 0); - bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), reg, 0); + bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), uint8_t(reg), 0); } else { @@ -1213,10 +1313,10 @@ struct Compiler if (!isConditionFast(expr->left)) { // Optimization: when right hand side is a local variable, we can use AND/OR - if (isExprLocalReg(expr->right)) + if (int reg = getExprLocalReg(expr->right); reg >= 0) { uint8_t lr = compileExprAuto(expr->left, rs); - uint8_t rr = getLocal(expr->right->as()->local); + uint8_t rr = uint8_t(reg); bytecode.emitABC(and_ ? LOP_AND : LOP_OR, target, lr, rr); return; @@ -1803,19 +1903,18 @@ struct Compiler } else if (AstExprLocal* expr = node->as()) { - if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue) + // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining + if (int reg = getExprLocalReg(expr); reg >= 0) + { + bytecode.emitABC(LOP_MOVE, target, uint8_t(reg), 0); + } + else { LUAU_ASSERT(expr->upvalue); uint8_t uid = getUpval(expr->local); bytecode.emitABC(LOP_GETUPVAL, target, uid, 0); } - else - { - uint8_t reg = getLocal(expr->local); - - bytecode.emitABC(LOP_MOVE, target, reg, 0); - } } else if (AstExprGlobal* expr = node->as()) { @@ -1879,8 +1978,8 @@ struct Compiler uint8_t compileExprAuto(AstExpr* node, RegScope&) { // Optimization: directly return locals instead of copying them to a temporary - if (isExprLocalReg(node)) - return getLocal(node->as()->local); + if (int reg = getExprLocalReg(node); reg >= 0) + return uint8_t(reg); // note: the register is owned by the parent scope uint8_t reg = allocReg(node, 1); @@ -1910,7 +2009,7 @@ struct Compiler for (size_t i = 0; i < targetCount; ++i) compileExprTemp(list.data[i], uint8_t(target + i)); - // compute expressions with values that go nowhere; this is required to run side-effecting code if any + // evaluate extra expressions for side effects for (size_t i = targetCount; i < list.size; ++i) { RegScope rsi(this); @@ -2008,20 +2107,21 @@ struct Compiler if (AstExprLocal* expr = node->as()) { - if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue) + // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining + if (int reg = getExprLocalReg(expr); reg >= 0) { - LUAU_ASSERT(expr->upvalue); - - LValue result = {LValue::Kind_Upvalue}; - result.upval = getUpval(expr->local); + LValue result = {LValue::Kind_Local}; + result.reg = uint8_t(reg); result.location = node->location; return result; } else { - LValue result = {LValue::Kind_Local}; - result.reg = getLocal(expr->local); + LUAU_ASSERT(expr->upvalue); + + LValue result = {LValue::Kind_Upvalue}; + result.upval = getUpval(expr->local); result.location = node->location; return result; @@ -2115,15 +2215,21 @@ struct Compiler compileLValueUse(lv, source, /* set= */ true); } - bool isExprLocalReg(AstExpr* expr) + int getExprLocalReg(AstExpr* node) { - AstExprLocal* le = expr->as(); - if (!le || (!FFlag::LuauCompileSupportInlining && le->upvalue)) - return false; + if (AstExprLocal* expr = node->as()) + { + // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining + Local* l = locals.find(expr->local); - Local* l = locals.find(le->local); - - return l && l->allocated; + return l && l->allocated ? l->reg : -1; + } + else if (AstExprGroup* expr = node->as()) + return getExprLocalReg(expr->expr); + else if (AstExprTypeAssertion* expr = node->as()) + return getExprLocalReg(expr->expr); + else + return -1; } bool isStatBreak(AstStat* node) @@ -2352,20 +2458,17 @@ struct Compiler // Optimization: return locals directly instead of copying them into a temporary // this is very important for a single return value and occasionally effective for multiple values - if (stat->list.size > 0 && isExprLocalReg(stat->list.data[0])) + if (int reg = stat->list.size > 0 ? getExprLocalReg(stat->list.data[0]) : -1; reg >= 0) { - temp = getLocal(stat->list.data[0]->as()->local); + temp = uint8_t(reg); consecutive = true; for (size_t i = 1; i < stat->list.size; ++i) - { - AstExpr* v = stat->list.data[i]; - if (!isExprLocalReg(v) || getLocal(v->as()->local) != temp + i) + if (getExprLocalReg(stat->list.data[i]) != int(temp + i)) { consecutive = false; break; } - } } if (!consecutive && stat->list.size > 0) @@ -2438,12 +2541,13 @@ struct Compiler { bool result = true; - bool visit(AstExpr* node) override + bool visit(AstExprFunction* node) override { - // functions may capture loop variable, and our upval handling doesn't handle elided variables (constant) - // TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues - result = result && !node->is(); - return result; + if (!FFlag::LuauCompileNestedClosureO2) + result = false; + + // short-circuit to avoid analyzing nested closure bodies + return false; } bool visit(AstStat* node) override @@ -2874,12 +2978,9 @@ struct Compiler void compileStatFunction(AstStatFunction* stat) { // Optimization: compile value expresion directly into target local register - if (isExprLocalReg(stat->name)) + if (int reg = getExprLocalReg(stat->name); reg >= 0) { - AstExprLocal* le = stat->name->as(); - LUAU_ASSERT(le); - - compileExpr(stat->func, getLocal(le->local)); + compileExpr(stat->func, uint8_t(reg)); return; } @@ -3399,6 +3500,12 @@ struct Compiler std::vector returnJumps; }; + struct Capture + { + LuauCaptureType type; + uint8_t data; + }; + BytecodeBuilder& bytecode; CompileOptions options; @@ -3422,6 +3529,7 @@ struct Compiler std::vector loopJumps; std::vector loops; std::vector inlineFrames; + std::vector captures; }; void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) @@ -3465,6 +3573,9 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName /* self= */ nullptr, AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main); + const Compiler::Function* mainf = compiler.functions.find(&main); + LUAU_ASSERT(mainf && mainf->upvals.empty()); + bytecode.setMainFunction(mainid); bytecode.finalize(); } diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index e4d59ea1..a62beeb1 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -3,8 +3,6 @@ #include -LUAU_FASTFLAG(LuauCompileSupportInlining) - namespace Luau { namespace Compile @@ -330,7 +328,7 @@ struct ConstantVisitor : AstVisitor { if (value.type != Constant::Type_Unknown) map[key] = value; - else if (!FFlag::LuauCompileSupportInlining || wasEmpty) + else if (wasEmpty) ; else if (Constant* old = map.find(key)) old->type = Constant::Type_Unknown; diff --git a/Sources.cmake b/Sources.cmake index d2430cc9..297f561a 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -73,6 +73,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/ToString.h Analysis/include/Luau/Transpiler.h Analysis/include/Luau/TxnLog.h + Analysis/include/Luau/TypeArena.h Analysis/include/Luau/TypeAttach.h Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeInfer.h @@ -108,6 +109,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/ToString.cpp Analysis/src/Transpiler.cpp Analysis/src/TxnLog.cpp + Analysis/src/TypeArena.cpp Analysis/src/TypeAttach.cpp Analysis/src/TypedAllocator.cpp Analysis/src/TypeInfer.cpp diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 9c1f387e..27187c61 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,10 +10,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry2, false) - -void (*lua_table_move_telemetry)(lua_State* L, int f, int e, int t, int nf, int nt); - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -199,29 +195,6 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); - void (*telemetrycb)(lua_State * L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; - - if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb && e >= f) - { - int nf = lua_objlen(L, 1); - int nt = lua_objlen(L, tt); - - bool report = false; - - // source index range must be in bounds in source table unless the table is empty (permits 1..#t moves) - if (!(f == 1 || (f >= 1 && f <= nf))) - report = true; - if (!(e == nf || (e >= 1 && e <= nf))) - report = true; - - // destination index must be in bounds in dest table or be exactly at the first empty element (permits concats) - if (!(t == nt + 1 || (t >= 1 && t <= nt))) - report = true; - - if (report) - telemetrycb(L, f, e, t, nf, nt); - } - if (e >= f) { /* otherwise, nothing to move */ luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move"); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 3c7c276a..9e2eb268 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -17,9 +17,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauIter, false) -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauIterCallTelemetry, false) - -void (*lua_iter_call_telemetry)(lua_State* L); // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ @@ -157,17 +154,6 @@ LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c) StkId ra = &L->base[a]; LUAU_ASSERT(ra + 3 <= L->top); - if (DFFlag::LuauIterCallTelemetry) - { - /* TODO: we might be able to stop supporting this depending on whether it's used in practice */ - void (*telemetrycb)(lua_State* L) = lua_iter_call_telemetry; - - if (telemetrycb && ttistable(ra) && fasttm(L, hvalue(ra)->metatable, TM_CALL)) - telemetrycb(L); - if (telemetrycb && ttisuserdata(ra) && fasttm(L, uvalue(ra)->metatable, TM_CALL)) - telemetrycb(L); - } - setobjs2s(L, ra + 3 + 2, ra + 2); setobjs2s(L, ra + 3 + 1, ra + 1); setobjs2s(L, ra + 3, ra); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index b4e9340c..caaccf4e 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2772,6 +2772,8 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { + ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; + check(R"( type tag = "cat" | "dog" local function f(a: tag) end @@ -2844,6 +2846,8 @@ f(@1) TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") { + ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; + check(R"( type tag = "strange\t\"cat\"" | 'nice\t"dog"' local function f(x: tag) end diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index b032060e..cf27d191 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4269,22 +4269,26 @@ FORNLOOP R3 -6 FORNLOOP R0 -11 RETURN R0 0 )"); +} - // can't unroll loops if the body has functions that refer to loop variables +TEST_CASE("LoopUnrollNestedClosure") +{ + ScopedFastFlag sff("LuauCompileNestedClosureO2", true); + + // if the body has functions that refer to loop variables, we unroll the loop and use MOVE+CAPTURE for upvalues CHECK_EQ("\n" + compileFunction(R"( -for i=1,1 do +for i=1,2 do local x = function() return i end end )", 1, 2), R"( -LOADN R2 1 -LOADN R0 1 LOADN R1 1 -FORNPREP R0 +3 -NEWCLOSURE R3 P0 -CAPTURE VAL R2 -FORNLOOP R0 -3 +NEWCLOSURE R0 P0 +CAPTURE VAL R1 +LOADN R1 2 +NEWCLOSURE R0 P0 +CAPTURE VAL R1 RETURN R0 0 )"); } @@ -4469,8 +4473,6 @@ RETURN R0 0 TEST_CASE("InlineBasic") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // inline function that returns a constant CHECK_EQ("\n" + compileFunction(R"( local function foo() @@ -4550,10 +4552,72 @@ RETURN R1 1 )"); } +TEST_CASE("InlineBasicProhibited") +{ + ScopedFastFlag sff("LuauCompileNestedClosureO2", true); + + // we can't inline variadic functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(...) + return 42 +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // we also can't inline functions that have internal loops + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + for i=1,4 do end +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineNestedClosures") +{ + ScopedFastFlag sff("LuauCompileNestedClosureO2", true); + + // we can inline functions that contain/return functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(x) + return function(y) return x + y end +end + +local x = foo(1)(2) +return x +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 1 +NEWCLOSURE R1 P1 +CAPTURE VAL R2 +LOADN R2 2 +CALL R1 1 1 +RETURN R1 1 +)"); +} + TEST_CASE("InlineMutate") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // if the argument is mutated, it gets a register even if the value is constant CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -4636,8 +4700,6 @@ RETURN R1 1 TEST_CASE("InlineUpval") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // if the argument is an upvalue, we naturally need to copy it to a local CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -4705,8 +4767,6 @@ RETURN R1 1 TEST_CASE("InlineFallthrough") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // if the function doesn't return, we still fill the results with nil CHECK_EQ("\n" + compileFunction(R"( local function foo() @@ -4759,8 +4819,6 @@ RETURN R1 -1 TEST_CASE("InlineCapture") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // can't inline function with nested functions that capture locals because they might be constants CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -4782,12 +4840,9 @@ RETURN R2 -1 TEST_CASE("InlineArgMismatch") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // when inlining a function, we must respect all the usual rules // caller might not have enough arguments - // TODO: we don't inline this atm CHECK_EQ("\n" + compileFunction(R"( local function foo(a) return a @@ -4799,13 +4854,11 @@ return x 1, 2), R"( DUPCLOSURE R0 K0 -MOVE R1 R0 -CALL R1 0 1 +LOADNIL R1 RETURN R1 1 )"); // caller might be using multret for arguments - // TODO: we don't inline this atm CHECK_EQ("\n" + compileFunction(R"( local function foo(a, b) return a + b @@ -4817,17 +4870,32 @@ return x 1, 2), R"( DUPCLOSURE R0 K0 -MOVE R1 R0 LOADK R3 K1 FASTCALL1 20 R3 +2 GETIMPORT R2 4 -CALL R2 1 -1 -CALL R1 -1 1 +CALL R2 1 2 +ADD R1 R2 R3 +RETURN R1 1 +)"); + + // caller might be using varargs for arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x = foo(...) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R2 2 +ADD R1 R2 R3 RETURN R1 1 )"); // caller might have too many arguments, but we still need to compute them for side effects - // TODO: we don't inline this atm CHECK_EQ("\n" + compileFunction(R"( local function foo(a) return a @@ -4839,19 +4907,34 @@ return x 1, 2), R"( DUPCLOSURE R0 K0 -MOVE R1 R0 +GETIMPORT R2 2 +CALL R2 0 1 +LOADN R1 42 +RETURN R1 1 +)"); + + // caller might not have enough arguments, and the arg might be mutated so it needs a register + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = 42 + return a +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R2 LOADN R2 42 -GETIMPORT R3 2 -CALL R3 0 -1 -CALL R1 -1 1 +MOVE R1 R2 RETURN R1 1 )"); } TEST_CASE("InlineMultiple") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // we call this with a different set of variable/constant args CHECK_EQ("\n" + compileFunction(R"( local function foo(a, b) @@ -4880,8 +4963,6 @@ RETURN R3 4 TEST_CASE("InlineChain") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // inline a chain of functions CHECK_EQ("\n" + compileFunction(R"( local function foo(a, b) @@ -4912,8 +4993,6 @@ RETURN R3 1 TEST_CASE("InlineThresholds") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - ScopedFastInt sfis[] = { {"LuauCompileInlineThreshold", 25}, {"LuauCompileInlineThresholdMaxBoost", 300}, @@ -4988,8 +5067,6 @@ RETURN R3 1 TEST_CASE("InlineIIFE") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // IIFE with arguments CHECK_EQ("\n" + compileFunction(R"( function choose(a, b, c) @@ -5025,8 +5102,6 @@ RETURN R3 1 TEST_CASE("InlineRecurseArguments") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // we can't inline a function if it's used to compute its own arguments CHECK_EQ("\n" + compileFunction(R"( local function foo(a, b) @@ -5036,22 +5111,20 @@ foo(foo(foo,foo(foo,foo))[foo]) 1, 2), R"( DUPCLOSURE R0 K0 -MOVE R1 R0 +MOVE R2 R0 +MOVE R3 R0 MOVE R4 R0 MOVE R5 R0 MOVE R6 R0 -CALL R4 2 1 -LOADNIL R3 -GETTABLE R2 R3 R0 -CALL R1 1 0 +CALL R4 2 -1 +CALL R2 -1 1 +GETTABLE R1 R2 R0 RETURN R0 0 )"); } TEST_CASE("InlineFastCallK") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - CHECK_EQ("\n" + compileFunction(R"( local function set(l0) rawset({}, l0) @@ -5080,8 +5153,6 @@ RETURN R0 0 TEST_CASE("InlineExprIndexK") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - CHECK_EQ("\n" + compileFunction(R"( local _ = function(l0) local _ = nil @@ -5141,6 +5212,58 @@ RETURN R0 0 )"); } +TEST_CASE("InlineHiddenMutation") +{ + // when the argument is assigned inside the function, we can't reuse the local + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = 42 + return a +end + +local x = ... +local y = foo(x :: number) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R3 R1 +LOADN R3 42 +MOVE R2 R3 +RETURN R2 1 +)"); + + // and neither can we do that when it's assigned outside the function + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + mutator() + return a +end + +local x = ... +mutator = function() x = 42 end + +local y = foo(x :: number) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +NEWCLOSURE R2 P1 +CAPTURE REF R1 +SETGLOBAL R2 K1 +MOVE R3 R1 +GETGLOBAL R4 K1 +CALL R4 0 0 +MOVE R2 R3 +CLOSEUPVALS R1 +RETURN R2 1 +)"); +} + TEST_CASE("ReturnConsecutive") { // we can return a single local directly @@ -5193,6 +5316,16 @@ return )"), R"( RETURN R0 0 +)"); + + // this optimization also works in presence of group / type casts + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return (x), y :: number +)"), + R"( +GETVARARGS R0 2 +RETURN R0 2 )"); } diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 4a999861..c7e18efd 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -198,10 +198,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_free_types") { - ScopedFastFlag sff[]{ - {"LuauLosslessClone", true}, - }; - TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); @@ -218,8 +214,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_free_types") TEST_CASE_FIXTURE(Fixture, "clone_free_tables") { - ScopedFastFlag sff{"LuauLosslessClone", true}; - TypeVar tableTy{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tableTy); ttv->state = TableState::Free; @@ -252,8 +246,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - fileResolver.source["Module/A"] = R"( --!nonstrict local a = {} diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 69430b1c..83c526ef 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -150,8 +150,6 @@ TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - CheckResult result = check(R"( --!nonstrict local T = {} @@ -169,8 +167,6 @@ TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - CheckResult result = check(R"( --!nonstrict local T = {} diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 41830682..dd49eb01 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -683,6 +683,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_marked_normal") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeFlagIsConservative", false} }; check(R"( @@ -697,6 +698,26 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_marked_normal") CHECK(t->normal); } +// Unfortunately, getting this right in the general case is difficult. +TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_not_marked_normal") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeFlagIsConservative", true} + }; + + check(R"( + type Fiber = { + return_: Fiber? + } + + local f: Fiber + )"); + + TypeId t = requireType("f"); + CHECK(!t->normal); +} + TEST_CASE_FIXTURE(Fixture, "variadic_tail_is_marked_normal") { ScopedFastFlag flags[] = { @@ -997,4 +1018,28 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_failure_bound_type_is_normal_but_not_its_bounde LUAU_REQUIRE_ERRORS(result); } +// We had an issue where a normal BoundTypeVar might point at a non-normal BoundTypeVar if it in turn pointed to a +// normal TypeVar because we were calling follow() in an improper place. +TEST_CASE_FIXTURE(Fixture, "bound_typevars_should_only_be_marked_normal_if_their_pointee_is_normal") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeFlagIsConservative", true}, + }; + + CheckResult result = check(R"( + local T = {} + + function T:M() + local function f(a) + print(self.prop) + self:g(a) + self.prop = a + end + end + + return T + )"); +} + TEST_SUITE_END(); diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index c16f60d5..14c17614 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -22,8 +22,6 @@ struct LimitFixture : BuiltinsFixture #if defined(_NOOPT) || defined(_DEBUG) ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; #endif - - ScopedFastFlag LuauJustOneCallFrameForHaveSeen{"LuauJustOneCallFrameForHaveSeen", true}; }; template diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index f38dd10a..b854bc51 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -650,6 +650,19 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") CHECK_EQ("test(first: a, second: string, ...: number): a", toStringNamedFunction("test", *ftv, opts)); } +TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_generics") +{ + ScopedFastFlag sff[] = { + {"LuauAlwaysQuantify", true}, + }; + + CheckResult result = check(R"( + function foo(x: a, y) end + )"); + + CHECK("(a, b) -> ()" == toString(requireType("foo"))); +} + TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") { ScopedFastFlag flag{"LuauDocFuncParameters", true}; @@ -685,5 +698,4 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") CHECK_EQ("foo:method(arg: string): ()", toStringNamedFunction("foo:method", *ftv, opts)); } - TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index b710ea0d..aa4ca415 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -878,8 +878,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_add_definitions_to_persistent_types") TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types") { ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -899,8 +897,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types") TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") { ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -916,11 +912,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") { - ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local function f(...: number?) return assert(...) @@ -933,11 +924,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types_even_from_type_pa TEST_CASE_FIXTURE(BuiltinsFixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") { - ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local function f(x: nil) return assert(x, "hmm") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 14f1f703..a28ba49e 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1496,8 +1496,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - CheckResult result = check(R"( local function f(x: any) end f() diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index de0c9391..78a5fee7 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -1121,4 +1121,78 @@ TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics1") +{ + ScopedFastFlag sff{"LuauApplyTypeFunctionFix", true}; + + // https://github.com/Roblox/luau/issues/484 + CheckResult result = check(R"( +--!strict +type MyObject = { + getReturnValue: (cb: () -> V) -> V +} +local object: MyObject = { + getReturnValue = function(cb: () -> U): U + return cb() + end, +} + +type ComplexObject = { + id: T, + nested: MyObject +} + +local complex: ComplexObject = { + id = "Foo", + nested = object, +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics2") +{ + ScopedFastFlag sff{"LuauApplyTypeFunctionFix", true}; + + // https://github.com/Roblox/luau/issues/484 + CheckResult result = check(R"( +--!strict +type MyObject = { + getReturnValue: (cb: () -> V) -> V +} +type ComplexObject = { + id: T, + nested: MyObject +} + +local complex2: ComplexObject = nil + +local x = complex2.nested.getReturnValue(function(): string + return "" +end) + +local y = complex2.nested.getReturnValue(function() + return 3 +end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_generic") +{ + ScopedFastFlag sff[] = { + {"LuauAlwaysQuantify", true}, + }; + + CheckResult result = check(R"( + function foo(f, x: X) + return f(x) + end + )"); + + CHECK("((X) -> (a...), X) -> (a...)" == toString(requireType("foo"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 41bc0c21..f75b2d11 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -177,8 +177,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guarante TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_depth") { - ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; - CheckResult result = check(R"( type A = {x: {y: {z: {thing: string}}}} type B = {x: {y: {z: {thing: string}}}} diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 765419c6..a3cae3de 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -475,8 +475,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") { - ScopedFastFlag luauInstantiateFollows{"LuauInstantiateFollows", true}; - // Just check that this doesn't assert check(R"( --!nonstrict diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 51f6fdfb..03614938 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -728,8 +728,6 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: string | number, b: boolean | number) return a == b diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index ee3ae972..9d227895 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -7,7 +7,6 @@ #include -LUAU_FASTFLAG(LuauEqConstraint) LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; @@ -183,8 +182,6 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") // We'll need to not only report an error on `a == b`, but also to refine both operands as `never` in the `==` branch. TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: string, b: boolean?) if a == b then @@ -208,8 +205,6 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") // Just needs to fully support equality refinement. Which is annoying without type states. TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") { - ScopedFastFlag sff{"LuauDiscriminableUnions2", true}; - CheckResult result = check(R"( type T = {x: string, y: number} | {x: nil, y: nil} @@ -471,4 +466,35 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_many_things_but_first_of_it CHECK_EQ("boolean", toString(requireType("b"))); } +TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeFlagIsConservative", true}, + }; + + CheckResult result = check(R"( + local function f(o) + local t = {} + t[o] = true + + local function foo(o) + o:m1() + t[o] = nil + end + + local function bar(o) + o:m2() + t[o] = true + end + + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: We're missing generics a... and b... + CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 8c130490..85a3334c 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -7,7 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauLowerBoundsCalculation) @@ -268,18 +267,10 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - if (FFlag::LuauDiscriminableUnions2) - { - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); - } + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -378,8 +369,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?, b: boolean?) if a == b then @@ -392,28 +381,15 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "nil"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "nil"); // a == b - - CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b - } + CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b } TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?) if a == 1 then @@ -426,24 +402,12 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1 - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number"); // a == 1 - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1 + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 } TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local function f(a: (string | number)?) if "hello" == a then @@ -462,8 +426,6 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?) if a ~= nil then @@ -476,21 +438,12 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "nil"); // a == nil - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil } TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") { - ScopedFastFlag sff{"LuauDiscriminableUnions2", true}; ScopedFastFlag sff2{"LuauWeakEqConstraint", true}; CheckResult result = check(R"( @@ -509,8 +462,6 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: any, b: {x: number}?) if a ~= b then @@ -521,22 +472,12 @@ TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_e LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}"); // a ~= b - } + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b } TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local t: {string} = {"hello"} @@ -554,18 +495,8 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b - } - else - { - // This is technically not wrong, but it's also wrong at the same time. - // The refinement code is none the wiser about the fact we pulled a string out of an array, so it has no choice but to narrow as just string. - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string"); // a == b - } + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b } TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") @@ -594,16 +525,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") end )"); - if (FFlag::LuauDiscriminableUnions2) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); } @@ -1009,10 +931,6 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( type T = {tag: "missing", x: nil} | {tag: "exists", x: string} @@ -1033,10 +951,6 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") TEST_CASE_FIXTURE(Fixture, "discriminate_tag") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( type Cat = {tag: "Cat", name: string, catfood: string} type Dog = {tag: "Dog", name: string, dogfood: string} @@ -1070,11 +984,6 @@ TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauAssertStripsFalsyTypes", true}, - }; - CheckResult result = check(R"( local function is_true(b: true) end local function is_false(b: false) end @@ -1093,11 +1002,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauAssertStripsFalsyTypes", true}, - }; - CheckResult result = check(R"( type Ok = { ok: true, value: T } type Err = { ok: false, error: E } @@ -1117,8 +1021,6 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_ TEST_CASE_FIXTURE(Fixture, "refine_a_property_not_to_be_nil_through_an_intersection_table") { - ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; - CheckResult result = check(R"( type T = {} & {f: ((string) -> string)?} local function f(t: T, x) @@ -1133,10 +1035,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_not_to_be_nil_through_an_intersect TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( type T = {tag: "Part", x: Part} | {tag: "Folder", x: Folder} @@ -1171,14 +1069,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") end )"); - if (FFlag::LuauDiscriminableUnions2) - LUAU_REQUIRE_NO_ERRORS(result); - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 79eeb824..d90dfbb5 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -139,6 +139,8 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") { + ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; + CheckResult result = check(R"( type MyEnum = "foo" | "bar" | "baz" local a : MyEnum = "bang" @@ -325,8 +327,6 @@ local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_singleton") { ScopedFastFlag sff[]{ - {"LuauEqConstraint", true}, - {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, {"LuauWeakEqConstraint", false}, }; @@ -350,11 +350,8 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") { ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauEqConstraint", true}, {"LuauWidenIfSupertypeIsFree2", true}, {"LuauWeakEqConstraint", false}, - {"LuauDoNotAccidentallyDependOnPointerOrdering", true}, }; CheckResult result = check(R"( @@ -390,7 +387,6 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables") { ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -419,6 +415,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") { ScopedFastFlag sff[]{ {"LuauWidenIfSupertypeIsFree2", true}, + {"LuauWeakEqConstraint", true}, }; CheckResult result = check(R"( @@ -456,10 +453,6 @@ TEST_CASE_FIXTURE(Fixture, "functions_are_not_to_be_widened") TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" then @@ -474,10 +467,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" or a == "bye" then @@ -492,10 +481,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" then @@ -510,10 +495,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" or a == "bye" then diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 5078b0bf..c924484a 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2279,8 +2279,6 @@ local y = #x TEST_CASE_FIXTURE(BuiltinsFixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") { - ScopedFastFlag sff{"LuauTerminateCyclicMetatableIndexLookup", true}; - // t :: t1 where t1 = {metatable {__index: t1, __tostring: (t1) -> string}} CheckResult result = check(R"( local mt = {} @@ -2313,8 +2311,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "give_up_after_one_metatable_index_look_up") TEST_CASE_FIXTURE(Fixture, "confusing_indexing") { - ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; - CheckResult result = check(R"( type T = {} & {p: number | string} local function f(t: T) @@ -2971,8 +2967,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") { - ScopedFastFlag sff{"LuauCheckImplicitNumbericKeys", true}; - CheckResult result = check(R"( local t: { [string]: number } = { 5, 6, 7 } )"); @@ -2984,4 +2978,32 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); } +TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") +{ + ScopedFastFlag luauExpectedPropTypeFromIndexer{"LuauExpectedPropTypeFromIndexer", true}; + ScopedFastFlag luauSubtypingAddOptPropsToUnsealedTables{"LuauSubtypingAddOptPropsToUnsealedTables", true}; + + CheckResult result = check(R"( + type X = { { x: boolean?, y: boolean? } } + + local l1: {[string]: X} = { key = { { x = true }, { y = true } } } + local l2: {[any]: X} = { key = { { x = true }, { y = true } } } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra_2") +{ + ScopedFastFlag luauExpectedPropTypeFromIndexer{"LuauExpectedPropTypeFromIndexer", true}; + + CheckResult result = check(R"( + type X = {[any]: string | boolean} + + local x: X = { key = "str" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 48cd1c3d..1d144db7 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -15,7 +15,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) -LUAU_FASTFLAG(LuauEqConstraint) using namespace Luau; @@ -308,7 +307,6 @@ TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count") int limit = 600; #endif - ScopedFastFlag sff{"LuauTableUseCounterInstead", true}; ScopedFastInt sfi{"LuauCheckRecursionLimit", limit}; CheckResult result = check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"); @@ -1011,8 +1009,6 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") { - ScopedFastFlag substituteFollowNewTypes{"LuauSubstituteFollowNewTypes", true}; - CheckResult result = check(R"( local obj = {} diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 277f3887..d19d80cb 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -7,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauLowerBoundsCalculation) -LUAU_FASTFLAG(LuauEqConstraint) using namespace Luau; From 70ff6b434704bd336ea33fbf219d483f5a19ecee Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 20 May 2022 13:00:53 -0700 Subject: [PATCH 069/102] Update performance.md (#494) Add a section on table length optimizations and reword the table iteration section a bit to account for generalized iteration. --- docs/_pages/performance.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/docs/_pages/performance.md b/docs/_pages/performance.md index b4fd3a7b..10e2341c 100644 --- a/docs/_pages/performance.md +++ b/docs/_pages/performance.md @@ -92,12 +92,22 @@ As a result, builtin calls are very fast in Luau - they are still slightly slowe ## Optimized table iteration -Luau implements a fully generic iteration protocol; however, for iteration through tables it recognizes three common idioms (`for .. in ipairs(t)`, `for .. in pairs(t)` and `for .. in next, t`) and emits specialized bytecode that is carefully optimized using custom internal iterators. +Luau implements a fully generic iteration protocol; however, for iteration through tables in addition to generalized iteration (`for .. in t`) it recognizes three common idioms (`for .. in ipairs(t)`, `for .. in pairs(t)` and `for .. in next, t`) and emits specialized bytecode that is carefully optimized using custom internal iterators. -As a result, iteration through tables typically doesn't result in function calls for every iteration; the performance of iteration using `pairs` and `ipairs` is comparable, so it's recommended to pick the iteration style based on readability instead of performance. +As a result, iteration through tables typically doesn't result in function calls for every iteration; the performance of iteration using generalized iteration, `pairs` and `ipairs` is comparable, so generalized iteration (without the use of `pairs`/`ipairs`) is recommended unless the code needs to be compatible with vanilla Lua or the specific semantics of `ipairs` (which stops at the first `nil` element) is required. Additionally, using generalized iteration avoids calling `pairs` when the loop starts which can be noticeable when the table is very short. Iterating through array-like tables using `for i=1,#t` tends to be slightly slower because of extra cost incurred when reading elements from the table. +## Optimized table length + +Luau tables use a hybrid array/hash storage, like in Lua; in some sense "arrays" don't truly exist and are an internal optimization, but some operations, notably `#t` and functions that depend on it, like `table.insert`, are defined by the Luau/Lua language to allow internal optimizations. Luau takes advantage of that fact. + +Unlike Lua, Luau guarantees that the element at index `#t` is stored in the array part of the table. This can accelerate various table operations that use indices limited by `#t`, and this makes `#t` worst-case complexity O(logN), unlike Lua where the worst case complexity is O(N). This also accelerates computation of this value for small tables like `{ [1] = 1 }` since we never need to look at the hash part. + +The "default" implementation of `#t` in both Lua and Luau is a binary search. Luau uses a special branch-free (depending on the compiler...) implementation of the binary search which results in 50+% faster computation of table length when it needs to be computed from scratch. + +Additionally, Luau can cache the length of the table and adjust it following operations like `table.insert`/`table.remove`; this means that in practice, `#t` is almost always a constant time operation. + ## Creating and modifying tables Luau implements several optimizations for table creation. When creating object-like tables, it's recommended to use table literals (`{ ... }`) and to specify all table fields in the literal in one go instead of assigning fields later; this triggers an optimization inspired by LuaJIT's "table templates" and results in higher performance when creating objects. When creating array-like tables, if the maximum size of the table is known up front, it's recommended to use `table.create` function which can create an empty table with preallocated storage, and optionally fill it with a given value. @@ -112,7 +122,7 @@ v.z = 3 return v ``` -When appending elements to tables, it's recommended to use `table.insert` (which is the fastest method to append an element to a table if the table size is not known). In cases when a table is filled sequentially, however, it's much more efficient to use a known index for insertion - together with preallocating tables using `table.create` this can result in much faster code, for example this is the fastest way to build a table of squares: +When appending elements to tables, it's recommended to use `table.insert` (which is the fastest method to append an element to a table if the table size is not known). In cases when a table is filled sequentially, however, it can be more efficient to use a known index for insertion - together with preallocating tables using `table.create` this can result in much faster code, for example this is the fastest way to build a table of squares: ```lua local t = table.create(N) From fb9c4311d8645fbce50186b4455992ad84194846 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petri=20H=C3=A4kkinen?= Date: Tue, 24 May 2022 18:59:12 +0300 Subject: [PATCH 070/102] Add lua_tolightuserdata, optimized lua_topointer (#496) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Petri Häkkinen --- VM/include/lua.h | 1 + VM/src/lapi.cpp | 19 ++++++++++++------- tests/Conformance.test.cpp | 2 ++ 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/VM/include/lua.h b/VM/include/lua.h index c3ebadb1..7f9647c8 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -148,6 +148,7 @@ LUA_API const char* lua_tostringatom(lua_State* L, int idx, int* atom); LUA_API const char* lua_namecallatom(lua_State* L, int* atom); LUA_API int lua_objlen(lua_State* L, int idx); LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx); +LUA_API void* lua_tolightuserdata(lua_State* L, int idx); LUA_API void* lua_touserdata(lua_State* L, int idx); LUA_API void* lua_touserdatatagged(lua_State* L, int idx, int tag); LUA_API int lua_userdatatag(lua_State* L, int idx); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index f8baefaf..df144a15 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -478,18 +478,22 @@ lua_CFunction lua_tocfunction(lua_State* L, int idx) return (!iscfunction(o)) ? NULL : cast_to(lua_CFunction, clvalue(o)->c.f); } +void* lua_tolightuserdata(lua_State* L, int idx) +{ + StkId o = index2addr(L, idx); + return (!ttislightuserdata(o)) ? NULL : pvalue(o); +} + void* lua_touserdata(lua_State* L, int idx) { StkId o = index2addr(L, idx); - switch (ttype(o)) - { - case LUA_TUSERDATA: + // fast-path: check userdata first since it is most likely the expected result + if (ttisuserdata(o)) return uvalue(o)->data; - case LUA_TLIGHTUSERDATA: + else if (ttislightuserdata(o)) return pvalue(o); - default: + else return NULL; - } } void* lua_touserdatatagged(lua_State* L, int idx, int tag) @@ -524,8 +528,9 @@ const void* lua_topointer(lua_State* L, int idx) case LUA_TTHREAD: return thvalue(o); case LUA_TUSERDATA: + return uvalue(o)->data; case LUA_TLIGHTUSERDATA: - return lua_touserdata(L, idx); + return pvalue(o); default: return NULL; } diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index a23ea470..4282bd78 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1066,6 +1066,7 @@ TEST_CASE("UserdataApi") int lud; lua_pushlightuserdata(L, &lud); + CHECK(lua_tolightuserdata(L, -1) == &lud); CHECK(lua_touserdata(L, -1) == &lud); CHECK(lua_topointer(L, -1) == &lud); @@ -1073,6 +1074,7 @@ TEST_CASE("UserdataApi") int* ud1 = (int*)lua_newuserdata(L, 4); *ud1 = 42; + CHECK(lua_tolightuserdata(L, -1) == nullptr); CHECK(lua_touserdata(L, -1) == ud1); CHECK(lua_topointer(L, -1) == ud1); From 69acf5ac07fbb0517fece9cd0df75d75cd1184cb Mon Sep 17 00:00:00 2001 From: T 'Filtered' C Date: Tue, 24 May 2022 19:29:17 +0100 Subject: [PATCH 071/102] Make coroutine.status use a string literal (#500) Implements #495 C++ isn't a language im very familliar with so this might be completely wrong --- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index be3fcd7d..9a2259f1 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -143,7 +143,7 @@ declare coroutine: { create: ((A...) -> R...) -> thread, resume: (thread, A...) -> (boolean, R...), running: () -> thread, - status: (thread) -> string, + status: (thread) -> "dead" | "running" | "normal" | "suspended", -- FIXME: This technically returns a function, but we can't represent this yet. wrap: ((A...) -> R...) -> any, yield: (A...) -> R..., From e13f17e2251e1fc3db03593cd139280fc84b2aca Mon Sep 17 00:00:00 2001 From: Austin <6193474+axstin@users.noreply.github.com> Date: Tue, 24 May 2022 13:32:03 -0500 Subject: [PATCH 072/102] Fix VM inconsistency caused by userdata C TM fast paths (#497) This fixes usage of userdata C functions in xpcall handler following call stack overflow --- VM/src/ldo.cpp | 16 +++++--- VM/src/ldo.h | 1 + VM/src/lvmexecute.cpp | 2 +- tests/conformance/errors.lua | 75 ++++++++++++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 7 deletions(-) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index c133a59e..c7904dde 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -213,6 +213,14 @@ CallInfo* luaD_growCI(lua_State* L) return ++L->ci; } +void luaD_checkCstack(lua_State *L) +{ + if (L->nCcalls == LUAI_MAXCCALLS) + luaG_runerror(L, "C stack overflow"); + else if (L->nCcalls >= (LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3))) + luaD_throw(L, LUA_ERRERR); /* error while handling stack error */ +} + /* ** Call a function (C or Lua). The function to be called is at *func. ** The arguments are on the stack, right after the function. @@ -222,12 +230,8 @@ CallInfo* luaD_growCI(lua_State* L) void luaD_call(lua_State* L, StkId func, int nResults) { if (++L->nCcalls >= LUAI_MAXCCALLS) - { - if (L->nCcalls == LUAI_MAXCCALLS) - luaG_runerror(L, "C stack overflow"); - else if (L->nCcalls >= (LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3))) - luaD_throw(L, LUA_ERRERR); /* error while handing stack error */ - } + luaD_checkCstack(L); + if (luau_precall(L, func, nResults) == PCRLUA) { /* is a Lua function? */ L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */ diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 6e16e6f1..5e9472bf 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -49,6 +49,7 @@ LUAI_FUNC int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t oldtop, pt LUAI_FUNC void luaD_reallocCI(lua_State* L, int newsize); LUAI_FUNC void luaD_reallocstack(lua_State* L, int newsize); LUAI_FUNC void luaD_growstack(lua_State* L, int n); +LUAI_FUNC void luaD_checkCstack(lua_State* L); LUAI_FUNC l_noret luaD_throw(lua_State* L, int errcode); LUAI_FUNC int luaD_rawrunprotected(lua_State* L, Pfunc f, void* ud); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 9e2eb268..3c505fe5 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -181,7 +181,7 @@ LUAU_NOINLINE static void luau_callTM(lua_State* L, int nparams, int res) ++L->nCcalls; if (L->nCcalls >= LUAI_MAXCCALLS) - luaG_runerror(L, "C stack overflow"); + luaD_checkCstack(L); luaD_checkstack(L, LUA_MINSTACK); diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index 297cf011..d8dc9bd2 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -167,6 +167,81 @@ if not limitedstack then end end +-- C stack overflow +if not limitedstack then + local count = 1 + local cso = setmetatable({}, { + __index = function(self, i) + count = count + 1 + return self[i] + end, + __newindex = function(self, i, v) + count = count + 1 + self[i] = v + end, + __tostring = function(self) + count = count + 1 + return tostring(self) + end + }) + + local ehline + local function ehassert(cond) + if not cond then + ehline = debug.info(2, "l") + error() + end + end + + local userdata = newproxy(true) + getmetatable(userdata).__index = print + assert(debug.info(print, "s") == "[C]") + + local s, e = xpcall(tostring, function(e) + ehassert(string.find(e, "C stack overflow")) + print("after __tostring C stack overflow", count) -- 198: 1 resume + 1 xpcall + 198 luaB_tostring calls (which runs our __tostring successfully 197 times, erroring on the last attempt) + ehassert(count > 1) + + local ps, pe + + -- __tostring overflow (lua_call) + count = 1 + ps, pe = pcall(tostring, cso) + print("after __tostring overflow in handler", count) -- 23: xpcall error handler + pcall + 23 luaB_tostring calls + ehassert(not ps and string.find(pe, "error in error handling")) + ehassert(count > 1) + + -- __index overflow (callTMres) + count = 1 + ps, pe = pcall(function() return cso[cso] end) + print("after __index overflow in handler", count) -- 23: xpcall error handler + pcall + 23 __index calls + ehassert(not ps and string.find(pe, "error in error handling")) + ehassert(count > 1) + + -- __newindex overflow (callTM) + count = 1 + ps, pe = pcall(function() cso[cso] = "kohuke" end) + print("after __newindex overflow in handler", count) -- 23: xpcall error handler + pcall + 23 __newindex calls + ehassert(not ps and string.find(pe, "error in error handling")) + ehassert(count > 1) + + -- test various C __index invocations on userdata + ehassert(pcall(function() return userdata[userdata] end)) -- LOP_GETTABLE + ehassert(pcall(function() return userdata[1] end)) -- LOP_GETTABLEN + ehassert(pcall(function() return userdata.StringConstant end)) -- LOP_GETTABLEKS (luau_callTM) + + -- lua_resume test + local coro = coroutine.create(function() end) + ps, pe = coroutine.resume(coro) + ehassert(not ps and string.find(pe, "C stack overflow")) + + return true + end, cso) + + assert(not s) + assert(e == true, "error in xpcall eh, line " .. tostring(ehline)) +end + --[[ local i=1 while stack[i] ~= l1 do From 61766a692c53cae3ea47408fb0c38f9b1af786d8 Mon Sep 17 00:00:00 2001 From: rblanckaert <63755228+rblanckaert@users.noreply.github.com> Date: Thu, 26 May 2022 15:08:16 -0700 Subject: [PATCH 073/102] Sync to upstream/release/529 (#505) * Adds a currently unused x86-64 assembler as a prerequisite for possible future JIT compilation * Fix a bug in table iteration (closes Possible table iteration bug #504) * Improved warning method when function is used as a type * Fix a bug with unsandboxed iteration with pairs() * Type of coroutine.status() is now a union of value types * Bytecode output for tests/debugging now has labels * Improvements to loop unrolling cost estimation * Report errors when the key obviously doesn't exist in the table --- Analysis/include/Luau/Frontend.h | 16 +- Analysis/include/Luau/Instantiation.h | 53 + Analysis/include/Luau/TypeArena.h | 2 +- Analysis/include/Luau/TypeInfer.h | 39 - Analysis/include/Luau/Unifier.h | 3 + Analysis/include/Luau/UnifierSharedState.h | 1 - Analysis/include/Luau/VisitTypeVar.h | 196 --- Analysis/src/Autocomplete.cpp | 31 +- Analysis/src/Frontend.cpp | 108 +- Analysis/src/Instantiation.cpp | 128 ++ Analysis/src/Module.cpp | 12 +- Analysis/src/Normalize.cpp | 265 +--- Analysis/src/Quantify.cpp | 3 +- Analysis/src/ToString.cpp | 42 +- Analysis/src/TypeArena.cpp | 2 +- Analysis/src/TypeInfer.cpp | 233 +--- Analysis/src/Unifier.cpp | 172 +-- Ast/include/Luau/TimeTrace.h | 2 +- Ast/src/Parser.cpp | 13 + CMakeLists.txt | 15 +- CodeGen/include/Luau/AssemblyBuilderX64.h | 169 +++ CodeGen/include/Luau/Condition.h | 46 + CodeGen/include/Luau/Label.h | 18 + CodeGen/include/Luau/OperandX64.h | 136 +++ CodeGen/include/Luau/RegisterX64.h | 116 ++ CodeGen/src/AssemblyBuilderX64.cpp | 1005 ++++++++++++++++ {Compiler => Common}/include/Luau/Bytecode.h | 0 {Ast => Common}/include/Luau/Common.h | 0 Compiler/include/Luau/BytecodeBuilder.h | 2 +- Compiler/src/BytecodeBuilder.cpp | 124 +- Compiler/src/Compiler.cpp | 179 ++- Compiler/src/CostModel.cpp | 124 +- Compiler/src/CostModel.h | 3 + Makefile | 30 +- Sources.cmake | 25 +- VM/src/lapi.cpp | 7 +- VM/src/lbuiltins.cpp | 4 +- VM/src/lbytecode.h | 5 +- VM/src/lcommon.h | 6 +- VM/src/ldo.cpp | 1 + VM/src/lvmexecute.cpp | 20 +- tests/AssemblyBuilderX64.test.cpp | 410 +++++++ tests/Autocomplete.test.cpp | 22 +- tests/Compiler.test.cpp | 1129 +++++++++++------- tests/Conformance.test.cpp | 174 ++- tests/CostModel.test.cpp | 6 +- tests/JsonEncoder.test.cpp | 77 +- tests/NonstrictMode.test.cpp | 5 +- tests/Normalize.test.cpp | 36 +- tests/Parser.test.cpp | 11 + tests/RuntimeLimits.test.cpp | 2 +- tests/ToDot.test.cpp | 10 +- tests/TypeInfer.builtins.test.cpp | 14 +- tests/TypeInfer.generics.test.cpp | 2 +- tests/TypeInfer.intersectionTypes.test.cpp | 2 +- tests/TypeInfer.modules.test.cpp | 4 +- tests/TypeInfer.provisional.test.cpp | 5 +- tests/TypeInfer.refinements.test.cpp | 7 +- tests/TypeInfer.singletons.test.cpp | 27 +- tests/TypeInfer.tables.test.cpp | 37 +- tests/TypeInfer.test.cpp | 11 +- tests/TypeInfer.typePacks.cpp | 8 +- tests/TypeVar.test.cpp | 3 +- tests/VisitTypeVar.test.cpp | 2 - tests/conformance/errors.lua | 5 + tests/conformance/nextvar.lua | 25 + tests/conformance/userdata.lua | 45 + tools/natvis/CodeGen.natvis | 50 + tools/patchtests.py | 76 ++ 69 files changed, 3837 insertions(+), 1724 deletions(-) create mode 100644 Analysis/include/Luau/Instantiation.h create mode 100644 Analysis/src/Instantiation.cpp create mode 100644 CodeGen/include/Luau/AssemblyBuilderX64.h create mode 100644 CodeGen/include/Luau/Condition.h create mode 100644 CodeGen/include/Luau/Label.h create mode 100644 CodeGen/include/Luau/OperandX64.h create mode 100644 CodeGen/include/Luau/RegisterX64.h create mode 100644 CodeGen/src/AssemblyBuilderX64.cpp rename {Compiler => Common}/include/Luau/Bytecode.h (100%) rename {Ast => Common}/include/Luau/Common.h (100%) create mode 100644 tests/AssemblyBuilderX64.test.cpp create mode 100644 tests/conformance/userdata.lua create mode 100644 tools/natvis/CodeGen.natvis create mode 100644 tools/patchtests.py diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 37e3cfdc..d7c9ca40 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -12,9 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauSeparateTypechecks) -LUAU_FASTFLAG(LuauDirtySourceModule) - namespace Luau { @@ -60,17 +57,12 @@ struct SourceNode { bool hasDirtySourceModule() const { - LUAU_ASSERT(FFlag::LuauDirtySourceModule); - return dirtySourceModule; } bool hasDirtyModule(bool forAutocomplete) const { - if (FFlag::LuauSeparateTypechecks) - return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; - else - return dirtyModule; + return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; } ModuleName name; @@ -90,10 +82,6 @@ struct FrontendOptions // is complete. bool retainFullTypeGraphs = false; - // When true, we run typechecking twice, once in the regular mode, and once in strict mode - // in order to get more precise type information (e.g. for autocomplete). - bool typecheckTwice_DEPRECATED = false; - // Run typechecking only in mode required for autocomplete (strict mode in order to get more precise type information) bool forAutocomplete = false; }; @@ -171,7 +159,7 @@ struct Frontend void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); private: - std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete_DEPRECATED); + std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete); diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h new file mode 100644 index 00000000..e05ceebe --- /dev/null +++ b/Analysis/include/Luau/Instantiation.h @@ -0,0 +1,53 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Substitution.h" +#include "Luau/TypeVar.h" +#include "Luau/Unifiable.h" + +namespace Luau +{ + +struct TypeArena; +struct TxnLog; + +// A substitution which replaces generic types in a given set by free types. +struct ReplaceGenerics : Substitution +{ + ReplaceGenerics( + const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector& generics, const std::vector& genericPacks) + : Substitution(log, arena) + , level(level) + , generics(generics) + , genericPacks(genericPacks) + { + } + + TypeLevel level; + std::vector generics; + std::vector genericPacks; + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +// A substitution which replaces generic functions by monomorphic functions +struct Instantiation : Substitution +{ + Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level) + : Substitution(log, arena) + , level(level) + { + } + + TypeLevel level; + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 7c74158b..559c55c8 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -39,4 +39,4 @@ struct TypeArena void freeze(TypeArena& arena); void unfreeze(TypeArena& arena); -} +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index fcaf5baa..183cc053 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -34,45 +34,6 @@ const AstStat* getFallthrough(const AstStat* node); struct UnifierOptions; struct Unifier; -// A substitution which replaces generic types in a given set by free types. -struct ReplaceGenerics : Substitution -{ - ReplaceGenerics( - const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector& generics, const std::vector& genericPacks) - : Substitution(log, arena) - , level(level) - , generics(generics) - , genericPacks(genericPacks) - { - } - - TypeLevel level; - std::vector generics; - std::vector genericPacks; - bool ignoreChildren(TypeId ty) override; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - -// A substitution which replaces generic functions by monomorphic functions -struct Instantiation : Substitution -{ - Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level) - : Substitution(log, arena) - , level(level) - { - } - - TypeLevel level; - bool ignoreChildren(TypeId ty) override; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - // A substitution which replaces free types by any struct Anyification : Substitution { diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 0e24c8b0..627b52ca 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -32,6 +32,9 @@ struct Widen : Substitution TypeId clean(TypeId ty) override; TypePackId clean(TypePackId ty) override; bool ignoreChildren(TypeId ty) override; + + TypeId operator()(TypeId ty); + TypePackId operator()(TypePackId ty); }; // TODO: Use this more widely. diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index 1a0b8b76..d4315d47 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -42,7 +42,6 @@ struct UnifierSharedState InternalErrorReporter* iceHandler; - DenseHashSet seenAny{nullptr}; DenseHashMap skipCacheForType{nullptr}; DenseHashSet, TypeIdPairHash> cachedUnify{{nullptr, nullptr}}; DenseHashMap, TypeErrorData, TypeIdPairHash> cachedUnifyError{{nullptr, nullptr}}; diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 2e98f526..f3839915 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -8,7 +8,6 @@ #include "Luau/TypePack.h" #include "Luau/TypeVar.h" -LUAU_FASTFLAG(LuauUseVisitRecursionLimit) LUAU_FASTINT(LuauVisitRecursionLimit) LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) @@ -62,168 +61,6 @@ inline void unsee(DenseHashSet& seen, const void* tv) // When DenseHashSet is used for 'visitTypeVarOnce', where don't forget visited elements } -template -void visit(TypePackId tp, F& f, Set& seen); - -template -void visit(TypeId ty, F& f, Set& seen) -{ - if (visit_detail::hasSeen(seen, ty)) - { - f.cycle(ty); - return; - } - - if (auto btv = get(ty)) - { - if (apply(ty, *btv, seen, f)) - visit(btv->boundTo, f, seen); - } - - else if (auto ftv = get(ty)) - apply(ty, *ftv, seen, f); - - else if (auto gtv = get(ty)) - apply(ty, *gtv, seen, f); - - else if (auto etv = get(ty)) - apply(ty, *etv, seen, f); - - else if (auto ctv = get(ty)) - { - if (apply(ty, *ctv, seen, f)) - { - for (TypeId part : ctv->parts) - visit(part, f, seen); - } - } - - else if (auto ptv = get(ty)) - apply(ty, *ptv, seen, f); - - else if (auto ftv = get(ty)) - { - if (apply(ty, *ftv, seen, f)) - { - visit(ftv->argTypes, f, seen); - visit(ftv->retType, f, seen); - } - } - - else if (auto ttv = get(ty)) - { - // Some visitors want to see bound tables, that's why we visit the original type - if (apply(ty, *ttv, seen, f)) - { - if (ttv->boundTo) - { - visit(*ttv->boundTo, f, seen); - } - else - { - for (auto& [_name, prop] : ttv->props) - visit(prop.type, f, seen); - - if (ttv->indexer) - { - visit(ttv->indexer->indexType, f, seen); - visit(ttv->indexer->indexResultType, f, seen); - } - } - } - } - - else if (auto mtv = get(ty)) - { - if (apply(ty, *mtv, seen, f)) - { - visit(mtv->table, f, seen); - visit(mtv->metatable, f, seen); - } - } - - else if (auto ctv = get(ty)) - { - if (apply(ty, *ctv, seen, f)) - { - for (const auto& [name, prop] : ctv->props) - visit(prop.type, f, seen); - - if (ctv->parent) - visit(*ctv->parent, f, seen); - - if (ctv->metatable) - visit(*ctv->metatable, f, seen); - } - } - - else if (auto atv = get(ty)) - apply(ty, *atv, seen, f); - - else if (auto utv = get(ty)) - { - if (apply(ty, *utv, seen, f)) - { - for (TypeId optTy : utv->options) - visit(optTy, f, seen); - } - } - - else if (auto itv = get(ty)) - { - if (apply(ty, *itv, seen, f)) - { - for (TypeId partTy : itv->parts) - visit(partTy, f, seen); - } - } - - visit_detail::unsee(seen, ty); -} - -template -void visit(TypePackId tp, F& f, Set& seen) -{ - if (visit_detail::hasSeen(seen, tp)) - { - f.cycle(tp); - return; - } - - if (auto btv = get(tp)) - { - if (apply(tp, *btv, seen, f)) - visit(btv->boundTo, f, seen); - } - - else if (auto ftv = get(tp)) - apply(tp, *ftv, seen, f); - - else if (auto gtv = get(tp)) - apply(tp, *gtv, seen, f); - - else if (auto etv = get(tp)) - apply(tp, *etv, seen, f); - - else if (auto pack = get(tp)) - { - apply(tp, *pack, seen, f); - - for (TypeId ty : pack->head) - visit(ty, f, seen); - - if (pack->tail) - visit(*pack->tail, f, seen); - } - else if (auto pack = get(tp)) - { - apply(tp, *pack, seen, f); - visit(pack->ty, f, seen); - } - - visit_detail::unsee(seen, tp); -} - } // namespace visit_detail template @@ -513,37 +350,4 @@ struct TypeVarOnceVisitor : GenericTypeVarVisitor> } }; -// Clip with FFlagLuauUseVisitRecursionLimit -template -void DEPRECATED_visitTypeVar(TID ty, F& f, std::unordered_set& seen) -{ - visit_detail::visit(ty, f, seen); -} - -// Delete and inline when clipping FFlagLuauUseVisitRecursionLimit -template -void DEPRECATED_visitTypeVar(TID ty, F& f) -{ - if (FFlag::LuauUseVisitRecursionLimit) - f.traverse(ty); - else - { - std::unordered_set seen; - visit_detail::visit(ty, f, seen); - } -} - -// Delete and inline when clipping FFlagLuauUseVisitRecursionLimit -template -void DEPRECATED_visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) -{ - if (FFlag::LuauUseVisitRecursionLimit) - f.traverse(ty); - else - { - seen.clear(); - visit_detail::visit(ty, f, seen); - } -} - } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 19d06cfc..b988ed35 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -1700,31 +1700,18 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { - if (FFlag::LuauSeparateTypechecks) - { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - FrontendOptions opts; - opts.forAutocomplete = true; - frontend.check(moduleName, opts); - } - else - { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - // FIXME: We don't need to typecheck for script analysis here, just for autocomplete. - frontend.check(moduleName); - } + // FIXME: We can improve performance here by parsing without checking. + // The old type graph is probably fine. (famous last words!) + FrontendOptions opts; + opts.forAutocomplete = true; + frontend.check(moduleName, opts); const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; - TypeChecker& typeChecker = - (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = - (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete; + ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName); if (!module) return {}; @@ -1752,9 +1739,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->mode = Mode::Strict; sourceModule->commentLocations = std::move(result.commentLocations); - TypeChecker& typeChecker = - (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - + TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete; ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); OwningAutocompleteResult autocompleteResult = { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 56c0ac2c..1d33f131 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -20,9 +20,7 @@ LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) -LUAU_FASTFLAGVARIABLE(LuauDirtySourceModule, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) namespace Luau @@ -361,32 +359,21 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond.hasDirtyModule(frontendOptions.forAutocomplete)) { // No recheck required. - if (FFlag::LuauSeparateTypechecks) + if (frontendOptions.forAutocomplete) { - if (frontendOptions.forAutocomplete) - { - auto it2 = moduleResolverForAutocomplete.modules.find(name); - if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); - } - else - { - auto it2 = moduleResolver.modules.find(name); - if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); - } - - return CheckResult{accumulateErrors( - sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; + auto it2 = moduleResolverForAutocomplete.modules.find(name); + if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); } else { auto it2 = moduleResolver.modules.find(name); if (it2 == moduleResolver.modules.end() || it2->second == nullptr) throw std::runtime_error("Frontend::modules does not have data for " + name); - - return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)}; } + + return CheckResult{ + accumulateErrors(sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; } std::vector buildQueue; @@ -428,7 +415,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& buildQueue, CheckResult& chec bool cyclic = false; { - auto [sourceNode, _] = getSourceNode(checkResult, root, forAutocomplete); + auto [sourceNode, _] = getSourceNode(checkResult, root); if (sourceNode) stack.push_back(sourceNode); } @@ -627,7 +603,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec } } - auto [sourceNode, _] = getSourceNode(checkResult, dep, forAutocomplete); + auto [sourceNode, _] = getSourceNode(checkResult, dep); if (sourceNode) { stack.push_back(sourceNode); @@ -671,7 +647,7 @@ LintResult Frontend::lint(const ModuleName& name, std::optional* markedDirty) { - if (FFlag::LuauSeparateTypechecks) - { - if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name)) - return; - } - else - { - if (!moduleResolver.modules.count(name)) - return; - } + if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name)) + return; std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) @@ -783,32 +751,12 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked if (markedDirty) markedDirty->push_back(next); - if (FFlag::LuauDirtySourceModule) - { - LUAU_ASSERT(FFlag::LuauSeparateTypechecks); + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + continue; - if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) - continue; - - sourceNode.dirtySourceModule = true; - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; - } - else if (FFlag::LuauSeparateTypechecks) - { - if (sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) - continue; - - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; - } - else - { - if (sourceNode.dirtyModule) - continue; - - sourceNode.dirtyModule = true; - } + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; if (0 == reverseDeps.count(name)) continue; @@ -835,14 +783,13 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. -std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete_DEPRECATED) +std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) { LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && - (FFlag::LuauDirtySourceModule ? !it->second.hasDirtySourceModule() : !it->second.hasDirtyModule(forAutocomplete_DEPRECATED))) + if (it != sourceNodes.end() && !it->second.hasDirtySourceModule()) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) @@ -885,21 +832,12 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceNode.name = name; sourceNode.requires.clear(); sourceNode.requireLocations.clear(); + sourceNode.dirtySourceModule = false; - if (FFlag::LuauDirtySourceModule) - sourceNode.dirtySourceModule = false; - - if (FFlag::LuauSeparateTypechecks) - { - if (it == sourceNodes.end()) - { - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; - } - } - else + if (it == sourceNodes.end()) { sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; } for (const auto& [moduleName, location] : requireTrace.requires) diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp new file mode 100644 index 00000000..4a12027d --- /dev/null +++ b/Analysis/src/Instantiation.cpp @@ -0,0 +1,128 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" +#include "Luau/Instantiation.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeArena.h" + +LUAU_FASTFLAG(LuauNoMethodLocations) + +namespace Luau +{ + +bool Instantiation::isDirty(TypeId ty) +{ + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (ftv->hasNoGenerics) + return false; + + return true; + } + else + { + return false; + } +} + +bool Instantiation::isDirty(TypePackId tp) +{ + return false; +} + +bool Instantiation::ignoreChildren(TypeId ty) +{ + if (log->getMutable(ty)) + return true; + else + return false; +} + +TypeId Instantiation::clean(TypeId ty) +{ + const FunctionTypeVar* ftv = log->getMutable(ty); + LUAU_ASSERT(ftv); + + FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + clone.magicFunction = ftv->magicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + TypeId result = addType(std::move(clone)); + + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + ReplaceGenerics replaceGenerics{log, arena, level, ftv->generics, ftv->genericPacks}; + + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; +} + +TypePackId Instantiation::clean(TypePackId tp) +{ + LUAU_ASSERT(false); + return tp; +} + +bool ReplaceGenerics::ignoreChildren(TypeId ty) +{ + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (ftv->hasNoGenerics) + return true; + + // We aren't recursing in the case of a generic function which + // binds the same generics. This can happen if, for example, there's recursive types. + // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. + // It's OK to use vector equality here, since we always generate fresh generics + // whenever we quantify, so the vectors overlap if and only if they are equal. + return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); + } + else + { + return false; + } +} + +bool ReplaceGenerics::isDirty(TypeId ty) +{ + if (const TableTypeVar* ttv = log->getMutable(ty)) + return ttv->state == TableState::Generic; + else if (log->getMutable(ty)) + return std::find(generics.begin(), generics.end(), ty) != generics.end(); + else + return false; +} + +bool ReplaceGenerics::isDirty(TypePackId tp) +{ + if (log->getMutable(tp)) + return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); + else + return false; +} + +TypeId ReplaceGenerics::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + if (const TableTypeVar* ttv = log->getMutable(ty)) + { + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + clone.definitionModuleName = ttv->definitionModuleName; + return addType(std::move(clone)); + } + else + return addType(FreeTypeVar{level}); +} + +TypePackId ReplaceGenerics::clean(TypePackId tp) +{ + LUAU_ASSERT(isDirty(tp)); + return addTypePack(TypePackVar(FreeTypePack{level})); +} + +} // namespace Luau diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 074a41e6..6591d60a 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -56,8 +56,18 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) struct ForceNormal : TypeVarOnceVisitor { + const TypeArena* typeArena = nullptr; + + ForceNormal(const TypeArena* typeArena) + : typeArena(typeArena) + { + } + bool visit(TypeId ty) override { + if (ty->owningArena != typeArena) + return false; + asMutable(ty)->normal = true; return true; } @@ -100,7 +110,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) normalize(*moduleScope->varargPack, interfaceTypes, ice); } - ForceNormal forceNormal; + ForceNormal forceNormal{&interfaceTypes}; for (auto& [name, tf] : moduleScope->exportedTypeBindings) { diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 30fd4af2..fb31df1e 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); namespace Luau { @@ -325,245 +326,6 @@ struct Normalize final : TypeVarVisitor int iterationLimit = 0; bool limitExceeded = false; - // TODO: Clip with FFlag::LuauUseVisitRecursionLimit - bool operator()(TypeId ty, const BoundTypeVar& btv, std::unordered_set& seen) - { - // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. - // So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack. - if (seen.find(asMutable(btv.boundTo)) != seen.end()) - return false; - - // It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases. - LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); - - asMutable(ty)->normal = btv.boundTo->normal; - return !ty->normal; - } - - bool operator()(TypeId ty, const FreeTypeVar& ftv) - { - return visit(ty, ftv); - } - bool operator()(TypeId ty, const PrimitiveTypeVar& ptv) - { - return visit(ty, ptv); - } - bool operator()(TypeId ty, const GenericTypeVar& gtv) - { - return visit(ty, gtv); - } - bool operator()(TypeId ty, const ErrorTypeVar& etv) - { - return visit(ty, etv); - } - bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - ConstrainedTypeVar* ctv = const_cast(&ctvRef); - - std::vector parts = std::move(ctv->parts); - - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId part : parts) - visit_detail::visit(part, *this, seen); - - std::vector newParts = normalizeUnion(parts); - - const bool normal = areNormal(newParts, seen, ice); - - if (newParts.size() == 1) - *asMutable(ty) = BoundTypeVar{newParts[0]}; - else - *asMutable(ty) = UnionTypeVar{std::move(newParts)}; - - asMutable(ty)->normal = normal; - - return false; - } - - bool operator()(TypeId ty, const FunctionTypeVar& ftv, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - visit_detail::visit(ftv.argTypes, *this, seen); - visit_detail::visit(ftv.retType, *this, seen); - - asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice); - - return false; - } - - bool operator()(TypeId ty, const TableTypeVar& ttv, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - bool normal = true; - - auto checkNormal = [&](TypeId t) { - // if t is on the stack, it is possible that this type is normal. - // If t is not normal and it is not on the stack, this type is definitely not normal. - if (!t->normal && seen.find(asMutable(t)) == seen.end()) - normal = false; - }; - - if (ttv.boundTo) - { - visit_detail::visit(*ttv.boundTo, *this, seen); - asMutable(ty)->normal = (*ttv.boundTo)->normal; - return false; - } - - for (const auto& [_name, prop] : ttv.props) - { - visit_detail::visit(prop.type, *this, seen); - checkNormal(prop.type); - } - - if (ttv.indexer) - { - visit_detail::visit(ttv.indexer->indexType, *this, seen); - checkNormal(ttv.indexer->indexType); - visit_detail::visit(ttv.indexer->indexResultType, *this, seen); - checkNormal(ttv.indexer->indexResultType); - } - - asMutable(ty)->normal = normal; - - return false; - } - - bool operator()(TypeId ty, const MetatableTypeVar& mtv, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - visit_detail::visit(mtv.table, *this, seen); - visit_detail::visit(mtv.metatable, *this, seen); - - asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; - - return false; - } - - bool operator()(TypeId ty, const ClassTypeVar& ctv) - { - return visit(ty, ctv); - } - bool operator()(TypeId ty, const AnyTypeVar& atv) - { - return visit(ty, atv); - } - bool operator()(TypeId ty, const UnionTypeVar& utvRef, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - UnionTypeVar* utv = &const_cast(utvRef); - std::vector options = std::move(utv->options); - - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId option : options) - visit_detail::visit(option, *this, seen); - - std::vector newOptions = normalizeUnion(options); - - const bool normal = areNormal(newOptions, seen, ice); - - LUAU_ASSERT(!newOptions.empty()); - - if (newOptions.size() == 1) - *asMutable(ty) = BoundTypeVar{newOptions[0]}; - else - utv->options = std::move(newOptions); - - asMutable(ty)->normal = normal; - - return false; - } - - bool operator()(TypeId ty, const IntersectionTypeVar& itvRef, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - IntersectionTypeVar* itv = &const_cast(itvRef); - - std::vector oldParts = std::move(itv->parts); - - for (TypeId part : oldParts) - visit_detail::visit(part, *this, seen); - - std::vector tables; - for (TypeId part : oldParts) - { - part = follow(part); - if (get(part)) - tables.push_back(part); - else - { - Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD - combineIntoIntersection(replacer, itv, part); - } - } - - // Don't allocate a new table if there's just one in the intersection. - if (tables.size() == 1) - itv->parts.push_back(tables[0]); - else if (!tables.empty()) - { - const TableTypeVar* first = get(tables[0]); - LUAU_ASSERT(first); - - TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); - TableTypeVar* ttv = getMutable(newTable); - for (TypeId part : tables) - { - // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need - // to be rewritten to point at 'newTable' in the clone. - Replacer replacer{&arena, part, newTable}; - combineIntoTable(replacer, ttv, part); - } - - itv->parts.push_back(newTable); - } - - asMutable(ty)->normal = areNormal(itv->parts, seen, ice); - - if (itv->parts.size() == 1) - { - TypeId part = itv->parts[0]; - *asMutable(ty) = BoundTypeVar{part}; - } - - return false; - } - - // TODO: Clip with FFlag::LuauUseVisitRecursionLimit - template - bool operator()(TypePackId, const T&) - { - return true; - } - - // TODO: Clip with FFlag::LuauUseVisitRecursionLimit - template - void cycle(TID) - { - } - bool visit(TypeId ty, const FreeTypeVar&) override { LUAU_ASSERT(!ty->normal); @@ -968,6 +730,9 @@ struct Normalize final : TypeVarVisitor */ TypeId combine(Replacer& replacer, TypeId a, TypeId b) { + if (FFlag::LuauNormalizeCombineEqFix) + b = follow(b); + if (FFlag::LuauNormalizeCombineTableFix && a == b) return a; @@ -986,7 +751,7 @@ struct Normalize final : TypeVarVisitor } else if (auto ttv = getMutable(a)) { - if (FFlag::LuauNormalizeCombineTableFix && !get(follow(b))) + if (FFlag::LuauNormalizeCombineTableFix && !get(FFlag::LuauNormalizeCombineEqFix ? b : follow(b))) return arena.addType(IntersectionTypeVar{{a, b}}); combineIntoTable(replacer, ttv, b); return a; @@ -1009,15 +774,7 @@ std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorRepo (void)clone(ty, arena, state); Normalize n{arena, ice}; - if (FFlag::LuauNormalizeFlagIsConservative) - { - DEPRECATED_visitTypeVar(ty, n); - } - else - { - std::unordered_set seen; - DEPRECATED_visitTypeVar(ty, n, seen); - } + n.traverse(ty); return {ty, !n.limitExceeded}; } @@ -1041,15 +798,7 @@ std::pair normalize(TypePackId tp, TypeArena& arena, InternalE (void)clone(tp, arena, state); Normalize n{arena, ice}; - if (FFlag::LuauNormalizeFlagIsConservative) - { - DEPRECATED_visitTypeVar(tp, n); - } - else - { - std::unordered_set seen; - DEPRECATED_visitTypeVar(tp, n, seen); - } + n.traverse(tp); return {tp, !n.limitExceeded}; } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 018d5632..c0f677d7 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -119,8 +119,7 @@ struct Quantifier final : TypeVarOnceVisitor void quantify(TypeId ty, TypeLevel level) { Quantifier q{level}; - DenseHashSet seen{nullptr}; - DEPRECATED_visitTypeVarOnce(ty, q, seen); + q.traverse(ty); FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index f90f7019..a4a3ec49 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -48,46 +48,6 @@ struct FindCyclicTypes final : TypeVarVisitor cycleTPs.insert(tp); } - // TODO: Clip all the operator()s when we clip FFlagLuauUseVisitRecursionLimit - - template - bool operator()(TypeId ty, const T&) - { - return visit(ty); - } - - bool operator()(TypeId ty, const TableTypeVar& ttv) = delete; - - bool operator()(TypeId ty, const TableTypeVar& ttv, std::unordered_set& seen) - { - if (!visited.insert(ty).second) - return false; - - if (ttv.name || ttv.syntheticName) - { - for (TypeId itp : ttv.instantiatedTypeParams) - DEPRECATED_visitTypeVar(itp, *this, seen); - - for (TypePackId itp : ttv.instantiatedTypePackParams) - DEPRECATED_visitTypeVar(itp, *this, seen); - - return exhaustive; - } - - return true; - } - - bool operator()(TypeId, const ClassTypeVar&) - { - return false; - } - - template - bool operator()(TypePackId tp, const T&) - { - return visit(tp); - } - bool visit(TypeId ty) override { return visited.insert(ty).second; @@ -128,7 +88,7 @@ void findCyclicTypes(std::set& cycles, std::set& cycleTPs, T { FindCyclicTypes fct; fct.exhaustive = exhaustive; - DEPRECATED_visitTypeVar(ty, fct); + fct.traverse(ty); cycles = std::move(fct.cycles); cycleTPs = std::move(fct.cycleTPs); diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index 673b002d..0c89d130 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -85,4 +85,4 @@ void unfreeze(TypeArena& arena) arena.typePacks.unfreeze(); } -} +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 208b3f2f..11813c76 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -3,6 +3,7 @@ #include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/Instantiation.h" #include "Luau/ModuleResolver.h" #include "Luau/Normalize.h" #include "Luau/Parser.h" @@ -10,13 +11,13 @@ #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Substitution.h" -#include "Luau/TopoSortStatements.h" -#include "Luau/TypePack.h" -#include "Luau/ToString.h" -#include "Luau/TypeUtils.h" -#include "Luau/ToString.h" -#include "Luau/TypeVar.h" #include "Luau/TimeTrace.h" +#include "Luau/TopoSortStatements.h" +#include "Luau/ToString.h" +#include "Luau/ToString.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/TypeVar.h" #include #include @@ -26,14 +27,11 @@ LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) -LUAU_FASTFLAGVARIABLE(LuauUseVisitRecursionLimit, false) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAG(LuauSeparateTypechecks) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false) LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) -LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) @@ -43,7 +41,6 @@ LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) -LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree2) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); @@ -51,6 +48,7 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) LUAU_FASTFLAGVARIABLE(LuauNoMethodLocations, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); +LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) namespace Luau { @@ -305,12 +303,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule.reset(new Module()); currentModule->type = module.type; - - if (FFlag::LuauSeparateTypechecks) - { - currentModule->allocator = module.allocator; - currentModule->names = module.names; - } + currentModule->allocator = module.allocator; + currentModule->names = module.names; iceHandler->moduleName = module.name; @@ -338,21 +332,14 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo if (prepareModuleScope) prepareModuleScope(module.name, currentModule->getModuleScope()); - if (FFlag::LuauSeparateTypechecks) - { - try - { - checkBlock(moduleScope, *module.root); - } - catch (const TimeLimitError&) - { - currentModule->timeout = true; - } - } - else + try { checkBlock(moduleScope, *module.root); } + catch (const TimeLimitError&) + { + currentModule->timeout = true; + } if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); @@ -443,7 +430,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) else ice("Unknown AstStat"); - if (FFlag::LuauSeparateTypechecks && finishTime && TimeTrace::getClock() > *finishTime) + if (finishTime && TimeTrace::getClock() > *finishTime) throw TimeLimitError(); } @@ -868,9 +855,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) TypeId right = nullptr; - Location loc = 0 == assign.values.size - ? assign.location - : i < assign.values.size ? assign.values.data[i]->location : assign.values.data[assign.values.size - 1]->location; + Location loc = 0 == assign.values.size ? assign.location + : i < assign.values.size ? assign.values.data[i]->location + : assign.values.data[assign.values.size - 1]->location; if (valueIter != valueEnd) { @@ -1825,7 +1812,7 @@ std::optional TypeChecker::findMetatableEntry(TypeId type, std::string e } std::optional TypeChecker::getIndexTypeFromType( - const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) + const ScopePtr& scope, TypeId type, const std::string& name, const Location& location, bool addErrors) { type = follow(type); @@ -1843,13 +1830,25 @@ std::optional TypeChecker::getIndexTypeFromType( if (TableTypeVar* tableType = getMutableTableType(type)) { - const auto& it = tableType->props.find(name); - if (it != tableType->props.end()) + if (auto it = tableType->props.find(name); it != tableType->props.end()) return it->second.type; else if (auto indexer = tableType->indexer) { - tryUnify(stringType, indexer->indexType, location); - return indexer->indexResultType; + // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. + ErrorVec errors = tryUnify(stringType, indexer->indexType, location); + + if (FFlag::LuauReportErrorsOnIndexerKeyMismatch) + { + if (errors.empty()) + return indexer->indexResultType; + + if (addErrors) + reportError(location, UnknownProperty{type, name}); + + return std::nullopt; + } + else + return indexer->indexResultType; } else if (tableType->state == TableState::Free) { @@ -1858,8 +1857,7 @@ std::optional TypeChecker::getIndexTypeFromType( return result; } - auto found = findTablePropertyRespectingMeta(type, name, location); - if (found) + if (auto found = findTablePropertyRespectingMeta(type, name, location)) return *found; } else if (const ClassTypeVar* cls = get(type)) @@ -2512,8 +2510,9 @@ TypeId TypeChecker::checkRelationalOperation( if (!matches) { - reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + reportError( + expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); return errorRecoveryType(booleanType); } } @@ -2522,8 +2521,9 @@ TypeId TypeChecker::checkRelationalOperation( { if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) { - reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + reportError( + expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); return errorRecoveryType(booleanType); } } @@ -3636,10 +3636,7 @@ void TypeChecker::checkArgumentList( } TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - if (FFlag::LuauWidenIfSupertypeIsFree2) - state.tryUnify(varPack, tail); - else - state.tryUnify(tail, varPack); + state.tryUnify(varPack, tail); return; } @@ -3707,7 +3704,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } TypePackId retPack; - if (FFlag::LuauLowerBoundsCalculation || !FFlag::LuauWidenIfSupertypeIsFree2) + if (FFlag::LuauLowerBoundsCalculation) { retPack = freshTypePack(scope->level); } @@ -3868,9 +3865,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope Widen widen{¤tModule->internalTypes}; for (; it != endIt; ++it) { - TypeId t = *it; - TypeId widened = widen.substitute(t).value_or(t); // Surely widening is infallible - adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widened}})); + adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widen(*it)}})); } TypePackId adjustedArgPack = addTypePack(TypePack{std::move(adjustedArgTypes), it.tail()}); @@ -3885,14 +3880,11 @@ std::optional> TypeChecker::checkCallOverload(const Scope else { TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - if (FFlag::LuauWidenIfSupertypeIsFree2) - { - UnifierOptions options; - options.isFunctionCall = true; - unify(r, fn, expr.location, options); - } - else - unify(fn, r, expr.location); + + UnifierOptions options; + options.isFunctionCall = true; + unify(r, fn, expr.location, options); + return {{retPack}}; } } @@ -4375,122 +4367,6 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s } } -bool Instantiation::isDirty(TypeId ty) -{ - if (const FunctionTypeVar* ftv = log->getMutable(ty)) - { - if (ftv->hasNoGenerics) - return false; - - return true; - } - else - { - return false; - } -} - -bool Instantiation::isDirty(TypePackId tp) -{ - return false; -} - -bool Instantiation::ignoreChildren(TypeId ty) -{ - if (log->getMutable(ty)) - return true; - else - return false; -} - -TypeId Instantiation::clean(TypeId ty) -{ - const FunctionTypeVar* ftv = log->getMutable(ty); - LUAU_ASSERT(ftv); - - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; - clone.magicFunction = ftv->magicFunction; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - TypeId result = addType(std::move(clone)); - - // Annoyingly, we have to do this even if there are no generics, - // to replace any generic tables. - ReplaceGenerics replaceGenerics{log, arena, level, ftv->generics, ftv->genericPacks}; - - // TODO: What to do if this returns nullopt? - // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; -} - -TypePackId Instantiation::clean(TypePackId tp) -{ - LUAU_ASSERT(false); - return tp; -} - -bool ReplaceGenerics::ignoreChildren(TypeId ty) -{ - if (const FunctionTypeVar* ftv = log->getMutable(ty)) - { - if (ftv->hasNoGenerics) - return true; - - // We aren't recursing in the case of a generic function which - // binds the same generics. This can happen if, for example, there's recursive types. - // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. - // It's OK to use vector equality here, since we always generate fresh generics - // whenever we quantify, so the vectors overlap if and only if they are equal. - return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); - } - else - { - return false; - } -} - -bool ReplaceGenerics::isDirty(TypeId ty) -{ - if (const TableTypeVar* ttv = log->getMutable(ty)) - return ttv->state == TableState::Generic; - else if (log->getMutable(ty)) - return std::find(generics.begin(), generics.end(), ty) != generics.end(); - else - return false; -} - -bool ReplaceGenerics::isDirty(TypePackId tp) -{ - if (log->getMutable(tp)) - return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); - else - return false; -} - -TypeId ReplaceGenerics::clean(TypeId ty) -{ - LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = log->getMutable(ty)) - { - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - return addType(std::move(clone)); - } - else - return addType(FreeTypeVar{level}); -} - -TypePackId ReplaceGenerics::clean(TypePackId tp) -{ - LUAU_ASSERT(isDirty(tp)); - return addTypePack(TypePackVar(FreeTypePack{level})); -} - bool Anyification::isDirty(TypeId ty) { if (ty->persistent) @@ -5295,7 +5171,7 @@ TypeId ApplyTypeFunction::clean(TypeId ty) { TypeId& arg = typeArguments[ty]; if (FFlag::LuauApplyTypeFunctionFix) - { + { LUAU_ASSERT(arg); return arg; } @@ -5309,7 +5185,7 @@ TypePackId ApplyTypeFunction::clean(TypePackId tp) { TypePackId& arg = typePackArguments[tp]; if (FFlag::LuauApplyTypeFunctionFix) - { + { LUAU_ASSERT(arg); return arg; } @@ -5837,9 +5713,6 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. auto predicate = [&](TypeId option) -> std::optional { - if (sense && isUndecidable(option)) - return FFlag::LuauWeakEqConstraint ? option : eqP.type; - if (!sense && isNil(eqP.type)) return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 9308e9ff..414b05f4 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -21,8 +21,6 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) -LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) -LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) namespace Luau @@ -149,8 +147,7 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel return; PromoteTypeLevels ptl{log, typeArena, minLevel}; - DenseHashSet seen{nullptr}; - DEPRECATED_visitTypeVarOnce(ty, ptl, seen); + ptl.traverse(ty); } void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) @@ -160,8 +157,7 @@ void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLev return; PromoteTypeLevels ptl{log, typeArena, minLevel}; - DenseHashSet seen{nullptr}; - DEPRECATED_visitTypeVarOnce(tp, ptl, seen); + ptl.traverse(tp); } struct SkipCacheForType final : TypeVarOnceVisitor @@ -172,49 +168,6 @@ struct SkipCacheForType final : TypeVarOnceVisitor { } - // TODO cycle() and operator() can be clipped with FFlagLuauUseVisitRecursionLimit - void cycle(TypeId) override {} - void cycle(TypePackId) override {} - - bool operator()(TypeId ty, const FreeTypeVar& ftv) - { - return visit(ty, ftv); - } - bool operator()(TypeId ty, const BoundTypeVar& btv) - { - return visit(ty, btv); - } - bool operator()(TypeId ty, const GenericTypeVar& gtv) - { - return visit(ty, gtv); - } - bool operator()(TypeId ty, const TableTypeVar& ttv) - { - return visit(ty, ttv); - } - bool operator()(TypePackId tp, const FreeTypePack& ftp) - { - return visit(tp, ftp); - } - bool operator()(TypePackId tp, const BoundTypePack& ftp) - { - return visit(tp, ftp); - } - bool operator()(TypePackId tp, const GenericTypePack& ftp) - { - return visit(tp, ftp); - } - template - bool operator()(TypeId ty, const T& t) - { - return visit(ty); - } - template - bool operator()(TypePackId tp, const T&) - { - return visit(tp); - } - bool visit(TypeId, const FreeTypeVar&) override { result = true; @@ -341,6 +294,16 @@ bool Widen::ignoreChildren(TypeId ty) return !log->is(ty); } +TypeId Widen::operator()(TypeId ty) +{ + return substitute(ty).value_or(ty); +} + +TypePackId Widen::operator()(TypePackId tp) +{ + return substitute(tp).value_or(tp); +} + static std::optional hasUnificationTooComplex(const ErrorVec& errors) { auto isUnificationTooComplex = [](const TypeError& te) { @@ -475,6 +438,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursFailed) { promoteTypeLevels(log, types, superLevel, subTy); + + Widen widen{types}; log.replace(superTy, BoundTypeVar(widen(subTy))); } @@ -612,9 +577,6 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId std::optional unificationTooComplex; std::optional firstFailedOption; - size_t count = uv->options.size(); - size_t i = 0; - for (TypeId type : uv->options) { Unifier innerState = makeChildUnifier(); @@ -630,60 +592,44 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId failed = true; } - - if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter2) - { - } - else - { - if (i == count - 1) - { - log.concat(std::move(innerState.log)); - } - - ++i; - } } // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. - if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter2) - { - auto tryBind = [this, subTy](TypeId superOption) { - superOption = log.follow(superOption); + auto tryBind = [this, subTy](TypeId superOption) { + superOption = log.follow(superOption); - // just skip if the superOption is not free-ish. - auto ttv = log.getMutable(superOption); - if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) - return; + // just skip if the superOption is not free-ish. + auto ttv = log.getMutable(superOption); + if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) + return; - // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype - // test is successful. - if (auto subUnion = get(subTy)) - { - if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) - return; - } - - // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. - // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. - if (log.haveSeen(subTy, superOption)) - { - // TODO: would it be nice for TxnLog::replace to do this? - if (log.is(superOption)) - log.bindTable(superOption, subTy); - else - log.replace(superOption, *subTy); - } - }; - - if (auto utv = log.getMutable(superTy)) + // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype + // test is successful. + if (auto subUnion = get(subTy)) { - for (TypeId ty : utv) - tryBind(ty); + if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) + return; } - else - tryBind(superTy); + + // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. + // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. + if (log.haveSeen(subTy, superOption)) + { + // TODO: would it be nice for TxnLog::replace to do this? + if (log.is(superOption)) + log.bindTable(superOption, subTy); + else + log.replace(superOption, *subTy); + } + }; + + if (auto utv = log.getMutable(superTy)) + { + for (TypeId ty : utv) + tryBind(ty); } + else + tryBind(superTy); if (unificationTooComplex) reportError(*unificationTooComplex); @@ -883,7 +829,7 @@ bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) auto skipCacheFor = [this](TypeId ty) { SkipCacheForType visitor{sharedState.skipCacheForType, types}; - DEPRECATED_visitTypeVarOnce(ty, visitor, sharedState.seenAny); + visitor.traverse(ty); sharedState.skipCacheForType[ty] = visitor.result; @@ -1088,6 +1034,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (!log.getMutable(superTp)) { + Widen widen{types}; log.replace(superTp, Unifiable::Bound(widen(subTp))); } } @@ -1671,28 +1618,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } } -TypeId Unifier::widen(TypeId ty) -{ - if (!FFlag::LuauWidenIfSupertypeIsFree2) - return ty; - - Widen widen{types}; - std::optional result = widen.substitute(ty); - // TODO: what does it mean for substitution to fail to widen? - return result.value_or(ty); -} - -TypePackId Unifier::widen(TypePackId tp) -{ - if (!FFlag::LuauWidenIfSupertypeIsFree2) - return tp; - - Widen widen{types}; - std::optional result = widen.substitute(tp); - // TODO: what does it mean for substitution to fail to widen? - return result.value_or(tp); -} - TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { ty = follow(ty); @@ -1809,10 +1734,7 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { if (auto subProp = findTablePropertyRespectingMeta(subTy, freeName)) { - if (FFlag::LuauWidenIfSupertypeIsFree2) - tryUnify_(*subProp, freeProp.type); - else - tryUnify_(freeProp.type, *subProp); + tryUnify_(*subProp, freeProp.type); /* * TypeVars are commonly cyclic, so it is entirely possible diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 9f7b2bdf..be282827 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Common.h" +#include "Luau/Common.h" #include diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index c053e6bd..eaf19914 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -11,6 +11,8 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) +LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) + namespace Luau { @@ -1589,6 +1591,17 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { return parseFunctionTypeAnnotation(allowPack); } + else if (FFlag::LuauParserFunctionKeywordAsTypeHelp && lexer.current().type == Lexeme::ReservedFunction) + { + Location location = lexer.current().location; + + nextLexeme(); + + return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, + "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " + "...any'"), + {}}; + } else { Location location = lexer.current().location; diff --git a/CMakeLists.txt b/CMakeLists.txt index ea352309..c624a132 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,9 +19,11 @@ if(LUAU_STATIC_CRT) endif() project(Luau LANGUAGES CXX C) +add_library(Luau.Common INTERFACE) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) add_library(Luau.Analysis STATIC) +add_library(Luau.CodeGen STATIC) add_library(Luau.VM STATIC) add_library(isocline STATIC) @@ -48,8 +50,11 @@ endif() include(Sources.cmake) +target_include_directories(Luau.Common INTERFACE Common/include) + target_compile_features(Luau.Ast PUBLIC cxx_std_17) target_include_directories(Luau.Ast PUBLIC Ast/include) +target_link_libraries(Luau.Ast PUBLIC Luau.Common) target_compile_features(Luau.Compiler PUBLIC cxx_std_17) target_include_directories(Luau.Compiler PUBLIC Compiler/include) @@ -59,8 +64,13 @@ target_compile_features(Luau.Analysis PUBLIC cxx_std_17) target_include_directories(Luau.Analysis PUBLIC Analysis/include) target_link_libraries(Luau.Analysis PUBLIC Luau.Ast) +target_compile_features(Luau.CodeGen PRIVATE cxx_std_17) +target_include_directories(Luau.CodeGen PUBLIC CodeGen/include) +target_link_libraries(Luau.CodeGen PUBLIC Luau.Common) + target_compile_features(Luau.VM PRIVATE cxx_std_11) target_include_directories(Luau.VM PUBLIC VM/include) +target_link_libraries(Luau.VM PUBLIC Luau.Common) target_include_directories(isocline PUBLIC extern/isocline/include) @@ -101,6 +111,7 @@ endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) +target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS}) @@ -120,6 +131,7 @@ endif() if(MSVC) target_link_options(Luau.Ast INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Ast.natvis) target_link_options(Luau.Analysis INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Analysis.natvis) + target_link_options(Luau.CodeGen INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/CodeGen.natvis) target_link_options(Luau.VM INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/VM.natvis) endif() @@ -127,6 +139,7 @@ endif() if(MSVC_IDE) target_sources(Luau.Ast PRIVATE tools/natvis/Ast.natvis) target_sources(Luau.Analysis PRIVATE tools/natvis/Analysis.natvis) + target_sources(Luau.CodeGen PRIVATE tools/natvis/CodeGen.natvis) target_sources(Luau.VM PRIVATE tools/natvis/VM.natvis) endif() @@ -154,7 +167,7 @@ endif() if(LUAU_BUILD_TESTS) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.UnitTest PRIVATE extern) - target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) + target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Conformance PRIVATE extern) diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h new file mode 100644 index 00000000..c5979d3c --- /dev/null +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -0,0 +1,169 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Condition.h" +#include "Luau/Label.h" +#include "Luau/OperandX64.h" +#include "Luau/RegisterX64.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +class AssemblyBuilderX64 +{ +public: + explicit AssemblyBuilderX64(bool logText); + ~AssemblyBuilderX64(); + + // Base two operand instructions with 9 opcode selection + void add(OperandX64 lhs, OperandX64 rhs); + void sub(OperandX64 lhs, OperandX64 rhs); + void cmp(OperandX64 lhs, OperandX64 rhs); + void and_(OperandX64 lhs, OperandX64 rhs); + void or_(OperandX64 lhs, OperandX64 rhs); + void xor_(OperandX64 lhs, OperandX64 rhs); + + // Binary shift instructions with special rhs handling + void sal(OperandX64 lhs, OperandX64 rhs); + void sar(OperandX64 lhs, OperandX64 rhs); + void shl(OperandX64 lhs, OperandX64 rhs); + void shr(OperandX64 lhs, OperandX64 rhs); + + // Two operand mov instruction has additional specialized encodings + void mov(OperandX64 lhs, OperandX64 rhs); + void mov64(RegisterX64 lhs, int64_t imm); + + // Base one operand instruction with 2 opcode selection + void div(OperandX64 op); + void idiv(OperandX64 op); + void mul(OperandX64 op); + void neg(OperandX64 op); + void not_(OperandX64 op); + + void test(OperandX64 lhs, OperandX64 rhs); + void lea(OperandX64 lhs, OperandX64 rhs); + + void push(OperandX64 op); + void pop(OperandX64 op); + void ret(); + + // Control flow + void jcc(Condition cond, Label& label); + void jmp(Label& label); + void jmp(OperandX64 op); + + // AVX + void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vaddsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + + void vsqrtpd(OperandX64 dst, OperandX64 src); + void vsqrtps(OperandX64 dst, OperandX64 src); + void vsqrtsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vsqrtss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + + void vmovsd(OperandX64 dst, OperandX64 src); + void vmovsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vmovss(OperandX64 dst, OperandX64 src); + void vmovss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vmovapd(OperandX64 dst, OperandX64 src); + void vmovaps(OperandX64 dst, OperandX64 src); + void vmovupd(OperandX64 dst, OperandX64 src); + void vmovups(OperandX64 dst, OperandX64 src); + + // Run final checks + void finalize(); + + // Places a label at current location and returns it + Label setLabel(); + + // Assigns label position to the current location + void setLabel(Label& label); + + // Constant allocation (uses rip-relative addressing) + OperandX64 i64(int64_t value); + OperandX64 f32(float value); + OperandX64 f64(double value); + OperandX64 f32x4(float x, float y, float z, float w); + + // Resulting data and code that need to be copied over one after the other + // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' + std::vector data; + std::vector code; + + std::string text; + +private: + // Instruction archetypes + void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev, + uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg); + void placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code, uint8_t codeImm8, uint8_t opreg); + void placeBinaryRegAndRegMem(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code); + void placeBinaryRegMemAndReg(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code); + + void placeUnaryModRegMem(const char* name, OperandX64 op, uint8_t code8, uint8_t code, uint8_t opreg); + + void placeShift(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t opreg); + + void placeJcc(const char* name, Label& label, uint8_t cc); + + void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); + void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, uint8_t coderev, bool setW, uint8_t mode, uint8_t prefix); + void placeAvx(const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); + + // Instruction components + void placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs); + void placeModRegMem(OperandX64 rhs, uint8_t regop); + void placeRex(RegisterX64 op); + void placeRex(OperandX64 op); + void placeRex(RegisterX64 lhs, OperandX64 rhs); + void placeVex(OperandX64 dst, OperandX64 src1, OperandX64 src2, bool setW, uint8_t mode, uint8_t prefix); + void placeImm8Or32(int32_t imm); + void placeImm8(int32_t imm); + void placeImm32(int32_t imm); + void placeImm64(int64_t imm); + void placeLabel(Label& label); + void place(uint8_t byte); + + void commit(); + LUAU_NOINLINE void extend(); + uint32_t getCodeSize(); + + // Data + size_t allocateData(size_t size, size_t align); + + // Logging of assembly in text form (Intel asm with VS disassembly formatting) + LUAU_NOINLINE void log(const char* opcode); + LUAU_NOINLINE void log(const char* opcode, OperandX64 op); + LUAU_NOINLINE void log(const char* opcode, OperandX64 op1, OperandX64 op2); + LUAU_NOINLINE void log(const char* opcode, OperandX64 op1, OperandX64 op2, OperandX64 op3); + LUAU_NOINLINE void log(Label label); + LUAU_NOINLINE void log(const char* opcode, Label label); + void log(OperandX64 op); + void logAppend(const char* fmt, ...); + + const char* getSizeName(SizeX64 size); + const char* getRegisterName(RegisterX64 reg); + + uint32_t nextLabel = 1; + std::vector(a) -> a" == toString(idType)); +} + +#if 1 +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") +{ + AstStatBlock* block = parse(R"( + local function a(c) + local function d(e) + return c + end + + return d + end + + local b = a(5) + )"); + + cgb.visit(block); + + ToStringOptions opts; + + ConstraintSolver cs{&arena, cgb.rootScope}; + + cs.run(); + + TypeId idType = requireBinding(cgb.rootScope, "b"); + + CHECK("(a) -> number" == toString(idType, opts)); +} +#endif + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 03f3e15c..232ec2de 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -17,6 +17,8 @@ static const char* mainModuleName = "MainModule"; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -249,7 +251,10 @@ std::optional Fixture::getType(const std::string& name) ModulePtr module = getMainModule(); REQUIRE(module); - return lookupName(module->getModuleScope(), name); + if (FFlag::DebugLuauDeferredConstraintResolution) + return linearSearchForBinding(module->getModuleScope2(), name.c_str()); + else + return lookupName(module->getModuleScope(), name); } TypeId Fixture::requireType(const std::string& name) @@ -421,6 +426,12 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); } +ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() + : Fixture() + , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} +{ +} + ModuleName fromString(std::string_view name) { return ModuleName(name); @@ -460,4 +471,27 @@ std::optional lookupName(ScopePtr scope, const std::string& name) return std::nullopt; } +std::optional linearSearchForBinding(Scope2* scope, const char* name) +{ + while (scope) + { + for (const auto& [n, ty] : scope->bindings) + { + if (n.astName() == name) + return ty; + } + + scope = scope->parent; + } + + return std::nullopt; +} + +void dump(const std::vector& constraints) +{ + ToStringOptions opts; + for (const auto& c : constraints) + printf("%s\n", toString(c, opts).c_str()); +} + } // namespace Luau diff --git a/tests/Fixture.h b/tests/Fixture.h index 901f7d42..ffcd4b9e 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Config.h" +#include "Luau/ConstraintGraphBuilder.h" #include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/IostreamHelpers.h" @@ -156,6 +157,16 @@ struct BuiltinsFixture : Fixture BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); }; +struct ConstraintGraphBuilderFixture : Fixture +{ + TypeArena arena; + ConstraintGraphBuilder cgb{&arena}; + + ScopedFastFlag forceTheFlag; + + ConstraintGraphBuilderFixture(); +}; + ModuleName fromString(std::string_view name); template @@ -175,9 +186,12 @@ bool isInArena(TypeId t, const TypeArena& arena); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void dump(const std::string& name, TypeId ty); +void dump(const std::vector& constraints); std::optional lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2) +std::optional linearSearchForBinding(Scope2* scope, const char* name); + } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 33b81be8..c0554669 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1031,8 +1031,6 @@ return false; TEST_CASE("check_without_builtin_next") { - ScopedFastFlag luauDoNotRelyOnNextBinding{"LuauDoNotRelyOnNextBinding", true}; - TestFileResolver fileResolver; TestConfigResolver configResolver; Frontend frontend(&fileResolver, &configResolver); diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp new file mode 100644 index 00000000..1a323c85 --- /dev/null +++ b/tests/NotNull.test.cpp @@ -0,0 +1,116 @@ +#include "Luau/NotNull.h" + +#include "doctest.h" + +#include +#include +#include + +using Luau::NotNull; + +namespace +{ + +struct Test +{ + int x; + float y; + + static int count; + Test() + { + ++count; + } + + ~Test() + { + --count; + } +}; + +int Test::count = 0; + +} + +int foo(NotNull p) +{ + return *p; +} + +void bar(int* q) +{} + +TEST_SUITE_BEGIN("NotNull"); + +TEST_CASE("basic_stuff") +{ + NotNull a = NotNull{new int(55)}; // Does runtime test + NotNull b{new int(55)}; // As above + // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not good. + + // a = nullptr; // nope + + NotNull d = a; // No runtime test. a is known not to be null. + + int e = *d; + *d = 1; + CHECK(e == 55); + + const NotNull f = d; + *f = 5; // valid: there is a difference between const NotNull and NotNull + // f = a; // nope + + CHECK_EQ(a, d); + CHECK(a != b); + + NotNull g(a); + CHECK(g == a); + + // *g = 123; // nope + + (void)f; + + NotNull t{new Test}; + t->x = 5; + t->y = 3.14f; + + const NotNull u = t; + // u->x = 44; // nope + int v = u->x; + CHECK(v == 5); + + bar(a); + + // a++; // nope + // a[41]; // nope + // a + 41; // nope + // a - 41; // nope + + delete a; + delete b; + delete t; + + CHECK_EQ(0, Test::count); +} + +TEST_CASE("hashable") +{ + std::unordered_map, const char*> map; + NotNull a{new int(8)}; + NotNull b{new int(10)}; + + std::string hello = "hello"; + std::string world = "world"; + + map[a] = hello.c_str(); + map[b] = world.c_str(); + + CHECK_EQ(2, map.size()); + CHECK_EQ(hello.c_str(), map[a]); + CHECK_EQ(world.c_str(), map[b]); + + delete a; + delete b; +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index b854bc51..4d9fad14 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -505,7 +505,6 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function id(x) return x end )"); @@ -518,7 +517,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function map(arr, fn) local t = {} @@ -537,7 +535,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(a: number, b: string) end local function test(...: T...): U... @@ -554,7 +551,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") TEST_CASE("toStringNamedFunction_unit_f") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; TypePackVar empty{TypePack{}}; FunctionTypeVar ftv{&empty, &empty, {}, false}; CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); @@ -562,7 +558,6 @@ TEST_CASE("toStringNamedFunction_unit_f") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: a, ...): (a, a, b...) return x, x, ... @@ -577,7 +572,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): ...number return 1, 2, 3 @@ -592,7 +586,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): (string, ...number) return 'a', 1, 2, 3 @@ -607,7 +600,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local f: (number, y: number) -> number )"); @@ -620,7 +612,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_ar TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: T, g: (T) -> U)): () end @@ -636,8 +627,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; - CheckResult result = check(R"( local function test(a, b : string, ... : number) return a end )"); @@ -665,7 +654,6 @@ TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () @@ -682,7 +670,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index d90129d7..6f4191e3 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -470,8 +470,6 @@ caused by: TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") { - ScopedFastFlag luauClassDefinitionModuleInError{"LuauClassDefinitionModuleInError", true}; - CheckResult result = check(R"( local i = ChildClass.New() type ChildClass = { x: number } diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index a3cae3de..4444cd66 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -78,8 +78,6 @@ TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") { - ScopedFastFlag luauDoNotRelyOnNextBinding{"LuauDoNotRelyOnNextBinding", true}; - CheckResult result = check(R"( local foo = "bar" for i, v in foo do diff --git a/tests/Variant.test.cpp b/tests/Variant.test.cpp index fcf37875..aa0731ca 100644 --- a/tests/Variant.test.cpp +++ b/tests/Variant.test.cpp @@ -13,6 +13,25 @@ struct Foo int x = 42; }; +struct Bar +{ + explicit Bar(int x) + : prop(x * 2) + { + ++count; + } + + ~Bar() + { + --count; + } + + int prop; + static int count; +}; + +int Bar::count = 0; + TEST_SUITE_BEGIN("Variant"); TEST_CASE("DefaultCtor") @@ -46,6 +65,29 @@ TEST_CASE("Create") CHECK(get_if(&v3)->x == 3); } +TEST_CASE("Emplace") +{ + { + Variant v1; + + CHECK(0 == Bar::count); + int& i = v1.emplace(5); + CHECK(5 == i); + + CHECK(0 == Bar::count); + + CHECK(get_if(&v1) == &i); + + Bar& bar = v1.emplace(11); + CHECK(22 == bar.prop); + CHECK(1 == Bar::count); + + CHECK(get_if(&v1) == &bar); + } + + CHECK(0 == Bar::count); +} + TEST_CASE("NonPOD") { // initialize (copy) diff --git a/tools/natvis/CodeGen.natvis b/tools/natvis/CodeGen.natvis index 47ff0db1..5ff6e143 100644 --- a/tools/natvis/CodeGen.natvis +++ b/tools/natvis/CodeGen.natvis @@ -2,7 +2,7 @@ - noreg + noreg rip al @@ -36,14 +36,20 @@ - {reg} - {mem.size,en} ptr[{mem.base} + {mem.index}*{(int)mem.scale,d} + {disp}] + {base} + {memSize,en} ptr[{base} + {index}*{(int)scale,d} + {imm}] + {memSize,en} ptr[{index}*{(int)scale,d} + {imm}] + {memSize,en} ptr[{base} + {imm}] + {memSize,en} ptr[{imm}] {imm} - reg - mem + base imm - disp + memSize + base + index + scale + imm From 9619f036ac39834d3db07a4b89a35e46fd269664 Mon Sep 17 00:00:00 2001 From: Daniel Nachun Date: Mon, 6 Jun 2022 15:52:55 -0700 Subject: [PATCH 077/102] fix build with newer GCC (#522) --- Analysis/include/Luau/Variant.h | 1 + 1 file changed, 1 insertion(+) diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index c9c97c92..f637222e 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace Luau { From b44912cd203c526bb470382cebf4b81ee737443a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petri=20H=C3=A4kkinen?= Date: Thu, 9 Jun 2022 19:41:52 +0300 Subject: [PATCH 078/102] Allow vector fastcall constructor to work with 3-4 arguments with 4-wide vectors (#511) --- VM/src/lbuiltins.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index cc6e560a..deaf1407 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1018,18 +1018,20 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { -#if LUA_VECTOR_SIZE == 4 - if (nparams >= 4 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1) && ttisnumber(args + 2)) -#else if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) -#endif { double x = nvalue(arg0); double y = nvalue(args); double z = nvalue(args + 1); #if LUA_VECTOR_SIZE == 4 - double w = nvalue(args + 2); + double w = 0.0; + if (nparams >= 4) + { + if (!ttisnumber(args + 2)) + return -1; + w = nvalue(args + 2); + } setvvalue(res, float(x), float(y), float(z), float(w)); #else setvvalue(res, float(x), float(y), float(z), 0.0f); From b8ef74372167580cc5a053dea21792ec7590a1d1 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 9 Jun 2022 10:03:37 -0700 Subject: [PATCH 079/102] RFC: Do not implement safe navigation operator (#501) This meta-RFC proposes removing the previously accepted RFC on safe navigation operator. This is probably going to be disappointing but is the best course of action that reflects our ideals in language evolution. See PR thread for rationale and discussion. --- rfcs/syntax-safe-navigation-operator.md | 104 ------------------------ 1 file changed, 104 deletions(-) delete mode 100644 rfcs/syntax-safe-navigation-operator.md diff --git a/rfcs/syntax-safe-navigation-operator.md b/rfcs/syntax-safe-navigation-operator.md deleted file mode 100644 index 11c4b37f..00000000 --- a/rfcs/syntax-safe-navigation-operator.md +++ /dev/null @@ -1,104 +0,0 @@ -# Safe navigation postfix operator (?) - -**Note**: We have unresolved issues with interaction between this feature and Roblox instance hierarchy. This may affect the viability of this proposal. - -## Summary - -Introduce syntax to navigate through `nil` values, or short-circuit with `nil` if it was encountered. - - -## Motivation - -nil values are very common in Lua, and take care to prevent runtime errors. - -Currently, attempting to index `dog.name` while caring for `dog` being nil requires some form of the following: - -```lua -local dogName = nil -if dog ~= nil then - dogName = dog.name -end -``` - -...or the unusual to read... - -```lua -local dogName = dog and dog.name -``` - -...which will return `false` if `dog` is `false`, instead of throwing an error because of the index of `false.name`. - -Luau provides the if...else expression making this turn into: - -```lua -local dogName = if dog == nil then nil else dog.name -``` - -...but this is fairly clunky for such a common expression. - -## Design - -The safe navigation operator will make all of these smooth, by supporting `x?.y` to safely index nil values. `dog?.name` would resolve to `nil` if `dog` was nil, or the name otherwise. - -The previous example turns into `local dogName = dog?.name` (or just using `dog?.name` elsewhere). - -Failing the nil-safety check early would make the entire expression nil, for instance `dog?.body.legs` would resolve to `nil` if `dog` is nil, rather than resolve `dog?.body` into nil, then turning into `nil.legs`. - -```lua -dog?.name --[[ is the same as ]] if dog == nil then nil else dog.name -``` - -The short-circuiting is limited within the expression. - -```lua -dog?.owner.name -- This will return nil if `dog` is nil -(dog?.owner).name -- `(dog?.owner)` resolves to nil, of which `name` is then indexed. This will error at runtime if `dog` is nil. - -dog?.legs + 3 -- `dog?.legs` is resolved on its own, meaning this will error at runtime if it is nil (`nil + 3`) -``` - -The operator must be used in the context of either a call or an index, and so: - -```lua -local value = x? -``` - -...would be invalid syntax. - -This syntax would be based on expressions, and not identifiers, meaning that `(x or y)?.call()` would be valid syntax. - -### Type -If the expression is typed as an optional, then the resulting type would be the final expression, also optional. Otherwise, it'll just be the resulting type if `?` wasn't used. - -```lua -local optionalObject: { name: string }? -local optionalObjectName = optionalObject?.name -- resolves to `string?` - -local nonOptionalObject: { name: string } -local nonOptionalObjectName = nonOptionalObject?.name -- resolves to `string` -``` - -### Calling - -This RFC only specifies `x?.y` as an index method. `x?:y()` is currently unspecified, and `x?.y(args)` as a syntax will be reserved (will error if you try to use it). - -While being able to support `dog?.getName()` is useful, it provides [some logistical issues for the language](https://github.com/Roblox/luau/pull/142#issuecomment-990563536). - -`x?.y(args)` will be reserved both so that this can potentially be resolved later down the line if something comes up, but also because it would be a guaranteed runtime error under this RFC: `dog?.getName()` will first index `dog?.getName`, which will return nil, then will attempt to call it. - -### Assignment -`x?.y = z` is not supported, and will be reported as a syntax error. - -## Drawbacks - -As with all syntax additions, this adds complexity to the parsing of expressions, and the execution of cancelling the rest of the expression could prove challenging. - -Furthermore, with the proposed syntax, it might lock off other uses of `?` within code (and not types) for the future as being ambiguous. - -## Alternatives - -Doing nothing is an option, as current standard if-checks already work, as well as the `and` trick in other use cases, but as shown before this can create some hard to read code, and nil values are common enough that the safe navigation operator is welcome. - -Supporting optional calls/indexes, such as `x?[1]` and `x?()`, while not out of scope, are likely too fringe to support, while adding on a significant amount of parsing difficulty, especially in the case of shorthand function calls, such as `x?{}` and `x?""`. - -It is possible to make `x?.y = z` resolve to only setting `x.y` if `x` is nil, but assignments silently failing can be seen as surprising. From bcab792e0d2eb91fc0618c733e8c66333293b813 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 9 Jun 2022 10:18:03 -0700 Subject: [PATCH 080/102] Update STATUS.md Generalized iteration got implemented, safe indexing got dropped. --- rfcs/STATUS.md | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/rfcs/STATUS.md b/rfcs/STATUS.md index ef55b5c4..d2fe86f0 100644 --- a/rfcs/STATUS.md +++ b/rfcs/STATUS.md @@ -15,26 +15,12 @@ This document tracks unimplemented RFCs. **Status**: Needs implementation -## Safe navigation operator - -[RFC: Safe navigation postfix operator (?)](https://github.com/Roblox/luau/blob/master/rfcs/syntax-safe-navigation-operator.md) - -**Status**: Needs implementation. - -**Notes**: We have unresolved issues with interaction between this feature and Roblox instance hierarchy. This may affect the viability of this proposal. - ## String interpolation [RFC: String interpolation](https://github.com/Roblox/luau/blob/master/rfcs/syntax-string-interpolation.md) **Status**: Needs implementation -## Generalized iteration - -[RFC: Generalized iteration](https://github.com/Roblox/luau/blob/master/rfcs/generalized-iteration.md) - -**Status**: Implemented but not fully rolled out yet. - ## Lower Bounds Calculation [RFC: Lower bounds calculation](https://github.com/Roblox/luau/blob/master/rfcs/lower-bounds-calculation.md) From ca5fbbfc2467353e39799927425584c08a45b3d8 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 9 Jun 2022 10:18:29 -0700 Subject: [PATCH 081/102] Mark generalized iteration RFC as implemented --- rfcs/generalized-iteration.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rfcs/generalized-iteration.md b/rfcs/generalized-iteration.md index 72bdd69e..99671090 100644 --- a/rfcs/generalized-iteration.md +++ b/rfcs/generalized-iteration.md @@ -1,5 +1,7 @@ # Generalized iteration +**Status**: Implemented + ## Summary Introduce support for iterating over tables without using `pairs`/`ipairs` as well as a generic customization point for iteration via `__iter` metamethod. From b066e4c8f8851d9727d079935cfa2e0f9b108531 Mon Sep 17 00:00:00 2001 From: rblanckaert <63755228+rblanckaert@users.noreply.github.com> Date: Fri, 10 Jun 2022 09:58:21 -0700 Subject: [PATCH 082/102] 0.531 (#532) * Fix free Luau type being fully overwritten by 'any' and causing UAF * Fix lua_clonefunction implementation replacing top instead of pushing * Falsey values other than false can now narrow refinements * Fix lua_getmetatable, lua_getfenv not waking thread up * FIx a case where lua_objlen could push a new string without thread wakeup or GC * Moved Luau math and bit32 definitions to definition file * Improve Luau parse recovery of incorrect return type token --- Analysis/include/Luau/TypePack.h | 13 +- Analysis/include/Luau/TypeVar.h | 11 +- Analysis/src/Autocomplete.cpp | 41 ++-- Analysis/src/BuiltinDefinitions.cpp | 36 +--- Analysis/src/Clone.cpp | 5 - Analysis/src/EmbeddedBuiltinDefinitions.cpp | 8 +- Analysis/src/Instantiation.cpp | 4 - Analysis/src/Quantify.cpp | 35 ---- Analysis/src/Scope.cpp | 5 +- Analysis/src/Substitution.cpp | 1 - Analysis/src/TxnLog.cpp | 56 +++++- Analysis/src/TypeInfer.cpp | 113 ++++++------ Analysis/src/TypePack.cpp | 21 +++ Analysis/src/TypeVar.cpp | 21 +++ Ast/src/Parser.cpp | 22 ++- Compiler/src/BytecodeBuilder.cpp | 10 +- Compiler/src/Compiler.cpp | 195 +++++--------------- VM/src/lapi.cpp | 106 +++++------ VM/src/ldo.cpp | 25 ++- VM/src/lvmexecute.cpp | 19 +- tests/Autocomplete.test.cpp | 37 ++-- tests/Compiler.test.cpp | 13 -- tests/Conformance.test.cpp | 59 +++++- tests/Module.test.cpp | 20 ++ tests/Parser.test.cpp | 17 ++ tests/TypeInfer.aliases.test.cpp | 6 - tests/TypeInfer.loops.test.cpp | 8 - tests/TypeInfer.refinements.test.cpp | 32 +++- tests/TypeInfer.singletons.test.cpp | 2 - tests/TypePack.test.cpp | 16 ++ tests/TypeVar.test.cpp | 20 ++ tests/conformance/apicalls.lua | 11 ++ tests/conformance/pcall.lua | 17 ++ 33 files changed, 536 insertions(+), 469 deletions(-) diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index bbc65f94..c1de242f 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -48,13 +48,24 @@ struct TypePackVar explicit TypePackVar(const TypePackVariant& ty); explicit TypePackVar(TypePackVariant&& ty); TypePackVar(TypePackVariant&& ty, bool persistent); + bool operator==(const TypePackVar& rhs) const; + TypePackVar& operator=(TypePackVariant&& tp); + TypePackVar& operator=(const TypePackVar& rhs); + + // Re-assignes the content of the pack, but doesn't change the owning arena and can't make pack persistent. + void reassign(const TypePackVar& rhs) + { + ty = rhs.ty; + } + TypePackVariant ty; + bool persistent = false; - // Pointer to the type arena that allocated this type. + // Pointer to the type arena that allocated this pack. TypeArena* owningArena = nullptr; }; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index b3c455cf..b59e7c64 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -334,7 +334,6 @@ struct TableTypeVar // We need to know which is which when we stringify types. std::optional syntheticName; - std::map methodDefinitionLocations; // TODO: Remove with FFlag::LuauNoMethodLocations std::vector instantiatedTypeParams; std::vector instantiatedTypePackParams; ModuleName definitionModuleName; @@ -465,6 +464,14 @@ struct TypeVar final { } + // Re-assignes the content of the type, but doesn't change the owning arena and can't make type persistent. + void reassign(const TypeVar& rhs) + { + ty = rhs.ty; + normal = rhs.normal; + documentationSymbol = rhs.documentationSymbol; + } + TypeVariant ty; // Kludge: A persistent TypeVar is one that belongs to the global scope. @@ -486,6 +493,8 @@ struct TypeVar final TypeVar& operator=(const TypeVariant& rhs); TypeVar& operator=(TypeVariant&& rhs); + + TypeVar& operator=(const TypeVar& rhs); }; using SeenSet = std::set>; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index b988ed35..a8319c59 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,8 +14,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); -LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteClassSecurityLevel, false); -LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) +LUAU_FASTFLAG(LuauSelfCallAutocompleteFix2) static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -248,7 +247,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ ty = follow(ty); auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2); InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); @@ -267,7 +266,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) { if (std::optional firstRetTy = first(ftv->retType)) return checkTypeMatch(typeArena, *firstRetTy, expectedType); @@ -308,7 +307,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } } - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; else return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; @@ -325,7 +324,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) { - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) rootTy = follow(rootTy); ty = follow(ty); @@ -335,7 +334,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId seen.insert(ty); auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2); if (indexType == PropIndexType::Key) return false; @@ -368,7 +367,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } }; auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) { - LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix2); if (indexType == PropIndexType::Key) return false; @@ -382,10 +381,15 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId return calledWithSelf == ftv->hasSelf; } - if (std::optional firstArgTy = first(ftv->argTypes)) + // If a call is made with ':', it is invalid if a function has incompatible first argument or no arguments at all + // If a call is made with '.', but it was declared with 'self', it is considered invalid if first argument is compatible + if (calledWithSelf || ftv->hasSelf) { - if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) - return calledWithSelf; + if (std::optional firstArgTy = first(ftv->argTypes)) + { + if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) + return calledWithSelf; + } } return !calledWithSelf; @@ -427,7 +431,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryKind::Property, type, prop.deprecated, - FFlag::LuauSelfCallAutocompleteFix ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), + FFlag::LuauSelfCallAutocompleteFix2 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), typeCorrect, containingClass, &prop, @@ -462,8 +466,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId containingClass = containingClass.value_or(cls); fillProps(cls->props); if (cls->parent) - autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, - FFlag::LuauFixAutocompleteClassSecurityLevel ? containingClass : cls); + autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); } else if (auto tbl = get(ty)) fillProps(tbl->props); @@ -471,7 +474,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId { autocompleteProps(module, typeArena, rootTy, mt->table, indexType, nodes, result, seen); - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) { if (auto mtable = get(mt->metatable)) fillMetatableProps(mtable); @@ -537,7 +540,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen; - if (!FFlag::LuauSelfCallAutocompleteFix) + if (!FFlag::LuauSelfCallAutocompleteFix2) innerSeen = seen; if (isNil(*iter)) @@ -563,7 +566,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ++iter; } } - else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix) + else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix2) { if (pt->metatable) { @@ -571,7 +574,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId fillMetatableProps(mtable); } } - else if (FFlag::LuauSelfCallAutocompleteFix && get(get(ty))) + else if (FFlag::LuauSelfCallAutocompleteFix2 && get(get(ty))) { autocompleteProps(module, typeArena, rootTy, getSingletonTypes().stringType, indexType, nodes, result, seen); } @@ -1501,7 +1504,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - if (!FFlag::LuauSelfCallAutocompleteFix && isString(ty)) + if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty)) return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), finder.ancestry}; else diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 5ed6de67..98737b43 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -179,44 +179,13 @@ void registerBuiltinTypes(TypeChecker& typeChecker) LUAU_ASSERT(!typeChecker.globalTypes.typeVars.isFrozen()); LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); - TypeId numberType = typeChecker.numberType; - TypeId booleanType = typeChecker.booleanType; TypeId nilType = typeChecker.nilType; TypeArena& arena = typeChecker.globalTypes; - TypePackId oneNumberPack = arena.addTypePack({numberType}); - TypePackId oneBooleanPack = arena.addTypePack({booleanType}); - - TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); - TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); - - TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{ - listOfAtLeastOneNumber, - oneNumberPack, - }); - - TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack}); - LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); - TypeId mathLibType = getGlobalBinding(typeChecker, "math"); - if (TableTypeVar* ttv = getMutable(mathLibType)) - { - ttv->props["min"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.min"); - ttv->props["max"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.max"); - } - - TypeId bit32LibType = getGlobalBinding(typeChecker, "bit32"); - if (TableTypeVar* ttv = getMutable(bit32LibType)) - { - ttv->props["band"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.band"); - ttv->props["bor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bor"); - ttv->props["bxor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bxor"); - ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest"); - } - TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); @@ -231,7 +200,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - // next(t: Table, i: K | nil) -> (K, V) + // next(t: Table, i: K?) -> (K, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); addGlobalBinding(typeChecker, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); @@ -241,8 +210,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - // NOTE we are missing 'i: K | nil' argument in the first return types' argument. - // pairs(t: Table) -> ((Table) -> (K, V), Table, nil) + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 19e3383e..9180f309 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -9,7 +9,6 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) -LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau { @@ -241,8 +240,6 @@ void TypeCloner::operator()(const TableTypeVar& t) arg = clone(arg, dest, cloneState); ttv->definitionModuleName = t.definitionModuleName; - if (!FFlag::LuauNoMethodLocations) - ttv->methodDefinitionLocations = t.methodDefinitionLocations; ttv->tags = t.tags; } @@ -406,8 +403,6 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index f184b74e..2407e3ef 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -7,7 +7,10 @@ namespace Luau static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC( declare bit32: { - -- band, bor, bxor, and btest are declared in C++ + band: (...number) -> number, + bor: (...number) -> number, + bxor: (...number) -> number, + btest: (number, ...number) -> boolean, rrotate: (number, number) -> number, lrotate: (number, number) -> number, lshift: (number, number) -> number, @@ -50,7 +53,8 @@ declare math: { asin: (number) -> number, atan2: (number, number) -> number, - -- min and max are declared in C++. + min: (number, ...number) -> number, + max: (number, ...number) -> number, pi: number, huge: number, diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 4a12027d..f145a511 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -4,8 +4,6 @@ #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" -LUAU_FASTFLAG(LuauNoMethodLocations) - namespace Luau { @@ -110,8 +108,6 @@ TypeId ReplaceGenerics::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; return addType(std::move(clone)); } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 8f2cc8e3..21775373 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -32,41 +32,6 @@ struct Quantifier final : TypeVarOnceVisitor LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); } - void cycle(TypeId) override {} - void cycle(TypePackId) override {} - - bool operator()(TypeId ty, const FreeTypeVar& ftv) - { - return visit(ty, ftv); - } - - template - bool operator()(TypeId ty, const T& t) - { - return true; - } - - template - bool operator()(TypePackId, const T&) - { - return true; - } - - bool operator()(TypeId ty, const ConstrainedTypeVar&) - { - return true; - } - - bool operator()(TypeId ty, const TableTypeVar& ttv) - { - return visit(ty, ttv); - } - - bool operator()(TypePackId tp, const FreeTypePack& ftp) - { - return visit(tp, ftp); - } - /// @return true if outer encloses inner bool subsumes(Scope2* outer, Scope2* inner) { diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 0a362a5e..011e28d4 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -2,8 +2,6 @@ #include "Luau/Scope.h" -LUAU_FASTFLAG(LuauTwoPassAliasDefinitionFix); - namespace Luau { @@ -19,8 +17,7 @@ Scope::Scope(const ScopePtr& parent, int subLevel) , returnType(parent->returnType) , level(parent->level.incr()) { - if (FFlag::LuauTwoPassAliasDefinitionFix) - level = level.incr(); + level = level.incr(); level.subLevel = subLevel; } diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 50c516db..5a22deeb 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -10,7 +10,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) -LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau { diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index e45c0cbd..4c6d54e0 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,6 +7,8 @@ #include #include +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) + namespace Luau { @@ -80,18 +82,32 @@ void TxnLog::commit() { for (auto& [ty, rep] : typeVarChanges) { - TypeArena* owningArena = ty->owningArena; - TypeVar* mtv = asMutable(ty); - *mtv = rep.get()->pending; - mtv->owningArena = owningArena; + if (FFlag::LuauNonCopyableTypeVarFields) + { + asMutable(ty)->reassign(rep.get()->pending); + } + else + { + TypeArena* owningArena = ty->owningArena; + TypeVar* mtv = asMutable(ty); + *mtv = rep.get()->pending; + mtv->owningArena = owningArena; + } } for (auto& [tp, rep] : typePackChanges) { - TypeArena* owningArena = tp->owningArena; - TypePackVar* mpv = asMutable(tp); - *mpv = rep.get()->pending; - mpv->owningArena = owningArena; + if (FFlag::LuauNonCopyableTypeVarFields) + { + asMutable(tp)->reassign(rep.get()->pending); + } + else + { + TypeArena* owningArena = tp->owningArena; + TypePackVar* mpv = asMutable(tp); + *mpv = rep.get()->pending; + mpv->owningArena = owningArena; + } } clear(); @@ -178,8 +194,13 @@ PendingType* TxnLog::queue(TypeId ty) // about this type, we don't want to mutate the parent's state. auto& pending = typeVarChanges[ty]; if (!pending) + { pending = std::make_unique(*ty); + if (FFlag::LuauNonCopyableTypeVarFields) + pending->pending.owningArena = nullptr; + } + return pending.get(); } @@ -191,8 +212,13 @@ PendingTypePack* TxnLog::queue(TypePackId tp) // about this type, we don't want to mutate the parent's state. auto& pending = typePackChanges[tp]; if (!pending) + { pending = std::make_unique(*tp); + if (FFlag::LuauNonCopyableTypeVarFields) + pending->pending.owningArena = nullptr; + } + return pending.get(); } @@ -229,14 +255,24 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) { PendingType* newTy = queue(ty); - newTy->pending = replacement; + + if (FFlag::LuauNonCopyableTypeVarFields) + newTy->pending.reassign(replacement); + else + newTy->pending = replacement; + return newTy; } PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) { PendingTypePack* newTp = queue(tp); - newTp->pending = replacement; + + if (FFlag::LuauNonCopyableTypeVarFields) + newTp->pending.reassign(replacement); + else + newTp->pending = replacement; + return newTp; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4931bc59..447cd029 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,21 +33,20 @@ LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) +LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) -LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); -LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) -LUAU_FASTFLAGVARIABLE(LuauNoMethodLocations, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) +LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) +LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) namespace Luau { @@ -358,8 +357,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo unifierState.cachedUnifyError.clear(); unifierState.skipCacheForType.clear(); - if (FFlag::LuauTwoPassAliasDefinitionFix) - duplicateTypeAliases.clear(); + duplicateTypeAliases.clear(); return std::move(currentModule); } @@ -610,7 +608,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (FFlag::LuauTwoPassAliasDefinitionFix && typealias->name == kParseNameError) + if (typealias->name == kParseNameError) continue; auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; @@ -619,7 +617,16 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - *asMutable(type) = *errorRecoveryType(anyType); + if (FFlag::LuauNonCopyableTypeVarFields) + { + TypeVar* mty = asMutable(follow(type)); + mty->reassign(*errorRecoveryType(anyType)); + } + else + { + *asMutable(type) = *errorRecoveryType(anyType); + } + reportError(TypeError{typealias->location, OccursCheckFailed{}}); } } @@ -1131,45 +1138,43 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } - if (FFlag::LuauTypecheckIter) + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) { - if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) + // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions + // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments + // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + for (TypeId var : varTypes) + unify(anyType, var, forin.location); + + return check(loopScope, *forin.body); + } + + if (const TableTypeVar* iterTable = get(iterTy)) + { + // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer + // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting + if (iterTable->indexer) { - // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions - // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments - // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + if (varTypes.size() > 0) + unify(iterTable->indexer->indexType, varTypes[0], forin.location); + + if (varTypes.size() > 1) + unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); + + for (size_t i = 2; i < varTypes.size(); ++i) + unify(nilType, varTypes[i], forin.location); + } + else + { + TypeId varTy = errorRecoveryType(loopScope); + for (TypeId var : varTypes) - unify(anyType, var, forin.location); + unify(varTy, var, forin.location); - return check(loopScope, *forin.body); + reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); } - else if (const TableTypeVar* iterTable = get(iterTy)) - { - // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer - // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting - if (iterTable->indexer) - { - if (varTypes.size() > 0) - unify(iterTable->indexer->indexType, varTypes[0], forin.location); - if (varTypes.size() > 1) - unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); - - for (size_t i = 2; i < varTypes.size(); ++i) - unify(nilType, varTypes[i], forin.location); - } - else - { - TypeId varTy = errorRecoveryType(loopScope); - - for (TypeId var : varTypes) - unify(varTy, var, forin.location); - - reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); - } - - return check(loopScope, *forin.body); - } + return check(loopScope, *forin.body); } const FunctionTypeVar* iterFunc = get(iterTy); @@ -1334,7 +1339,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. - if (FFlag::LuauTwoPassAliasDefinitionFix && name == kParseNameError) + if (name == kParseNameError) return; std::optional binding; @@ -1353,8 +1358,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; - if (FFlag::LuauTwoPassAliasDefinitionFix) - duplicateTypeAliases.insert({typealias.exported, name}); + duplicateTypeAliases.insert({typealias.exported, name}); } else { @@ -1378,7 +1382,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be // interesting. - if (FFlag::LuauTwoPassAliasDefinitionFix && duplicateTypeAliases.find({typealias.exported, name})) + if (duplicateTypeAliases.find({typealias.exported, name})) return; if (!binding) @@ -1422,9 +1426,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; @@ -1462,9 +1463,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } TypeId& bindingType = bindingsMap[name].type; - bool ok = unify(ty, bindingType, typealias.location); - if (FFlag::LuauTwoPassAliasDefinitionFix && ok) + if (unify(ty, bindingType, typealias.location)) bindingType = ty; if (FFlag::LuauLowerBoundsCalculation) @@ -1532,7 +1532,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) ftv->hasSelf = true; } } @@ -3099,8 +3099,6 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T property.type = freshTy(); property.location = indexName->indexLocation; - if (!FFlag::LuauNoMethodLocations) - ttv->methodDefinitionLocations[name] = funName.location; return property.type; } else if (funName.is()) @@ -4393,8 +4391,6 @@ TypeId Anyification::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; @@ -4705,8 +4701,11 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) if (isNil(ty)) return sense ? std::nullopt : std::optional(ty); - // at this point, anything else is kept if sense is true, or eliminated otherwise - return sense ? std::optional(ty) : std::nullopt; + // at this point, anything else is kept if sense is true, or replaced by nil + if (FFlag::LuauFalsyPredicateReturnsNilInstead) + return sense ? ty : nilType; + else + return sense ? std::optional(ty) : std::nullopt; }; } diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 30503233..82451bd1 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) + namespace Luau { @@ -36,6 +38,25 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) return *this; } +TypePackVar& TypePackVar::operator=(const TypePackVar& rhs) +{ + if (FFlag::LuauNonCopyableTypeVarFields) + { + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); + + reassign(rhs); + } + else + { + ty = rhs.ty; + persistent = rhs.persistent; + owningArena = rhs.owningArena; + } + + return *this; +} + TypePackIterator::TypePackIterator(TypePackId typePack) : TypePackIterator(typePack, TxnLog::empty()) { diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 12cbed91..33bfe254 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -24,6 +24,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) namespace Luau { @@ -644,6 +645,26 @@ TypeVar& TypeVar::operator=(TypeVariant&& rhs) return *this; } +TypeVar& TypeVar::operator=(const TypeVar& rhs) +{ + if (FFlag::LuauNonCopyableTypeVarFields) + { + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); + + reassign(rhs); + } + else + { + ty = rhs.ty; + persistent = rhs.persistent; + normal = rhs.normal; + owningArena = rhs.owningArena; + } + + return *this; +} + TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index eaf19914..95bce3ee 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -12,6 +12,7 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) +LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false) namespace Luau { @@ -1118,8 +1119,12 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() { - if (options.allowTypeAnnotations && lexer.current().type == ':') + if (options.allowTypeAnnotations && + (lexer.current().type == ':' || (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow))) { + if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow) + report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); + nextLexeme(); unsigned int oldRecursionCount = recursionCounter; @@ -1350,8 +1355,12 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray paramTypes = copy(params); + bool returnTypeIntroducer = + FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false; + // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element - if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) + if (params.size() == 1 && !varargAnnotation && monomorphic && + (FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow)) { if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; @@ -1359,7 +1368,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {params[0], {}}; } - if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack) + if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && monomorphic && allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; AstArray> paramNames = copy(names); @@ -1373,8 +1382,13 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray' instead of ':'"); + lexer.next(); + } // Users occasionally write '()' as the 'unit' type when they actually want to use 'nil', here we'll try to give a more specific error - if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) + else if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) { report(Location(begin.location, lexer.previousLocation()), "Expected '->' after '()' when parsing function type; did you mean 'nil'?"); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 3aa12d99..597b2f0a 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAG(LuauCompileNestedClosureO2) - namespace Luau { @@ -390,17 +388,15 @@ int32_t BytecodeBuilder::addConstantClosure(uint32_t fid) int16_t BytecodeBuilder::addChildFunction(uint32_t fid) { - if (FFlag::LuauCompileNestedClosureO2) - if (int16_t* cache = protoMap.find(fid)) - return *cache; + if (int16_t* cache = protoMap.find(fid)) + return *cache; uint32_t id = uint32_t(protos.size()); if (id >= kMaxClosureCount) return -1; - if (FFlag::LuauCompileNestedClosureO2) - protoMap[fid] = int16_t(id); + protoMap[fid] = int16_t(id); protos.push_back(fid); return int16_t(id); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index eea56c60..7431cde4 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -16,7 +16,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileIter, false) LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false) LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) @@ -26,8 +25,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileNestedClosureO2, false) - namespace Luau { @@ -172,30 +169,6 @@ struct Compiler return node->as(); } - bool canInlineFunctionBody(AstStat* stat) - { - if (FFlag::LuauCompileNestedClosureO2) - return true; // TODO: remove this function - - struct CanInlineVisitor : AstVisitor - { - bool result = true; - - bool visit(AstExprFunction* node) override - { - result = false; - - // short-circuit to avoid analyzing nested closure bodies - return false; - } - }; - - CanInlineVisitor canInline; - stat->visit(&canInline); - - return canInline.result; - } - uint32_t compileFunction(AstExprFunction* func) { LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); @@ -268,7 +241,7 @@ struct Compiler f.upvals = upvals; // record information for inlining - if (options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) && !getfenvUsed && !setfenvUsed) + if (options.optimizationLevel >= 2 && !func->vararg && !getfenvUsed && !setfenvUsed) { f.canInline = true; f.stackSize = stackSize; @@ -827,110 +800,62 @@ struct Compiler if (pid < 0) CompileError::raise(expr->location, "Exceeded closure limit; simplify the code to compile"); - if (FFlag::LuauCompileNestedClosureO2) - { - captures.clear(); - captures.reserve(f->upvals.size()); - - for (AstLocal* uv : f->upvals) - { - LUAU_ASSERT(uv->functionDepth < expr->functionDepth); - - if (int reg = getLocalReg(uv); reg >= 0) - { - // note: we can't check if uv is an upvalue in the current frame because inlining can migrate from upvalues to locals - Variable* ul = variables.find(uv); - bool immutable = !ul || !ul->written; - - captures.push_back({immutable ? LCT_VAL : LCT_REF, uint8_t(reg)}); - } - else if (const Constant* uc = locstants.find(uv); uc && uc->type != Constant::Type_Unknown) - { - // inlining can result in an upvalue capture of a constant, in which case we can't capture without a temporary register - uint8_t reg = allocReg(expr, 1); - compileExprConstant(expr, uc, reg); - - captures.push_back({LCT_VAL, reg}); - } - else - { - LUAU_ASSERT(uv->functionDepth < expr->functionDepth - 1); - - // get upvalue from parent frame - // note: this will add uv to the current upvalue list if necessary - uint8_t uid = getUpval(uv); - - captures.push_back({LCT_UPVAL, uid}); - } - } - - // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure - // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it - // is used) - int16_t shared = -1; - - if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - shared = int16_t(cid); - } - - if (shared >= 0) - bytecode.emitAD(LOP_DUPCLOSURE, target, shared); - else - bytecode.emitAD(LOP_NEWCLOSURE, target, pid); - - for (const Capture& c : captures) - bytecode.emitABC(LOP_CAPTURE, uint8_t(c.type), c.data, 0); - - return; - } - - bool shared = false; - - // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure - // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it - // is used) - if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, int16_t(cid)); - shared = true; - } - } - - if (!shared) - bytecode.emitAD(LOP_NEWCLOSURE, target, pid); + // we use a scratch vector to reduce allocations; this is safe since compileExprFunction is not reentrant + captures.clear(); + captures.reserve(f->upvals.size()); for (AstLocal* uv : f->upvals) { LUAU_ASSERT(uv->functionDepth < expr->functionDepth); - Variable* ul = variables.find(uv); - bool immutable = !ul || !ul->written; - - if (uv->functionDepth == expr->functionDepth - 1) + if (int reg = getLocalReg(uv); reg >= 0) { - // get local variable - int reg = getLocalReg(uv); - LUAU_ASSERT(reg >= 0); + // note: we can't check if uv is an upvalue in the current frame because inlining can migrate from upvalues to locals + Variable* ul = variables.find(uv); + bool immutable = !ul || !ul->written; - bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), uint8_t(reg), 0); + captures.push_back({immutable ? LCT_VAL : LCT_REF, uint8_t(reg)}); + } + else if (const Constant* uc = locstants.find(uv); uc && uc->type != Constant::Type_Unknown) + { + // inlining can result in an upvalue capture of a constant, in which case we can't capture without a temporary register + uint8_t reg = allocReg(expr, 1); + compileExprConstant(expr, uc, reg); + + captures.push_back({LCT_VAL, reg}); } else { + LUAU_ASSERT(uv->functionDepth < expr->functionDepth - 1); + // get upvalue from parent frame // note: this will add uv to the current upvalue list if necessary uint8_t uid = getUpval(uv); - bytecode.emitABC(LOP_CAPTURE, LCT_UPVAL, uid, 0); + captures.push_back({LCT_UPVAL, uid}); } } + + // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure + // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it + // is used) + int16_t shared = -1; + + if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) + { + int32_t cid = bytecode.addConstantClosure(f->id); + + if (cid >= 0 && cid < 32768) + shared = int16_t(cid); + } + + if (shared >= 0) + bytecode.emitAD(LOP_DUPCLOSURE, target, shared); + else + bytecode.emitAD(LOP_NEWCLOSURE, target, pid); + + for (const Capture& c : captures) + bytecode.emitABC(LOP_CAPTURE, uint8_t(c.type), c.data, 0); } LuauOpcode getUnaryOp(AstExprUnary::Op op) @@ -2511,30 +2436,6 @@ struct Compiler pushLocal(stat->vars.data[i], uint8_t(vars + i)); } - bool canUnrollForBody(AstStatFor* stat) - { - if (FFlag::LuauCompileNestedClosureO2) - return true; // TODO: remove this function - - struct CanUnrollVisitor : AstVisitor - { - bool result = true; - - bool visit(AstExprFunction* node) override - { - result = false; - - // short-circuit to avoid analyzing nested closure bodies - return false; - } - }; - - CanUnrollVisitor canUnroll; - stat->body->visit(&canUnroll); - - return canUnroll.result; - } - bool tryCompileUnrolledFor(AstStatFor* stat, int thresholdBase, int thresholdMaxBoost) { Constant one = {Constant::Type_Number}; @@ -2560,12 +2461,6 @@ struct Compiler return false; } - if (!canUnrollForBody(stat)) - { - bytecode.addDebugRemark("loop unroll failed: unsupported loop body"); - return false; - } - if (Variable* lv = variables.find(stat->var); lv && lv->written) { bytecode.addDebugRemark("loop unroll failed: mutable loop variable"); @@ -2730,12 +2625,12 @@ struct Compiler uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u)); LUAU_ASSERT(vars == regs + 3); - // Optimization: when we iterate through pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration - // index These instructions dynamically check if generator is equal to next/inext and bail out They assume that the generator produces 2 - // variables, which is why we allocate at least 2 above (see vars assignment) - LuauOpcode skipOp = FFlag::LuauCompileIter ? LOP_FORGPREP : LOP_JUMP; + LuauOpcode skipOp = LOP_FORGPREP; LuauOpcode loopOp = LOP_FORGLOOP; + // Optimization: when we iterate via pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration index + // These instructions dynamically check if generator is equal to next/inext and bail out + // They assume that the generator produces 2 variables, which is why we allocate at least 2 above (see vars assignment) if (options.optimizationLevel >= 1 && stat->vars.size <= 2) { if (stat->values.size == 1 && stat->values.data[0]->is()) diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index f86371da..3c3b7bd0 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,6 +14,26 @@ #include +/* + * This file contains most implementations of core Lua APIs from lua.h. + * + * These implementations should use api_check macros to verify that stack and type contracts hold; it's the callers + * responsibility to, for example, pass a valid table index to lua_rawgetfield. Generally errors should only be raised + * for conditions caller can't predict such as an out-of-memory error. + * + * The caller is expected to handle stack reservation (by using less than LUA_MINSTACK slots or by calling lua_checkstack). + * To ensure this is handled correctly, use api_incr_top(L) when pushing values to the stack. + * + * Functions that push any collectable objects to the stack *should* call luaC_checkthreadsleep. Failure to do this can result + * in stack references that point to dead objects since sleeping threads don't get rescanned. + * + * Functions that push newly created objects to the stack *should* call luaC_checkGC in addition to luaC_checkthreadsleep. + * Failure to do this can result in OOM since GC may never run. + * + * Note that luaC_checkGC may scan the thread and put it back to sleep; functions that call both before pushing objects must + * therefore call luaC_checkGC before luaC_checkthreadsleep to guarantee the object is pushed to an awake thread. + */ + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -221,15 +241,13 @@ void lua_insert(lua_State* L, int idx) void lua_replace(lua_State* L, int idx) { - /* explicit test for incompatible code */ - if (idx == LUA_ENVIRONINDEX && L->ci == L->base_ci) - luaG_runerror(L, "no calling environment"); api_checknelems(L, 1); luaC_checkthreadsleep(L); StkId o = index2addr(L, idx); api_checkvalidindex(L, o); if (idx == LUA_ENVIRONINDEX) { + api_check(L, L->ci != L->base_ci); Closure* func = curr_func(L); api_check(L, ttistable(L->top - 1)); func->env = hvalue(L->top - 1); @@ -443,9 +461,7 @@ const float* lua_tovector(lua_State* L, int idx) { StkId o = index2addr(L, idx); if (!ttisvector(o)) - { return NULL; - } return vvalue(o); } @@ -460,11 +476,6 @@ int lua_objlen(lua_State* L, int idx) return uvalue(o)->len; case LUA_TTABLE: return luaH_getn(hvalue(o)); - case LUA_TNUMBER: - { - int l = (luaV_tostring(L, o) ? tsvalue(o)->len : 0); - return l; - } default: return 0; } @@ -752,10 +763,9 @@ void lua_setsafeenv(lua_State* L, int objindex, int enabled) int lua_getmetatable(lua_State* L, int objindex) { - const TValue* obj; + luaC_checkthreadsleep(L); Table* mt = NULL; - int res; - obj = index2addr(L, objindex); + const TValue* obj = index2addr(L, objindex); switch (ttype(obj)) { case LUA_TTABLE: @@ -768,21 +778,18 @@ int lua_getmetatable(lua_State* L, int objindex) mt = L->global->mt[ttype(obj)]; break; } - if (mt == NULL) - res = 0; - else + if (mt) { sethvalue(L, L->top, mt); api_incr_top(L); - res = 1; } - return res; + return mt != NULL; } void lua_getfenv(lua_State* L, int idx) { - StkId o; - o = index2addr(L, idx); + luaC_checkthreadsleep(L); + StkId o = index2addr(L, idx); api_checkvalidindex(L, o); switch (ttype(o)) { @@ -806,9 +813,8 @@ void lua_getfenv(lua_State* L, int idx) void lua_settable(lua_State* L, int idx) { - StkId t; api_checknelems(L, 2); - t = index2addr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_settable(L, t, L->top - 2, L->top - 1); L->top -= 2; /* pop index and value */ @@ -817,22 +823,20 @@ void lua_settable(lua_State* L, int idx) void lua_setfield(lua_State* L, int idx, const char* k) { - StkId t; - TValue key; api_checknelems(L, 1); - t = index2addr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); + TValue key; setsvalue(L, &key, luaS_new(L, k)); luaV_settable(L, t, &key, L->top - 1); - L->top--; /* pop value */ + L->top--; return; } void lua_rawset(lua_State* L, int idx) { - StkId t; api_checknelems(L, 2); - t = index2addr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); if (hvalue(t)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -844,9 +848,8 @@ void lua_rawset(lua_State* L, int idx) void lua_rawseti(lua_State* L, int idx, int n) { - StkId o; api_checknelems(L, 1); - o = index2addr(L, idx); + StkId o = index2addr(L, idx); api_check(L, ttistable(o)); if (hvalue(o)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -858,14 +861,11 @@ void lua_rawseti(lua_State* L, int idx, int n) int lua_setmetatable(lua_State* L, int objindex) { - TValue* obj; - Table* mt; api_checknelems(L, 1); - obj = index2addr(L, objindex); + TValue* obj = index2addr(L, objindex); api_checkvalidindex(L, obj); - if (ttisnil(L->top - 1)) - mt = NULL; - else + Table* mt = NULL; + if (!ttisnil(L->top - 1)) { api_check(L, ttistable(L->top - 1)); mt = hvalue(L->top - 1); @@ -900,10 +900,9 @@ int lua_setmetatable(lua_State* L, int objindex) int lua_setfenv(lua_State* L, int idx) { - StkId o; int res = 1; api_checknelems(L, 1); - o = index2addr(L, idx); + StkId o = index2addr(L, idx); api_checkvalidindex(L, o); api_check(L, ttistable(L->top - 1)); switch (ttype(o)) @@ -970,24 +969,21 @@ static void f_call(lua_State* L, void* ud) int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) { - struct CallS c; - int status; - ptrdiff_t func; api_checknelems(L, nargs + 1); api_check(L, L->status == 0); checkresults(L, nargs, nresults); - if (errfunc == 0) - func = 0; - else + ptrdiff_t func = 0; + if (errfunc != 0) { StkId o = index2addr(L, errfunc); api_checkvalidindex(L, o); func = savestack(L, o); } + struct CallS c; c.func = L->top - (nargs + 1); /* function to be called */ c.nresults = nresults; - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); + int status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); adjustresults(L, nresults); return status; @@ -1247,12 +1243,10 @@ const char* lua_getupvalue(lua_State* L, int funcindex, int n) const char* lua_setupvalue(lua_State* L, int funcindex, int n) { - const char* name; - TValue* val; - StkId fi; - fi = index2addr(L, funcindex); api_checknelems(L, 1); - name = aux_upvalue(fi, n, &val); + StkId fi = index2addr(L, funcindex); + TValue* val; + const char* name = aux_upvalue(fi, n, &val); if (name) { L->top--; @@ -1319,14 +1313,16 @@ void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*)) void lua_clonefunction(lua_State* L, int idx) { + luaC_checkGC(L); + luaC_checkthreadsleep(L); StkId p = index2addr(L, idx); api_check(L, isLfunction(p)); - - luaC_checkthreadsleep(L); - Closure* cl = clvalue(p); - Closure* newcl = luaF_newLclosure(L, 0, L->gt, cl->l.p); - setclvalue(L, L->top - 1, newcl); + Closure* newcl = luaF_newLclosure(L, cl->nupvalues, L->gt, cl->l.p); + for (int i = 0; i < cl->nupvalues; ++i) + setobj2n(L, &newcl->l.uprefs[i], &cl->l.uprefs[i]); + setclvalue(L, L->top, newcl); + api_incr_top(L); } lua_Callbacks* lua_callbacks(lua_State* L) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index a71fce52..0642cb6d 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -202,22 +202,29 @@ void luaD_growstack(lua_State* L, int n) CallInfo* luaD_growCI(lua_State* L) { - if (L->size_ci > LUAI_MAXCALLS) /* overflow while handling overflow? */ - luaD_throw(L, LUA_ERRERR); - else - { - luaD_reallocCI(L, 2 * L->size_ci); - if (L->size_ci > LUAI_MAXCALLS) - luaG_runerror(L, "stack overflow"); - } + /* allow extra stack space to handle stack overflow in xpcall */ + const int hardlimit = LUAI_MAXCALLS + (LUAI_MAXCALLS >> 3); + + if (L->size_ci >= hardlimit) + luaD_throw(L, LUA_ERRERR); /* error while handling stack error */ + + int request = L->size_ci * 2; + luaD_reallocCI(L, L->size_ci >= LUAI_MAXCALLS ? hardlimit : request < LUAI_MAXCALLS ? request : LUAI_MAXCALLS); + + if (L->size_ci > LUAI_MAXCALLS) + luaG_runerror(L, "stack overflow"); + return ++L->ci; } void luaD_checkCstack(lua_State* L) { + /* allow extra stack space to handle stack overflow in xpcall */ + const int hardlimit = LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3); + if (L->nCcalls == LUAI_MAXCCALLS) luaG_runerror(L, "C stack overflow"); - else if (L->nCcalls >= (LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3))) + else if (L->nCcalls >= hardlimit) luaD_throw(L, LUA_ERRERR); /* error while handling stack error */ } diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index f9fd6574..e0a96474 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,8 +16,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauIter, false) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -2214,7 +2212,7 @@ static void luau_execute(lua_State* L) { /* will be called during FORGLOOP */ } - else if (FFlag::LuauIter) + else { Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); @@ -2259,17 +2257,6 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); uint32_t aux = *pc; - if (!FFlag::LuauIter) - { - bool stop; - VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux)); - - // note that we need to increment pc by 1 to exit the loop since we need to skip over aux - pc += stop ? 1 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } - // fast-path: builtin table iteration if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) { @@ -2362,7 +2349,7 @@ static void luau_execute(lua_State* L) /* ra+1 is already the table */ setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } - else if (FFlag::LuauIter && !ttisfunction(ra)) + else if (!ttisfunction(ra)) { VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); } @@ -2434,7 +2421,7 @@ static void luau_execute(lua_State* L) /* ra+1 is already the table */ setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } - else if (FFlag::LuauIter && !ttisfunction(ra)) + else if (!ttisfunction(ra)) { VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index cc5b31c9..dea1ab19 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2764,8 +2764,6 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { - ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; - check(R"( type tag = "cat" | "dog" local function f(a: tag) end @@ -2838,8 +2836,6 @@ f(@1) TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") { - ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; - check(R"( type tag = "strange\t\"cat\"" | 'nice\t"dog"' local function f(x: tag) end @@ -2873,7 +2869,7 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_on_class") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; loadDefinition(R"( declare class Foo @@ -2913,7 +2909,7 @@ t.@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local t = {} @@ -2929,7 +2925,7 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local f: (() -> number) & ((number) -> number) = function(x: number?) return 2 end @@ -2961,7 +2957,7 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local s = "hello" @@ -2980,7 +2976,7 @@ s:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local s = "hello" @@ -2989,17 +2985,15 @@ s.@1 auto ac = autocomplete('1'); - REQUIRE(ac.entryMap.count("byte")); - CHECK(ac.entryMap["byte"].wrongIndexType == true); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == false); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == true); } -TEST_CASE_FIXTURE(ACBuiltinsFixture, "string_library_non_self_calls_are_fine") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_non_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( string.@1 @@ -3013,11 +3007,24 @@ string.@1 CHECK(ac.entryMap["char"].wrongIndexType == false); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == false); + + check(R"( +table.@1 + )"); + + ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("remove")); + CHECK(ac.entryMap["remove"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("getn")); + CHECK(ac.entryMap["getn"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("insert")); + CHECK(ac.entryMap["insert"].wrongIndexType == false); } -TEST_CASE_FIXTURE(ACBuiltinsFixture, "string_library_self_calls_are_invalid") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_self_calls_are_invalid") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( string:@1 diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 20139650..6eee254e 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,7 +261,6 @@ L1: RETURN R0 0 TEST_CASE("ForBytecode") { - ScopedFastFlag sff("LuauCompileIter", true); ScopedFastFlag sff2("LuauCompileIterNoPairs", false); // basic for loop: variable directly refers to internal iteration index (R2) @@ -350,8 +349,6 @@ RETURN R0 0 TEST_CASE("ForBytecodeBuiltin") { - ScopedFastFlag sff("LuauCompileIter", true); - // we generally recognize builtins like pairs/ipairs and emit special opcodes CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( GETIMPORT R0 1 @@ -2323,8 +2320,6 @@ return result TEST_CASE("DebugLineInfoFor") { - ScopedFastFlag sff("LuauCompileIter", true); - Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -4355,8 +4350,6 @@ L1: RETURN R0 0 TEST_CASE("LoopUnrollControlFlow") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - ScopedFastInt sfis[] = { {"LuauCompileLoopUnrollThreshold", 50}, {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, @@ -4475,8 +4468,6 @@ RETURN R0 0 TEST_CASE("LoopUnrollNestedClosure") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - // if the body has functions that refer to loop variables, we unroll the loop and use MOVE+CAPTURE for upvalues CHECK_EQ("\n" + compileFunction(R"( for i=1,2 do @@ -4756,8 +4747,6 @@ RETURN R1 1 TEST_CASE("InlineBasicProhibited") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - // we can't inline variadic functions CHECK_EQ("\n" + compileFunction(R"( local function foo(...) @@ -4833,8 +4822,6 @@ RETURN R1 1 TEST_CASE("InlineNestedClosures") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - // we can inline functions that contain/return functions CHECK_EQ("\n" + compileFunction(R"( local function foo(x) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index f7f2b4ac..96a2775f 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -741,7 +741,7 @@ TEST_CASE("ApiTables") lua_pop(L, 1); } -TEST_CASE("ApiFunctionCalls") +TEST_CASE("ApiCalls") { StateRef globalState = runConformance("apicalls.lua"); lua_State* L = globalState.get(); @@ -790,6 +790,58 @@ TEST_CASE("ApiFunctionCalls") CHECK(lua_equal(L2, -1, -2) == 1); lua_pop(L2, 2); } + + // lua_clonefunction + fenv + { + lua_getfield(L, LUA_GLOBALSINDEX, "getpi"); + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3.1415926); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "getpi"); + + // clone & override env + lua_clonefunction(L, -1); + lua_newtable(L); + lua_pushnumber(L, 42); + lua_setfield(L, -2, "pi"); + lua_setfenv(L, -2); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + + // this one calls original function again + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3.1415926); + lua_pop(L, 1); + } + + // lua_clonefunction + upvalues + { + lua_getfield(L, LUA_GLOBALSINDEX, "incuv"); + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 1); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "incuv"); + // two clones + lua_clonefunction(L, -1); + lua_clonefunction(L, -2); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 2); + lua_pop(L, 1); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3); + lua_pop(L, 1); + + // this one calls original function again + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 4); + lua_pop(L, 1); + } } static bool endsWith(const std::string& str, const std::string& suffix) @@ -1113,11 +1165,6 @@ TEST_CASE("UserdataApi") TEST_CASE("Iter") { - ScopedFastFlag sffs[] = { - {"LuauCompileIter", true}, - {"LuauIter", true}, - }; - runConformance("iter.lua"); } diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index c7e18efd..89b13ab1 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -300,4 +300,24 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); } +TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + fileResolver.source["Module/A"] = R"( +export type A = B +type B = A + )"; + + FrontendOptions opts; + opts.retainFullTypeGraphs = false; + CheckResult result = frontend.check("Module/A", opts); + LUAU_REQUIRE_ERRORS(result); + + auto mod = frontend.moduleResolver.getModule("Module/A"); + auto it = mod->getModuleScope()->exportedTypeBindings.find("A"); + REQUIRE(it != mod->getModuleScope()->exportedTypeBindings.end()); + CHECK(toString(it->second.type) == "any"); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 87b1263f..878023e3 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2622,6 +2622,23 @@ type Z = { a: string | T..., b: number } REQUIRE_EQ(3, result.errors.size()); } +TEST_CASE_FIXTURE(Fixture, "recover_function_return_type_annotations") +{ + ScopedFastFlag sff{"LuauReturnTypeTokenConfusion", true}; + ParseResult result = tryParse(R"( +type Custom = { x: A, y: B, z: C } +type Packed = { x: (A...) -> () } +type F = (number): Custom +type G = Packed<(number): (string, number, boolean)> +local function f(x: number) -> Custom +end + )"); + REQUIRE_EQ(3, result.errors.size()); + CHECK_EQ(result.errors[0].getMessage(), "Return types in function type annotations are written after '->' instead of ':'"); + CHECK_EQ(result.errors[1].getMessage(), "Return types in function type annotations are written after '->' instead of ':'"); + CHECK_EQ(result.errors[2].getMessage(), "Function return type annotations are written after ':' instead of '->'"); +} + TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation") { ScopedFastFlag sff{"LuauParserFunctionKeywordAsTypeHelp", true}; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 7562a4d7..86cc9701 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -615,8 +615,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_ */ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any") { - ScopedFastFlag sff[] = {{"LuauTwoPassAliasDefinitionFix", true}}; - CheckResult result = check(R"( local function x() local y: FutureType = {}::any @@ -633,10 +631,6 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") { - ScopedFastFlag sff[] = { - {"LuauTwoPassAliasDefinitionFix", true}, - }; - CheckResult result = check(R"( local B = {} B.bar = 4 diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 4444cd66..1c6fe1d8 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -486,8 +486,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t: {string} = {} local key @@ -506,8 +504,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t: {string} = {} local extra @@ -522,8 +518,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t = {} for k, v in t do @@ -539,8 +533,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_iter_metamethod") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t = {} setmetatable(t, { __iter = function(o) return next, o.children end }) diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 6785f277..207b3cff 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -932,6 +932,8 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { + ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; + CheckResult result = check(R"( type T = {tag: "missing", x: nil} | {tag: "exists", x: string} @@ -947,7 +949,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ(R"({| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ(R"({| tag: "exists", x: string |} | {| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); } TEST_CASE_FIXTURE(Fixture, "discriminate_tag") @@ -1191,7 +1193,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") { - const std::string code = R"( + CheckResult result = check(R"( function f(a) if type(a) == "boolean" then local a1 = a @@ -1201,10 +1203,30 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") local a3 = a end end - )"; - CheckResult result = check(code); + )"); + LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") +{ + ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; + + CheckResult result = check(R"( + local function f(t: {number}) + local x = t[1] + if not x then + local foo = x + else + local bar = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("number", toString(requireTypeAtPosition({6, 28}))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 14a5a6ae..a90f434f 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -139,8 +139,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") { - ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; - CheckResult result = check(R"( type MyEnum = "foo" | "bar" | "baz" local a : MyEnum = "bang" diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index c4931578..8a5a65fe 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -197,4 +197,20 @@ TEST_CASE_FIXTURE(TypePackFixture, "std_distance") CHECK_EQ(4, std::distance(b, e)); } +TEST_CASE("content_reassignment") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + TypePackVar myError{Unifiable::Error{}, /*presistent*/ true}; + + TypeArena arena; + + TypePackId futureError = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); + asMutable(futureError)->reassign(myError); + + CHECK(get(futureError) != nullptr); + CHECK(!futureError->persistent); + CHECK(futureError->owningArena == &arena); +} + TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index bb2d94ba..4f8fc502 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -416,4 +416,24 @@ TEST_CASE("proof_that_isBoolean_uses_all_of") CHECK(!isBoolean(&union_)); } +TEST_CASE("content_reassignment") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; + myAny.normal = true; + myAny.documentationSymbol = "@global/any"; + + TypeArena arena; + + TypeId futureAny = arena.addType(FreeTypeVar{TypeLevel{}}); + asMutable(futureAny)->reassign(myAny); + + CHECK(get(futureAny) != nullptr); + CHECK(!futureAny->persistent); + CHECK(futureAny->normal); + CHECK(futureAny->documentationSymbol == "@global/any"); + CHECK(futureAny->owningArena == &arena); +} + TEST_SUITE_END(); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 7a4058b5..27416623 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -11,4 +11,15 @@ function create_with_tm(x) return setmetatable({ a = x }, m) end +local gen = 0 +function incuv() + gen += 1 + return gen +end + +pi = 3.1415926 +function getpi() + return pi +end + return('OK') diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index 84ac2ba1..969209fc 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -144,4 +144,21 @@ coroutine.resume(co) resumeerror(co, "fail") checkresults({ true, false, "fail" }, coroutine.resume(co)) +-- stack overflow needs to happen at the call limit +local calllimit = 20000 +function recurse(n) return n <= 1 and 1 or recurse(n-1) + 1 end + +-- we use one frame for top-level function and one frame is the service frame for coroutines +assert(recurse(calllimit - 2) == calllimit - 2) + +-- note that when calling through pcall, pcall eats one more frame +checkresults({ true, calllimit - 3 }, pcall(recurse, calllimit - 3)) +checkerror(pcall(recurse, calllimit - 2)) + +-- xpcall handler runs in context of the stack frame, but this works just fine since we allow extra stack consumption past stack overflow +checkresults({ false, "ok" }, xpcall(recurse, function() return string.reverse("ko") end, calllimit - 2)) + +-- however, if xpcall handler itself runs out of extra stack space, we get "error in error handling" +checkresults({ false, "error in error handling" }, xpcall(recurse, function() return recurse(calllimit) end, calllimit - 2)) + return 'OK' From c30ab0647b88d42b872aabf3b7d142e331c7e8ab Mon Sep 17 00:00:00 2001 From: JohnnyMorganz Date: Tue, 14 Jun 2022 16:39:25 +0100 Subject: [PATCH 083/102] Improve table stringifier when line breaks enabled (#488) * Improve table stringifier when line breaks enabled * Add FFlag * Fix FFlags --- Analysis/src/ToString.cpp | 98 ++++++++++++++++++++++++++++++--------- tests/ToString.test.cpp | 36 ++++++++++++++ 2 files changed, 113 insertions(+), 21 deletions(-) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 8490350d..04d15cf7 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -18,6 +18,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) * Fair warning: Setting this will break a lot of Luau unit tests. */ LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) +LUAU_FASTFLAGVARIABLE(LuauToStringTableBracesNewlines, false) namespace Luau { @@ -283,7 +284,8 @@ struct TypeVarStringifier } Luau::visit( - [this, tv](auto&& t) { + [this, tv](auto&& t) + { return (*this)(tv, t); }, tv->ty); @@ -557,22 +559,54 @@ struct TypeVarStringifier { case TableState::Sealed: state.result.invalid = true; - openbrace = "{| "; - closedbrace = " |}"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{|"; + closedbrace = "|}"; + } + else + { + openbrace = "{| "; + closedbrace = " |}"; + } break; case TableState::Unsealed: - openbrace = "{ "; - closedbrace = " }"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{"; + closedbrace = "}"; + } + else + { + openbrace = "{ "; + closedbrace = " }"; + } break; case TableState::Free: state.result.invalid = true; - openbrace = "{- "; - closedbrace = " -}"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{-"; + closedbrace = "-}"; + } + else + { + openbrace = "{- "; + closedbrace = " -}"; + } break; case TableState::Generic: state.result.invalid = true; - openbrace = "{+ "; - closedbrace = " +}"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{+"; + closedbrace = "+}"; + } + else + { + openbrace = "{+ "; + closedbrace = " +}"; + } break; } @@ -591,6 +625,8 @@ struct TypeVarStringifier bool comma = false; if (ttv.indexer) { + if (FFlag::LuauToStringTableBracesNewlines) + state.newline(); state.emit("["); stringify(ttv.indexer->indexType); state.emit("]: "); @@ -607,6 +643,10 @@ struct TypeVarStringifier state.emit(","); state.newline(); } + else if (FFlag::LuauToStringTableBracesNewlines) + { + state.newline(); + } size_t length = state.result.name.length() - oldLength; @@ -633,6 +673,13 @@ struct TypeVarStringifier } state.dedent(); + if (FFlag::LuauToStringTableBracesNewlines) + { + if (comma) + state.newline(); + else + state.emit(" "); + } state.emit(closedbrace); state.unsee(&ttv); @@ -833,7 +880,8 @@ struct TypePackStringifier } Luau::visit( - [this, tp](auto&& t) { + [this, tp](auto&& t) + { return (*this)(tp, t); }, tp->ty); @@ -964,9 +1012,11 @@ static void assignCycleNames(const std::set& cycles, const std::set(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) { // If we have a cycle type in type parameters, assign a cycle name for this named table - if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { - return cycles.count(follow(el)); - }) != ttv->instantiatedTypeParams.end()) + if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), + [&](auto&& el) + { + return cycles.count(follow(el)); + }) != ttv->instantiatedTypeParams.end()) cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; continue; @@ -1062,9 +1112,11 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + }); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1075,7 +1127,8 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) { + [&tvs, cycleTy = cycleTy](auto&& t) + { return tvs(cycleTy, t); }, cycleTy->ty); @@ -1132,9 +1185,11 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + }); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1145,7 +1200,8 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto t) { + [&tvs, cycleTy = cycleTy](auto t) + { return tvs(cycleTy, t); }, cycleTy->ty); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 4d9fad14..4d2e94ee 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -60,6 +60,42 @@ TEST_CASE_FIXTURE(Fixture, "named_table") CHECK_EQ("TheTable", toString(&table)); } +TEST_CASE_FIXTURE(Fixture, "empty_table") +{ + ScopedFastFlag LuauToStringTableBracesNewlines("LuauToStringTableBracesNewlines", true); + CheckResult result = check(R"( + local a: {} + )"); + + CHECK_EQ("{| |}", toString(requireType("a"))); + + // Should stay the same with useLineBreaks enabled + ToStringOptions opts; + opts.useLineBreaks = true; + CHECK_EQ("{| |}", toString(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") +{ + ScopedFastFlag LuauToStringTableBracesNewlines("LuauToStringTableBracesNewlines", true); + CheckResult result = check(R"( + local a: { prop: string, anotherProp: number, thirdProp: boolean } + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + opts.indent = true; + + //clang-format off + CHECK_EQ("{|\n" + " anotherProp: number,\n" + " prop: string,\n" + " thirdProp: boolean\n" + "|}", + toString(requireType("a"), opts)); + //clang-format on +} + TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") { CheckResult result = check(R"( From da01056022ce7e85f402932488bdb7f4ef495a65 Mon Sep 17 00:00:00 2001 From: Allan N Jeremy Date: Tue, 14 Jun 2022 18:48:07 +0300 Subject: [PATCH 084/102] Added Luau Benchmark Workflows (#530) --- .github/workflows/benchmark.yml | 109 ++++++++++++++++++++++++++++++++ scripts/run-with-cachegrind.sh | 102 ++++++++++++++++++++++++++++++ 2 files changed, 211 insertions(+) create mode 100644 .github/workflows/benchmark.yml create mode 100644 scripts/run-with-cachegrind.sh diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 00000000..68f63006 --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,109 @@ +name: Luau Benchmarks + +on: + push: + branches: + - master + + paths-ignore: + - "docs/**" + - "papers/**" + - "rfcs/**" + - "*.md" + - "prototyping/**" + +jobs: + benchmarks-run: + name: Run ${{ matrix.bench.title }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + bench: + - { + script: "run-benchmarks", + timeout: 12, + title: "Luau Benchmarks", + cachegrindTitle: "Performance", + cachegrindIterCount: 20, + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Luau + uses: actions/checkout@v3 + + - name: Build Luau + run: make config=release luau luau-analyze + + - uses: actions/setup-python@v3 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + python -m pip install requests + python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Install valgrind + run: | + sudo apt-get install valgrind + + - name: Run benchmark + run: | + python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt + + - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle }}" ${{ matrix.bench.cachegrindIterCount }} | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Checkout Benchmark Results repository + uses: actions/checkout@v3 + with: + repository: ${{ matrix.benchResultsRepo.name }} + ref: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + + - name: Store ${{ matrix.bench.title }} result + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} + tool: "benchmarkluau" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + alert-threshold: 150% + fail-threshold: 1000% + fail-on-alert: false + comment-on-alert: true + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Store ${{ matrix.bench.title }} result + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} (CacheGrind) + tool: "roblox" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + alert-threshold: 150% + fail-threshold: 1000% + fail-on-alert: false + comment-on-alert: true + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Push benchmark results + + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add ./dev/bench/data.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. diff --git a/scripts/run-with-cachegrind.sh b/scripts/run-with-cachegrind.sh new file mode 100644 index 00000000..eb4a8c3f --- /dev/null +++ b/scripts/run-with-cachegrind.sh @@ -0,0 +1,102 @@ +#!/bin/bash +set -euo pipefail +IFS=$'\n\t' + +declare -A event_map +event_map[Ir]="TotalInstructionsExecuted,executions\n" +event_map[I1mr]="L1_InstrReadCacheMisses,misses/op\n" +event_map[ILmr]="LL_InstrReadCacheMisses,misses/op\n" +event_map[Dr]="TotalMemoryReads,reads\n" +event_map[D1mr]="L1_DataReadCacheMisses,misses/op\n" +event_map[DLmr]="LL_DataReadCacheMisses,misses/op\n" +event_map[Dw]="TotalMemoryWrites,writes\n" +event_map[D1mw]="L1_DataWriteCacheMisses,misses/op\n" +event_map[DLmw]="LL_DataWriteCacheMisses,misses/op\n" +event_map[Bc]="ConditionalBranchesExecuted,executions\n" +event_map[Bcm]="ConditionalBranchMispredictions,mispredictions/op\n" +event_map[Bi]="IndirectBranchesExecuted,executions\n" +event_map[Bim]="IndirectBranchMispredictions,mispredictions/op\n" + +now_ms() { + echo -n $(date +%s%N | cut -b1-13) +} + +# Run cachegrind on a given benchmark and echo the results. +ITERATION_COUNT=$4 +START_TIME=$(now_ms) + +valgrind \ + --quiet \ + --tool=cachegrind \ + "$1" "$2" >/dev/null + +TIME_ELAPSED=$(bc <<< "$(now_ms) - ${START_TIME}") + +# Generate report using cg_annotate and extract the header and totals of the +# recorded events valgrind was configured to record. +CG_RESULTS=$(cg_annotate $(ls -t cachegrind.out.* | head -1)) +CG_HEADERS=$(grep -B2 'PROGRAM TOTALS$' <<< "$CG_RESULTS" | head -1 | sed -E 's/\s+/\n/g' | sed '/^$/d') +CG_TOTALS=$(grep 'PROGRAM TOTALS$' <<< "$CG_RESULTS" | head -1 | grep -Po '[0-9,]+\s' | tr -d ', ') + +TOTALS_ARRAY=($CG_TOTALS) +HEADERS_ARRAY=($CG_HEADERS) + +declare -A header_map +for i in "${!TOTALS_ARRAY[@]}"; do + header_map[${HEADERS_ARRAY[$i]}]=$i +done + +# Map the results to the format that the benchmark script expects. +for i in "${!TOTALS_ARRAY[@]}"; do + TOTAL=${TOTALS_ARRAY[$i]} + + # Labels and unit descriptions are packed together in the map. + EVENT_TUPLE=${event_map[${HEADERS_ARRAY[$i]}]} + IFS=$',' read -d '\n' -ra EVENT_VALUES < <(printf "%s" "$EVENT_TUPLE") + EVENT_NAME="${EVENT_VALUES[0]}" + UNIT="${EVENT_VALUES[1]}" + + case ${HEADERS_ARRAY[$i]} in + I1mr | ILmr) + REF=${TOTALS_ARRAY[header_map["Ir"]]} + OPS_PER_SEC=$(bc -l <<< "$TOTAL / $REF") + ;; + + D1mr | DLmr) + REF=${TOTALS_ARRAY[header_map["Dr"]]} + OPS_PER_SEC=$(bc -l <<< "$TOTAL / $REF") + ;; + + D1mw | DLmw) + REF=${TOTALS_ARRAY[header_map["Dw"]]} + OPS_PER_SEC=$(bc -l <<< "$TOTAL / $REF") + ;; + + Bcm) + REF=${TOTALS_ARRAY[header_map["Bc"]]} + OPS_PER_SEC=$(bc -l <<< "$TOTAL / $REF") + ;; + + Bim) + REF=${TOTALS_ARRAY[header_map["Bi"]]} + OPS_PER_SEC=$(bc -l <<< "$TOTAL / $REF") + ;; + + *) + OPS_PER_SEC=$(bc -l <<< "$TOTAL") + ;; + esac + + STD_DEV="0%" + RUNS="1" + + if [[ $OPS_PER_SEC =~ ^[+-]?[0-9]*$ ]] + then # $OPS_PER_SEC is integer + printf "%s#%s x %.0f %s ±%s (%d runs sampled)\n" \ + "$3" "$EVENT_NAME" "$OPS_PER_SEC" "$UNIT" "$STD_DEV" "$RUNS" + else # $OPS_PER_SEC is float + printf "%s#%s x %.10f %s ±%s (%d runs sampled)\n" \ + "$3" "$EVENT_NAME" "$OPS_PER_SEC" "$UNIT" "$STD_DEV" "$RUNS" + fi + +done From 948f678f935eeff9484468bcc3c9edfd0417d61d Mon Sep 17 00:00:00 2001 From: Alan Jeffrey <403333+asajeffrey@users.noreply.github.com> Date: Tue, 14 Jun 2022 22:03:43 -0500 Subject: [PATCH 085/102] Prototyping function overload resolution (#508) --- prototyping/Luau/FunctionTypes.agda | 38 -- prototyping/Luau/ResolveOverloads.agda | 98 ++++ prototyping/Luau/StrictMode.agda | 2 +- prototyping/Luau/StrictMode/ToString.agda | 5 +- prototyping/Luau/Subtyping.agda | 13 +- prototyping/Luau/TypeCheck.agda | 8 +- prototyping/Luau/TypeNormalization.agda | 8 +- prototyping/Luau/TypeSaturation.agda | 66 +++ prototyping/Properties.agda | 1 - prototyping/Properties/DecSubtyping.agda | 138 +++++- prototyping/Properties/FunctionTypes.agda | 150 ------ prototyping/Properties/ResolveOverloads.agda | 189 ++++++++ prototyping/Properties/StrictMode.agda | 215 +++++---- prototyping/Properties/Subtyping.agda | 165 ++++--- prototyping/Properties/TypeCheck.agda | 12 +- prototyping/Properties/TypeNormalization.agda | 46 +- prototyping/Properties/TypeSaturation.agda | 433 ++++++++++++++++++ 17 files changed, 1207 insertions(+), 380 deletions(-) delete mode 100644 prototyping/Luau/FunctionTypes.agda create mode 100644 prototyping/Luau/ResolveOverloads.agda create mode 100644 prototyping/Luau/TypeSaturation.agda delete mode 100644 prototyping/Properties/FunctionTypes.agda create mode 100644 prototyping/Properties/ResolveOverloads.agda create mode 100644 prototyping/Properties/TypeSaturation.agda diff --git a/prototyping/Luau/FunctionTypes.agda b/prototyping/Luau/FunctionTypes.agda deleted file mode 100644 index 7607052b..00000000 --- a/prototyping/Luau/FunctionTypes.agda +++ /dev/null @@ -1,38 +0,0 @@ -{-# OPTIONS --rewriting #-} - -open import FFI.Data.Either using (Either; Left; Right) -open import Luau.Type using (Type; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) -open import Luau.TypeNormalization using (normalize) - -module Luau.FunctionTypes where - --- The domain of a normalized type -srcⁿ : Type → Type -srcⁿ (S ⇒ T) = S -srcⁿ (S ∩ T) = srcⁿ S ∪ srcⁿ T -srcⁿ never = unknown -srcⁿ T = never - --- To get the domain of a type, we normalize it first We need to do --- this, since if we try to use it on non-normalized types, we get --- --- src(number ∩ string) = src(number) ∪ src(string) = never ∪ never --- src(never) = unknown --- --- so src doesn't respect type equivalence. -src : Type → Type -src (S ⇒ T) = S -src T = srcⁿ(normalize T) - --- The codomain of a type -tgt : Type → Type -tgt nil = never -tgt (S ⇒ T) = T -tgt never = never -tgt unknown = unknown -tgt number = never -tgt boolean = never -tgt string = never -tgt (S ∪ T) = (tgt S) ∪ (tgt T) -tgt (S ∩ T) = (tgt S) ∩ (tgt T) - diff --git a/prototyping/Luau/ResolveOverloads.agda b/prototyping/Luau/ResolveOverloads.agda new file mode 100644 index 00000000..67175176 --- /dev/null +++ b/prototyping/Luau/ResolveOverloads.agda @@ -0,0 +1,98 @@ +{-# OPTIONS --rewriting #-} + +module Luau.ResolveOverloads where + +open import FFI.Data.Either using (Left; Right) +open import Luau.Subtyping using (_<:_; _≮:_; Language; witness; scalar; unknown; never; function-ok) +open import Luau.Type using (Type ; _⇒_; _∩_; _∪_; unknown; never) +open import Luau.TypeSaturation using (saturate) +open import Luau.TypeNormalization using (normalize) +open import Properties.Contradiction using (CONTRADICTION) +open import Properties.DecSubtyping using (dec-subtyping; dec-subtypingⁿ; <:-impl-<:ᵒ) +open import Properties.Functions using (_∘_) +open import Properties.Subtyping using (<:-refl; <:-trans; <:-trans-≮:; ≮:-trans-<:; <:-∩-left; <:-∩-right; <:-∩-glb; <:-impl-¬≮:; <:-unknown; <:-function; function-≮:-never; <:-never; unknown-≮:-function; scalar-≮:-function; ≮:-∪-right; scalar-≮:-never; <:-∪-left; <:-∪-right) +open import Properties.TypeNormalization using (Normal; FunType; normal; _⇒_; _∩_; _∪_; never; unknown; <:-normalize; normalize-<:; fun-≮:-never; unknown-≮:-fun; scalar-≮:-fun) +open import Properties.TypeSaturation using (Overloads; Saturated; _⊆ᵒ_; _<:ᵒ_; normal-saturate; saturated; <:-saturate; saturate-<:; defn; here; left; right) + +-- The domain of a normalized type +srcⁿ : Type → Type +srcⁿ (S ⇒ T) = S +srcⁿ (S ∩ T) = srcⁿ S ∪ srcⁿ T +srcⁿ never = unknown +srcⁿ T = never + +-- To get the domain of a type, we normalize it first We need to do +-- this, since if we try to use it on non-normalized types, we get +-- +-- src(number ∩ string) = src(number) ∪ src(string) = never ∪ never +-- src(never) = unknown +-- +-- so src doesn't respect type equivalence. +src : Type → Type +src (S ⇒ T) = S +src T = srcⁿ(normalize T) + +-- Calculate the result of applying a function type `F` to an argument type `V`. +-- We do this by finding an overload of `F` that has the most precise type, +-- that is an overload `(Sʳ ⇒ Tʳ)` where `V <: Sʳ` and moreover +-- for any other such overload `(S ⇒ T)` we have that `Tʳ <: T`. + +-- For example if `F` is `(number -> number) & (nil -> nil) & (number? -> number?)` +-- then to resolve `F` with argument type `number`, we pick the `number -> number` +-- overload, but if the argument is `number?`, we pick `number? -> number?`./ + +-- Not all types have such a most precise overload, but saturated ones do. + +data ResolvedTo F G V : Set where + + yes : ∀ Sʳ Tʳ → + + Overloads F (Sʳ ⇒ Tʳ) → + (V <: Sʳ) → + (∀ {S T} → Overloads G (S ⇒ T) → (V <: S) → (Tʳ <: T)) → + -------------------------------------------- + ResolvedTo F G V + + no : + + (∀ {S T} → Overloads G (S ⇒ T) → (V ≮: S)) → + -------------------------------------------- + ResolvedTo F G V + +Resolved : Type → Type → Set +Resolved F V = ResolvedTo F F V + +target : ∀ {F V} → Resolved F V → Type +target (yes _ T _ _ _) = T +target (no _) = unknown + +-- We can resolve any saturated function type +resolveˢ : ∀ {F G V} → FunType G → Saturated F → Normal V → (G ⊆ᵒ F) → ResolvedTo F G V +resolveˢ (Sⁿ ⇒ Tⁿ) (defn sat-∩ sat-∪) Vⁿ G⊆F with dec-subtypingⁿ Vⁿ Sⁿ +resolveˢ (Sⁿ ⇒ Tⁿ) (defn sat-∩ sat-∪) Vⁿ G⊆F | Left V≮:S = no (λ { here → V≮:S }) +resolveˢ (Sⁿ ⇒ Tⁿ) (defn sat-∩ sat-∪) Vⁿ G⊆F | Right V<:S = yes _ _ (G⊆F here) V<:S (λ { here _ → <:-refl }) +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F with resolveˢ Gᶠ (defn sat-∩ sat-∪) Vⁿ (G⊆F ∘ left) | resolveˢ Hᶠ (defn sat-∩ sat-∪) Vⁿ (G⊆F ∘ right) +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F | yes S₁ T₁ o₁ V<:S₁ tgt₁ | yes S₂ T₂ o₂ V<:S₂ tgt₂ with sat-∩ o₁ o₂ +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F | yes S₁ T₁ o₁ V<:S₁ tgt₁ | yes S₂ T₂ o₂ V<:S₂ tgt₂ | defn o p₁ p₂ = + yes _ _ o (<:-trans (<:-∩-glb V<:S₁ V<:S₂) p₁) (λ { (left o) p → <:-trans p₂ (<:-trans <:-∩-left (tgt₁ o p)) ; (right o) p → <:-trans p₂ (<:-trans <:-∩-right (tgt₂ o p)) }) +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F | yes S₁ T₁ o₁ V<:S₁ tgt₁ | no src₂ = + yes _ _ o₁ V<:S₁ (λ { (left o) p → tgt₁ o p ; (right o) p → CONTRADICTION (<:-impl-¬≮: p (src₂ o)) }) +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F | no src₁ | yes S₂ T₂ o₂ V<:S₂ tgt₂ = + yes _ _ o₂ V<:S₂ (λ { (left o) p → CONTRADICTION (<:-impl-¬≮: p (src₁ o)) ; (right o) p → tgt₂ o p }) +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F | no src₁ | no src₂ = + no (λ { (left o) → src₁ o ; (right o) → src₂ o }) + +-- Which means we can resolve any normalized type, by saturating it first +resolveᶠ : ∀ {F V} → FunType F → Normal V → Type +resolveᶠ Fᶠ Vⁿ = target (resolveˢ (normal-saturate Fᶠ) (saturated Fᶠ) Vⁿ (λ o → o)) + +resolveⁿ : ∀ {F V} → Normal F → Normal V → Type +resolveⁿ (Sⁿ ⇒ Tⁿ) Vⁿ = resolveᶠ (Sⁿ ⇒ Tⁿ) Vⁿ +resolveⁿ (Fᶠ ∩ Gᶠ) Vⁿ = resolveᶠ (Fᶠ ∩ Gᶠ) Vⁿ +resolveⁿ (Sⁿ ∪ Tˢ) Vⁿ = unknown +resolveⁿ unknown Vⁿ = unknown +resolveⁿ never Vⁿ = never + +-- Which means we can resolve any type, by normalizing it first +resolve : Type → Type → Type +resolve F V = resolveⁿ (normal F) (normal V) diff --git a/prototyping/Luau/StrictMode.agda b/prototyping/Luau/StrictMode.agda index d3c0f153..0628951b 100644 --- a/prototyping/Luau/StrictMode.agda +++ b/prototyping/Luau/StrictMode.agda @@ -5,8 +5,8 @@ module Luau.StrictMode where open import Agda.Builtin.Equality using (_≡_) open import FFI.Data.Maybe using (just; nothing) open import Luau.Syntax using (Expr; Stat; Block; BinaryOperator; yes; nil; addr; var; binexp; var_∈_; _⟨_⟩∈_; function_is_end; _$_; block_is_end; local_←_; _∙_; done; return; name; +; -; *; /; <; >; <=; >=; ··) -open import Luau.FunctionTypes using (src; tgt) open import Luau.Type using (Type; nil; number; string; boolean; _⇒_; _∪_; _∩_) +open import Luau.ResolveOverloads using (src; resolve) open import Luau.Subtyping using (_≮:_) open import Luau.Heap using (Heap; function_is_end) renaming (_[_] to _[_]ᴴ) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_) renaming (_[_] to _[_]ⱽ) diff --git a/prototyping/Luau/StrictMode/ToString.agda b/prototyping/Luau/StrictMode/ToString.agda index eee5722e..7c5f0253 100644 --- a/prototyping/Luau/StrictMode/ToString.agda +++ b/prototyping/Luau/StrictMode/ToString.agda @@ -4,7 +4,7 @@ module Luau.StrictMode.ToString where open import Agda.Builtin.Nat using (Nat; suc) open import FFI.Data.String using (String; _++_) -open import Luau.Subtyping using (_≮:_; Tree; witness; scalar; function; function-ok; function-err) +open import Luau.Subtyping using (_≮:_; Tree; witness; scalar; function; function-ok; function-err; function-tgt) open import Luau.StrictMode using (Warningᴱ; Warningᴮ; UnallocatedAddress; UnboundVariable; FunctionCallMismatch; FunctionDefnMismatch; BlockMismatch; app₁; app₂; BinOpMismatch₁; BinOpMismatch₂; bin₁; bin₂; block₁; return; LocalVarMismatch; local₁; local₂; function₁; function₂; heap; expr; block; addr) open import Luau.Syntax using (Expr; val; yes; var; var_∈_; _⟨_⟩∈_; _$_; addr; number; binexp; nil; function_is_end; block_is_end; done; return; local_←_; _∙_; fun; arg; name) open import Luau.Type using (number; boolean; string; nil) @@ -27,8 +27,9 @@ treeToString (scalar boolean) n v = v ++ " is a boolean" treeToString (scalar string) n v = v ++ " is a string" treeToString (scalar nil) n v = v ++ " is nil" treeToString function n v = v ++ " is a function" -treeToString (function-ok t) n v = treeToString t n (v ++ "()") +treeToString (function-ok s t) n v = treeToString t (suc n) (v ++ "(" ++ w ++ ")") ++ " when\n " ++ treeToString s (suc n) w where w = tmp n treeToString (function-err t) n v = v ++ "(" ++ w ++ ") can error when\n " ++ treeToString t (suc n) w where w = tmp n +treeToString (function-tgt t) n v = treeToString t n (v ++ "()") subtypeWarningToString : ∀ {T U} → (T ≮: U) → String subtypeWarningToString (witness t p q) = "\n because provided type contains v, where " ++ treeToString t 0 "v" diff --git a/prototyping/Luau/Subtyping.agda b/prototyping/Luau/Subtyping.agda index 624b6be4..dc2abed0 100644 --- a/prototyping/Luau/Subtyping.agda +++ b/prototyping/Luau/Subtyping.agda @@ -13,8 +13,9 @@ data Tree : Set where scalar : ∀ {T} → Scalar T → Tree function : Tree - function-ok : Tree → Tree + function-ok : Tree → Tree → Tree function-err : Tree → Tree + function-tgt : Tree → Tree data Language : Type → Tree → Set data ¬Language : Type → Tree → Set @@ -23,8 +24,10 @@ data Language where scalar : ∀ {T} → (s : Scalar T) → Language T (scalar s) function : ∀ {T U} → Language (T ⇒ U) function - function-ok : ∀ {T U u} → (Language U u) → Language (T ⇒ U) (function-ok u) + function-ok₁ : ∀ {T U t u} → (¬Language T t) → Language (T ⇒ U) (function-ok t u) + function-ok₂ : ∀ {T U t u} → (Language U u) → Language (T ⇒ U) (function-ok t u) function-err : ∀ {T U t} → (¬Language T t) → Language (T ⇒ U) (function-err t) + function-tgt : ∀ {T U t} → (Language U t) → Language (T ⇒ U) (function-tgt t) left : ∀ {T U t} → Language T t → Language (T ∪ U) t right : ∀ {T U u} → Language U u → Language (T ∪ U) u _,_ : ∀ {T U t} → Language T t → Language U t → Language (T ∩ U) t @@ -34,11 +37,13 @@ data ¬Language where scalar-scalar : ∀ {S T} → (s : Scalar S) → (Scalar T) → (S ≢ T) → ¬Language T (scalar s) scalar-function : ∀ {S} → (Scalar S) → ¬Language S function - scalar-function-ok : ∀ {S u} → (Scalar S) → ¬Language S (function-ok u) + scalar-function-ok : ∀ {S t u} → (Scalar S) → ¬Language S (function-ok t u) scalar-function-err : ∀ {S t} → (Scalar S) → ¬Language S (function-err t) + scalar-function-tgt : ∀ {S t} → (Scalar S) → ¬Language S (function-tgt t) function-scalar : ∀ {S T U} (s : Scalar S) → ¬Language (T ⇒ U) (scalar s) - function-ok : ∀ {T U u} → (¬Language U u) → ¬Language (T ⇒ U) (function-ok u) + function-ok : ∀ {T U t u} → (Language T t) → (¬Language U u) → ¬Language (T ⇒ U) (function-ok t u) function-err : ∀ {T U t} → (Language T t) → ¬Language (T ⇒ U) (function-err t) + function-tgt : ∀ {T U t} → (¬Language U t) → ¬Language (T ⇒ U) (function-tgt t) _,_ : ∀ {T U t} → ¬Language T t → ¬Language U t → ¬Language (T ∪ U) t left : ∀ {T U t} → ¬Language T t → ¬Language (T ∩ U) t right : ∀ {T U u} → ¬Language U u → ¬Language (T ∩ U) u diff --git a/prototyping/Luau/TypeCheck.agda b/prototyping/Luau/TypeCheck.agda index d4fabb90..1abc1eda 100644 --- a/prototyping/Luau/TypeCheck.agda +++ b/prototyping/Luau/TypeCheck.agda @@ -3,16 +3,18 @@ module Luau.TypeCheck where open import Agda.Builtin.Equality using (_≡_) +open import FFI.Data.Either using (Either; Left; Right) open import FFI.Data.Maybe using (Maybe; just) +open import Luau.ResolveOverloads using (resolve) open import Luau.Syntax using (Expr; Stat; Block; BinaryOperator; yes; nil; addr; number; bool; string; val; var; var_∈_; _⟨_⟩∈_; function_is_end; _$_; block_is_end; binexp; local_←_; _∙_; done; return; name; +; -; *; /; <; >; ==; ~=; <=; >=; ··) open import Luau.Var using (Var) open import Luau.Addr using (Addr) -open import Luau.FunctionTypes using (src; tgt) open import Luau.Heap using (Heap; Object; function_is_end) renaming (_[_] to _[_]ᴴ) open import Luau.Type using (Type; nil; unknown; number; boolean; string; _⇒_) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_) renaming (_[_] to _[_]ⱽ) open import FFI.Data.Vector using (Vector) open import FFI.Data.Maybe using (Maybe; just; nothing) +open import Properties.DecSubtyping using (dec-subtyping) open import Properties.Product using (_×_; _,_) orUnknown : Maybe Type → Type @@ -113,8 +115,8 @@ data _⊢ᴱ_∈_ where Γ ⊢ᴱ M ∈ T → Γ ⊢ᴱ N ∈ U → - ---------------------- - Γ ⊢ᴱ (M $ N) ∈ (tgt T) + ---------------------------- + Γ ⊢ᴱ (M $ N) ∈ (resolve T U) function : ∀ {f x B T U V Γ} → diff --git a/prototyping/Luau/TypeNormalization.agda b/prototyping/Luau/TypeNormalization.agda index 341883ea..08f14474 100644 --- a/prototyping/Luau/TypeNormalization.agda +++ b/prototyping/Luau/TypeNormalization.agda @@ -2,11 +2,7 @@ module Luau.TypeNormalization where open import Luau.Type using (Type; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) --- The top non-function type -¬function : Type -¬function = number ∪ (string ∪ (nil ∪ boolean)) - --- Unions and intersections of normalized types +-- Operations on normalized types _∪ᶠ_ : Type → Type → Type _∪ⁿˢ_ : Type → Type → Type _∩ⁿˢ_ : Type → Type → Type @@ -23,8 +19,8 @@ F ∪ᶠ G = F ∪ G S ∪ⁿ (T₁ ∪ T₂) = (S ∪ⁿ T₁) ∪ T₂ S ∪ⁿ unknown = unknown S ∪ⁿ never = S -unknown ∪ⁿ T = unknown never ∪ⁿ T = T +unknown ∪ⁿ T = unknown (S₁ ∪ S₂) ∪ⁿ G = (S₁ ∪ⁿ G) ∪ S₂ F ∪ⁿ G = F ∪ᶠ G diff --git a/prototyping/Luau/TypeSaturation.agda b/prototyping/Luau/TypeSaturation.agda new file mode 100644 index 00000000..fa24ff73 --- /dev/null +++ b/prototyping/Luau/TypeSaturation.agda @@ -0,0 +1,66 @@ +module Luau.TypeSaturation where + +open import Luau.Type using (Type; _⇒_; _∩_; _∪_) +open import Luau.TypeNormalization using (_∪ⁿ_; _∩ⁿ_) + +-- So, there's a problem with overloaded functions +-- (of the form (S_1 ⇒ T_1) ∩⋯∩ (S_n ⇒ T_n)) +-- which is that it's not good enough to compare them +-- for subtyping by comparing all of their overloads. + +-- For example (nil → nil) is a subtype of (number? → number?) ∩ (string? → string?) +-- but not a subtype of any of its overloads. + +-- To fix this, we adapt the semantic subtyping algorithm for +-- function types, given in +-- https://www.irif.fr/~gc/papers/covcon-again.pdf and +-- https://pnwamk.github.io/sst-tutorial/ + +-- A function type is *intersection-saturated* if for any overloads +-- (S₁ ⇒ T₁) and (S₂ ⇒ T₂), there exists an overload which is a subtype +-- of ((S₁ ∩ S₂) ⇒ (T₁ ∩ T₂)). + +-- A function type is *union-saturated* if for any overloads +-- (S₁ ⇒ T₁) and (S₂ ⇒ T₂), there exists an overload which is a subtype +-- of ((S₁ ∪ S₂) ⇒ (T₁ ∪ T₂)). + +-- A function type is *saturated* if it is both intersection- and +-- union-saturated. + +-- For example (number? → number?) ∩ (string? → string?) +-- is not saturated, but (number? → number?) ∩ (string? → string?) ∩ (nil → nil) ∩ ((number ∪ string)? → (number ∪ string)?) +-- is. + +-- Saturated function types have the nice property that they can ber +-- compared by just comparing their overloads: F <: G whenever for any +-- overload of G, there is an overload os F which is a subtype of it. + +-- Forunately every function type can be saturated! +_⋓_ : Type → Type → Type +(S₁ ⇒ T₁) ⋓ (S₂ ⇒ T₂) = (S₁ ∪ⁿ S₂) ⇒ (T₁ ∪ⁿ T₂) +(F₁ ∩ G₁) ⋓ F₂ = (F₁ ⋓ F₂) ∩ (G₁ ⋓ F₂) +F₁ ⋓ (F₂ ∩ G₂) = (F₁ ⋓ F₂) ∩ (F₁ ⋓ G₂) +F ⋓ G = F ∩ G + +_⋒_ : Type → Type → Type +(S₁ ⇒ T₁) ⋒ (S₂ ⇒ T₂) = (S₁ ∩ⁿ S₂) ⇒ (T₁ ∩ⁿ T₂) +(F₁ ∩ G₁) ⋒ F₂ = (F₁ ⋒ F₂) ∩ (G₁ ⋒ F₂) +F₁ ⋒ (F₂ ∩ G₂) = (F₁ ⋒ F₂) ∩ (F₁ ⋒ G₂) +F ⋒ G = F ∩ G + +_∩ᵘ_ : Type → Type → Type +F ∩ᵘ G = (F ∩ G) ∩ (F ⋓ G) + +_∩ⁱ_ : Type → Type → Type +F ∩ⁱ G = (F ∩ G) ∩ (F ⋒ G) + +∪-saturate : Type → Type +∪-saturate (F ∩ G) = (∪-saturate F ∩ᵘ ∪-saturate G) +∪-saturate F = F + +∩-saturate : Type → Type +∩-saturate (F ∩ G) = (∩-saturate F ∩ⁱ ∩-saturate G) +∩-saturate F = F + +saturate : Type → Type +saturate F = ∪-saturate (∩-saturate F) diff --git a/prototyping/Properties.agda b/prototyping/Properties.agda index b696c0fa..f883a3ea 100644 --- a/prototyping/Properties.agda +++ b/prototyping/Properties.agda @@ -7,7 +7,6 @@ import Properties.Dec import Properties.DecSubtyping import Properties.Equality import Properties.Functions -import Properties.FunctionTypes import Properties.Remember import Properties.Step import Properties.StrictMode diff --git a/prototyping/Properties/DecSubtyping.agda b/prototyping/Properties/DecSubtyping.agda index 332520a9..8dc7a446 100644 --- a/prototyping/Properties/DecSubtyping.agda +++ b/prototyping/Properties/DecSubtyping.agda @@ -4,21 +4,23 @@ module Properties.DecSubtyping where open import Agda.Builtin.Equality using (_≡_; refl) open import FFI.Data.Either using (Either; Left; Right; mapLR; swapLR; cond) -open import Luau.FunctionTypes using (src; srcⁿ; tgt) -open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_) +open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-function-tgt; scalar-scalar; function-scalar; function-ok; function-ok₁; function-ok₂; function-err; function-tgt; left; right; _,_) open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) +open import Luau.TypeNormalization using (_∪ⁿ_; _∩ⁿ_) +open import Luau.TypeSaturation using (saturate) open import Properties.Contradiction using (CONTRADICTION; ¬) open import Properties.Functions using (_∘_) -open import Properties.Subtyping using (<:-refl; <:-trans; ≮:-trans-<:; <:-trans-≮:; <:-never; <:-unknown; <:-∪-left; <:-∪-right; <:-∪-lub; ≮:-∪-left; ≮:-∪-right; <:-∩-left; <:-∩-right; <:-∩-glb; ≮:-∩-left; ≮:-∩-right; dec-language; scalar-<:; <:-everything; <:-function; ≮:-function-left; ≮:-function-right) -open import Properties.TypeNormalization using (FunType; Normal; never; unknown; _∩_; _∪_; _⇒_; normal; <:-normalize; normalize-<:) -open import Properties.FunctionTypes using (fun-¬scalar; ¬fun-scalar; fun-function; src-unknown-≮:; tgt-never-≮:; src-tgtᶠ-<:) +open import Properties.Subtyping using (<:-refl; <:-trans; ≮:-trans-<:; <:-trans-≮:; <:-never; <:-unknown; <:-∪-left; <:-∪-right; <:-∪-lub; ≮:-∪-left; ≮:-∪-right; <:-∩-left; <:-∩-right; <:-∩-glb; ≮:-∩-left; ≮:-∩-right; dec-language; scalar-<:; <:-everything; <:-function; ≮:-function-left; ≮:-function-right; <:-impl-¬≮:; <:-intersect; <:-function-∩-∪; <:-function-∩; <:-union; ≮:-left-∪; ≮:-right-∪; <:-∩-distr-∪; <:-impl-⊇; language-comp) +open import Properties.TypeNormalization using (FunType; Normal; never; unknown; _∩_; _∪_; _⇒_; normal; <:-normalize; normalize-<:; normal-∩ⁿ; normal-∪ⁿ; ∪-<:-∪ⁿ; ∪ⁿ-<:-∪; ∩ⁿ-<:-∩; ∩-<:-∩ⁿ; normalᶠ; fun-top; fun-function; fun-¬scalar) +open import Properties.TypeSaturation using (Overloads; Saturated; _⊆ᵒ_; _<:ᵒ_; defn; here; left; right; ov-language; ov-<:; saturated; normal-saturate; normal-overload-src; normal-overload-tgt; saturate-<:; <:-saturate; <:ᵒ-impl-<:; _>>=ˡ_; _>>=ʳ_) open import Properties.Equality using (_≢_) --- Honest this terminates, since src and tgt reduce the depth of nested arrows +-- Honest this terminates, since saturation maintains the depth of nested arrows {-# TERMINATING #-} dec-subtypingˢⁿ : ∀ {T U} → Scalar T → Normal U → Either (T ≮: U) (T <: U) -dec-subtypingᶠ : ∀ {T U} → FunType T → FunType U → Either (T ≮: U) (T <: U) -dec-subtypingᶠⁿ : ∀ {T U} → FunType T → Normal U → Either (T ≮: U) (T <: U) +dec-subtypingˢᶠ : ∀ {F G} → FunType F → Saturated F → FunType G → Either (F ≮: G) (F <:ᵒ G) +dec-subtypingᶠ : ∀ {F G} → FunType F → FunType G → Either (F ≮: G) (F <: G) +dec-subtypingᶠⁿ : ∀ {F U} → FunType F → Normal U → Either (F ≮: U) (F <: U) dec-subtypingⁿ : ∀ {T U} → Normal T → Normal U → Either (T ≮: U) (T <: U) dec-subtyping : ∀ T U → Either (T ≮: U) (T <: U) @@ -26,22 +28,116 @@ dec-subtypingˢⁿ T U with dec-language _ (scalar T) dec-subtypingˢⁿ T U | Left p = Left (witness (scalar T) (scalar T) p) dec-subtypingˢⁿ T U | Right p = Right (scalar-<: T p) -dec-subtypingᶠ {T = T} _ (U ⇒ V) with dec-subtypingⁿ U (normal (src T)) | dec-subtypingⁿ (normal (tgt T)) V -dec-subtypingᶠ {T = T} _ (U ⇒ V) | Left p | q = Left (≮:-trans-<: (src-unknown-≮: (≮:-trans-<: p (<:-normalize (src T)))) (<:-function <:-refl <:-unknown)) -dec-subtypingᶠ {T = T} _ (U ⇒ V) | Right p | Left q = Left (≮:-trans-<: (tgt-never-≮: (<:-trans-≮: (normalize-<: (tgt T)) q)) (<:-trans (<:-function <:-never <:-refl) <:-∪-right)) -dec-subtypingᶠ T (U ⇒ V) | Right p | Right q = Right (src-tgtᶠ-<: T (<:-trans p (normalize-<: _)) (<:-trans (<:-normalize _) q)) +dec-subtypingˢᶠ {F} {S ⇒ T} Fᶠ (defn sat-∩ sat-∪) (Sⁿ ⇒ Tⁿ) = result (top Fᶠ (λ o → o)) where -dec-subtypingᶠ T (U ∩ V) with dec-subtypingᶠ T U | dec-subtypingᶠ T V -dec-subtypingᶠ T (U ∩ V) | Left p | q = Left (≮:-∩-left p) -dec-subtypingᶠ T (U ∩ V) | Right p | Left q = Left (≮:-∩-right q) -dec-subtypingᶠ T (U ∩ V) | Right p | Right q = Right (<:-∩-glb p q) + data Top G : Set where + + defn : ∀ Sᵗ Tᵗ → + + Overloads F (Sᵗ ⇒ Tᵗ) → + (∀ {S′ T′} → Overloads G (S′ ⇒ T′) → (S′ <: Sᵗ)) → + ------------- + Top G + + top : ∀ {G} → (FunType G) → (G ⊆ᵒ F) → Top G + top {S′ ⇒ T′} _ G⊆F = defn S′ T′ (G⊆F here) (λ { here → <:-refl }) + top (Gᶠ ∩ Hᶠ) G⊆F with top Gᶠ (G⊆F ∘ left) | top Hᶠ (G⊆F ∘ right) + top (Gᶠ ∩ Hᶠ) G⊆F | defn Rᵗ Sᵗ p p₁ | defn Tᵗ Uᵗ q q₁ with sat-∪ p q + top (Gᶠ ∩ Hᶠ) G⊆F | defn Rᵗ Sᵗ p p₁ | defn Tᵗ Uᵗ q q₁ | defn n r r₁ = defn _ _ n + (λ { (left o) → <:-trans (<:-trans (p₁ o) <:-∪-left) r ; (right o) → <:-trans (<:-trans (q₁ o) <:-∪-right) r }) + + result : Top F → Either (F ≮: (S ⇒ T)) (F <:ᵒ (S ⇒ T)) + result (defn Sᵗ Tᵗ oᵗ srcᵗ) with dec-subtypingⁿ Sⁿ (normal-overload-src Fᶠ oᵗ) + result (defn Sᵗ Tᵗ oᵗ srcᵗ) | Left (witness s Ss ¬Sᵗs) = Left (witness (function-err s) (ov-language Fᶠ (λ o → function-err (<:-impl-⊇ (srcᵗ o) s ¬Sᵗs))) (function-err Ss)) + result (defn Sᵗ Tᵗ oᵗ srcᵗ) | Right S<:Sᵗ = result₀ (largest Fᶠ (λ o → o)) where + + data LargestSrc (G : Type) : Set where + + yes : ∀ S₀ T₀ → + + Overloads F (S₀ ⇒ T₀) → + T₀ <: T → + (∀ {S′ T′} → Overloads G (S′ ⇒ T′) → T′ <: T → (S′ <: S₀)) → + ----------------------- + LargestSrc G + + no : ∀ S₀ T₀ → + + Overloads F (S₀ ⇒ T₀) → + T₀ ≮: T → + (∀ {S′ T′} → Overloads G (S′ ⇒ T′) → T₀ <: T′) → + ----------------------- + LargestSrc G + + largest : ∀ {G} → (FunType G) → (G ⊆ᵒ F) → LargestSrc G + largest {S′ ⇒ T′} (S′ⁿ ⇒ T′ⁿ) G⊆F with dec-subtypingⁿ T′ⁿ Tⁿ + largest {S′ ⇒ T′} (S′ⁿ ⇒ T′ⁿ) G⊆F | Left T′≮:T = no S′ T′ (G⊆F here) T′≮:T λ { here → <:-refl } + largest {S′ ⇒ T′} (S′ⁿ ⇒ T′ⁿ) G⊆F | Right T′<:T = yes S′ T′ (G⊆F here) T′<:T (λ { here _ → <:-refl }) + largest (Gᶠ ∩ Hᶠ) GH⊆F with largest Gᶠ (GH⊆F ∘ left) | largest Hᶠ (GH⊆F ∘ right) + largest (Gᶠ ∩ Hᶠ) GH⊆F | no S₁ T₁ o₁ T₁≮:T tgt₁ | no S₂ T₂ o₂ T₂≮:T tgt₂ with sat-∩ o₁ o₂ + largest (Gᶠ ∩ Hᶠ) GH⊆F | no S₁ T₁ o₁ T₁≮:T tgt₁ | no S₂ T₂ o₂ T₂≮:T tgt₂ | defn o src tgt with dec-subtypingⁿ (normal-overload-tgt Fᶠ o) Tⁿ + largest (Gᶠ ∩ Hᶠ) GH⊆F | no S₁ T₁ o₁ T₁≮:T tgt₁ | no S₂ T₂ o₂ T₂≮:T tgt₂ | defn o src tgt | Left T₀≮:T = no _ _ o T₀≮:T (λ { (left o) → <:-trans tgt (<:-trans <:-∩-left (tgt₁ o)) ; (right o) → <:-trans tgt (<:-trans <:-∩-right (tgt₂ o)) }) + largest (Gᶠ ∩ Hᶠ) GH⊆F | no S₁ T₁ o₁ T₁≮:T tgt₁ | no S₂ T₂ o₂ T₂≮:T tgt₂ | defn o src tgt | Right T₀<:T = yes _ _ o T₀<:T (λ { (left o) p → CONTRADICTION (<:-impl-¬≮: p (<:-trans-≮: (tgt₁ o) T₁≮:T)) ; (right o) p → CONTRADICTION (<:-impl-¬≮: p (<:-trans-≮: (tgt₂ o) T₂≮:T)) }) + largest (Gᶠ ∩ Hᶠ) GH⊆F | no S₁ T₁ o₁ T₁≮:T tgt₁ | yes S₂ T₂ o₂ T₂<:T src₂ = yes S₂ T₂ o₂ T₂<:T (λ { (left o) p → CONTRADICTION (<:-impl-¬≮: p (<:-trans-≮: (tgt₁ o) T₁≮:T)) ; (right o) p → src₂ o p }) + largest (Gᶠ ∩ Hᶠ) GH⊆F | yes S₁ T₁ o₁ T₁<:T src₁ | no S₂ T₂ o₂ T₂≮:T tgt₂ = yes S₁ T₁ o₁ T₁<:T (λ { (left o) p → src₁ o p ; (right o) p → CONTRADICTION (<:-impl-¬≮: p (<:-trans-≮: (tgt₂ o) T₂≮:T)) }) + largest (Gᶠ ∩ Hᶠ) GH⊆F | yes S₁ T₁ o₁ T₁<:T src₁ | yes S₂ T₂ o₂ T₂<:T src₂ with sat-∪ o₁ o₂ + largest (Gᶠ ∩ Hᶠ) GH⊆F | yes S₁ T₁ o₁ T₁<:T src₁ | yes S₂ T₂ o₂ T₂<:T src₂ | defn o src tgt = yes _ _ o (<:-trans tgt (<:-∪-lub T₁<:T T₂<:T)) + (λ { (left o) T′<:T → <:-trans (src₁ o T′<:T) (<:-trans <:-∪-left src) + ; (right o) T′<:T → <:-trans (src₂ o T′<:T) (<:-trans <:-∪-right src) + }) + + result₀ : LargestSrc F → Either (F ≮: (S ⇒ T)) (F <:ᵒ (S ⇒ T)) + result₀ (no S₀ T₀ o₀ (witness t T₀t ¬Tt) tgt₀) = Left (witness (function-tgt t) (ov-language Fᶠ (λ o → function-tgt (tgt₀ o t T₀t))) (function-tgt ¬Tt)) + result₀ (yes S₀ T₀ o₀ T₀<:T src₀) with dec-subtypingⁿ Sⁿ (normal-overload-src Fᶠ o₀) + result₀ (yes S₀ T₀ o₀ T₀<:T src₀) | Right S<:S₀ = Right λ { here → defn o₀ S<:S₀ T₀<:T } + result₀ (yes S₀ T₀ o₀ T₀<:T src₀) | Left (witness s Ss ¬S₀s) = Left (result₁ (smallest Fᶠ (λ o → o))) where + + data SmallestTgt (G : Type) : Set where + + defn : ∀ S₁ T₁ → + + Overloads F (S₁ ⇒ T₁) → + Language S₁ s → + (∀ {S′ T′} → Overloads G (S′ ⇒ T′) → Language S′ s → (T₁ <: T′)) → + ----------------------- + SmallestTgt G + + smallest : ∀ {G} → (FunType G) → (G ⊆ᵒ F) → SmallestTgt G + smallest {S′ ⇒ T′} _ G⊆F with dec-language S′ s + smallest {S′ ⇒ T′} _ G⊆F | Left ¬S′s = defn Sᵗ Tᵗ oᵗ (S<:Sᵗ s Ss) λ { here S′s → CONTRADICTION (language-comp s ¬S′s S′s) } + smallest {S′ ⇒ T′} _ G⊆F | Right S′s = defn S′ T′ (G⊆F here) S′s (λ { here _ → <:-refl }) + smallest (Gᶠ ∩ Hᶠ) GH⊆F with smallest Gᶠ (GH⊆F ∘ left) | smallest Hᶠ (GH⊆F ∘ right) + smallest (Gᶠ ∩ Hᶠ) GH⊆F | defn S₁ T₁ o₁ R₁s tgt₁ | defn S₂ T₂ o₂ R₂s tgt₂ with sat-∩ o₁ o₂ + smallest (Gᶠ ∩ Hᶠ) GH⊆F | defn S₁ T₁ o₁ R₁s tgt₁ | defn S₂ T₂ o₂ R₂s tgt₂ | defn o src tgt = defn _ _ o (src s (R₁s , R₂s)) + (λ { (left o) S′s → <:-trans (<:-trans tgt <:-∩-left) (tgt₁ o S′s) + ; (right o) S′s → <:-trans (<:-trans tgt <:-∩-right) (tgt₂ o S′s) + }) + + result₁ : SmallestTgt F → (F ≮: (S ⇒ T)) + result₁ (defn S₁ T₁ o₁ S₁s tgt₁) with dec-subtypingⁿ (normal-overload-tgt Fᶠ o₁) Tⁿ + result₁ (defn S₁ T₁ o₁ S₁s tgt₁) | Right T₁<:T = CONTRADICTION (language-comp s ¬S₀s (src₀ o₁ T₁<:T s S₁s)) + result₁ (defn S₁ T₁ o₁ S₁s tgt₁) | Left (witness t T₁t ¬Tt) = witness (function-ok s t) (ov-language Fᶠ lemma) (function-ok Ss ¬Tt) where + + lemma : ∀ {S′ T′} → Overloads F (S′ ⇒ T′) → Language (S′ ⇒ T′) (function-ok s t) + lemma {S′} o with dec-language S′ s + lemma {S′} o | Left ¬S′s = function-ok₁ ¬S′s + lemma {S′} o | Right S′s = function-ok₂ (tgt₁ o S′s t T₁t) + +dec-subtypingˢᶠ F Fˢ (G ∩ H) with dec-subtypingˢᶠ F Fˢ G | dec-subtypingˢᶠ F Fˢ H +dec-subtypingˢᶠ F Fˢ (G ∩ H) | Left F≮:G | _ = Left (≮:-∩-left F≮:G) +dec-subtypingˢᶠ F Fˢ (G ∩ H) | _ | Left F≮:H = Left (≮:-∩-right F≮:H) +dec-subtypingˢᶠ F Fˢ (G ∩ H) | Right F<:G | Right F<:H = Right (λ { (left o) → F<:G o ; (right o) → F<:H o }) + +dec-subtypingᶠ F G with dec-subtypingˢᶠ (normal-saturate F) (saturated F) G +dec-subtypingᶠ F G | Left H≮:G = Left (<:-trans-≮: (saturate-<: F) H≮:G) +dec-subtypingᶠ F G | Right H<:G = Right (<:-trans (<:-saturate F) (<:ᵒ-impl-<: (normal-saturate F) G H<:G)) dec-subtypingᶠⁿ T never = Left (witness function (fun-function T) never) dec-subtypingᶠⁿ T unknown = Right <:-unknown dec-subtypingᶠⁿ T (U ⇒ V) = dec-subtypingᶠ T (U ⇒ V) dec-subtypingᶠⁿ T (U ∩ V) = dec-subtypingᶠ T (U ∩ V) dec-subtypingᶠⁿ T (U ∪ V) with dec-subtypingᶠⁿ T U -dec-subtypingᶠⁿ T (U ∪ V) | Left (witness t p q) = Left (witness t p (q , ¬fun-scalar V T p)) +dec-subtypingᶠⁿ T (U ∪ V) | Left (witness t p q) = Left (witness t p (q , fun-¬scalar V T p)) dec-subtypingᶠⁿ T (U ∪ V) | Right p = Right (<:-trans p <:-∪-left) dec-subtypingⁿ never U = Right <:-never @@ -68,3 +164,11 @@ dec-subtyping T U with dec-subtypingⁿ (normal T) (normal U) dec-subtyping T U | Left p = Left (<:-trans-≮: (normalize-<: T) (≮:-trans-<: p (<:-normalize U))) dec-subtyping T U | Right p = Right (<:-trans (<:-normalize T) (<:-trans p (normalize-<: U))) +-- As a corollary, for saturated functions +-- <:ᵒ coincides with <:, that is F is a subtype of (S ⇒ T) precisely +-- when one of its overloads is. + +<:-impl-<:ᵒ : ∀ {F G} → FunType F → Saturated F → FunType G → (F <: G) → (F <:ᵒ G) +<:-impl-<:ᵒ {F} {G} Fᶠ Fˢ Gᶠ F<:G with dec-subtypingˢᶠ Fᶠ Fˢ Gᶠ +<:-impl-<:ᵒ {F} {G} Fᶠ Fˢ Gᶠ F<:G | Left F≮:G = CONTRADICTION (<:-impl-¬≮: F<:G F≮:G) +<:-impl-<:ᵒ {F} {G} Fᶠ Fˢ Gᶠ F<:G | Right F<:ᵒG = F<:ᵒG diff --git a/prototyping/Properties/FunctionTypes.agda b/prototyping/Properties/FunctionTypes.agda deleted file mode 100644 index 514477f1..00000000 --- a/prototyping/Properties/FunctionTypes.agda +++ /dev/null @@ -1,150 +0,0 @@ -{-# OPTIONS --rewriting #-} - -module Properties.FunctionTypes where - -open import FFI.Data.Either using (Either; Left; Right; mapLR; swapLR; cond) -open import Luau.FunctionTypes using (srcⁿ; src; tgt) -open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_) -open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_; skalar) -open import Properties.Contradiction using (CONTRADICTION; ¬; ⊥) -open import Properties.Functions using (_∘_) -open import Properties.Subtyping using (<:-refl; ≮:-refl; <:-trans-≮:; skalar-scalar; <:-impl-⊇; skalar-function-ok; language-comp) -open import Properties.TypeNormalization using (FunType; Normal; never; unknown; _∩_; _∪_; _⇒_; normal; <:-normalize; normalize-<:) - --- Properties of src -function-err-srcⁿ : ∀ {T t} → (FunType T) → (¬Language (srcⁿ T) t) → Language T (function-err t) -function-err-srcⁿ (S ⇒ T) p = function-err p -function-err-srcⁿ (S ∩ T) (p₁ , p₂) = (function-err-srcⁿ S p₁ , function-err-srcⁿ T p₂) - -¬function-err-srcᶠ : ∀ {T t} → (FunType T) → (Language (srcⁿ T) t) → ¬Language T (function-err t) -¬function-err-srcᶠ (S ⇒ T) p = function-err p -¬function-err-srcᶠ (S ∩ T) (left p) = left (¬function-err-srcᶠ S p) -¬function-err-srcᶠ (S ∩ T) (right p) = right (¬function-err-srcᶠ T p) - -¬function-err-srcⁿ : ∀ {T t} → (Normal T) → (Language (srcⁿ T) t) → ¬Language T (function-err t) -¬function-err-srcⁿ never p = never -¬function-err-srcⁿ unknown (scalar ()) -¬function-err-srcⁿ (S ⇒ T) p = function-err p -¬function-err-srcⁿ (S ∩ T) (left p) = left (¬function-err-srcᶠ S p) -¬function-err-srcⁿ (S ∩ T) (right p) = right (¬function-err-srcᶠ T p) -¬function-err-srcⁿ (S ∪ T) (scalar ()) - -¬function-err-src : ∀ {T t} → (Language (src T) t) → ¬Language T (function-err t) -¬function-err-src {T = S ⇒ T} p = function-err p -¬function-err-src {T = nil} p = scalar-function-err nil -¬function-err-src {T = never} p = never -¬function-err-src {T = unknown} (scalar ()) -¬function-err-src {T = boolean} p = scalar-function-err boolean -¬function-err-src {T = number} p = scalar-function-err number -¬function-err-src {T = string} p = scalar-function-err string -¬function-err-src {T = S ∪ T} p = <:-impl-⊇ (<:-normalize (S ∪ T)) _ (¬function-err-srcⁿ (normal (S ∪ T)) p) -¬function-err-src {T = S ∩ T} p = <:-impl-⊇ (<:-normalize (S ∩ T)) _ (¬function-err-srcⁿ (normal (S ∩ T)) p) - -src-¬function-errᶠ : ∀ {T t} → (FunType T) → Language T (function-err t) → (¬Language (srcⁿ T) t) -src-¬function-errᶠ (S ⇒ T) (function-err p) = p -src-¬function-errᶠ (S ∩ T) (p₁ , p₂) = (src-¬function-errᶠ S p₁ , src-¬function-errᶠ T p₂) - -src-¬function-errⁿ : ∀ {T t} → (Normal T) → Language T (function-err t) → (¬Language (srcⁿ T) t) -src-¬function-errⁿ unknown p = never -src-¬function-errⁿ (S ⇒ T) (function-err p) = p -src-¬function-errⁿ (S ∩ T) (p₁ , p₂) = (src-¬function-errᶠ S p₁ , src-¬function-errᶠ T p₂) -src-¬function-errⁿ (S ∪ T) p = never - -src-¬function-err : ∀ {T t} → Language T (function-err t) → (¬Language (src T) t) -src-¬function-err {T = S ⇒ T} (function-err p) = p -src-¬function-err {T = unknown} p = never -src-¬function-err {T = S ∪ T} p = src-¬function-errⁿ (normal (S ∪ T)) (<:-normalize (S ∪ T) _ p) -src-¬function-err {T = S ∩ T} p = src-¬function-errⁿ (normal (S ∩ T)) (<:-normalize (S ∩ T) _ p) - -fun-¬scalar : ∀ {S T} (s : Scalar S) → FunType T → ¬Language T (scalar s) -fun-¬scalar s (S ⇒ T) = function-scalar s -fun-¬scalar s (S ∩ T) = left (fun-¬scalar s S) - -¬fun-scalar : ∀ {S T t} (s : Scalar S) → FunType T → Language T t → ¬Language S t -¬fun-scalar s (S ⇒ T) function = scalar-function s -¬fun-scalar s (S ⇒ T) (function-ok p) = scalar-function-ok s -¬fun-scalar s (S ⇒ T) (function-err p) = scalar-function-err s -¬fun-scalar s (S ∩ T) (p₁ , p₂) = ¬fun-scalar s T p₂ - -fun-function : ∀ {T} → FunType T → Language T function -fun-function (S ⇒ T) = function -fun-function (S ∩ T) = (fun-function S , fun-function T) - -srcⁿ-¬scalar : ∀ {S T t} (s : Scalar S) → Normal T → Language T (scalar s) → (¬Language (srcⁿ T) t) -srcⁿ-¬scalar s never (scalar ()) -srcⁿ-¬scalar s unknown p = never -srcⁿ-¬scalar s (S ⇒ T) (scalar ()) -srcⁿ-¬scalar s (S ∩ T) (p₁ , p₂) = CONTRADICTION (language-comp (scalar s) (fun-¬scalar s S) p₁) -srcⁿ-¬scalar s (S ∪ T) p = never - -src-¬scalar : ∀ {S T t} (s : Scalar S) → Language T (scalar s) → (¬Language (src T) t) -src-¬scalar {T = nil} s p = never -src-¬scalar {T = T ⇒ U} s (scalar ()) -src-¬scalar {T = never} s (scalar ()) -src-¬scalar {T = unknown} s p = never -src-¬scalar {T = boolean} s p = never -src-¬scalar {T = number} s p = never -src-¬scalar {T = string} s p = never -src-¬scalar {T = T ∪ U} s p = srcⁿ-¬scalar s (normal (T ∪ U)) (<:-normalize (T ∪ U) (scalar s) p) -src-¬scalar {T = T ∩ U} s p = srcⁿ-¬scalar s (normal (T ∩ U)) (<:-normalize (T ∩ U) (scalar s) p) - -srcⁿ-unknown-≮: : ∀ {T U} → (Normal U) → (T ≮: srcⁿ U) → (U ≮: (T ⇒ unknown)) -srcⁿ-unknown-≮: never (witness t p q) = CONTRADICTION (language-comp t q unknown) -srcⁿ-unknown-≮: unknown (witness t p q) = witness (function-err t) unknown (function-err p) -srcⁿ-unknown-≮: (U ⇒ V) (witness t p q) = witness (function-err t) (function-err q) (function-err p) -srcⁿ-unknown-≮: (U ∩ V) (witness t p q) = witness (function-err t) (function-err-srcⁿ (U ∩ V) q) (function-err p) -srcⁿ-unknown-≮: (U ∪ V) (witness t p q) = witness (scalar V) (right (scalar V)) (function-scalar V) - -src-unknown-≮: : ∀ {T U} → (T ≮: src U) → (U ≮: (T ⇒ unknown)) -src-unknown-≮: {U = nil} (witness t p q) = witness (scalar nil) (scalar nil) (function-scalar nil) -src-unknown-≮: {U = T ⇒ U} (witness t p q) = witness (function-err t) (function-err q) (function-err p) -src-unknown-≮: {U = never} (witness t p q) = CONTRADICTION (language-comp t q unknown) -src-unknown-≮: {U = unknown} (witness t p q) = witness (function-err t) unknown (function-err p) -src-unknown-≮: {U = boolean} (witness t p q) = witness (scalar boolean) (scalar boolean) (function-scalar boolean) -src-unknown-≮: {U = number} (witness t p q) = witness (scalar number) (scalar number) (function-scalar number) -src-unknown-≮: {U = string} (witness t p q) = witness (scalar string) (scalar string) (function-scalar string) -src-unknown-≮: {U = T ∪ U} p = <:-trans-≮: (normalize-<: (T ∪ U)) (srcⁿ-unknown-≮: (normal (T ∪ U)) p) -src-unknown-≮: {U = T ∩ U} p = <:-trans-≮: (normalize-<: (T ∩ U)) (srcⁿ-unknown-≮: (normal (T ∩ U)) p) - -unknown-src-≮: : ∀ {S T U} → (U ≮: S) → (T ≮: (U ⇒ unknown)) → (U ≮: src T) -unknown-src-≮: (witness t x x₁) (witness (scalar s) p (function-scalar s)) = witness t x (src-¬scalar s p) -unknown-src-≮: r (witness (function-ok (scalar s)) p (function-ok (scalar-scalar s () q))) -unknown-src-≮: r (witness (function-ok (function-ok _)) p (function-ok (scalar-function-ok ()))) -unknown-src-≮: r (witness (function-err t) p (function-err q)) = witness t q (src-¬function-err p) - --- Properties of tgt -tgt-function-ok : ∀ {T t} → (Language (tgt T) t) → Language T (function-ok t) -tgt-function-ok {T = nil} (scalar ()) -tgt-function-ok {T = T₁ ⇒ T₂} p = function-ok p -tgt-function-ok {T = never} (scalar ()) -tgt-function-ok {T = unknown} p = unknown -tgt-function-ok {T = boolean} (scalar ()) -tgt-function-ok {T = number} (scalar ()) -tgt-function-ok {T = string} (scalar ()) -tgt-function-ok {T = T₁ ∪ T₂} (left p) = left (tgt-function-ok p) -tgt-function-ok {T = T₁ ∪ T₂} (right p) = right (tgt-function-ok p) -tgt-function-ok {T = T₁ ∩ T₂} (p₁ , p₂) = (tgt-function-ok p₁ , tgt-function-ok p₂) - -function-ok-tgt : ∀ {T t} → Language T (function-ok t) → (Language (tgt T) t) -function-ok-tgt (function-ok p) = p -function-ok-tgt (left p) = left (function-ok-tgt p) -function-ok-tgt (right p) = right (function-ok-tgt p) -function-ok-tgt (p₁ , p₂) = (function-ok-tgt p₁ , function-ok-tgt p₂) -function-ok-tgt unknown = unknown - -tgt-never-≮: : ∀ {T U} → (tgt T ≮: U) → (T ≮: (skalar ∪ (never ⇒ U))) -tgt-never-≮: (witness t p q) = witness (function-ok t) (tgt-function-ok p) (skalar-function-ok , function-ok q) - -never-tgt-≮: : ∀ {T U} → (T ≮: (skalar ∪ (never ⇒ U))) → (tgt T ≮: U) -never-tgt-≮: (witness (scalar s) p (q₁ , q₂)) = CONTRADICTION (≮:-refl (witness (scalar s) (skalar-scalar s) q₁)) -never-tgt-≮: (witness function p (q₁ , scalar-function ())) -never-tgt-≮: (witness (function-ok t) p (q₁ , function-ok q₂)) = witness t (function-ok-tgt p) q₂ -never-tgt-≮: (witness (function-err (scalar s)) p (q₁ , function-err (scalar ()))) - -src-tgtᶠ-<: : ∀ {T U V} → (FunType T) → (U <: src T) → (tgt T <: V) → (T <: (U ⇒ V)) -src-tgtᶠ-<: T p q (scalar s) r = CONTRADICTION (language-comp (scalar s) (fun-¬scalar s T) r) -src-tgtᶠ-<: T p q function r = function -src-tgtᶠ-<: T p q (function-ok s) r = function-ok (q s (function-ok-tgt r)) -src-tgtᶠ-<: T p q (function-err s) r = function-err (<:-impl-⊇ p s (src-¬function-err r)) - - diff --git a/prototyping/Properties/ResolveOverloads.agda b/prototyping/Properties/ResolveOverloads.agda new file mode 100644 index 00000000..8de4a875 --- /dev/null +++ b/prototyping/Properties/ResolveOverloads.agda @@ -0,0 +1,189 @@ +{-# OPTIONS --rewriting #-} + +module Properties.ResolveOverloads where + +open import FFI.Data.Either using (Left; Right) +open import Luau.ResolveOverloads using (Resolved; src; srcⁿ; resolve; resolveⁿ; resolveᶠ; resolveˢ; target; yes; no) +open import Luau.Subtyping using (_<:_; _≮:_; Language; ¬Language; witness; scalar; unknown; never; function; function-ok; function-err; function-tgt; function-scalar; function-ok₁; function-ok₂; scalar-scalar; scalar-function; scalar-function-ok; scalar-function-err; scalar-function-tgt; _,_; left; right) +open import Luau.Type using (Type ; Scalar; _⇒_; _∩_; _∪_; nil; boolean; number; string; unknown; never) +open import Luau.TypeSaturation using (saturate) +open import Luau.TypeNormalization using (normalize) +open import Properties.Contradiction using (CONTRADICTION) +open import Properties.DecSubtyping using (dec-subtyping; dec-subtypingⁿ; <:-impl-<:ᵒ) +open import Properties.Functions using (_∘_) +open import Properties.Subtyping using (<:-refl; <:-trans; <:-trans-≮:; ≮:-trans-<:; <:-∩-left; <:-∩-right; <:-∩-glb; <:-impl-¬≮:; <:-unknown; <:-function; function-≮:-never; <:-never; unknown-≮:-function; scalar-≮:-function; ≮:-∪-right; scalar-≮:-never; <:-∪-left; <:-∪-right; <:-impl-⊇; language-comp) +open import Properties.TypeNormalization using (Normal; FunType; normal; _⇒_; _∩_; _∪_; never; unknown; <:-normalize; normalize-<:; fun-≮:-never; unknown-≮:-fun; scalar-≮:-fun) +open import Properties.TypeSaturation using (Overloads; Saturated; _⊆ᵒ_; _<:ᵒ_; normal-saturate; saturated; <:-saturate; saturate-<:; defn; here; left; right) + +-- Properties of src +function-err-srcⁿ : ∀ {T t} → (FunType T) → (¬Language (srcⁿ T) t) → Language T (function-err t) +function-err-srcⁿ (S ⇒ T) p = function-err p +function-err-srcⁿ (S ∩ T) (p₁ , p₂) = (function-err-srcⁿ S p₁ , function-err-srcⁿ T p₂) + +¬function-err-srcᶠ : ∀ {T t} → (FunType T) → (Language (srcⁿ T) t) → ¬Language T (function-err t) +¬function-err-srcᶠ (S ⇒ T) p = function-err p +¬function-err-srcᶠ (S ∩ T) (left p) = left (¬function-err-srcᶠ S p) +¬function-err-srcᶠ (S ∩ T) (right p) = right (¬function-err-srcᶠ T p) + +¬function-err-srcⁿ : ∀ {T t} → (Normal T) → (Language (srcⁿ T) t) → ¬Language T (function-err t) +¬function-err-srcⁿ never p = never +¬function-err-srcⁿ unknown (scalar ()) +¬function-err-srcⁿ (S ⇒ T) p = function-err p +¬function-err-srcⁿ (S ∩ T) (left p) = left (¬function-err-srcᶠ S p) +¬function-err-srcⁿ (S ∩ T) (right p) = right (¬function-err-srcᶠ T p) +¬function-err-srcⁿ (S ∪ T) (scalar ()) + +¬function-err-src : ∀ {T t} → (Language (src T) t) → ¬Language T (function-err t) +¬function-err-src {T = S ⇒ T} p = function-err p +¬function-err-src {T = nil} p = scalar-function-err nil +¬function-err-src {T = never} p = never +¬function-err-src {T = unknown} (scalar ()) +¬function-err-src {T = boolean} p = scalar-function-err boolean +¬function-err-src {T = number} p = scalar-function-err number +¬function-err-src {T = string} p = scalar-function-err string +¬function-err-src {T = S ∪ T} p = <:-impl-⊇ (<:-normalize (S ∪ T)) _ (¬function-err-srcⁿ (normal (S ∪ T)) p) +¬function-err-src {T = S ∩ T} p = <:-impl-⊇ (<:-normalize (S ∩ T)) _ (¬function-err-srcⁿ (normal (S ∩ T)) p) + +src-¬function-errᶠ : ∀ {T t} → (FunType T) → Language T (function-err t) → (¬Language (srcⁿ T) t) +src-¬function-errᶠ (S ⇒ T) (function-err p) = p +src-¬function-errᶠ (S ∩ T) (p₁ , p₂) = (src-¬function-errᶠ S p₁ , src-¬function-errᶠ T p₂) + +src-¬function-errⁿ : ∀ {T t} → (Normal T) → Language T (function-err t) → (¬Language (srcⁿ T) t) +src-¬function-errⁿ unknown p = never +src-¬function-errⁿ (S ⇒ T) (function-err p) = p +src-¬function-errⁿ (S ∩ T) (p₁ , p₂) = (src-¬function-errᶠ S p₁ , src-¬function-errᶠ T p₂) +src-¬function-errⁿ (S ∪ T) p = never + +src-¬function-err : ∀ {T t} → Language T (function-err t) → (¬Language (src T) t) +src-¬function-err {T = S ⇒ T} (function-err p) = p +src-¬function-err {T = unknown} p = never +src-¬function-err {T = S ∪ T} p = src-¬function-errⁿ (normal (S ∪ T)) (<:-normalize (S ∪ T) _ p) +src-¬function-err {T = S ∩ T} p = src-¬function-errⁿ (normal (S ∩ T)) (<:-normalize (S ∩ T) _ p) + +fun-¬scalar : ∀ {S T} (s : Scalar S) → FunType T → ¬Language T (scalar s) +fun-¬scalar s (S ⇒ T) = function-scalar s +fun-¬scalar s (S ∩ T) = left (fun-¬scalar s S) + +¬fun-scalar : ∀ {S T t} (s : Scalar S) → FunType T → Language T t → ¬Language S t +¬fun-scalar s (S ⇒ T) function = scalar-function s +¬fun-scalar s (S ⇒ T) (function-ok₁ p) = scalar-function-ok s +¬fun-scalar s (S ⇒ T) (function-ok₂ p) = scalar-function-ok s +¬fun-scalar s (S ⇒ T) (function-err p) = scalar-function-err s +¬fun-scalar s (S ⇒ T) (function-tgt p) = scalar-function-tgt s +¬fun-scalar s (S ∩ T) (p₁ , p₂) = ¬fun-scalar s T p₂ + +fun-function : ∀ {T} → FunType T → Language T function +fun-function (S ⇒ T) = function +fun-function (S ∩ T) = (fun-function S , fun-function T) + +srcⁿ-¬scalar : ∀ {S T t} (s : Scalar S) → Normal T → Language T (scalar s) → (¬Language (srcⁿ T) t) +srcⁿ-¬scalar s never (scalar ()) +srcⁿ-¬scalar s unknown p = never +srcⁿ-¬scalar s (S ⇒ T) (scalar ()) +srcⁿ-¬scalar s (S ∩ T) (p₁ , p₂) = CONTRADICTION (language-comp (scalar s) (fun-¬scalar s S) p₁) +srcⁿ-¬scalar s (S ∪ T) p = never + +src-¬scalar : ∀ {S T t} (s : Scalar S) → Language T (scalar s) → (¬Language (src T) t) +src-¬scalar {T = nil} s p = never +src-¬scalar {T = T ⇒ U} s (scalar ()) +src-¬scalar {T = never} s (scalar ()) +src-¬scalar {T = unknown} s p = never +src-¬scalar {T = boolean} s p = never +src-¬scalar {T = number} s p = never +src-¬scalar {T = string} s p = never +src-¬scalar {T = T ∪ U} s p = srcⁿ-¬scalar s (normal (T ∪ U)) (<:-normalize (T ∪ U) (scalar s) p) +src-¬scalar {T = T ∩ U} s p = srcⁿ-¬scalar s (normal (T ∩ U)) (<:-normalize (T ∩ U) (scalar s) p) + +srcⁿ-unknown-≮: : ∀ {T U} → (Normal U) → (T ≮: srcⁿ U) → (U ≮: (T ⇒ unknown)) +srcⁿ-unknown-≮: never (witness t p q) = CONTRADICTION (language-comp t q unknown) +srcⁿ-unknown-≮: unknown (witness t p q) = witness (function-err t) unknown (function-err p) +srcⁿ-unknown-≮: (U ⇒ V) (witness t p q) = witness (function-err t) (function-err q) (function-err p) +srcⁿ-unknown-≮: (U ∩ V) (witness t p q) = witness (function-err t) (function-err-srcⁿ (U ∩ V) q) (function-err p) +srcⁿ-unknown-≮: (U ∪ V) (witness t p q) = witness (scalar V) (right (scalar V)) (function-scalar V) + +src-unknown-≮: : ∀ {T U} → (T ≮: src U) → (U ≮: (T ⇒ unknown)) +src-unknown-≮: {U = nil} (witness t p q) = witness (scalar nil) (scalar nil) (function-scalar nil) +src-unknown-≮: {U = T ⇒ U} (witness t p q) = witness (function-err t) (function-err q) (function-err p) +src-unknown-≮: {U = never} (witness t p q) = CONTRADICTION (language-comp t q unknown) +src-unknown-≮: {U = unknown} (witness t p q) = witness (function-err t) unknown (function-err p) +src-unknown-≮: {U = boolean} (witness t p q) = witness (scalar boolean) (scalar boolean) (function-scalar boolean) +src-unknown-≮: {U = number} (witness t p q) = witness (scalar number) (scalar number) (function-scalar number) +src-unknown-≮: {U = string} (witness t p q) = witness (scalar string) (scalar string) (function-scalar string) +src-unknown-≮: {U = T ∪ U} p = <:-trans-≮: (normalize-<: (T ∪ U)) (srcⁿ-unknown-≮: (normal (T ∪ U)) p) +src-unknown-≮: {U = T ∩ U} p = <:-trans-≮: (normalize-<: (T ∩ U)) (srcⁿ-unknown-≮: (normal (T ∩ U)) p) + +unknown-src-≮: : ∀ {S T U} → (U ≮: S) → (T ≮: (U ⇒ unknown)) → (U ≮: src T) +unknown-src-≮: (witness t x x₁) (witness (scalar s) p (function-scalar s)) = witness t x (src-¬scalar s p) +unknown-src-≮: r (witness (function-ok s .(scalar s₁)) p (function-ok x (scalar-scalar s₁ () x₂))) +unknown-src-≮: r (witness (function-ok s .function) p (function-ok x (scalar-function ()))) +unknown-src-≮: r (witness (function-ok s .(function-ok _ _)) p (function-ok x (scalar-function-ok ()))) +unknown-src-≮: r (witness (function-ok s .(function-err _)) p (function-ok x (scalar-function-err ()))) +unknown-src-≮: r (witness (function-err t) p (function-err q)) = witness t q (src-¬function-err p) +unknown-src-≮: r (witness (function-tgt t) p (function-tgt (scalar-function-tgt ()))) + +-- Properties of resolve +resolveˢ-<:-⇒ : ∀ {F V U} → (FunType F) → (Saturated F) → (FunType (V ⇒ U)) → (r : Resolved F V) → (F <: (V ⇒ U)) → (target r <: U) +resolveˢ-<:-⇒ Fᶠ Fˢ V⇒Uᶠ r F<:V⇒U with <:-impl-<:ᵒ Fᶠ Fˢ V⇒Uᶠ F<:V⇒U here +resolveˢ-<:-⇒ Fᶠ Fˢ V⇒Uᶠ (yes Sʳ Tʳ oʳ V<:Sʳ tgtʳ) F<:V⇒U | defn o o₁ o₂ = <:-trans (tgtʳ o o₁) o₂ +resolveˢ-<:-⇒ Fᶠ Fˢ V⇒Uᶠ (no tgtʳ) F<:V⇒U | defn o o₁ o₂ = CONTRADICTION (<:-impl-¬≮: o₁ (tgtʳ o)) + +resolveⁿ-<:-⇒ : ∀ {F V U} → (Fⁿ : Normal F) → (Vⁿ : Normal V) → (Uⁿ : Normal U) → (F <: (V ⇒ U)) → (resolveⁿ Fⁿ Vⁿ <: U) +resolveⁿ-<:-⇒ (Sⁿ ⇒ Tⁿ) Vⁿ Uⁿ F<:V⇒U = resolveˢ-<:-⇒ (normal-saturate (Sⁿ ⇒ Tⁿ)) (saturated (Sⁿ ⇒ Tⁿ)) (Vⁿ ⇒ Uⁿ) (resolveˢ (normal-saturate (Sⁿ ⇒ Tⁿ)) (saturated (Sⁿ ⇒ Tⁿ)) Vⁿ (λ o → o)) F<:V⇒U +resolveⁿ-<:-⇒ (Fⁿ ∩ Gⁿ) Vⁿ Uⁿ F<:V⇒U = resolveˢ-<:-⇒ (normal-saturate (Fⁿ ∩ Gⁿ)) (saturated (Fⁿ ∩ Gⁿ)) (Vⁿ ⇒ Uⁿ) (resolveˢ (normal-saturate (Fⁿ ∩ Gⁿ)) (saturated (Fⁿ ∩ Gⁿ)) Vⁿ (λ o → o)) (<:-trans (saturate-<: (Fⁿ ∩ Gⁿ)) F<:V⇒U) +resolveⁿ-<:-⇒ (Sⁿ ∪ Tˢ) Vⁿ Uⁿ F<:V⇒U = CONTRADICTION (<:-impl-¬≮: F<:V⇒U (<:-trans-≮: <:-∪-right (scalar-≮:-function Tˢ))) +resolveⁿ-<:-⇒ never Vⁿ Uⁿ F<:V⇒U = <:-never +resolveⁿ-<:-⇒ unknown Vⁿ Uⁿ F<:V⇒U = CONTRADICTION (<:-impl-¬≮: F<:V⇒U unknown-≮:-function) + +resolve-<:-⇒ : ∀ {F V U} → (F <: (V ⇒ U)) → (resolve F V <: U) +resolve-<:-⇒ {F} {V} {U} F<:V⇒U = <:-trans (resolveⁿ-<:-⇒ (normal F) (normal V) (normal U) (<:-trans (normalize-<: F) (<:-trans F<:V⇒U (<:-normalize (V ⇒ U))))) (normalize-<: U) + +resolve-≮:-⇒ : ∀ {F V U} → (resolve F V ≮: U) → (F ≮: (V ⇒ U)) +resolve-≮:-⇒ {F} {V} {U} FV≮:U with dec-subtyping F (V ⇒ U) +resolve-≮:-⇒ {F} {V} {U} FV≮:U | Left F≮:V⇒U = F≮:V⇒U +resolve-≮:-⇒ {F} {V} {U} FV≮:U | Right F<:V⇒U = CONTRADICTION (<:-impl-¬≮: (resolve-<:-⇒ F<:V⇒U) FV≮:U) + +<:-resolveˢ-⇒ : ∀ {S T V} → (r : Resolved (S ⇒ T) V) → (V <: S) → T <: target r +<:-resolveˢ-⇒ (yes S T here _ _) V<:S = <:-refl +<:-resolveˢ-⇒ (no _) V<:S = <:-unknown + +<:-resolveⁿ-⇒ : ∀ {S T V} → (Sⁿ : Normal S) → (Tⁿ : Normal T) → (Vⁿ : Normal V) → (V <: S) → T <: resolveⁿ (Sⁿ ⇒ Tⁿ) Vⁿ +<:-resolveⁿ-⇒ Sⁿ Tⁿ Vⁿ V<:S = <:-resolveˢ-⇒ (resolveˢ (Sⁿ ⇒ Tⁿ) (saturated (Sⁿ ⇒ Tⁿ)) Vⁿ (λ o → o)) V<:S + +<:-resolve-⇒ : ∀ {S T V} → (V <: S) → T <: resolve (S ⇒ T) V +<:-resolve-⇒ {S} {T} {V} V<:S = <:-trans (<:-normalize T) (<:-resolveⁿ-⇒ (normal S) (normal T) (normal V) (<:-trans (normalize-<: V) (<:-trans V<:S (<:-normalize S)))) + +<:-resolveˢ : ∀ {F G V W} → (r : Resolved F V) → (s : Resolved G W) → (F <:ᵒ G) → (V <: W) → target r <: target s +<:-resolveˢ (yes Sʳ Tʳ oʳ V<:Sʳ tgtʳ) (yes Sˢ Tˢ oˢ W<:Sˢ tgtˢ) F<:G V<:W with F<:G oˢ +<:-resolveˢ (yes Sʳ Tʳ oʳ V<:Sʳ tgtʳ) (yes Sˢ Tˢ oˢ W<:Sˢ tgtˢ) F<:G V<:W | defn o o₁ o₂ = <:-trans (tgtʳ o (<:-trans (<:-trans V<:W W<:Sˢ) o₁)) o₂ +<:-resolveˢ (no r) (yes Sˢ Tˢ oˢ W<:Sˢ tgtˢ) F<:G V<:W with F<:G oˢ +<:-resolveˢ (no r) (yes Sˢ Tˢ oˢ W<:Sˢ tgtˢ) F<:G V<:W | defn o o₁ o₂ = CONTRADICTION (<:-impl-¬≮: (<:-trans V<:W (<:-trans W<:Sˢ o₁)) (r o)) +<:-resolveˢ r (no s) F<:G V<:W = <:-unknown + +<:-resolveᶠ : ∀ {F G V W} → (Fᶠ : FunType F) → (Gᶠ : FunType G) → (Vⁿ : Normal V) → (Wⁿ : Normal W) → (F <: G) → (V <: W) → resolveᶠ Fᶠ Vⁿ <: resolveᶠ Gᶠ Wⁿ +<:-resolveᶠ Fᶠ Gᶠ Vⁿ Wⁿ F<:G V<:W = <:-resolveˢ + (resolveˢ (normal-saturate Fᶠ) (saturated Fᶠ) Vⁿ (λ o → o)) + (resolveˢ (normal-saturate Gᶠ) (saturated Gᶠ) Wⁿ (λ o → o)) + (<:-impl-<:ᵒ (normal-saturate Fᶠ) (saturated Fᶠ) (normal-saturate Gᶠ) (<:-trans (saturate-<: Fᶠ) (<:-trans F<:G (<:-saturate Gᶠ)))) + V<:W + +<:-resolveⁿ : ∀ {F G V W} → (Fⁿ : Normal F) → (Gⁿ : Normal G) → (Vⁿ : Normal V) → (Wⁿ : Normal W) → (F <: G) → (V <: W) → resolveⁿ Fⁿ Vⁿ <: resolveⁿ Gⁿ Wⁿ +<:-resolveⁿ (Rⁿ ⇒ Sⁿ) (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W = <:-resolveᶠ (Rⁿ ⇒ Sⁿ) (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W +<:-resolveⁿ (Rⁿ ⇒ Sⁿ) (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W = <:-resolveᶠ (Rⁿ ⇒ Sⁿ) (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W +<:-resolveⁿ (Eⁿ ∩ Fⁿ) (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W = <:-resolveᶠ (Eⁿ ∩ Fⁿ) (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W +<:-resolveⁿ (Eⁿ ∩ Fⁿ) (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W = <:-resolveᶠ (Eⁿ ∩ Fⁿ) (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W +<:-resolveⁿ (Fⁿ ∪ Sˢ) (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (≮:-∪-right (scalar-≮:-function Sˢ))) +<:-resolveⁿ unknown (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G unknown-≮:-function) +<:-resolveⁿ (Fⁿ ∪ Sˢ) (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (≮:-∪-right (scalar-≮:-fun (Gⁿ ∩ Hⁿ) Sˢ))) +<:-resolveⁿ unknown (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (unknown-≮:-fun (Gⁿ ∩ Hⁿ))) +<:-resolveⁿ (Rⁿ ⇒ Sⁿ) never Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (fun-≮:-never (Rⁿ ⇒ Sⁿ))) +<:-resolveⁿ (Eⁿ ∩ Fⁿ) never Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (fun-≮:-never (Eⁿ ∩ Fⁿ))) +<:-resolveⁿ (Fⁿ ∪ Sˢ) never Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (≮:-∪-right (scalar-≮:-never Sˢ))) +<:-resolveⁿ unknown never Vⁿ Wⁿ F<:G V<:W = F<:G +<:-resolveⁿ never Gⁿ Vⁿ Wⁿ F<:G V<:W = <:-never +<:-resolveⁿ Fⁿ (Gⁿ ∪ Uˢ) Vⁿ Wⁿ F<:G V<:W = <:-unknown +<:-resolveⁿ Fⁿ unknown Vⁿ Wⁿ F<:G V<:W = <:-unknown + +<:-resolve : ∀ {F G V W} → (F <: G) → (V <: W) → resolve F V <: resolve G W +<:-resolve {F} {G} {V} {W} F<:G V<:W = <:-resolveⁿ (normal F) (normal G) (normal V) (normal W) + (<:-trans (normalize-<: F) (<:-trans F<:G (<:-normalize G))) + (<:-trans (normalize-<: V) (<:-trans V<:W (<:-normalize W))) diff --git a/prototyping/Properties/StrictMode.agda b/prototyping/Properties/StrictMode.agda index 69e9131c..948674b9 100644 --- a/prototyping/Properties/StrictMode.agda +++ b/prototyping/Properties/StrictMode.agda @@ -7,11 +7,11 @@ open import Agda.Builtin.Equality using (_≡_; refl) open import FFI.Data.Either using (Either; Left; Right; mapL; mapR; mapLR; swapLR; cond) open import FFI.Data.Maybe using (Maybe; just; nothing) open import Luau.Heap using (Heap; Object; function_is_end; defn; alloc; ok; next; lookup-not-allocated) renaming (_≡_⊕_↦_ to _≡ᴴ_⊕_↦_; _[_] to _[_]ᴴ; ∅ to ∅ᴴ) +open import Luau.ResolveOverloads using (src; resolve) open import Luau.StrictMode using (Warningᴱ; Warningᴮ; Warningᴼ; Warningᴴ; UnallocatedAddress; UnboundVariable; FunctionCallMismatch; app₁; app₂; BinOpMismatch₁; BinOpMismatch₂; bin₁; bin₂; BlockMismatch; block₁; return; LocalVarMismatch; local₁; local₂; FunctionDefnMismatch; function₁; function₂; heap; expr; block; addr) open import Luau.Substitution using (_[_/_]ᴮ; _[_/_]ᴱ; _[_/_]ᴮunless_; var_[_/_]ᴱwhenever_) -open import Luau.Subtyping using (_≮:_; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_; Tree; Language; ¬Language) +open import Luau.Subtyping using (_<:_; _≮:_; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_; Tree; Language; ¬Language) open import Luau.Syntax using (Expr; yes; var; val; var_∈_; _⟨_⟩∈_; _$_; addr; number; bool; string; binexp; nil; function_is_end; block_is_end; done; return; local_←_; _∙_; fun; arg; name; ==; ~=) -open import Luau.FunctionTypes using (src; tgt) open import Luau.Type using (Type; nil; number; boolean; string; _⇒_; never; unknown; _∩_; _∪_; _≡ᵀ_; _≡ᴹᵀ_) open import Luau.TypeCheck using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; _⊢ᴴᴮ_▷_∈_; _⊢ᴴᴱ_▷_∈_; nil; var; addr; app; function; block; done; return; local; orUnknown; srcBinOp; tgtBinOp) open import Luau.Var using (_≡ⱽ_) @@ -23,8 +23,10 @@ open import Properties.Equality using (_≢_; sym; cong; trans; subst₁) open import Properties.Dec using (Dec; yes; no) open import Properties.Contradiction using (CONTRADICTION; ¬) open import Properties.Functions using (_∘_) -open import Properties.FunctionTypes using (never-tgt-≮:; tgt-never-≮:; src-unknown-≮:; unknown-src-≮:) -open import Properties.Subtyping using (unknown-≮:; ≡-trans-≮:; ≮:-trans-≡; ≮:-trans; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-never; unknown-≮:-scalar; scalar-≮:-never; unknown-≮:-never) +open import Properties.DecSubtyping using (dec-subtyping) +open import Properties.Subtyping using (unknown-≮:; ≡-trans-≮:; ≮:-trans-≡; ≮:-trans; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-never; unknown-≮:-scalar; scalar-≮:-never; unknown-≮:-never; <:-refl; <:-unknown; <:-impl-¬≮:) +open import Properties.ResolveOverloads using (src-unknown-≮:; unknown-src-≮:; <:-resolve; resolve-<:-⇒; <:-resolve-⇒) +open import Properties.Subtyping using (unknown-≮:; ≡-trans-≮:; ≮:-trans-≡; ≮:-trans; <:-trans-≮:; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-never; unknown-≮:-scalar; scalar-≮:-never; unknown-≮:-never; ≡-impl-<:; ≡-trans-<:; <:-trans-≡; ≮:-trans-<:; <:-trans) open import Properties.TypeCheck using (typeOfᴼ; typeOfᴹᴼ; typeOfⱽ; typeOfᴱ; typeOfᴮ; typeCheckᴱ; typeCheckᴮ; typeCheckᴼ; typeCheckᴴ) open import Luau.OpSem using (_⟦_⟧_⟶_; _⊢_⟶*_⊣_; _⊢_⟶ᴮ_⊣_; _⊢_⟶ᴱ_⊣_; app₁; app₂; function; beta; return; block; done; local; subst; binOp₀; binOp₁; binOp₂; refl; step; +; -; *; /; <; >; ==; ~=; <=; >=; ··) open import Luau.RuntimeError using (BinOpError; RuntimeErrorᴱ; RuntimeErrorᴮ; FunctionMismatch; BinOpMismatch₁; BinOpMismatch₂; UnboundVariable; SEGV; app₁; app₂; bin₁; bin₂; block; local; return; +; -; *; /; <; >; <=; >=; ··) @@ -63,51 +65,32 @@ lookup-⊑-nothing {H} a (snoc defn) p with a ≡ᴬ next H lookup-⊑-nothing {H} a (snoc defn) p | yes refl = refl lookup-⊑-nothing {H} a (snoc o) p | no q = trans (lookup-not-allocated o q) p -heap-weakeningᴱ : ∀ Γ H M {H′ U} → (H ⊑ H′) → (typeOfᴱ H′ Γ M ≮: U) → (typeOfᴱ H Γ M ≮: U) -heap-weakeningᴱ Γ H (var x) h p = p -heap-weakeningᴱ Γ H (val nil) h p = p -heap-weakeningᴱ Γ H (val (addr a)) refl p = p -heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) p with a ≡ᴬ b -heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = a} defn) p | yes refl = unknown-≮: p -heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) p | no r = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ (lookup-not-allocated q r))) p -heap-weakeningᴱ Γ H (val (number x)) h p = p -heap-weakeningᴱ Γ H (val (bool x)) h p = p -heap-weakeningᴱ Γ H (val (string x)) h p = p -heap-weakeningᴱ Γ H (M $ N) h p = never-tgt-≮: (heap-weakeningᴱ Γ H M h (tgt-never-≮: p)) -heap-weakeningᴱ Γ H (function f ⟨ var x ∈ T ⟩∈ U is B end) h p = p -heap-weakeningᴱ Γ H (block var b ∈ T is B end) h p = p -heap-weakeningᴱ Γ H (binexp M op N) h p = p +<:-heap-weakeningᴱ : ∀ Γ H M {H′} → (H ⊑ H′) → (typeOfᴱ H′ Γ M <: typeOfᴱ H Γ M) +<:-heap-weakeningᴱ Γ H (var x) h = <:-refl +<:-heap-weakeningᴱ Γ H (val nil) h = <:-refl +<:-heap-weakeningᴱ Γ H (val (addr a)) refl = <:-refl +<:-heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) with a ≡ᴬ b +<:-heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = a} defn) | yes refl = <:-unknown +<:-heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) | no r = ≡-impl-<: (sym (cong orUnknown (cong typeOfᴹᴼ (lookup-not-allocated q r)))) +<:-heap-weakeningᴱ Γ H (val (number n)) h = <:-refl +<:-heap-weakeningᴱ Γ H (val (bool b)) h = <:-refl +<:-heap-weakeningᴱ Γ H (val (string s)) h = <:-refl +<:-heap-weakeningᴱ Γ H (M $ N) h = <:-resolve (<:-heap-weakeningᴱ Γ H M h) (<:-heap-weakeningᴱ Γ H N h) +<:-heap-weakeningᴱ Γ H (function f ⟨ var x ∈ S ⟩∈ T is B end) h = <:-refl +<:-heap-weakeningᴱ Γ H (block var b ∈ T is N end) h = <:-refl +<:-heap-weakeningᴱ Γ H (binexp M op N) h = <:-refl -heap-weakeningᴮ : ∀ Γ H B {H′ U} → (H ⊑ H′) → (typeOfᴮ H′ Γ B ≮: U) → (typeOfᴮ H Γ B ≮: U) -heap-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h p = heap-weakeningᴮ (Γ ⊕ f ↦ (T ⇒ U)) H B h p -heap-weakeningᴮ Γ H (local var x ∈ T ← M ∙ B) h p = heap-weakeningᴮ (Γ ⊕ x ↦ T) H B h p -heap-weakeningᴮ Γ H (return M ∙ B) h p = heap-weakeningᴱ Γ H M h p -heap-weakeningᴮ Γ H done h p = p +<:-heap-weakeningᴮ : ∀ Γ H B {H′} → (H ⊑ H′) → (typeOfᴮ H′ Γ B <: typeOfᴮ H Γ B) +<:-heap-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h = <:-heap-weakeningᴮ (Γ ⊕ f ↦ (T ⇒ U)) H B h +<:-heap-weakeningᴮ Γ H (local var x ∈ T ← M ∙ B) h = <:-heap-weakeningᴮ (Γ ⊕ x ↦ T) H B h +<:-heap-weakeningᴮ Γ H (return M ∙ B) h = <:-heap-weakeningᴱ Γ H M h +<:-heap-weakeningᴮ Γ H done h = <:-refl -substitutivityᴱ : ∀ {Γ T U} H M v x → (typeOfᴱ H Γ (M [ v / x ]ᴱ) ≮: U) → Either (typeOfᴱ H (Γ ⊕ x ↦ T) M ≮: U) (typeOfᴱ H ∅ (val v) ≮: T) -substitutivityᴱ-whenever : ∀ {Γ T U} H v x y (r : Dec(x ≡ y)) → (typeOfᴱ H Γ (var y [ v / x ]ᴱwhenever r) ≮: U) → Either (typeOfᴱ H (Γ ⊕ x ↦ T) (var y) ≮: U) (typeOfᴱ H ∅ (val v) ≮: T) -substitutivityᴮ : ∀ {Γ T U} H B v x → (typeOfᴮ H Γ (B [ v / x ]ᴮ) ≮: U) → Either (typeOfᴮ H (Γ ⊕ x ↦ T) B ≮: U) (typeOfᴱ H ∅ (val v) ≮: T) -substitutivityᴮ-unless : ∀ {Γ T U V} H B v x y (r : Dec(x ≡ y)) → (typeOfᴮ H (Γ ⊕ y ↦ U) (B [ v / x ]ᴮunless r) ≮: V) → Either (typeOfᴮ H ((Γ ⊕ x ↦ T) ⊕ y ↦ U) B ≮: V) (typeOfᴱ H ∅ (val v) ≮: T) -substitutivityᴮ-unless-yes : ∀ {Γ Γ′ T V} H B v x y (r : x ≡ y) → (Γ′ ≡ Γ) → (typeOfᴮ H Γ (B [ v / x ]ᴮunless yes r) ≮: V) → Either (typeOfᴮ H Γ′ B ≮: V) (typeOfᴱ H ∅ (val v) ≮: T) -substitutivityᴮ-unless-no : ∀ {Γ Γ′ T V} H B v x y (r : x ≢ y) → (Γ′ ≡ Γ ⊕ x ↦ T) → (typeOfᴮ H Γ (B [ v / x ]ᴮunless no r) ≮: V) → Either (typeOfᴮ H Γ′ B ≮: V) (typeOfᴱ H ∅ (val v) ≮: T) +≮:-heap-weakeningᴱ : ∀ Γ H M {H′ U} → (H ⊑ H′) → (typeOfᴱ H′ Γ M ≮: U) → (typeOfᴱ H Γ M ≮: U) +≮:-heap-weakeningᴱ Γ H M h p = <:-trans-≮: (<:-heap-weakeningᴱ Γ H M h) p -substitutivityᴱ H (var y) v x p = substitutivityᴱ-whenever H v x y (x ≡ⱽ y) p -substitutivityᴱ H (val w) v x p = Left p -substitutivityᴱ H (binexp M op N) v x p = Left p -substitutivityᴱ H (M $ N) v x p = mapL never-tgt-≮: (substitutivityᴱ H M v x (tgt-never-≮: p)) -substitutivityᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x p = Left p -substitutivityᴱ H (block var b ∈ T is B end) v x p = Left p -substitutivityᴱ-whenever H v x x (yes refl) q = swapLR (≮:-trans q) -substitutivityᴱ-whenever H v x y (no p) q = Left (≡-trans-≮: (cong orUnknown (sym (⊕-lookup-miss x y _ _ p))) q) - -substitutivityᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x p = substitutivityᴮ-unless H B v x f (x ≡ⱽ f) p -substitutivityᴮ H (local var y ∈ T ← M ∙ B) v x p = substitutivityᴮ-unless H B v x y (x ≡ⱽ y) p -substitutivityᴮ H (return M ∙ B) v x p = substitutivityᴱ H M v x p -substitutivityᴮ H done v x p = Left p -substitutivityᴮ-unless H B v x y (yes p) q = substitutivityᴮ-unless-yes H B v x y p (⊕-over p) q -substitutivityᴮ-unless H B v x y (no p) q = substitutivityᴮ-unless-no H B v x y p (⊕-swap p) q -substitutivityᴮ-unless-yes H B v x y refl refl p = Left p -substitutivityᴮ-unless-no H B v x y p refl q = substitutivityᴮ H B v x q +≮:-heap-weakeningᴮ : ∀ Γ H B {H′ U} → (H ⊑ H′) → (typeOfᴮ H′ Γ B ≮: U) → (typeOfᴮ H Γ B ≮: U) +≮:-heap-weakeningᴮ Γ H B h p = <:-trans-≮: (<:-heap-weakeningᴮ Γ H B h) p binOpPreservation : ∀ H {op v w x} → (v ⟦ op ⟧ w ⟶ x) → (tgtBinOp op ≡ typeOfᴱ H ∅ (val x)) binOpPreservation H (+ m n) = refl @@ -122,24 +105,78 @@ binOpPreservation H (== v w) = refl binOpPreservation H (~= v w) = refl binOpPreservation H (·· v w) = refl -reflect-subtypingᴱ : ∀ H M {H′ M′ T} → (H ⊢ M ⟶ᴱ M′ ⊣ H′) → (typeOfᴱ H′ ∅ M′ ≮: T) → Either (typeOfᴱ H ∅ M ≮: T) (Warningᴱ H (typeCheckᴱ H ∅ M)) -reflect-subtypingᴮ : ∀ H B {H′ B′ T} → (H ⊢ B ⟶ᴮ B′ ⊣ H′) → (typeOfᴮ H′ ∅ B′ ≮: T) → Either (typeOfᴮ H ∅ B ≮: T) (Warningᴮ H (typeCheckᴮ H ∅ B)) +<:-substitutivityᴱ : ∀ {Γ T} H M v x → (typeOfᴱ H ∅ (val v) <: T) → (typeOfᴱ H Γ (M [ v / x ]ᴱ) <: typeOfᴱ H (Γ ⊕ x ↦ T) M) +<:-substitutivityᴱ-whenever : ∀ {Γ T} H v x y (r : Dec(x ≡ y)) → (typeOfᴱ H ∅ (val v) <: T) → (typeOfᴱ H Γ (var y [ v / x ]ᴱwhenever r) <: typeOfᴱ H (Γ ⊕ x ↦ T) (var y)) +<:-substitutivityᴮ : ∀ {Γ T} H B v x → (typeOfᴱ H ∅ (val v) <: T) → (typeOfᴮ H Γ (B [ v / x ]ᴮ) <: typeOfᴮ H (Γ ⊕ x ↦ T) B) +<:-substitutivityᴮ-unless : ∀ {Γ T U} H B v x y (r : Dec(x ≡ y)) → (typeOfᴱ H ∅ (val v) <: T) → (typeOfᴮ H (Γ ⊕ y ↦ U) (B [ v / x ]ᴮunless r) <: typeOfᴮ H ((Γ ⊕ x ↦ T) ⊕ y ↦ U) B) +<:-substitutivityᴮ-unless-yes : ∀ {Γ Γ′} H B v x y (r : x ≡ y) → (Γ′ ≡ Γ) → (typeOfᴮ H Γ (B [ v / x ]ᴮunless yes r) <: typeOfᴮ H Γ′ B) +<:-substitutivityᴮ-unless-no : ∀ {Γ Γ′ T} H B v x y (r : x ≢ y) → (Γ′ ≡ Γ ⊕ x ↦ T) → (typeOfᴱ H ∅ (val v) <: T) → (typeOfᴮ H Γ (B [ v / x ]ᴮunless no r) <: typeOfᴮ H Γ′ B) -reflect-subtypingᴱ H (M $ N) (app₁ s) p = mapLR never-tgt-≮: app₁ (reflect-subtypingᴱ H M s (tgt-never-≮: p)) -reflect-subtypingᴱ H (M $ N) (app₂ v s) p = Left (never-tgt-≮: (heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) (tgt-never-≮: p))) -reflect-subtypingᴱ H (M $ N) (beta (function f ⟨ var y ∈ T ⟩∈ U is B end) v refl q) p = Left (≡-trans-≮: (cong tgt (cong orUnknown (cong typeOfᴹᴼ q))) p) -reflect-subtypingᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) p = Left p -reflect-subtypingᴱ H (block var b ∈ T is B end) (block s) p = Left p -reflect-subtypingᴱ H (block var b ∈ T is return (val v) ∙ B end) (return v) p = mapR BlockMismatch (swapLR (≮:-trans p)) -reflect-subtypingᴱ H (block var b ∈ T is done end) done p = mapR BlockMismatch (swapLR (≮:-trans p)) -reflect-subtypingᴱ H (binexp M op N) (binOp₀ s) p = Left (≡-trans-≮: (binOpPreservation H s) p) -reflect-subtypingᴱ H (binexp M op N) (binOp₁ s) p = Left p -reflect-subtypingᴱ H (binexp M op N) (binOp₂ s) p = Left p +<:-substitutivityᴱ H (var y) v x p = <:-substitutivityᴱ-whenever H v x y (x ≡ⱽ y) p +<:-substitutivityᴱ H (val w) v x p = <:-refl +<:-substitutivityᴱ H (binexp M op N) v x p = <:-refl +<:-substitutivityᴱ H (M $ N) v x p = <:-resolve (<:-substitutivityᴱ H M v x p) (<:-substitutivityᴱ H N v x p) +<:-substitutivityᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x p = <:-refl +<:-substitutivityᴱ H (block var b ∈ T is B end) v x p = <:-refl +<:-substitutivityᴱ-whenever H v x x (yes refl) p = p +<:-substitutivityᴱ-whenever H v x y (no o) p = (≡-impl-<: (cong orUnknown (⊕-lookup-miss x y _ _ o))) -reflect-subtypingᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a defn) p = mapLR (heap-weakeningᴮ _ _ B (snoc defn)) (CONTRADICTION ∘ ≮:-refl) (substitutivityᴮ _ B (addr a) f p) -reflect-subtypingᴮ H (local var x ∈ T ← M ∙ B) (local s) p = Left (heap-weakeningᴮ (x ↦ T) H B (rednᴱ⊑ s) p) -reflect-subtypingᴮ H (local var x ∈ T ← M ∙ B) (subst v) p = mapR LocalVarMismatch (substitutivityᴮ H B v x p) -reflect-subtypingᴮ H (return M ∙ B) (return s) p = mapR return (reflect-subtypingᴱ H M s p) +<:-substitutivityᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x p = <:-substitutivityᴮ-unless H B v x f (x ≡ⱽ f) p +<:-substitutivityᴮ H (local var y ∈ T ← M ∙ B) v x p = <:-substitutivityᴮ-unless H B v x y (x ≡ⱽ y) p +<:-substitutivityᴮ H (return M ∙ B) v x p = <:-substitutivityᴱ H M v x p +<:-substitutivityᴮ H done v x p = <:-refl +<:-substitutivityᴮ-unless H B v x y (yes r) p = <:-substitutivityᴮ-unless-yes H B v x y r (⊕-over r) +<:-substitutivityᴮ-unless H B v x y (no r) p = <:-substitutivityᴮ-unless-no H B v x y r (⊕-swap r) p +<:-substitutivityᴮ-unless-yes H B v x y refl refl = <:-refl +<:-substitutivityᴮ-unless-no H B v x y r refl p = <:-substitutivityᴮ H B v x p + +≮:-substitutivityᴱ : ∀ {Γ T U} H M v x → (typeOfᴱ H Γ (M [ v / x ]ᴱ) ≮: U) → Either (typeOfᴱ H (Γ ⊕ x ↦ T) M ≮: U) (typeOfᴱ H ∅ (val v) ≮: T) +≮:-substitutivityᴱ {T = T} H M v x p with dec-subtyping (typeOfᴱ H ∅ (val v)) T +≮:-substitutivityᴱ H M v x p | Left q = Right q +≮:-substitutivityᴱ H M v x p | Right q = Left (<:-trans-≮: (<:-substitutivityᴱ H M v x q) p) + +≮:-substitutivityᴮ : ∀ {Γ T U} H B v x → (typeOfᴮ H Γ (B [ v / x ]ᴮ) ≮: U) → Either (typeOfᴮ H (Γ ⊕ x ↦ T) B ≮: U) (typeOfᴱ H ∅ (val v) ≮: T) +≮:-substitutivityᴮ {T = T} H M v x p with dec-subtyping (typeOfᴱ H ∅ (val v)) T +≮:-substitutivityᴮ H M v x p | Left q = Right q +≮:-substitutivityᴮ H M v x p | Right q = Left (<:-trans-≮: (<:-substitutivityᴮ H M v x q) p) + +≮:-substitutivityᴮ-unless : ∀ {Γ T U V} H B v x y (r : Dec(x ≡ y)) → (typeOfᴮ H (Γ ⊕ y ↦ U) (B [ v / x ]ᴮunless r) ≮: V) → Either (typeOfᴮ H ((Γ ⊕ x ↦ T) ⊕ y ↦ U) B ≮: V) (typeOfᴱ H ∅ (val v) ≮: T) +≮:-substitutivityᴮ-unless {T = T} H B v x y r p with dec-subtyping (typeOfᴱ H ∅ (val v)) T +≮:-substitutivityᴮ-unless H B v x y r p | Left q = Right q +≮:-substitutivityᴮ-unless H B v x y r p | Right q = Left (<:-trans-≮: (<:-substitutivityᴮ-unless H B v x y r q) p) + +<:-reductionᴱ : ∀ H M {H′ M′} → (H ⊢ M ⟶ᴱ M′ ⊣ H′) → Either (typeOfᴱ H′ ∅ M′ <: typeOfᴱ H ∅ M) (Warningᴱ H (typeCheckᴱ H ∅ M)) +<:-reductionᴮ : ∀ H B {H′ B′} → (H ⊢ B ⟶ᴮ B′ ⊣ H′) → Either (typeOfᴮ H′ ∅ B′ <: typeOfᴮ H ∅ B) (Warningᴮ H (typeCheckᴮ H ∅ B)) + +<:-reductionᴱ H (M $ N) (app₁ s) = mapLR (λ p → <:-resolve p (<:-heap-weakeningᴱ ∅ H N (rednᴱ⊑ s))) app₁ (<:-reductionᴱ H M s) +<:-reductionᴱ H (M $ N) (app₂ q s) = mapLR (λ p → <:-resolve (<:-heap-weakeningᴱ ∅ H M (rednᴱ⊑ s)) p) app₂ (<:-reductionᴱ H N s) +<:-reductionᴱ H (M $ N) (beta (function f ⟨ var y ∈ S ⟩∈ U is B end) v refl q) with dec-subtyping (typeOfᴱ H ∅ (val v)) S +<:-reductionᴱ H (M $ N) (beta (function f ⟨ var y ∈ S ⟩∈ U is B end) v refl q) | Left r = Right (FunctionCallMismatch (≮:-trans-≡ r (cong src (cong orUnknown (cong typeOfᴹᴼ (sym q)))))) +<:-reductionᴱ H (M $ N) (beta (function f ⟨ var y ∈ S ⟩∈ U is B end) v refl q) | Right r = Left (<:-trans-≡ (<:-resolve-⇒ r) (cong (λ F → resolve F (typeOfᴱ H ∅ N)) (cong orUnknown (cong typeOfᴹᴼ (sym q))))) +<:-reductionᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) = Left <:-refl +<:-reductionᴱ H (block var b ∈ T is B end) (block s) = Left <:-refl +<:-reductionᴱ H (block var b ∈ T is return (val v) ∙ B end) (return v) with dec-subtyping (typeOfᴱ H ∅ (val v)) T +<:-reductionᴱ H (block var b ∈ T is return (val v) ∙ B end) (return v) | Left p = Right (BlockMismatch p) +<:-reductionᴱ H (block var b ∈ T is return (val v) ∙ B end) (return v) | Right p = Left p +<:-reductionᴱ H (block var b ∈ T is done end) done with dec-subtyping nil T +<:-reductionᴱ H (block var b ∈ T is done end) done | Left p = Right (BlockMismatch p) +<:-reductionᴱ H (block var b ∈ T is done end) done | Right p = Left p +<:-reductionᴱ H (binexp M op N) (binOp₀ s) = Left (≡-impl-<: (sym (binOpPreservation H s))) +<:-reductionᴱ H (binexp M op N) (binOp₁ s) = Left <:-refl +<:-reductionᴱ H (binexp M op N) (binOp₂ s) = Left <:-refl + +<:-reductionᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a defn) = Left (<:-trans (<:-substitutivityᴮ _ B (addr a) f <:-refl) (<:-heap-weakeningᴮ (f ↦ (T ⇒ U)) H B (snoc defn))) +<:-reductionᴮ H (local var x ∈ T ← M ∙ B) (local s) = Left (<:-heap-weakeningᴮ (x ↦ T) H B (rednᴱ⊑ s)) +<:-reductionᴮ H (local var x ∈ T ← M ∙ B) (subst v) with dec-subtyping (typeOfᴱ H ∅ (val v)) T +<:-reductionᴮ H (local var x ∈ T ← M ∙ B) (subst v) | Left p = Right (LocalVarMismatch p) +<:-reductionᴮ H (local var x ∈ T ← M ∙ B) (subst v) | Right p = Left (<:-substitutivityᴮ H B v x p) +<:-reductionᴮ H (return M ∙ B) (return s) = mapR return (<:-reductionᴱ H M s) + +≮:-reductionᴱ : ∀ H M {H′ M′ T} → (H ⊢ M ⟶ᴱ M′ ⊣ H′) → (typeOfᴱ H′ ∅ M′ ≮: T) → Either (typeOfᴱ H ∅ M ≮: T) (Warningᴱ H (typeCheckᴱ H ∅ M)) +≮:-reductionᴱ H M s p = mapL (λ q → <:-trans-≮: q p) (<:-reductionᴱ H M s) + +≮:-reductionᴮ : ∀ H B {H′ B′ T} → (H ⊢ B ⟶ᴮ B′ ⊣ H′) → (typeOfᴮ H′ ∅ B′ ≮: T) → Either (typeOfᴮ H ∅ B ≮: T) (Warningᴮ H (typeCheckᴮ H ∅ B)) +≮:-reductionᴮ H B s p = mapL (λ q → <:-trans-≮: q p) (<:-reductionᴮ H B s) reflect-substitutionᴱ : ∀ {Γ T} H M v x → Warningᴱ H (typeCheckᴱ H Γ (M [ v / x ]ᴱ)) → Either (Warningᴱ H (typeCheckᴱ H (Γ ⊕ x ↦ T) M)) (Either (Warningᴱ H (typeCheckᴱ H ∅ (val v))) (typeOfᴱ H ∅ (val v) ≮: T)) reflect-substitutionᴱ-whenever : ∀ {Γ T} H v x y (p : Dec(x ≡ y)) → Warningᴱ H (typeCheckᴱ H Γ (var y [ v / x ]ᴱwhenever p)) → Either (Warningᴱ H (typeCheckᴱ H (Γ ⊕ x ↦ T) (var y))) (Either (Warningᴱ H (typeCheckᴱ H ∅ (val v))) (typeOfᴱ H ∅ (val v) ≮: T)) @@ -150,29 +187,29 @@ reflect-substitutionᴮ-unless-no : ∀ {Γ Γ′ T} H B v x y (r : x ≢ y) → reflect-substitutionᴱ H (var y) v x W = reflect-substitutionᴱ-whenever H v x y (x ≡ⱽ y) W reflect-substitutionᴱ H (val (addr a)) v x (UnallocatedAddress r) = Left (UnallocatedAddress r) -reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) with substitutivityᴱ H N v x p +reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) with ≮:-substitutivityᴱ H N v x p reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Right W = Right (Right W) -reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q with substitutivityᴱ H M v x (src-unknown-≮: q) +reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q with ≮:-substitutivityᴱ H M v x (src-unknown-≮: q) reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q | Left r = Left ((FunctionCallMismatch ∘ unknown-src-≮: q) r) reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q | Right W = Right (Right W) reflect-substitutionᴱ H (M $ N) v x (app₁ W) = mapL app₁ (reflect-substitutionᴱ H M v x W) reflect-substitutionᴱ H (M $ N) v x (app₂ W) = mapL app₂ (reflect-substitutionᴱ H N v x W) -reflect-substitutionᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x (FunctionDefnMismatch q) = mapLR FunctionDefnMismatch Right (substitutivityᴮ-unless H B v x y (x ≡ⱽ y) q) +reflect-substitutionᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x (FunctionDefnMismatch q) = mapLR FunctionDefnMismatch Right (≮:-substitutivityᴮ-unless H B v x y (x ≡ⱽ y) q) reflect-substitutionᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x (function₁ W) = mapL function₁ (reflect-substitutionᴮ-unless H B v x y (x ≡ⱽ y) W) -reflect-substitutionᴱ H (block var b ∈ T is B end) v x (BlockMismatch q) = mapLR BlockMismatch Right (substitutivityᴮ H B v x q) +reflect-substitutionᴱ H (block var b ∈ T is B end) v x (BlockMismatch q) = mapLR BlockMismatch Right (≮:-substitutivityᴮ H B v x q) reflect-substitutionᴱ H (block var b ∈ T is B end) v x (block₁ W′) = mapL block₁ (reflect-substitutionᴮ H B v x W′) -reflect-substitutionᴱ H (binexp M op N) v x (BinOpMismatch₁ q) = mapLR BinOpMismatch₁ Right (substitutivityᴱ H M v x q) -reflect-substitutionᴱ H (binexp M op N) v x (BinOpMismatch₂ q) = mapLR BinOpMismatch₂ Right (substitutivityᴱ H N v x q) +reflect-substitutionᴱ H (binexp M op N) v x (BinOpMismatch₁ q) = mapLR BinOpMismatch₁ Right (≮:-substitutivityᴱ H M v x q) +reflect-substitutionᴱ H (binexp M op N) v x (BinOpMismatch₂ q) = mapLR BinOpMismatch₂ Right (≮:-substitutivityᴱ H N v x q) reflect-substitutionᴱ H (binexp M op N) v x (bin₁ W) = mapL bin₁ (reflect-substitutionᴱ H M v x W) reflect-substitutionᴱ H (binexp M op N) v x (bin₂ W) = mapL bin₂ (reflect-substitutionᴱ H N v x W) reflect-substitutionᴱ-whenever H a x x (yes refl) (UnallocatedAddress p) = Right (Left (UnallocatedAddress p)) reflect-substitutionᴱ-whenever H v x y (no p) (UnboundVariable q) = Left (UnboundVariable (trans (sym (⊕-lookup-miss x y _ _ p)) q)) -reflect-substitutionᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x (FunctionDefnMismatch q) = mapLR FunctionDefnMismatch Right (substitutivityᴮ-unless H C v x y (x ≡ⱽ y) q) +reflect-substitutionᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x (FunctionDefnMismatch q) = mapLR FunctionDefnMismatch Right (≮:-substitutivityᴮ-unless H C v x y (x ≡ⱽ y) q) reflect-substitutionᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x (function₁ W) = mapL function₁ (reflect-substitutionᴮ-unless H C v x y (x ≡ⱽ y) W) reflect-substitutionᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x (function₂ W) = mapL function₂ (reflect-substitutionᴮ-unless H B v x f (x ≡ⱽ f) W) -reflect-substitutionᴮ H (local var y ∈ T ← M ∙ B) v x (LocalVarMismatch q) = mapLR LocalVarMismatch Right (substitutivityᴱ H M v x q) +reflect-substitutionᴮ H (local var y ∈ T ← M ∙ B) v x (LocalVarMismatch q) = mapLR LocalVarMismatch Right (≮:-substitutivityᴱ H M v x q) reflect-substitutionᴮ H (local var y ∈ T ← M ∙ B) v x (local₁ W) = mapL local₁ (reflect-substitutionᴱ H M v x W) reflect-substitutionᴮ H (local var y ∈ T ← M ∙ B) v x (local₂ W) = mapL local₂ (reflect-substitutionᴮ-unless H B v x y (x ≡ⱽ y) W) reflect-substitutionᴮ H (return M ∙ B) v x (return W) = mapL return (reflect-substitutionᴱ H M v x W) @@ -187,61 +224,61 @@ reflect-weakeningᴮ : ∀ Γ H B {H′} → (H ⊑ H′) → Warningᴮ H′ (t reflect-weakeningᴱ Γ H (var x) h (UnboundVariable p) = (UnboundVariable p) reflect-weakeningᴱ Γ H (val (addr a)) h (UnallocatedAddress p) = UnallocatedAddress (lookup-⊑-nothing a h p) -reflect-weakeningᴱ Γ H (M $ N) h (FunctionCallMismatch p) = FunctionCallMismatch (heap-weakeningᴱ Γ H N h (unknown-src-≮: p (heap-weakeningᴱ Γ H M h (src-unknown-≮: p)))) +reflect-weakeningᴱ Γ H (M $ N) h (FunctionCallMismatch p) = FunctionCallMismatch (≮:-heap-weakeningᴱ Γ H N h (unknown-src-≮: p (≮:-heap-weakeningᴱ Γ H M h (src-unknown-≮: p)))) reflect-weakeningᴱ Γ H (M $ N) h (app₁ W) = app₁ (reflect-weakeningᴱ Γ H M h W) reflect-weakeningᴱ Γ H (M $ N) h (app₂ W) = app₂ (reflect-weakeningᴱ Γ H N h W) -reflect-weakeningᴱ Γ H (binexp M op N) h (BinOpMismatch₁ p) = BinOpMismatch₁ (heap-weakeningᴱ Γ H M h p) -reflect-weakeningᴱ Γ H (binexp M op N) h (BinOpMismatch₂ p) = BinOpMismatch₂ (heap-weakeningᴱ Γ H N h p) +reflect-weakeningᴱ Γ H (binexp M op N) h (BinOpMismatch₁ p) = BinOpMismatch₁ (≮:-heap-weakeningᴱ Γ H M h p) +reflect-weakeningᴱ Γ H (binexp M op N) h (BinOpMismatch₂ p) = BinOpMismatch₂ (≮:-heap-weakeningᴱ Γ H N h p) reflect-weakeningᴱ Γ H (binexp M op N) h (bin₁ W′) = bin₁ (reflect-weakeningᴱ Γ H M h W′) reflect-weakeningᴱ Γ H (binexp M op N) h (bin₂ W′) = bin₂ (reflect-weakeningᴱ Γ H N h W′) -reflect-weakeningᴱ Γ H (function f ⟨ var y ∈ T ⟩∈ U is B end) h (FunctionDefnMismatch p) = FunctionDefnMismatch (heap-weakeningᴮ (Γ ⊕ y ↦ T) H B h p) +reflect-weakeningᴱ Γ H (function f ⟨ var y ∈ T ⟩∈ U is B end) h (FunctionDefnMismatch p) = FunctionDefnMismatch (≮:-heap-weakeningᴮ (Γ ⊕ y ↦ T) H B h p) reflect-weakeningᴱ Γ H (function f ⟨ var y ∈ T ⟩∈ U is B end) h (function₁ W) = function₁ (reflect-weakeningᴮ (Γ ⊕ y ↦ T) H B h W) -reflect-weakeningᴱ Γ H (block var b ∈ T is B end) h (BlockMismatch p) = BlockMismatch (heap-weakeningᴮ Γ H B h p) +reflect-weakeningᴱ Γ H (block var b ∈ T is B end) h (BlockMismatch p) = BlockMismatch (≮:-heap-weakeningᴮ Γ H B h p) reflect-weakeningᴱ Γ H (block var b ∈ T is B end) h (block₁ W) = block₁ (reflect-weakeningᴮ Γ H B h W) reflect-weakeningᴮ Γ H (return M ∙ B) h (return W) = return (reflect-weakeningᴱ Γ H M h W) -reflect-weakeningᴮ Γ H (local var y ∈ T ← M ∙ B) h (LocalVarMismatch p) = LocalVarMismatch (heap-weakeningᴱ Γ H M h p) +reflect-weakeningᴮ Γ H (local var y ∈ T ← M ∙ B) h (LocalVarMismatch p) = LocalVarMismatch (≮:-heap-weakeningᴱ Γ H M h p) reflect-weakeningᴮ Γ H (local var y ∈ T ← M ∙ B) h (local₁ W) = local₁ (reflect-weakeningᴱ Γ H M h W) reflect-weakeningᴮ Γ H (local var y ∈ T ← M ∙ B) h (local₂ W) = local₂ (reflect-weakeningᴮ (Γ ⊕ y ↦ T) H B h W) -reflect-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h (FunctionDefnMismatch p) = FunctionDefnMismatch (heap-weakeningᴮ (Γ ⊕ x ↦ T) H C h p) +reflect-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h (FunctionDefnMismatch p) = FunctionDefnMismatch (≮:-heap-weakeningᴮ (Γ ⊕ x ↦ T) H C h p) reflect-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h (function₁ W) = function₁ (reflect-weakeningᴮ (Γ ⊕ x ↦ T) H C h W) reflect-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h (function₂ W) = function₂ (reflect-weakeningᴮ (Γ ⊕ f ↦ (T ⇒ U)) H B h W) reflect-weakeningᴼ : ∀ H O {H′} → (H ⊑ H′) → Warningᴼ H′ (typeCheckᴼ H′ O) → Warningᴼ H (typeCheckᴼ H O) -reflect-weakeningᴼ H (just function f ⟨ var x ∈ T ⟩∈ U is B end) h (FunctionDefnMismatch p) = FunctionDefnMismatch (heap-weakeningᴮ (x ↦ T) H B h p) +reflect-weakeningᴼ H (just function f ⟨ var x ∈ T ⟩∈ U is B end) h (FunctionDefnMismatch p) = FunctionDefnMismatch (≮:-heap-weakeningᴮ (x ↦ T) H B h p) reflect-weakeningᴼ H (just function f ⟨ var x ∈ T ⟩∈ U is B end) h (function₁ W) = function₁ (reflect-weakeningᴮ (x ↦ T) H B h W) reflectᴱ : ∀ H M {H′ M′} → (H ⊢ M ⟶ᴱ M′ ⊣ H′) → Warningᴱ H′ (typeCheckᴱ H′ ∅ M′) → Either (Warningᴱ H (typeCheckᴱ H ∅ M)) (Warningᴴ H (typeCheckᴴ H)) reflectᴮ : ∀ H B {H′ B′} → (H ⊢ B ⟶ᴮ B′ ⊣ H′) → Warningᴮ H′ (typeCheckᴮ H′ ∅ B′) → Either (Warningᴮ H (typeCheckᴮ H ∅ B)) (Warningᴴ H (typeCheckᴴ H)) -reflectᴱ H (M $ N) (app₁ s) (FunctionCallMismatch p) = cond (Left ∘ FunctionCallMismatch ∘ heap-weakeningᴱ ∅ H N (rednᴱ⊑ s) ∘ unknown-src-≮: p) (Left ∘ app₁) (reflect-subtypingᴱ H M s (src-unknown-≮: p)) +reflectᴱ H (M $ N) (app₁ s) (FunctionCallMismatch p) = cond (Left ∘ FunctionCallMismatch ∘ ≮:-heap-weakeningᴱ ∅ H N (rednᴱ⊑ s) ∘ unknown-src-≮: p) (Left ∘ app₁) (≮:-reductionᴱ H M s (src-unknown-≮: p)) reflectᴱ H (M $ N) (app₁ s) (app₁ W′) = mapL app₁ (reflectᴱ H M s W′) reflectᴱ H (M $ N) (app₁ s) (app₂ W′) = Left (app₂ (reflect-weakeningᴱ ∅ H N (rednᴱ⊑ s) W′)) -reflectᴱ H (M $ N) (app₂ p s) (FunctionCallMismatch q) = cond (λ r → Left (FunctionCallMismatch (unknown-src-≮: r (heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) (src-unknown-≮: r))))) (Left ∘ app₂) (reflect-subtypingᴱ H N s q) +reflectᴱ H (M $ N) (app₂ p s) (FunctionCallMismatch q) = cond (λ r → Left (FunctionCallMismatch (unknown-src-≮: r (≮:-heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) (src-unknown-≮: r))))) (Left ∘ app₂) (≮:-reductionᴱ H N s q) reflectᴱ H (M $ N) (app₂ p s) (app₁ W′) = Left (app₁ (reflect-weakeningᴱ ∅ H M (rednᴱ⊑ s) W′)) reflectᴱ H (M $ N) (app₂ p s) (app₂ W′) = mapL app₂ (reflectᴱ H N s W′) -reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) with substitutivityᴮ H B v x q +reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) with ≮:-substitutivityᴮ H B v x q reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) | Left r = Right (addr a p (FunctionDefnMismatch r)) reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) | Right r = Left (FunctionCallMismatch (≮:-trans-≡ r ((cong src (cong orUnknown (cong typeOfᴹᴼ (sym p))))))) reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) with reflect-substitutionᴮ _ B v x W′ reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Left W = Right (addr a p (function₁ W)) reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Right (Left W) = Left (app₂ W) reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Right (Right q) = Left (FunctionCallMismatch (≮:-trans-≡ q (cong src (cong orUnknown (cong typeOfᴹᴼ (sym p)))))) -reflectᴱ H (block var b ∈ T is B end) (block s) (BlockMismatch p) = Left (cond BlockMismatch block₁ (reflect-subtypingᴮ H B s p)) +reflectᴱ H (block var b ∈ T is B end) (block s) (BlockMismatch p) = Left (cond BlockMismatch block₁ (≮:-reductionᴮ H B s p)) reflectᴱ H (block var b ∈ T is B end) (block s) (block₁ W′) = mapL block₁ (reflectᴮ H B s W′) reflectᴱ H (block var b ∈ T is B end) (return v) W′ = Left (block₁ (return W′)) reflectᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) (UnallocatedAddress ()) reflectᴱ H (binexp M op N) (binOp₀ ()) (UnallocatedAddress p) -reflectᴱ H (binexp M op N) (binOp₁ s) (BinOpMismatch₁ p) = Left (cond BinOpMismatch₁ bin₁ (reflect-subtypingᴱ H M s p)) -reflectᴱ H (binexp M op N) (binOp₁ s) (BinOpMismatch₂ p) = Left (BinOpMismatch₂ (heap-weakeningᴱ ∅ H N (rednᴱ⊑ s) p)) +reflectᴱ H (binexp M op N) (binOp₁ s) (BinOpMismatch₁ p) = Left (cond BinOpMismatch₁ bin₁ (≮:-reductionᴱ H M s p)) +reflectᴱ H (binexp M op N) (binOp₁ s) (BinOpMismatch₂ p) = Left (BinOpMismatch₂ (≮:-heap-weakeningᴱ ∅ H N (rednᴱ⊑ s) p)) reflectᴱ H (binexp M op N) (binOp₁ s) (bin₁ W′) = mapL bin₁ (reflectᴱ H M s W′) reflectᴱ H (binexp M op N) (binOp₁ s) (bin₂ W′) = Left (bin₂ (reflect-weakeningᴱ ∅ H N (rednᴱ⊑ s) W′)) -reflectᴱ H (binexp M op N) (binOp₂ s) (BinOpMismatch₁ p) = Left (BinOpMismatch₁ (heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) p)) -reflectᴱ H (binexp M op N) (binOp₂ s) (BinOpMismatch₂ p) = Left (cond BinOpMismatch₂ bin₂ (reflect-subtypingᴱ H N s p)) +reflectᴱ H (binexp M op N) (binOp₂ s) (BinOpMismatch₁ p) = Left (BinOpMismatch₁ (≮:-heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) p)) +reflectᴱ H (binexp M op N) (binOp₂ s) (BinOpMismatch₂ p) = Left (cond BinOpMismatch₂ bin₂ (≮:-reductionᴱ H N s p)) reflectᴱ H (binexp M op N) (binOp₂ s) (bin₁ W′) = Left (bin₁ (reflect-weakeningᴱ ∅ H M (rednᴱ⊑ s) W′)) reflectᴱ H (binexp M op N) (binOp₂ s) (bin₂ W′) = mapL bin₂ (reflectᴱ H N s W′) -reflectᴮ H (local var x ∈ T ← M ∙ B) (local s) (LocalVarMismatch p) = Left (cond LocalVarMismatch local₁ (reflect-subtypingᴱ H M s p)) +reflectᴮ H (local var x ∈ T ← M ∙ B) (local s) (LocalVarMismatch p) = Left (cond LocalVarMismatch local₁ (≮:-reductionᴱ H M s p)) reflectᴮ H (local var x ∈ T ← M ∙ B) (local s) (local₁ W′) = mapL local₁ (reflectᴱ H M s W′) reflectᴮ H (local var x ∈ T ← M ∙ B) (local s) (local₂ W′) = Left (local₂ (reflect-weakeningᴮ (x ↦ T) H B (rednᴱ⊑ s) W′)) reflectᴮ H (local var x ∈ T ← M ∙ B) (subst v) W′ = Left (cond local₂ (cond local₁ LocalVarMismatch) (reflect-substitutionᴮ H B v x W′)) @@ -258,7 +295,7 @@ reflectᴴᴱ H (M $ N) (app₁ s) W = mapL app₁ (reflectᴴᴱ H M s W) reflectᴴᴱ H (M $ N) (app₂ v s) W = mapL app₂ (reflectᴴᴱ H N s W) reflectᴴᴱ H (M $ N) (beta O v refl p) W = Right W reflectᴴᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a p) (addr b refl W) with b ≡ᴬ a -reflectᴴᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) (addr b refl (FunctionDefnMismatch p)) | yes refl = Left (FunctionDefnMismatch (heap-weakeningᴮ (x ↦ T) H B (snoc defn) p)) +reflectᴴᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) (addr b refl (FunctionDefnMismatch p)) | yes refl = Left (FunctionDefnMismatch (≮:-heap-weakeningᴮ (x ↦ T) H B (snoc defn) p)) reflectᴴᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) (addr b refl (function₁ W)) | yes refl = Left (function₁ (reflect-weakeningᴮ (x ↦ T) H B (snoc defn) W)) reflectᴴᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a p) (addr b refl W) | no q = Right (addr b (lookup-not-allocated p q) (reflect-weakeningᴼ H _ (snoc p) W)) reflectᴴᴱ H (block var b ∈ T is B end) (block s) W = mapL block₁ (reflectᴴᴮ H B s W) @@ -269,7 +306,7 @@ reflectᴴᴱ H (binexp M op N) (binOp₁ s) W = mapL bin₁ (reflectᴴᴱ H M reflectᴴᴱ H (binexp M op N) (binOp₂ s) W = mapL bin₂ (reflectᴴᴱ H N s W) reflectᴴᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a p) (addr b refl W) with b ≡ᴬ a -reflectᴴᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a defn) (addr b refl (FunctionDefnMismatch p)) | yes refl = Left (FunctionDefnMismatch (heap-weakeningᴮ (x ↦ T) H C (snoc defn) p)) +reflectᴴᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a defn) (addr b refl (FunctionDefnMismatch p)) | yes refl = Left (FunctionDefnMismatch (≮:-heap-weakeningᴮ (x ↦ T) H C (snoc defn) p)) reflectᴴᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a defn) (addr b refl (function₁ W)) | yes refl = Left (function₁ (reflect-weakeningᴮ (x ↦ T) H C (snoc defn) W)) reflectᴴᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a p) (addr b refl W) | no q = Right (addr b (lookup-not-allocated p q) (reflect-weakeningᴼ H _ (snoc p) W)) reflectᴴᴮ H (local var x ∈ T ← M ∙ B) (local s) W = mapL local₁ (reflectᴴᴱ H M s W) diff --git a/prototyping/Properties/Subtyping.agda b/prototyping/Properties/Subtyping.agda index 34e6691f..73bf0e9a 100644 --- a/prototyping/Properties/Subtyping.agda +++ b/prototyping/Properties/Subtyping.agda @@ -5,7 +5,7 @@ module Properties.Subtyping where open import Agda.Builtin.Equality using (_≡_; refl) open import FFI.Data.Either using (Either; Left; Right; mapLR; swapLR; cond) open import FFI.Data.Maybe using (Maybe; just; nothing) -open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_) +open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-function-tgt; scalar-scalar; function-scalar; function-ok; function-ok₁; function-ok₂; function-err; function-tgt; left; right; _,_) open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_; skalar) open import Properties.Contradiction using (CONTRADICTION; ¬; ⊥) open import Properties.Equality using (_≢_) @@ -19,37 +19,42 @@ dec-language nil (scalar boolean) = Left (scalar-scalar boolean nil (λ ())) dec-language nil (scalar string) = Left (scalar-scalar string nil (λ ())) dec-language nil (scalar nil) = Right (scalar nil) dec-language nil function = Left (scalar-function nil) -dec-language nil (function-ok t) = Left (scalar-function-ok nil) +dec-language nil (function-ok s t) = Left (scalar-function-ok nil) dec-language nil (function-err t) = Left (scalar-function-err nil) dec-language boolean (scalar number) = Left (scalar-scalar number boolean (λ ())) dec-language boolean (scalar boolean) = Right (scalar boolean) dec-language boolean (scalar string) = Left (scalar-scalar string boolean (λ ())) dec-language boolean (scalar nil) = Left (scalar-scalar nil boolean (λ ())) dec-language boolean function = Left (scalar-function boolean) -dec-language boolean (function-ok t) = Left (scalar-function-ok boolean) +dec-language boolean (function-ok s t) = Left (scalar-function-ok boolean) dec-language boolean (function-err t) = Left (scalar-function-err boolean) dec-language number (scalar number) = Right (scalar number) dec-language number (scalar boolean) = Left (scalar-scalar boolean number (λ ())) dec-language number (scalar string) = Left (scalar-scalar string number (λ ())) dec-language number (scalar nil) = Left (scalar-scalar nil number (λ ())) dec-language number function = Left (scalar-function number) -dec-language number (function-ok t) = Left (scalar-function-ok number) +dec-language number (function-ok s t) = Left (scalar-function-ok number) dec-language number (function-err t) = Left (scalar-function-err number) dec-language string (scalar number) = Left (scalar-scalar number string (λ ())) dec-language string (scalar boolean) = Left (scalar-scalar boolean string (λ ())) dec-language string (scalar string) = Right (scalar string) dec-language string (scalar nil) = Left (scalar-scalar nil string (λ ())) dec-language string function = Left (scalar-function string) -dec-language string (function-ok t) = Left (scalar-function-ok string) +dec-language string (function-ok s t) = Left (scalar-function-ok string) dec-language string (function-err t) = Left (scalar-function-err string) dec-language (T₁ ⇒ T₂) (scalar s) = Left (function-scalar s) dec-language (T₁ ⇒ T₂) function = Right function -dec-language (T₁ ⇒ T₂) (function-ok t) = mapLR function-ok function-ok (dec-language T₂ t) +dec-language (T₁ ⇒ T₂) (function-ok s t) = cond (Right ∘ function-ok₁) (λ p → mapLR (function-ok p) function-ok₂ (dec-language T₂ t)) (dec-language T₁ s) dec-language (T₁ ⇒ T₂) (function-err t) = mapLR function-err function-err (swapLR (dec-language T₁ t)) dec-language never t = Left never dec-language unknown t = Right unknown dec-language (T₁ ∪ T₂) t = cond (λ p → cond (Left ∘ _,_ p) (Right ∘ right) (dec-language T₂ t)) (Right ∘ left) (dec-language T₁ t) dec-language (T₁ ∩ T₂) t = cond (Left ∘ left) (λ p → cond (Left ∘ right) (Right ∘ _,_ p) (dec-language T₂ t)) (dec-language T₁ t) +dec-language nil (function-tgt t) = Left (scalar-function-tgt nil) +dec-language (T₁ ⇒ T₂) (function-tgt t) = mapLR function-tgt function-tgt (dec-language T₂ t) +dec-language boolean (function-tgt t) = Left (scalar-function-tgt boolean) +dec-language number (function-tgt t) = Left (scalar-function-tgt number) +dec-language string (function-tgt t) = Left (scalar-function-tgt string) -- ¬Language T is the complement of Language T language-comp : ∀ {T} t → ¬Language T t → ¬(Language T t) @@ -61,9 +66,12 @@ language-comp (scalar s) (scalar-scalar s p₁ p₂) (scalar s) = p₂ refl language-comp (scalar s) (function-scalar s) (scalar s) = language-comp function (scalar-function s) function language-comp (scalar s) never (scalar ()) language-comp function (scalar-function ()) function -language-comp (function-ok t) (scalar-function-ok ()) (function-ok q) -language-comp (function-ok t) (function-ok p) (function-ok q) = language-comp t p q -language-comp (function-err t) (function-err p) (function-err q) = language-comp t q p +language-comp (function-ok s t) (scalar-function-ok ()) (function-ok₁ p) +language-comp (function-ok s t) (function-ok p₁ p₂) (function-ok₁ q) = language-comp s q p₁ +language-comp (function-ok s t) (function-ok p₁ p₂) (function-ok₂ q) = language-comp t p₂ q +language-comp (function-err t) (function-err p) (function-err q) = language-comp t q p +language-comp (function-tgt t) (scalar-function-tgt ()) (function-tgt q) +language-comp (function-tgt t) (function-tgt p) (function-tgt q) = language-comp t p q -- ≮: is the complement of <: ¬≮:-impl-<: : ∀ {T U} → ¬(T ≮: U) → (T <: U) @@ -90,9 +98,18 @@ language-comp (function-err t) (function-err p) (function-err q) = language-comp ≮:-trans-≡ : ∀ {S T U} → (S ≮: T) → (T ≡ U) → (S ≮: U) ≮:-trans-≡ p refl = p +<:-trans-≡ : ∀ {S T U} → (S <: T) → (T ≡ U) → (S <: U) +<:-trans-≡ p refl = p + +≡-impl-<: : ∀ {T U} → (T ≡ U) → (T <: U) +≡-impl-<: refl = <:-refl + ≡-trans-≮: : ∀ {S T U} → (S ≡ T) → (T ≮: U) → (S ≮: U) ≡-trans-≮: refl p = p +≡-trans-<: : ∀ {S T U} → (S ≡ T) → (T <: U) → (S <: U) +≡-trans-<: refl p = p + ≮:-trans : ∀ {S T U} → (S ≮: U) → Either (S ≮: T) (T ≮: U) ≮:-trans {T = T} (witness t p q) = mapLR (witness t p) (λ z → witness t z q) (dec-language T t) @@ -141,6 +158,12 @@ language-comp (function-err t) (function-err p) (function-err q) = language-comp ≮:-∪-right : ∀ {S T U} → (T ≮: U) → ((S ∪ T) ≮: U) ≮:-∪-right (witness t p q) = witness t (right p) q +≮:-left-∪ : ∀ {S T U} → (S ≮: (T ∪ U)) → (S ≮: T) +≮:-left-∪ (witness t p (q₁ , q₂)) = witness t p q₁ + +≮:-right-∪ : ∀ {S T U} → (S ≮: (T ∪ U)) → (S ≮: U) +≮:-right-∪ (witness t p (q₁ , q₂)) = witness t p q₂ + -- Properties of intersection <:-intersect : ∀ {R S T U} → (R <: T) → (S <: U) → ((R ∩ S) <: (T ∩ U)) @@ -158,6 +181,12 @@ language-comp (function-err t) (function-err p) (function-err q) = language-comp <:-∩-symm : ∀ {T U} → (T ∩ U) <: (U ∩ T) <:-∩-symm t (p₁ , p₂) = (p₂ , p₁) +<:-∩-assocl : ∀ {S T U} → (S ∩ (T ∩ U)) <: ((S ∩ T) ∩ U) +<:-∩-assocl t (p , (p₁ , p₂)) = (p , p₁) , p₂ + +<:-∩-assocr : ∀ {S T U} → ((S ∩ T) ∩ U) <: (S ∩ (T ∩ U)) +<:-∩-assocr t ((p , p₁) , p₂) = p , (p₁ , p₂) + ≮:-∩-left : ∀ {S T U} → (S ≮: T) → (S ≮: (T ∩ U)) ≮:-∩-left (witness t p q) = witness t p (left q) @@ -199,47 +228,84 @@ language-comp (function-err t) (function-err p) (function-err q) = language-comp ∪-distr-∩-<: t (left p₁ , right p₂) = right p₂ ∪-distr-∩-<: t (right p₁ , p₂) = right p₁ +∩-<:-∪ : ∀ {S T} → (S ∩ T) <: (S ∪ T) +∩-<:-∪ t (p , _) = left p + -- Properties of functions <:-function : ∀ {R S T U} → (R <: S) → (T <: U) → (S ⇒ T) <: (R ⇒ U) <:-function p q function function = function -<:-function p q (function-ok t) (function-ok r) = function-ok (q t r) +<:-function p q (function-ok s t) (function-ok₁ r) = function-ok₁ (<:-impl-⊇ p s r) +<:-function p q (function-ok s t) (function-ok₂ r) = function-ok₂ (q t r) <:-function p q (function-err s) (function-err r) = function-err (<:-impl-⊇ p s r) +<:-function p q (function-tgt t) (function-tgt r) = function-tgt (q t r) + +<:-function-∩-∩ : ∀ {R S T U} → ((R ⇒ T) ∩ (S ⇒ U)) <: ((R ∩ S) ⇒ (T ∩ U)) +<:-function-∩-∩ function (function , function) = function +<:-function-∩-∩ (function-ok s t) (function-ok₁ p , q) = function-ok₁ (left p) +<:-function-∩-∩ (function-ok s t) (function-ok₂ p , function-ok₁ q) = function-ok₁ (right q) +<:-function-∩-∩ (function-ok s t) (function-ok₂ p , function-ok₂ q) = function-ok₂ (p , q) +<:-function-∩-∩ (function-err s) (function-err p , q) = function-err (left p) +<:-function-∩-∩ (function-tgt s) (function-tgt p , function-tgt q) = function-tgt (p , q) <:-function-∩-∪ : ∀ {R S T U} → ((R ⇒ T) ∩ (S ⇒ U)) <: ((R ∪ S) ⇒ (T ∪ U)) <:-function-∩-∪ function (function , function) = function -<:-function-∩-∪ (function-ok t) (function-ok p₁ , function-ok p₂) = function-ok (right p₂) -<:-function-∩-∪ (function-err _) (function-err p₁ , function-err q₂) = function-err (p₁ , q₂) +<:-function-∩-∪ (function-ok s t) (function-ok₁ p₁ , function-ok₁ p₂) = function-ok₁ (p₁ , p₂) +<:-function-∩-∪ (function-ok s t) (p₁ , function-ok₂ p₂) = function-ok₂ (right p₂) +<:-function-∩-∪ (function-ok s t) (function-ok₂ p₁ , p₂) = function-ok₂ (left p₁) +<:-function-∩-∪ (function-err s) (function-err p₁ , function-err q₂) = function-err (p₁ , q₂) +<:-function-∩-∪ (function-tgt t) (function-tgt p , q) = function-tgt (left p) <:-function-∩ : ∀ {S T U} → ((S ⇒ T) ∩ (S ⇒ U)) <: (S ⇒ (T ∩ U)) <:-function-∩ function (function , function) = function -<:-function-∩ (function-ok t) (function-ok p₁ , function-ok p₂) = function-ok (p₁ , p₂) +<:-function-∩ (function-ok s t) (p₁ , function-ok₁ p₂) = function-ok₁ p₂ +<:-function-∩ (function-ok s t) (function-ok₁ p₁ , p₂) = function-ok₁ p₁ +<:-function-∩ (function-ok s t) (function-ok₂ p₁ , function-ok₂ p₂) = function-ok₂ (p₁ , p₂) <:-function-∩ (function-err s) (function-err p₁ , function-err p₂) = function-err p₂ +<:-function-∩ (function-tgt t) (function-tgt p₁ , function-tgt p₂) = function-tgt (p₁ , p₂) <:-function-∪ : ∀ {R S T U} → ((R ⇒ S) ∪ (T ⇒ U)) <: ((R ∩ T) ⇒ (S ∪ U)) <:-function-∪ function (left function) = function -<:-function-∪ (function-ok t) (left (function-ok p)) = function-ok (left p) +<:-function-∪ (function-ok s t) (left (function-ok₁ p)) = function-ok₁ (left p) +<:-function-∪ (function-ok s t) (left (function-ok₂ p)) = function-ok₂ (left p) <:-function-∪ (function-err s) (left (function-err p)) = function-err (left p) <:-function-∪ (scalar s) (left (scalar ())) <:-function-∪ function (right function) = function -<:-function-∪ (function-ok t) (right (function-ok p)) = function-ok (right p) +<:-function-∪ (function-ok s t) (right (function-ok₁ p)) = function-ok₁ (right p) +<:-function-∪ (function-ok s t) (right (function-ok₂ p)) = function-ok₂ (right p) <:-function-∪ (function-err s) (right (function-err x)) = function-err (right x) <:-function-∪ (scalar s) (right (scalar ())) +<:-function-∪ (function-tgt t) (left (function-tgt p)) = function-tgt (left p) +<:-function-∪ (function-tgt t) (right (function-tgt p)) = function-tgt (right p) <:-function-∪-∩ : ∀ {R S T U} → ((R ∩ S) ⇒ (T ∪ U)) <: ((R ⇒ T) ∪ (S ⇒ U)) <:-function-∪-∩ function function = left function -<:-function-∪-∩ (function-ok t) (function-ok (left p)) = left (function-ok p) -<:-function-∪-∩ (function-ok t) (function-ok (right p)) = right (function-ok p) +<:-function-∪-∩ (function-ok s t) (function-ok₁ (left p)) = left (function-ok₁ p) +<:-function-∪-∩ (function-ok s t) (function-ok₂ (left p)) = left (function-ok₂ p) +<:-function-∪-∩ (function-ok s t) (function-ok₁ (right p)) = right (function-ok₁ p) +<:-function-∪-∩ (function-ok s t) (function-ok₂ (right p)) = right (function-ok₂ p) <:-function-∪-∩ (function-err s) (function-err (left p)) = left (function-err p) <:-function-∪-∩ (function-err s) (function-err (right p)) = right (function-err p) +<:-function-∪-∩ (function-tgt t) (function-tgt (left p)) = left (function-tgt p) +<:-function-∪-∩ (function-tgt t) (function-tgt (right p)) = right (function-tgt p) + +<:-function-left : ∀ {R S T U} → (S ⇒ T) <: (R ⇒ U) → (R <: S) +<:-function-left {R} {S} p s Rs with dec-language S s +<:-function-left p s Rs | Right Ss = Ss +<:-function-left p s Rs | Left ¬Ss with p (function-err s) (function-err ¬Ss) +<:-function-left p s Rs | Left ¬Ss | function-err ¬Rs = CONTRADICTION (language-comp s ¬Rs Rs) + +<:-function-right : ∀ {R S T U} → (S ⇒ T) <: (R ⇒ U) → (T <: U) +<:-function-right p t Tt with p (function-tgt t) (function-tgt Tt) +<:-function-right p t Tt | function-tgt St = St ≮:-function-left : ∀ {R S T U} → (R ≮: S) → (S ⇒ T) ≮: (R ⇒ U) ≮:-function-left (witness t p q) = witness (function-err t) (function-err q) (function-err p) ≮:-function-right : ∀ {R S T U} → (T ≮: U) → (S ⇒ T) ≮: (R ⇒ U) -≮:-function-right (witness t p q) = witness (function-ok t) (function-ok p) (function-ok q) +≮:-function-right (witness t p q) = witness (function-tgt t) (function-tgt p) (function-tgt q) -- Properties of scalars -skalar-function-ok : ∀ {t} → (¬Language skalar (function-ok t)) +skalar-function-ok : ∀ {s t} → (¬Language skalar (function-ok s t)) skalar-function-ok = (scalar-function-ok number , (scalar-function-ok string , (scalar-function-ok nil , scalar-function-ok boolean))) scalar-<: : ∀ {S T} → (s : Scalar S) → Language T (scalar s) → (S <: T) @@ -261,7 +327,7 @@ scalar-≮:-function : ∀ {S T U} → (Scalar U) → (U ≮: (S ⇒ T)) scalar-≮:-function s = witness (scalar s) (scalar s) (function-scalar s) unknown-≮:-scalar : ∀ {U} → (Scalar U) → (unknown ≮: U) -unknown-≮:-scalar s = witness (function-ok (scalar s)) unknown (scalar-function-ok s) +unknown-≮:-scalar s = witness function unknown (scalar-function s) scalar-≮:-never : ∀ {U} → (Scalar U) → (U ≮: never) scalar-≮:-never s = witness (scalar s) (scalar s) never @@ -288,6 +354,9 @@ never-≮: (witness t p q) = witness t p never unknown-≮:-never : (unknown ≮: never) unknown-≮:-never = witness (scalar nil) unknown never +unknown-≮:-function : ∀ {S T} → (unknown ≮: (S ⇒ T)) +unknown-≮:-function = witness (scalar nil) unknown (function-scalar nil) + function-≮:-never : ∀ {T U} → ((T ⇒ U) ≮: never) function-≮:-never = witness function function never @@ -310,8 +379,9 @@ function-≮:-never = witness function function never <:-everything : unknown <: ((never ⇒ unknown) ∪ skalar) <:-everything (scalar s) p = right (skalar-scalar s) <:-everything function p = left function -<:-everything (function-ok t) p = left (function-ok unknown) +<:-everything (function-ok s t) p = left (function-ok₁ never) <:-everything (function-err s) p = left (function-err never) +<:-everything (function-tgt t) p = left (function-tgt unknown) -- A Gentle Introduction To Semantic Subtyping (https://www.cduce.org/papers/gentle.pdf) -- defines a "set-theoretic" model (sec 2.5) @@ -351,8 +421,9 @@ set-theoretic-if {S₁} {T₁} {S₂} {T₂} p Q q (t , just u) Qtu (S₂t , ¬T S₁t | Right r = r ¬T₁u : ¬(Language T₁ u) - ¬T₁u T₁u with p (function-ok u) (function-ok T₁u) - ¬T₁u T₁u | function-ok T₂u = ¬T₂u T₂u + ¬T₁u T₁u with p (function-ok t u) (function-ok₂ T₁u) + ¬T₁u T₁u | function-ok₁ ¬S₂t = language-comp t ¬S₂t S₂t + ¬T₁u T₁u | function-ok₂ T₂u = ¬T₂u T₂u set-theoretic-if {S₁} {T₁} {S₂} {T₂} p Q q (t , nothing) Qt- (S₂t , _) = q (t , nothing) Qt- (S₁t , λ ()) where @@ -365,33 +436,41 @@ set-theoretic-if {S₁} {T₁} {S₂} {T₂} p Q q (t , nothing) Qt- (S₂t , _) not-quite-set-theoretic-only-if : ∀ {S₁ T₁ S₂ T₂} → -- We don't quite have that this is a set-theoretic model - -- it's only true when Language T₁ and ¬Language T₂ t₂ are inhabited - -- in particular it's not true when T₁ is never, or T₂ is unknown. - ∀ s₂ t₂ → Language S₂ s₂ → ¬Language T₂ t₂ → + -- it's only true when Language S₂ is inhabited + -- in particular it's not true when S₂ is never, + ∀ s₂ → Language S₂ s₂ → -- This is the "only if" part of being a set-theoretic model (∀ Q → Q ⊆ Comp((Language S₁) ⊗ Comp(Lift(Language T₁))) → Q ⊆ Comp((Language S₂) ⊗ Comp(Lift(Language T₂)))) → (Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂)) -not-quite-set-theoretic-only-if {S₁} {T₁} {S₂} {T₂} s₂ t₂ S₂s₂ ¬T₂t₂ p = r where +not-quite-set-theoretic-only-if {S₁} {T₁} {S₂} {T₂} s₂ S₂s₂ p = r where Q : (Tree × Maybe Tree) → Set Q (t , just u) = Either (¬Language S₁ t) (Language T₁ u) Q (t , nothing) = ¬Language S₁ t - - q : Q ⊆ Comp((Language S₁) ⊗ Comp(Lift(Language T₁))) + + q : Q ⊆ Comp(Language S₁ ⊗ Comp(Lift(Language T₁))) q (t , just u) (Left ¬S₁t) (S₁t , ¬T₁u) = language-comp t ¬S₁t S₁t q (t , just u) (Right T₂u) (S₁t , ¬T₁u) = ¬T₁u T₂u q (t , nothing) ¬S₁t (S₁t , _) = language-comp t ¬S₁t S₁t - + r : Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂) r function function = function r (function-err s) (function-err ¬S₁s) with dec-language S₂ s r (function-err s) (function-err ¬S₁s) | Left ¬S₂s = function-err ¬S₂s r (function-err s) (function-err ¬S₁s) | Right S₂s = CONTRADICTION (p Q q (s , nothing) ¬S₁s (S₂s , λ ())) - r (function-ok t) (function-ok T₁t) with dec-language T₂ t - r (function-ok t) (function-ok T₁t) | Left ¬T₂t = CONTRADICTION (p Q q (s₂ , just t) (Right T₁t) (S₂s₂ , language-comp t ¬T₂t)) - r (function-ok t) (function-ok T₁t) | Right T₂t = function-ok T₂t + r (function-ok s t) (function-ok₁ ¬S₁s) with dec-language S₂ s + r (function-ok s t) (function-ok₁ ¬S₁s) | Left ¬S₂s = function-ok₁ ¬S₂s + r (function-ok s t) (function-ok₁ ¬S₁s) | Right S₂s = CONTRADICTION (p Q q (s , nothing) ¬S₁s (S₂s , λ ())) + r (function-ok s t) (function-ok₂ T₁t) with dec-language T₂ t + r (function-ok s t) (function-ok₂ T₁t) | Left ¬T₂t with dec-language S₂ s + r (function-ok s t) (function-ok₂ T₁t) | Left ¬T₂t | Left ¬S₂s = function-ok₁ ¬S₂s + r (function-ok s t) (function-ok₂ T₁t) | Left ¬T₂t | Right S₂s = CONTRADICTION (p Q q (s , just t) (Right T₁t) (S₂s , language-comp t ¬T₂t)) + r (function-ok s t) (function-ok₂ T₁t) | Right T₂t = function-ok₂ T₂t + r (function-tgt t) (function-tgt T₁t) with dec-language T₂ t + r (function-tgt t) (function-tgt T₁t) | Left ¬T₂t = CONTRADICTION (p Q q (s₂ , just t) (Right T₁t) (S₂s₂ , language-comp t ¬T₂t)) + r (function-tgt t) (function-tgt T₁t) | Right T₂t = function-tgt T₂t -- A counterexample when the argument type is empty. @@ -399,22 +478,4 @@ set-theoretic-counterexample-one : (∀ Q → Q ⊆ Comp((Language never) ⊗ Co set-theoretic-counterexample-one Q q ((scalar s) , u) Qtu (scalar () , p) set-theoretic-counterexample-two : (never ⇒ number) ≮: (never ⇒ string) -set-theoretic-counterexample-two = witness - (function-ok (scalar number)) (function-ok (scalar number)) - (function-ok (scalar-scalar number string (λ ()))) - --- At some point we may deal with overloaded function resolution, which should fix this problem... --- The reason why this is connected to overloaded functions is that currently we have that the type of --- f(x) is (tgt T) where f:T. Really we should have the type depend on the type of x, that is use (tgt T U), --- where U is the type of x. In particular (tgt (S => T) (U & V)) should be the same as (tgt ((S&U) => T) V) --- and tgt(never => T) should be unknown. For example --- --- tgt((number => string) & (string => bool))(number) --- is tgt(number => string)(number) & tgt(string => bool)(number) --- is tgt(number => string)(number) & tgt(string => bool)(number&unknown) --- is tgt(number => string)(number) & tgt(string&number => bool)(unknown) --- is tgt(number => string)(number) & tgt(never => bool)(unknown) --- is string & unknown --- is string --- --- there's some discussion of this in the Gentle Introduction paper. +set-theoretic-counterexample-two = witness (function-tgt (scalar number)) (function-tgt (scalar number)) (function-tgt (scalar-scalar number string (λ ()))) diff --git a/prototyping/Properties/TypeCheck.agda b/prototyping/Properties/TypeCheck.agda index 37fbeda5..b53bbd04 100644 --- a/prototyping/Properties/TypeCheck.agda +++ b/prototyping/Properties/TypeCheck.agda @@ -6,9 +6,9 @@ open import Agda.Builtin.Equality using (_≡_; refl) open import Agda.Builtin.Bool using (Bool; true; false) open import FFI.Data.Maybe using (Maybe; just; nothing) open import FFI.Data.Either using (Either) +open import Luau.ResolveOverloads using (resolve) open import Luau.TypeCheck using (_⊢ᴱ_∈_; _⊢ᴮ_∈_; ⊢ᴼ_; ⊢ᴴ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; nil; var; addr; number; bool; string; app; function; block; binexp; done; return; local; nothing; orUnknown; tgtBinOp) open import Luau.Syntax using (Block; Expr; Value; BinaryOperator; yes; nil; addr; number; bool; string; val; var; binexp; _$_; function_is_end; block_is_end; _∙_; return; done; local_←_; _⟨_⟩; _⟨_⟩∈_; var_∈_; name; fun; arg; +; -; *; /; <; >; ==; ~=; <=; >=) -open import Luau.FunctionTypes using (src; tgt) open import Luau.Type using (Type; nil; unknown; never; number; boolean; string; _⇒_) open import Luau.RuntimeType using (RuntimeType; nil; number; function; string; valueType) open import Luau.VarCtxt using (VarCtxt; ∅; _↦_; _⊕_↦_; _⋒_; _⊝_) renaming (_[_] to _[_]ⱽ) @@ -40,7 +40,7 @@ typeOfᴮ : Heap yes → VarCtxt → (Block yes) → Type typeOfᴱ H Γ (var x) = orUnknown(Γ [ x ]ⱽ) typeOfᴱ H Γ (val v) = orUnknown(typeOfⱽ H v) -typeOfᴱ H Γ (M $ N) = tgt(typeOfᴱ H Γ M) +typeOfᴱ H Γ (M $ N) = resolve (typeOfᴱ H Γ M) (typeOfᴱ H Γ N) typeOfᴱ H Γ (function f ⟨ var x ∈ S ⟩∈ T is B end) = S ⇒ T typeOfᴱ H Γ (block var b ∈ T is B end) = T typeOfᴱ H Γ (binexp M op N) = tgtBinOp op @@ -50,14 +50,6 @@ typeOfᴮ H Γ (local var x ∈ T ← M ∙ B) = typeOfᴮ H (Γ ⊕ x ↦ T) B typeOfᴮ H Γ (return M ∙ B) = typeOfᴱ H Γ M typeOfᴮ H Γ done = nil -mustBeFunction : ∀ H Γ v → (never ≢ src (typeOfᴱ H Γ (val v))) → (function ≡ valueType(v)) -mustBeFunction H Γ nil p = CONTRADICTION (p refl) -mustBeFunction H Γ (addr a) p = refl -mustBeFunction H Γ (number n) p = CONTRADICTION (p refl) -mustBeFunction H Γ (bool true) p = CONTRADICTION (p refl) -mustBeFunction H Γ (bool false) p = CONTRADICTION (p refl) -mustBeFunction H Γ (string x) p = CONTRADICTION (p refl) - mustBeNumber : ∀ H Γ v → (typeOfᴱ H Γ (val v) ≡ number) → (valueType(v) ≡ number) mustBeNumber H Γ (addr a) p with remember (H [ a ]ᴴ) mustBeNumber H Γ (addr a) p | (just O , q) with trans (cong orUnknown (cong typeOfᴹᴼ (sym q))) p diff --git a/prototyping/Properties/TypeNormalization.agda b/prototyping/Properties/TypeNormalization.agda index 299f648c..cbd8139f 100644 --- a/prototyping/Properties/TypeNormalization.agda +++ b/prototyping/Properties/TypeNormalization.agda @@ -3,12 +3,12 @@ module Properties.TypeNormalization where open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) -open import Luau.Subtyping using (scalar-function-err) +open import Luau.Subtyping using (Tree; Language; ¬Language; function; scalar; unknown; left; right; function-ok₁; function-ok₂; function-err; function-tgt; scalar-function; scalar-function-ok; scalar-function-err; scalar-function-tgt; function-scalar; _,_) open import Luau.TypeNormalization using (_∪ⁿ_; _∩ⁿ_; _∪ᶠ_; _∪ⁿˢ_; _∩ⁿˢ_; normalize) -open import Luau.Subtyping using (_<:_) +open import Luau.Subtyping using (_<:_; _≮:_; witness; never) open import Properties.Subtyping using (<:-trans; <:-refl; <:-unknown; <:-never; <:-∪-left; <:-∪-right; <:-∪-lub; <:-∩-left; <:-∩-right; <:-∩-glb; <:-∩-symm; <:-function; <:-function-∪-∩; <:-function-∩-∪; <:-function-∪; <:-everything; <:-union; <:-∪-assocl; <:-∪-assocr; <:-∪-symm; <:-intersect; ∪-distl-∩-<:; ∪-distr-∩-<:; <:-∪-distr-∩; <:-∪-distl-∩; ∩-distl-∪-<:; <:-∩-distl-∪; <:-∩-distr-∪; scalar-∩-function-<:-never; scalar-≢-∩-<:-never) --- Notmal forms for types +-- Normal forms for types data FunType : Type → Set data Normal : Type → Set @@ -17,11 +17,11 @@ data FunType where _∩_ : ∀ {F G} → FunType F → FunType G → FunType (F ∩ G) data Normal where - never : Normal never - unknown : Normal unknown _⇒_ : ∀ {S T} → Normal S → Normal T → Normal (S ⇒ T) _∩_ : ∀ {F G} → FunType F → FunType G → Normal (F ∩ G) _∪_ : ∀ {S T} → Normal S → Scalar T → Normal (S ∪ T) + never : Normal never + unknown : Normal unknown data OptScalar : Type → Set where never : OptScalar never @@ -30,6 +30,38 @@ data OptScalar : Type → Set where string : OptScalar string nil : OptScalar nil +-- Top function type +fun-top : ∀ {F} → (FunType F) → (F <: (never ⇒ unknown)) +fun-top (S ⇒ T) = <:-function <:-never <:-unknown +fun-top (F ∩ G) = <:-trans <:-∩-left (fun-top F) + +-- function types are inhabited +fun-function : ∀ {F} → FunType F → Language F function +fun-function (S ⇒ T) = function +fun-function (F ∩ G) = (fun-function F , fun-function G) + +fun-≮:-never : ∀ {F} → FunType F → (F ≮: never) +fun-≮:-never F = witness function (fun-function F) never + +-- function types aren't scalars +fun-¬scalar : ∀ {F S t} → (s : Scalar S) → FunType F → Language F t → ¬Language S t +fun-¬scalar s (S ⇒ T) function = scalar-function s +fun-¬scalar s (S ⇒ T) (function-ok₁ p) = scalar-function-ok s +fun-¬scalar s (S ⇒ T) (function-ok₂ p) = scalar-function-ok s +fun-¬scalar s (S ⇒ T) (function-err p) = scalar-function-err s +fun-¬scalar s (S ⇒ T) (function-tgt p) = scalar-function-tgt s +fun-¬scalar s (F ∩ G) (p₁ , p₂) = fun-¬scalar s G p₂ + +¬scalar-fun : ∀ {F S} → FunType F → (s : Scalar S) → ¬Language F (scalar s) +¬scalar-fun (S ⇒ T) s = function-scalar s +¬scalar-fun (F ∩ G) s = left (¬scalar-fun F s) + +scalar-≮:-fun : ∀ {F S} → FunType F → Scalar S → S ≮: F +scalar-≮:-fun F s = witness (scalar s) (scalar s) (¬scalar-fun F s) + +unknown-≮:-fun : ∀ {F} → FunType F → unknown ≮: F +unknown-≮:-fun F = witness (scalar nil) unknown (¬scalar-fun F nil) + -- Normalization produces normal types normal : ∀ T → Normal (normalize T) normalᶠ : ∀ {F} → FunType F → Normal F @@ -40,7 +72,7 @@ normal-∩ⁿˢ : ∀ {S T} → Normal S → Scalar T → OptScalar (S ∩ⁿˢ normal-∪ᶠ : ∀ {F G} → FunType F → FunType G → FunType (F ∪ᶠ G) normal nil = never ∪ nil -normal (S ⇒ T) = normalᶠ ((normal S) ⇒ (normal T)) +normal (S ⇒ T) = (normal S) ⇒ (normal T) normal never = never normal unknown = unknown normal boolean = never ∪ boolean @@ -338,7 +370,7 @@ flipper = <:-trans <:-∪-assocr (<:-trans (<:-union <:-refl <:-∪-symm) <:-∪ ∪-<:-∪ⁿ unknown (T ⇒ U) = <:-unknown ∪-<:-∪ⁿ (R ⇒ S) (T ⇒ U) = ∪-<:-∪ᶠ (R ⇒ S) (T ⇒ U) ∪-<:-∪ⁿ (R ∩ S) (T ⇒ U) = ∪-<:-∪ᶠ (R ∩ S) (T ⇒ U) -∪-<:-∪ⁿ (R ∪ S) (T ⇒ U) = <:-trans <:-∪-assocr (<:-trans (<:-union <:-refl <:-∪-symm) (<:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ R (T ⇒ U)) <:-refl))) +∪-<:-∪ⁿ (R ∪ S) (T ⇒ U) = <:-trans <:-∪-assocr (<:-trans (<:-union <:-refl <:-∪-symm) (<:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ R (T ⇒ U)) <:-refl))) ∪-<:-∪ⁿ never (T ∩ U) = <:-∪-lub <:-never <:-refl ∪-<:-∪ⁿ unknown (T ∩ U) = <:-unknown ∪-<:-∪ⁿ (R ⇒ S) (T ∩ U) = ∪-<:-∪ᶠ (R ⇒ S) (T ∩ U) diff --git a/prototyping/Properties/TypeSaturation.agda b/prototyping/Properties/TypeSaturation.agda new file mode 100644 index 00000000..13f7d171 --- /dev/null +++ b/prototyping/Properties/TypeSaturation.agda @@ -0,0 +1,433 @@ +{-# OPTIONS --rewriting #-} + +module Properties.TypeSaturation where + +open import Agda.Builtin.Equality using (_≡_; refl) +open import FFI.Data.Either using (Either; Left; Right) +open import Luau.Subtyping using (Tree; Language; ¬Language; _<:_; _≮:_; witness; scalar; function; function-err; function-ok; function-ok₁; function-ok₂; scalar-function; _,_; never) +open import Luau.Type using (Type; _⇒_; _∩_; _∪_; never; unknown) +open import Luau.TypeNormalization using (_∩ⁿ_; _∪ⁿ_) +open import Luau.TypeSaturation using (_⋓_; _⋒_; _∩ᵘ_; _∩ⁱ_; ∪-saturate; ∩-saturate; saturate) +open import Properties.Subtyping using (dec-language; language-comp; <:-impl-⊇; <:-refl; <:-trans; <:-trans-≮:; <:-impl-¬≮: ; <:-never; <:-unknown; <:-function; <:-union; <:-∪-symm; <:-∪-left; <:-∪-right; <:-∪-lub; <:-∪-assocl; <:-∪-assocr; <:-intersect; <:-∩-symm; <:-∩-left; <:-∩-right; <:-∩-glb; ≮:-function-left; ≮:-function-right; <:-function-∩-∪; <:-function-∩-∩; <:-∩-assocl; <:-∩-assocr; ∩-<:-∪; <:-∩-distl-∪; ∩-distl-∪-<:; <:-∩-distr-∪; ∩-distr-∪-<:) +open import Properties.TypeNormalization using (Normal; FunType; _⇒_; _∩_; _∪_; never; unknown; normal-∪ⁿ; normal-∩ⁿ; ∪ⁿ-<:-∪; ∪-<:-∪ⁿ; ∩ⁿ-<:-∩; ∩-<:-∩ⁿ) +open import Properties.Contradiction using (CONTRADICTION) +open import Properties.Functions using (_∘_) + +-- Saturation preserves normalization +normal-⋒ : ∀ {F G} → FunType F → FunType G → FunType (F ⋒ G) +normal-⋒ (R ⇒ S) (T ⇒ U) = (normal-∩ⁿ R T) ⇒ (normal-∩ⁿ S U) +normal-⋒ (R ⇒ S) (G ∩ H) = normal-⋒ (R ⇒ S) G ∩ normal-⋒ (R ⇒ S) H +normal-⋒ (E ∩ F) G = normal-⋒ E G ∩ normal-⋒ F G + +normal-⋓ : ∀ {F G} → FunType F → FunType G → FunType (F ⋓ G) +normal-⋓ (R ⇒ S) (T ⇒ U) = (normal-∪ⁿ R T) ⇒ (normal-∪ⁿ S U) +normal-⋓ (R ⇒ S) (G ∩ H) = normal-⋓ (R ⇒ S) G ∩ normal-⋓ (R ⇒ S) H +normal-⋓ (E ∩ F) G = normal-⋓ E G ∩ normal-⋓ F G + +normal-∩-saturate : ∀ {F} → FunType F → FunType (∩-saturate F) +normal-∩-saturate (S ⇒ T) = S ⇒ T +normal-∩-saturate (F ∩ G) = (normal-∩-saturate F ∩ normal-∩-saturate G) ∩ normal-⋒ (normal-∩-saturate F) (normal-∩-saturate G) + +normal-∪-saturate : ∀ {F} → FunType F → FunType (∪-saturate F) +normal-∪-saturate (S ⇒ T) = S ⇒ T +normal-∪-saturate (F ∩ G) = (normal-∪-saturate F ∩ normal-∪-saturate G) ∩ normal-⋓ (normal-∪-saturate F) (normal-∪-saturate G) + +normal-saturate : ∀ {F} → FunType F → FunType (saturate F) +normal-saturate F = normal-∪-saturate (normal-∩-saturate F) + +-- Saturation resects subtyping +∪-saturate-<: : ∀ {F} → FunType F → ∪-saturate F <: F +∪-saturate-<: (S ⇒ T) = <:-refl +∪-saturate-<: (F ∩ G) = <:-trans <:-∩-left (<:-intersect (∪-saturate-<: F) (∪-saturate-<: G)) + +∩-saturate-<: : ∀ {F} → FunType F → ∩-saturate F <: F +∩-saturate-<: (S ⇒ T) = <:-refl +∩-saturate-<: (F ∩ G) = <:-trans <:-∩-left (<:-intersect (∩-saturate-<: F) (∩-saturate-<: G)) + +saturate-<: : ∀ {F} → FunType F → saturate F <: F +saturate-<: F = <:-trans (∪-saturate-<: (normal-∩-saturate F)) (∩-saturate-<: F) + +∩-<:-⋓ : ∀ {F G} → FunType F → FunType G → (F ∩ G) <: (F ⋓ G) +∩-<:-⋓ (R ⇒ S) (T ⇒ U) = <:-trans <:-function-∩-∪ (<:-function (∪ⁿ-<:-∪ R T) (∪-<:-∪ⁿ S U)) +∩-<:-⋓ (R ⇒ S) (G ∩ H) = <:-trans (<:-∩-glb (<:-intersect <:-refl <:-∩-left) (<:-intersect <:-refl <:-∩-right)) (<:-intersect (∩-<:-⋓ (R ⇒ S) G) (∩-<:-⋓ (R ⇒ S) H)) +∩-<:-⋓ (E ∩ F) G = <:-trans (<:-∩-glb (<:-intersect <:-∩-left <:-refl) (<:-intersect <:-∩-right <:-refl)) (<:-intersect (∩-<:-⋓ E G) (∩-<:-⋓ F G)) + +∩-<:-⋒ : ∀ {F G} → FunType F → FunType G → (F ∩ G) <: (F ⋒ G) +∩-<:-⋒ (R ⇒ S) (T ⇒ U) = <:-trans <:-function-∩-∩ (<:-function (∩ⁿ-<:-∩ R T) (∩-<:-∩ⁿ S U)) +∩-<:-⋒ (R ⇒ S) (G ∩ H) = <:-trans (<:-∩-glb (<:-intersect <:-refl <:-∩-left) (<:-intersect <:-refl <:-∩-right)) (<:-intersect (∩-<:-⋒ (R ⇒ S) G) (∩-<:-⋒ (R ⇒ S) H)) +∩-<:-⋒ (E ∩ F) G = <:-trans (<:-∩-glb (<:-intersect <:-∩-left <:-refl) (<:-intersect <:-∩-right <:-refl)) (<:-intersect (∩-<:-⋒ E G) (∩-<:-⋒ F G)) + +<:-∪-saturate : ∀ {F} → FunType F → F <: ∪-saturate F +<:-∪-saturate (S ⇒ T) = <:-refl +<:-∪-saturate (F ∩ G) = <:-∩-glb (<:-intersect (<:-∪-saturate F) (<:-∪-saturate G)) (<:-trans (<:-intersect (<:-∪-saturate F) (<:-∪-saturate G)) (∩-<:-⋓ (normal-∪-saturate F) (normal-∪-saturate G))) + +<:-∩-saturate : ∀ {F} → FunType F → F <: ∩-saturate F +<:-∩-saturate (S ⇒ T) = <:-refl +<:-∩-saturate (F ∩ G) = <:-∩-glb (<:-intersect (<:-∩-saturate F) (<:-∩-saturate G)) (<:-trans (<:-intersect (<:-∩-saturate F) (<:-∩-saturate G)) (∩-<:-⋒ (normal-∩-saturate F) (normal-∩-saturate G))) + +<:-saturate : ∀ {F} → FunType F → F <: saturate F +<:-saturate F = <:-trans (<:-∩-saturate F) (<:-∪-saturate (normal-∩-saturate F)) + +-- Overloads F is the set of overloads of F +data Overloads : Type → Type → Set where + + here : ∀ {S T} → Overloads (S ⇒ T) (S ⇒ T) + left : ∀ {S T F G} → Overloads F (S ⇒ T) → Overloads (F ∩ G) (S ⇒ T) + right : ∀ {S T F G} → Overloads G (S ⇒ T) → Overloads (F ∩ G) (S ⇒ T) + +normal-overload-src : ∀ {F S T} → FunType F → Overloads F (S ⇒ T) → Normal S +normal-overload-src (S ⇒ T) here = S +normal-overload-src (F ∩ G) (left o) = normal-overload-src F o +normal-overload-src (F ∩ G) (right o) = normal-overload-src G o + +normal-overload-tgt : ∀ {F S T} → FunType F → Overloads F (S ⇒ T) → Normal T +normal-overload-tgt (S ⇒ T) here = T +normal-overload-tgt (F ∩ G) (left o) = normal-overload-tgt F o +normal-overload-tgt (F ∩ G) (right o) = normal-overload-tgt G o + +-- An inductive presentation of the overloads of F ⋓ G +data ∪-Lift (P Q : Type → Set) : Type → Set where + + union : ∀ {R S T U} → + + P (R ⇒ S) → + Q (T ⇒ U) → + -------------------- + ∪-Lift P Q ((R ∪ T) ⇒ (S ∪ U)) + +-- An inductive presentation of the overloads of F ⋒ G +data ∩-Lift (P Q : Type → Set) : Type → Set where + + intersect : ∀ {R S T U} → + + P (R ⇒ S) → + Q (T ⇒ U) → + -------------------- + ∩-Lift P Q ((R ∩ T) ⇒ (S ∩ U)) + +-- An inductive presentation of the overloads of ∪-saturate F +data ∪-Saturate (P : Type → Set) : Type → Set where + + base : ∀ {S T} → + + P (S ⇒ T) → + -------------------- + ∪-Saturate P (S ⇒ T) + + union : ∀ {R S T U} → + + ∪-Saturate P (R ⇒ S) → + ∪-Saturate P (T ⇒ U) → + -------------------- + ∪-Saturate P ((R ∪ T) ⇒ (S ∪ U)) + +-- An inductive presentation of the overloads of ∩-saturate F +data ∩-Saturate (P : Type → Set) : Type → Set where + + base : ∀ {S T} → + + P (S ⇒ T) → + -------------------- + ∩-Saturate P (S ⇒ T) + + intersect : ∀ {R S T U} → + + ∩-Saturate P (R ⇒ S) → + ∩-Saturate P (T ⇒ U) → + -------------------- + ∩-Saturate P ((R ∩ T) ⇒ (S ∩ U)) + +-- The <:-up-closure of a set of function types +data <:-Close (P : Type → Set) : Type → Set where + + defn : ∀ {R S T U} → + + P (S ⇒ T) → + R <: S → + T <: U → + ------------------ + <:-Close P (R ⇒ U) + +-- F ⊆ᵒ G whenever every overload of F is an overload of G +_⊆ᵒ_ : Type → Type → Set +F ⊆ᵒ G = ∀ {S T} → Overloads F (S ⇒ T) → Overloads G (S ⇒ T) + +-- F <:ᵒ G when every overload of G is a supertype of an overload of F +_<:ᵒ_ : Type → Type → Set +_<:ᵒ_ F G = ∀ {S T} → Overloads G (S ⇒ T) → <:-Close (Overloads F) (S ⇒ T) + +-- P ⊂: Q when any type in P is a subtype of some type in Q +_⊂:_ : (Type → Set) → (Type → Set) → Set +P ⊂: Q = ∀ {S T} → P (S ⇒ T) → <:-Close Q (S ⇒ T) + +-- <:-Close is a monad +just : ∀ {P S T} → P (S ⇒ T) → <:-Close P (S ⇒ T) +just p = defn p <:-refl <:-refl + +infixl 5 _>>=_ _>>=ˡ_ _>>=ʳ_ +_>>=_ : ∀ {P Q S T} → <:-Close P (S ⇒ T) → (P ⊂: Q) → <:-Close Q (S ⇒ T) +(defn p p₁ p₂) >>= P⊂Q with P⊂Q p +(defn p p₁ p₂) >>= P⊂Q | defn q q₁ q₂ = defn q (<:-trans p₁ q₁) (<:-trans q₂ p₂) + +_>>=ˡ_ : ∀ {P R S T} → <:-Close P (S ⇒ T) → (R <: S) → <:-Close P (R ⇒ T) +(defn p p₁ p₂) >>=ˡ q = defn p (<:-trans q p₁) p₂ + +_>>=ʳ_ : ∀ {P S T U} → <:-Close P (S ⇒ T) → (T <: U) → <:-Close P (S ⇒ U) +(defn p p₁ p₂) >>=ʳ q = defn p p₁ (<:-trans p₂ q) + +-- Properties of ⊂: +⊂:-refl : ∀ {P} → P ⊂: P +⊂:-refl p = just p + +_[∪]_ : ∀ {P Q R S T U} → <:-Close P (R ⇒ S) → <:-Close Q (T ⇒ U) → <:-Close (∪-Lift P Q) ((R ∪ T) ⇒ (S ∪ U)) +(defn p p₁ p₂) [∪] (defn q q₁ q₂) = defn (union p q) (<:-union p₁ q₁) (<:-union p₂ q₂) + +_[∩]_ : ∀ {P Q R S T U} → <:-Close P (R ⇒ S) → <:-Close Q (T ⇒ U) → <:-Close (∩-Lift P Q) ((R ∩ T) ⇒ (S ∩ U)) +(defn p p₁ p₂) [∩] (defn q q₁ q₂) = defn (intersect p q) (<:-intersect p₁ q₁) (<:-intersect p₂ q₂) + +⊂:-∩-saturate-inj : ∀ {P} → P ⊂: ∩-Saturate P +⊂:-∩-saturate-inj p = defn (base p) <:-refl <:-refl + +⊂:-∪-saturate-inj : ∀ {P} → P ⊂: ∪-Saturate P +⊂:-∪-saturate-inj p = just (base p) + +⊂:-∩-lift-saturate : ∀ {P} → ∩-Lift (∩-Saturate P) (∩-Saturate P) ⊂: ∩-Saturate P +⊂:-∩-lift-saturate (intersect p q) = just (intersect p q) + +⊂:-∪-lift-saturate : ∀ {P} → ∪-Lift (∪-Saturate P) (∪-Saturate P) ⊂: ∪-Saturate P +⊂:-∪-lift-saturate (union p q) = just (union p q) + +⊂:-∩-lift : ∀ {P Q R S} → (P ⊂: Q) → (R ⊂: S) → (∩-Lift P R ⊂: ∩-Lift Q S) +⊂:-∩-lift P⊂Q R⊂S (intersect n o) = P⊂Q n [∩] R⊂S o + +⊂:-∪-lift : ∀ {P Q R S} → (P ⊂: Q) → (R ⊂: S) → (∪-Lift P R ⊂: ∪-Lift Q S) +⊂:-∪-lift P⊂Q R⊂S (union n o) = P⊂Q n [∪] R⊂S o + +⊂:-∩-saturate : ∀ {P Q} → (P ⊂: Q) → (∩-Saturate P ⊂: ∩-Saturate Q) +⊂:-∩-saturate P⊂Q (base p) = P⊂Q p >>= ⊂:-∩-saturate-inj +⊂:-∩-saturate P⊂Q (intersect p q) = (⊂:-∩-saturate P⊂Q p [∩] ⊂:-∩-saturate P⊂Q q) >>= ⊂:-∩-lift-saturate + +⊂:-∪-saturate : ∀ {P Q} → (P ⊂: Q) → (∪-Saturate P ⊂: ∪-Saturate Q) +⊂:-∪-saturate P⊂Q (base p) = P⊂Q p >>= ⊂:-∪-saturate-inj +⊂:-∪-saturate P⊂Q (union p q) = (⊂:-∪-saturate P⊂Q p [∪] ⊂:-∪-saturate P⊂Q q) >>= ⊂:-∪-lift-saturate + +⊂:-∩-saturate-indn : ∀ {P Q} → (P ⊂: Q) → (∩-Lift Q Q ⊂: Q) → (∩-Saturate P ⊂: Q) +⊂:-∩-saturate-indn P⊂Q QQ⊂Q (base p) = P⊂Q p +⊂:-∩-saturate-indn P⊂Q QQ⊂Q (intersect p q) = (⊂:-∩-saturate-indn P⊂Q QQ⊂Q p [∩] ⊂:-∩-saturate-indn P⊂Q QQ⊂Q q) >>= QQ⊂Q + +⊂:-∪-saturate-indn : ∀ {P Q} → (P ⊂: Q) → (∪-Lift Q Q ⊂: Q) → (∪-Saturate P ⊂: Q) +⊂:-∪-saturate-indn P⊂Q QQ⊂Q (base p) = P⊂Q p +⊂:-∪-saturate-indn P⊂Q QQ⊂Q (union p q) = (⊂:-∪-saturate-indn P⊂Q QQ⊂Q p [∪] ⊂:-∪-saturate-indn P⊂Q QQ⊂Q q) >>= QQ⊂Q + +∪-saturate-resp-∩-saturation : ∀ {P} → (∩-Lift P P ⊂: P) → (∩-Lift (∪-Saturate P) (∪-Saturate P) ⊂: ∪-Saturate P) +∪-saturate-resp-∩-saturation ∩P⊂P (intersect (base p) (base q)) = ∩P⊂P (intersect p q) >>= ⊂:-∪-saturate-inj +∪-saturate-resp-∩-saturation ∩P⊂P (intersect p (union q q₁)) = (∪-saturate-resp-∩-saturation ∩P⊂P (intersect p q) [∪] ∪-saturate-resp-∩-saturation ∩P⊂P (intersect p q₁)) >>= ⊂:-∪-lift-saturate >>=ˡ <:-∩-distl-∪ >>=ʳ ∩-distl-∪-<: +∪-saturate-resp-∩-saturation ∩P⊂P (intersect (union p p₁) q) = (∪-saturate-resp-∩-saturation ∩P⊂P (intersect p q) [∪] ∪-saturate-resp-∩-saturation ∩P⊂P (intersect p₁ q)) >>= ⊂:-∪-lift-saturate >>=ˡ <:-∩-distr-∪ >>=ʳ ∩-distr-∪-<: + +ov-language : ∀ {F t} → FunType F → (∀ {S T} → Overloads F (S ⇒ T) → Language (S ⇒ T) t) → Language F t +ov-language (S ⇒ T) p = p here +ov-language (F ∩ G) p = (ov-language F (p ∘ left) , ov-language G (p ∘ right)) + +ov-<: : ∀ {F R S T U} → FunType F → Overloads F (R ⇒ S) → ((R ⇒ S) <: (T ⇒ U)) → F <: (T ⇒ U) +ov-<: F here p = p +ov-<: (F ∩ G) (left o) p = <:-trans <:-∩-left (ov-<: F o p) +ov-<: (F ∩ G) (right o) p = <:-trans <:-∩-right (ov-<: G o p) + +<:ᵒ-impl-<: : ∀ {F G} → FunType F → FunType G → (F <:ᵒ G) → (F <: G) +<:ᵒ-impl-<: F (T ⇒ U) F>= ⊂:-overloads-left +⊂:-overloads-⋒ (R ⇒ S) (G ∩ H) (intersect here (right o)) = ⊂:-overloads-⋒ (R ⇒ S) H (intersect here o) >>= ⊂:-overloads-right +⊂:-overloads-⋒ (E ∩ F) G (intersect (left n) o) = ⊂:-overloads-⋒ E G (intersect n o) >>= ⊂:-overloads-left +⊂:-overloads-⋒ (E ∩ F) G (intersect (right n) o) = ⊂:-overloads-⋒ F G (intersect n o) >>= ⊂:-overloads-right + +⊂:-⋒-overloads : ∀ {F G} → FunType F → FunType G → Overloads (F ⋒ G) ⊂: ∩-Lift (Overloads F) (Overloads G) +⊂:-⋒-overloads (R ⇒ S) (T ⇒ U) here = defn (intersect here here) (∩ⁿ-<:-∩ R T) (∩-<:-∩ⁿ S U) +⊂:-⋒-overloads (R ⇒ S) (G ∩ H) (left o) = ⊂:-⋒-overloads (R ⇒ S) G o >>= ⊂:-∩-lift ⊂:-refl ⊂:-overloads-left +⊂:-⋒-overloads (R ⇒ S) (G ∩ H) (right o) = ⊂:-⋒-overloads (R ⇒ S) H o >>= ⊂:-∩-lift ⊂:-refl ⊂:-overloads-right +⊂:-⋒-overloads (E ∩ F) G (left o) = ⊂:-⋒-overloads E G o >>= ⊂:-∩-lift ⊂:-overloads-left ⊂:-refl +⊂:-⋒-overloads (E ∩ F) G (right o) = ⊂:-⋒-overloads F G o >>= ⊂:-∩-lift ⊂:-overloads-right ⊂:-refl + +⊂:-overloads-⋓ : ∀ {F G} → FunType F → FunType G → ∪-Lift (Overloads F) (Overloads G) ⊂: Overloads (F ⋓ G) +⊂:-overloads-⋓ (R ⇒ S) (T ⇒ U) (union here here) = defn here (∪-<:-∪ⁿ R T) (∪ⁿ-<:-∪ S U) +⊂:-overloads-⋓ (R ⇒ S) (G ∩ H) (union here (left o)) = ⊂:-overloads-⋓ (R ⇒ S) G (union here o) >>= ⊂:-overloads-left +⊂:-overloads-⋓ (R ⇒ S) (G ∩ H) (union here (right o)) = ⊂:-overloads-⋓ (R ⇒ S) H (union here o) >>= ⊂:-overloads-right +⊂:-overloads-⋓ (E ∩ F) G (union (left n) o) = ⊂:-overloads-⋓ E G (union n o) >>= ⊂:-overloads-left +⊂:-overloads-⋓ (E ∩ F) G (union (right n) o) = ⊂:-overloads-⋓ F G (union n o) >>= ⊂:-overloads-right + +⊂:-⋓-overloads : ∀ {F G} → FunType F → FunType G → Overloads (F ⋓ G) ⊂: ∪-Lift (Overloads F) (Overloads G) +⊂:-⋓-overloads (R ⇒ S) (T ⇒ U) here = defn (union here here) (∪ⁿ-<:-∪ R T) (∪-<:-∪ⁿ S U) +⊂:-⋓-overloads (R ⇒ S) (G ∩ H) (left o) = ⊂:-⋓-overloads (R ⇒ S) G o >>= ⊂:-∪-lift ⊂:-refl ⊂:-overloads-left +⊂:-⋓-overloads (R ⇒ S) (G ∩ H) (right o) = ⊂:-⋓-overloads (R ⇒ S) H o >>= ⊂:-∪-lift ⊂:-refl ⊂:-overloads-right +⊂:-⋓-overloads (E ∩ F) G (left o) = ⊂:-⋓-overloads E G o >>= ⊂:-∪-lift ⊂:-overloads-left ⊂:-refl +⊂:-⋓-overloads (E ∩ F) G (right o) = ⊂:-⋓-overloads F G o >>= ⊂:-∪-lift ⊂:-overloads-right ⊂:-refl + +∪-saturate-overloads : ∀ {F} → FunType F → Overloads (∪-saturate F) ⊂: ∪-Saturate (Overloads F) +∪-saturate-overloads (S ⇒ T) here = just (base here) +∪-saturate-overloads (F ∩ G) (left (left o)) = ∪-saturate-overloads F o >>= ⊂:-∪-saturate ⊂:-overloads-left +∪-saturate-overloads (F ∩ G) (left (right o)) = ∪-saturate-overloads G o >>= ⊂:-∪-saturate ⊂:-overloads-right +∪-saturate-overloads (F ∩ G) (right o) = + ⊂:-⋓-overloads (normal-∪-saturate F) (normal-∪-saturate G) o >>= + ⊂:-∪-lift (∪-saturate-overloads F) (∪-saturate-overloads G) >>= + ⊂:-∪-lift (⊂:-∪-saturate ⊂:-overloads-left) (⊂:-∪-saturate ⊂:-overloads-right) >>= + ⊂:-∪-lift-saturate + +overloads-∪-saturate : ∀ {F} → FunType F → ∪-Saturate (Overloads F) ⊂: Overloads (∪-saturate F) +overloads-∪-saturate F = ⊂:-∪-saturate-indn (inj F) (step F) where + + inj : ∀ {F} → FunType F → Overloads F ⊂: Overloads (∪-saturate F) + inj (S ⇒ T) here = just here + inj (F ∩ G) (left p) = inj F p >>= ⊂:-overloads-left >>= ⊂:-overloads-left + inj (F ∩ G) (right p) = inj G p >>= ⊂:-overloads-right >>= ⊂:-overloads-left + + step : ∀ {F} → FunType F → ∪-Lift (Overloads (∪-saturate F)) (Overloads (∪-saturate F)) ⊂: Overloads (∪-saturate F) + step (S ⇒ T) (union here here) = defn here (<:-∪-lub <:-refl <:-refl) <:-∪-left + step (F ∩ G) (union (left (left p)) (left (left q))) = step F (union p q) >>= ⊂:-overloads-left >>= ⊂:-overloads-left + step (F ∩ G) (union (left (left p)) (left (right q))) = ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) (union p q) >>= ⊂:-overloads-right + step (F ∩ G) (union (left (right p)) (left (left q))) = ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) (union q p) >>= ⊂:-overloads-right >>=ˡ <:-∪-symm >>=ʳ <:-∪-symm + step (F ∩ G) (union (left (right p)) (left (right q))) = step G (union p q) >>= ⊂:-overloads-right >>= ⊂:-overloads-left + step (F ∩ G) (union p (right q)) with ⊂:-⋓-overloads (normal-∪-saturate F) (normal-∪-saturate G) q + step (F ∩ G) (union (left (left p)) (right q)) | defn (union q₁ q₂) q₃ q₄ = + (step F (union p q₁) [∪] just q₂) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union <:-refl q₃) <:-∪-assocl >>=ʳ + <:-trans <:-∪-assocr (<:-union <:-refl q₄) + step (F ∩ G) (union (left (right p)) (right q)) | defn (union q₁ q₂) q₃ q₄ = + (just q₁ [∪] step G (union p q₂)) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union <:-refl q₃) (<:-∪-lub (<:-trans <:-∪-left <:-∪-right) (<:-∪-lub <:-∪-left (<:-trans <:-∪-right <:-∪-right))) >>=ʳ + <:-trans (<:-∪-lub (<:-trans <:-∪-left <:-∪-right) (<:-∪-lub <:-∪-left (<:-trans <:-∪-right <:-∪-right))) (<:-union <:-refl q₄) + step (F ∩ G) (union (right p) (right q)) | defn (union q₁ q₂) q₃ q₄ with ⊂:-⋓-overloads (normal-∪-saturate F) (normal-∪-saturate G) p + step (F ∩ G) (union (right p) (right q)) | defn (union q₁ q₂) q₃ q₄ | defn (union p₁ p₂) p₃ p₄ = + (step F (union p₁ q₁) [∪] step G (union p₂ q₂)) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union p₃ q₃) (<:-∪-lub (<:-union <:-∪-left <:-∪-left) (<:-union <:-∪-right <:-∪-right)) >>=ʳ + <:-trans (<:-∪-lub (<:-union <:-∪-left <:-∪-left) (<:-union <:-∪-right <:-∪-right)) (<:-union p₄ q₄) + step (F ∩ G) (union (right p) q) with ⊂:-⋓-overloads (normal-∪-saturate F) (normal-∪-saturate G) p + step (F ∩ G) (union (right p) (left (left q))) | defn (union p₁ p₂) p₃ p₄ = + (step F (union p₁ q) [∪] just p₂) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union p₃ <:-refl) (<:-∪-lub (<:-union <:-∪-left <:-refl) (<:-trans <:-∪-right <:-∪-left)) >>=ʳ + <:-trans (<:-∪-lub (<:-union <:-∪-left <:-refl) (<:-trans <:-∪-right <:-∪-left)) (<:-union p₄ <:-refl) + step (F ∩ G) (union (right p) (left (right q))) | defn (union p₁ p₂) p₃ p₄ = + (just p₁ [∪] step G (union p₂ q)) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union p₃ <:-refl) <:-∪-assocr >>=ʳ + <:-trans <:-∪-assocl (<:-union p₄ <:-refl) + step (F ∩ G) (union (right p) (right q)) | defn (union p₁ p₂) p₃ p₄ with ⊂:-⋓-overloads (normal-∪-saturate F) (normal-∪-saturate G) q + step (F ∩ G) (union (right p) (right q)) | defn (union p₁ p₂) p₃ p₄ | defn (union q₁ q₂) q₃ q₄ = + (step F (union p₁ q₁) [∪] step G (union p₂ q₂)) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union p₃ q₃) (<:-∪-lub (<:-union <:-∪-left <:-∪-left) (<:-union <:-∪-right <:-∪-right)) >>=ʳ + <:-trans (<:-∪-lub (<:-union <:-∪-left <:-∪-left) (<:-union <:-∪-right <:-∪-right)) (<:-union p₄ q₄) + +∪-saturated : ∀ {F} → FunType F → ∪-Lift (Overloads (∪-saturate F)) (Overloads (∪-saturate F)) ⊂: Overloads (∪-saturate F) +∪-saturated F o = + ⊂:-∪-lift (∪-saturate-overloads F) (∪-saturate-overloads F) o >>= + ⊂:-∪-lift-saturate >>= + overloads-∪-saturate F + +∩-saturate-overloads : ∀ {F} → FunType F → Overloads (∩-saturate F) ⊂: ∩-Saturate (Overloads F) +∩-saturate-overloads (S ⇒ T) here = just (base here) +∩-saturate-overloads (F ∩ G) (left (left o)) = ∩-saturate-overloads F o >>= ⊂:-∩-saturate ⊂:-overloads-left +∩-saturate-overloads (F ∩ G) (left (right o)) = ∩-saturate-overloads G o >>= ⊂:-∩-saturate ⊂:-overloads-right +∩-saturate-overloads (F ∩ G) (right o) = + ⊂:-⋒-overloads (normal-∩-saturate F) (normal-∩-saturate G) o >>= + ⊂:-∩-lift (∩-saturate-overloads F) (∩-saturate-overloads G) >>= + ⊂:-∩-lift (⊂:-∩-saturate ⊂:-overloads-left) (⊂:-∩-saturate ⊂:-overloads-right) >>= + ⊂:-∩-lift-saturate + +overloads-∩-saturate : ∀ {F} → FunType F → ∩-Saturate (Overloads F) ⊂: Overloads (∩-saturate F) +overloads-∩-saturate F = ⊂:-∩-saturate-indn (inj F) (step F) where + + inj : ∀ {F} → FunType F → Overloads F ⊂: Overloads (∩-saturate F) + inj (S ⇒ T) here = just here + inj (F ∩ G) (left p) = inj F p >>= ⊂:-overloads-left >>= ⊂:-overloads-left + inj (F ∩ G) (right p) = inj G p >>= ⊂:-overloads-right >>= ⊂:-overloads-left + + step : ∀ {F} → FunType F → ∩-Lift (Overloads (∩-saturate F)) (Overloads (∩-saturate F)) ⊂: Overloads (∩-saturate F) + step (S ⇒ T) (intersect here here) = defn here <:-∩-left (<:-∩-glb <:-refl <:-refl) + step (F ∩ G) (intersect (left (left p)) (left (left q))) = step F (intersect p q) >>= ⊂:-overloads-left >>= ⊂:-overloads-left + step (F ∩ G) (intersect (left (left p)) (left (right q))) = ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) (intersect p q) >>= ⊂:-overloads-right + step (F ∩ G) (intersect (left (right p)) (left (left q))) = ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) (intersect q p) >>= ⊂:-overloads-right >>=ˡ <:-∩-symm >>=ʳ <:-∩-symm + step (F ∩ G) (intersect (left (right p)) (left (right q))) = step G (intersect p q) >>= ⊂:-overloads-right >>= ⊂:-overloads-left + step (F ∩ G) (intersect (right p) q) with ⊂:-⋒-overloads (normal-∩-saturate F) (normal-∩-saturate G) p + step (F ∩ G) (intersect (right p) (left (left q))) | defn (intersect p₁ p₂) p₃ p₄ = + (step F (intersect p₁ q) [∩] just p₂) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect p₃ <:-refl) (<:-∩-glb (<:-intersect <:-∩-left <:-refl) (<:-trans <:-∩-left <:-∩-right)) >>=ʳ + <:-trans (<:-∩-glb (<:-intersect <:-∩-left <:-refl) (<:-trans <:-∩-left <:-∩-right)) (<:-intersect p₄ <:-refl) + step (F ∩ G) (intersect (right p) (left (right q))) | defn (intersect p₁ p₂) p₃ p₄ = + (just p₁ [∩] step G (intersect p₂ q)) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect p₃ <:-refl) <:-∩-assocr >>=ʳ + <:-trans <:-∩-assocl (<:-intersect p₄ <:-refl) + step (F ∩ G) (intersect (right p) (right q)) | defn (intersect p₁ p₂) p₃ p₄ with ⊂:-⋒-overloads (normal-∩-saturate F) (normal-∩-saturate G) q + step (F ∩ G) (intersect (right p) (right q)) | defn (intersect p₁ p₂) p₃ p₄ | defn (intersect q₁ q₂) q₃ q₄ = + (step F (intersect p₁ q₁) [∩] step G (intersect p₂ q₂)) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect p₃ q₃) (<:-∩-glb (<:-intersect <:-∩-left <:-∩-left) (<:-intersect <:-∩-right <:-∩-right)) >>=ʳ + <:-trans (<:-∩-glb (<:-intersect <:-∩-left <:-∩-left) (<:-intersect <:-∩-right <:-∩-right)) (<:-intersect p₄ q₄) + step (F ∩ G) (intersect p (right q)) with ⊂:-⋒-overloads (normal-∩-saturate F) (normal-∩-saturate G) q + step (F ∩ G) (intersect (left (left p)) (right q)) | defn (intersect q₁ q₂) q₃ q₄ = + (step F (intersect p q₁) [∩] just q₂) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect <:-refl q₃) <:-∩-assocl >>=ʳ + <:-trans <:-∩-assocr (<:-intersect <:-refl q₄) + step (F ∩ G) (intersect (left (right p)) (right q)) | defn (intersect q₁ q₂) q₃ q₄ = + (just q₁ [∩] step G (intersect p q₂) ) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect <:-refl q₃) (<:-∩-glb (<:-trans <:-∩-right <:-∩-left) (<:-∩-glb <:-∩-left (<:-trans <:-∩-right <:-∩-right))) >>=ʳ + <:-∩-glb (<:-trans <:-∩-right <:-∩-left) (<:-trans (<:-∩-glb <:-∩-left (<:-trans <:-∩-right <:-∩-right)) q₄) + step (F ∩ G) (intersect (right p) (right q)) | defn (intersect q₁ q₂) q₃ q₄ with ⊂:-⋒-overloads (normal-∩-saturate F) (normal-∩-saturate G) p + step (F ∩ G) (intersect (right p) (right q)) | defn (intersect q₁ q₂) q₃ q₄ | defn (intersect p₁ p₂) p₃ p₄ = + (step F (intersect p₁ q₁) [∩] step G (intersect p₂ q₂)) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect p₃ q₃) (<:-∩-glb (<:-intersect <:-∩-left <:-∩-left) (<:-intersect <:-∩-right <:-∩-right)) >>=ʳ + <:-trans (<:-∩-glb (<:-intersect <:-∩-left <:-∩-left) (<:-intersect <:-∩-right <:-∩-right)) (<:-intersect p₄ q₄) + +saturate-overloads : ∀ {F} → FunType F → Overloads (saturate F) ⊂: ∪-Saturate (∩-Saturate (Overloads F)) +saturate-overloads F o = ∪-saturate-overloads (normal-∩-saturate F) o >>= (⊂:-∪-saturate (∩-saturate-overloads F)) + +overloads-saturate : ∀ {F} → FunType F → ∪-Saturate (∩-Saturate (Overloads F)) ⊂: Overloads (saturate F) +overloads-saturate F o = ⊂:-∪-saturate (overloads-∩-saturate F) o >>= overloads-∪-saturate (normal-∩-saturate F) + +-- Saturated F whenever +-- * if F has overloads (R ⇒ S) and (T ⇒ U) then F has an overload which is a subtype of ((R ∩ T) ⇒ (S ∩ U)) +-- * ditto union +data Saturated (F : Type) : Set where + + defn : + + (∀ {R S T U} → Overloads F (R ⇒ S) → Overloads F (T ⇒ U) → <:-Close (Overloads F) ((R ∩ T) ⇒ (S ∩ U))) → + (∀ {R S T U} → Overloads F (R ⇒ S) → Overloads F (T ⇒ U) → <:-Close (Overloads F) ((R ∪ T) ⇒ (S ∪ U))) → + ----------- + Saturated F + +-- saturated F is saturated! +saturated : ∀ {F} → FunType F → Saturated (saturate F) +saturated F = defn + (λ n o → (saturate-overloads F n [∩] saturate-overloads F o) >>= ∪-saturate-resp-∩-saturation ⊂:-∩-lift-saturate >>= overloads-saturate F) + (λ n o → ∪-saturated (normal-∩-saturate F) (union n o)) From f1b46f4b967f11fabe666da1de0e71b225368260 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 16 Jun 2022 18:05:14 -0700 Subject: [PATCH 086/102] Sync to upstream/release/532 (#545) --- Analysis/include/Luau/Constraint.h | 82 ++++ .../include/Luau/ConstraintGraphBuilder.h | 113 ++--- Analysis/include/Luau/ConstraintSolver.h | 103 +++-- .../include/Luau/ConstraintSolverLogger.h | 26 ++ Analysis/include/Luau/Frontend.h | 4 +- Analysis/include/Luau/Module.h | 1 + Analysis/include/Luau/Normalize.h | 1 + Analysis/include/Luau/NotNull.h | 41 +- Analysis/include/Luau/Quantify.h | 3 +- Analysis/include/Luau/RequireTracer.h | 2 +- Analysis/include/Luau/TypeChecker2.h | 13 + Analysis/include/Luau/TypeInfer.h | 41 +- Analysis/include/Luau/TypeVar.h | 37 +- Analysis/include/Luau/Unifier.h | 2 +- Analysis/include/Luau/VisitTypeVar.h | 2 +- Analysis/src/Autocomplete.cpp | 19 +- Analysis/src/BuiltinDefinitions.cpp | 68 +-- Analysis/src/Clone.cpp | 10 +- Analysis/src/Constraint.cpp | 14 + Analysis/src/ConstraintGraphBuilder.cpp | 406 +++++++++++++++--- Analysis/src/ConstraintSolver.cpp | 174 +++++--- Analysis/src/ConstraintSolverLogger.cpp | 139 ++++++ Analysis/src/Frontend.cpp | 34 +- Analysis/src/Instantiation.cpp | 2 +- Analysis/src/Linter.cpp | 4 +- Analysis/src/Normalize.cpp | 46 +- Analysis/src/Quantify.cpp | 153 ++++++- Analysis/src/RequireTracer.cpp | 14 +- Analysis/src/Substitution.cpp | 4 +- Analysis/src/ToDot.cpp | 2 +- Analysis/src/ToString.cpp | 32 +- Analysis/src/TypeAttach.cpp | 7 +- Analysis/src/TypeChecker2.cpp | 160 +++++++ Analysis/src/TypeInfer.cpp | 253 ++++++----- Analysis/src/TypeUtils.cpp | 2 +- Analysis/src/TypeVar.cpp | 41 +- Analysis/src/Unifier.cpp | 13 +- Compiler/src/BytecodeBuilder.cpp | 14 + Compiler/src/Compiler.cpp | 76 +++- Sources.cmake | 8 +- VM/src/lobject.h | 2 +- VM/src/ltable.cpp | 8 +- VM/src/ltm.cpp | 4 +- VM/src/ltm.h | 4 +- tests/Autocomplete.test.cpp | 2 +- tests/Compiler.test.cpp | 252 ++++++++++- tests/ConstraintGraphBuilder.test.cpp | 61 ++- tests/Fixture.cpp | 3 +- tests/Frontend.test.cpp | 8 +- tests/Module.test.cpp | 2 +- tests/NonstrictMode.test.cpp | 53 ++- tests/Normalize.test.cpp | 42 +- tests/NotNull.test.cpp | 53 ++- tests/TypeInfer.annotations.test.cpp | 2 +- tests/TypeInfer.functions.test.cpp | 60 ++- tests/TypeInfer.generics.test.cpp | 6 +- tests/TypeInfer.operators.test.cpp | 4 +- tests/TypeInfer.provisional.test.cpp | 5 +- tests/TypeInfer.refinements.test.cpp | 6 +- tests/TypeInfer.tables.test.cpp | 11 +- tests/TypeInfer.test.cpp | 89 +--- tests/TypeInfer.typePacks.cpp | 14 +- tools/natvis/Analysis.natvis | 66 +-- 63 files changed, 2186 insertions(+), 737 deletions(-) create mode 100644 Analysis/include/Luau/Constraint.h create mode 100644 Analysis/include/Luau/ConstraintSolverLogger.h create mode 100644 Analysis/include/Luau/TypeChecker2.h create mode 100644 Analysis/src/Constraint.cpp create mode 100644 Analysis/src/ConstraintSolverLogger.cpp create mode 100644 Analysis/src/TypeChecker2.cpp diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h new file mode 100644 index 00000000..c62166e2 --- /dev/null +++ b/Analysis/include/Luau/Constraint.h @@ -0,0 +1,82 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" +#include "Luau/NotNull.h" +#include "Luau/Variant.h" + +#include +#include + +namespace Luau +{ + +struct Scope2; +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +// subType <: superType +struct SubtypeConstraint +{ + TypeId subType; + TypeId superType; +}; + +// subPack <: superPack +struct PackSubtypeConstraint +{ + TypePackId subPack; + TypePackId superPack; +}; + +// subType ~ gen superType +struct GeneralizationConstraint +{ + TypeId generalizedType; + TypeId sourceType; + Scope2* scope; +}; + +// subType ~ inst superType +struct InstantiationConstraint +{ + TypeId subType; + TypeId superType; +}; + +using ConstraintV = Variant; +using ConstraintPtr = std::unique_ptr; + +struct Constraint +{ + Constraint(ConstraintV&& c, Location location); + + Constraint(const Constraint&) = delete; + Constraint& operator=(const Constraint&) = delete; + + ConstraintV c; + Location location; + std::vector> dependencies; +}; + +inline Constraint& asMutable(const Constraint& c) +{ + return const_cast(c); +} + +template +T* getMutable(Constraint& c) +{ + return ::Luau::get_if(&c.c); +} + +template +const T* get(const Constraint& c) +{ + return getMutable(asMutable(c)); +} + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 4234f2f6..da774a2a 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -4,9 +4,12 @@ #include #include +#include #include "Luau/Ast.h" +#include "Luau/Constraint.h" #include "Luau/Module.h" +#include "Luau/NotNull.h" #include "Luau/Symbol.h" #include "Luau/TypeVar.h" #include "Luau/Variant.h" @@ -14,69 +17,6 @@ namespace Luau { -struct Scope2; - -// subType <: superType -struct SubtypeConstraint -{ - TypeId subType; - TypeId superType; -}; - -// subPack <: superPack -struct PackSubtypeConstraint -{ - TypePackId subPack; - TypePackId superPack; -}; - -// subType ~ gen superType -struct GeneralizationConstraint -{ - TypeId subType; - TypeId superType; - Scope2* scope; -}; - -// subType ~ inst superType -struct InstantiationConstraint -{ - TypeId subType; - TypeId superType; -}; - -using ConstraintV = Variant; -using ConstraintPtr = std::unique_ptr; - -struct Constraint -{ - Constraint(ConstraintV&& c); - Constraint(ConstraintV&& c, std::vector dependencies); - - Constraint(const Constraint&) = delete; - Constraint& operator=(const Constraint&) = delete; - - ConstraintV c; - std::vector dependencies; -}; - -inline Constraint& asMutable(const Constraint& c) -{ - return const_cast(c); -} - -template -T* getMutable(Constraint& c) -{ - return ::Luau::get_if(&c.c); -} - -template -const T* get(const Constraint& c) -{ - return getMutable(asMutable(c)); -} - struct Scope2 { // The parent scope of this scope. Null if there is no parent (i.e. this @@ -102,6 +42,11 @@ struct ConstraintGraphBuilder TypeArena* const arena; // The root scope of the module we're generating constraints for. Scope2* rootScope; + // A mapping of AST node to TypeId. + DenseHashMap astTypes{nullptr}; + // A mapping of AST node to TypePackId. + DenseHashMap astTypePacks{nullptr}; + DenseHashMap astOriginalCallTypes{nullptr}; explicit ConstraintGraphBuilder(TypeArena* arena); @@ -128,8 +73,9 @@ struct ConstraintGraphBuilder * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param cv the constraint variant to add. + * @param location the location to attribute to the constraint. */ - void addConstraint(Scope2* scope, ConstraintV cv); + void addConstraint(Scope2* scope, ConstraintV cv, Location location); /** * Adds a constraint to a given scope. @@ -148,15 +94,48 @@ struct ConstraintGraphBuilder void visit(Scope2* scope, AstStat* stat); void visit(Scope2* scope, AstStatBlock* block); void visit(Scope2* scope, AstStatLocal* local); - void visit(Scope2* scope, AstStatLocalFunction* local); - void visit(Scope2* scope, AstStatReturn* local); + void visit(Scope2* scope, AstStatLocalFunction* function); + void visit(Scope2* scope, AstStatFunction* function); + void visit(Scope2* scope, AstStatReturn* ret); + void visit(Scope2* scope, AstStatAssign* assign); + void visit(Scope2* scope, AstStatIf* ifStatement); + + TypePackId checkExprList(Scope2* scope, const AstArray& exprs); TypePackId checkPack(Scope2* scope, AstArray exprs); TypePackId checkPack(Scope2* scope, AstExpr* expr); + /** + * Checks an expression that is expected to evaluate to one type. + * @param scope the scope the expression is contained within. + * @param expr the expression to check. + * @return the type of the expression. + */ TypeId check(Scope2* scope, AstExpr* expr); + + TypeId checkExprTable(Scope2* scope, AstExprTable* expr); + TypeId check(Scope2* scope, AstExprIndexName* indexName); + + std::pair checkFunctionSignature(Scope2* parent, AstExprFunction* fn); + + /** + * Checks the body of a function expression. + * @param scope the interior scope of the body of the function. + * @param fn the function expression to check. + */ + void checkFunctionBody(Scope2* scope, AstExprFunction* fn); }; -std::vector collectConstraints(Scope2* rootScope); +/** + * Collects a vector of borrowed constraints from the scope and all its child + * scopes. It is important to only call this function when you're done adding + * constraints to the scope or its descendants, lest the borrowed pointers + * become invalid due to a container reallocation. + * @param rootScope the root scope of the scope graph to collect constraints + * from. + * @return a list of pointers to constraints contained within the scope graph. + * None of these pointers should be null. + */ +std::vector> collectConstraints(Scope2* rootScope); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 85006e68..7e6d4461 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -4,7 +4,8 @@ #include "Luau/Error.h" #include "Luau/Variant.h" -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/Constraint.h" +#include "Luau/ConstraintSolverLogger.h" #include "Luau/TypeVar.h" #include @@ -20,39 +21,81 @@ struct ConstraintSolver { TypeArena* arena; InternalErrorReporter iceReporter; - // The entire set of constraints that the solver is trying to resolve. - std::vector constraints; + // The entire set of constraints that the solver is trying to resolve. It + // is important to not add elements to this vector, lest the underlying + // storage that we retain pointers to be mutated underneath us. + const std::vector> constraints; Scope2* rootScope; - std::vector errors; // This includes every constraint that has not been fully solved. // A constraint can be both blocked and unsolved, for instance. - std::unordered_set unsolvedConstraints; + std::vector> unsolvedConstraints; // A mapping of constraint pointer to how many things the constraint is // blocked on. Can be empty or 0 for constraints that are not blocked on // anything. - std::unordered_map blockedConstraints; + std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. - std::unordered_map> blocked; + std::unordered_map>> blocked; + + ConstraintSolverLogger logger; explicit ConstraintSolver(TypeArena* arena, Scope2* rootScope); /** * Attempts to dispatch all pending constraints and reach a type solution - * that satisfies all of the constraints, recording any errors that are - * encountered. + * that satisfies all of the constraints. **/ void run(); bool done(); - bool tryDispatch(const Constraint* c); - bool tryDispatch(const SubtypeConstraint& c); - bool tryDispatch(const PackSubtypeConstraint& c); - bool tryDispatch(const GeneralizationConstraint& c); - bool tryDispatch(const InstantiationConstraint& c, const Constraint* constraint); + bool tryDispatch(NotNull c, bool force); + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); + void block(NotNull target, NotNull constraint); + /** + * Block a constraint on the resolution of a TypeVar. + * @returns false always. This is just to allow tryDispatch to return the result of block() + */ + bool block(TypeId target, NotNull constraint); + bool block(TypePackId target, NotNull constraint); + + void unblock(NotNull progressed); + void unblock(TypeId progressed); + void unblock(TypePackId progressed); + + /** + * @returns true if the TypeId is in a blocked state. + */ + bool isBlocked(TypeId ty); + + /** + * Returns whether the constraint is blocked on anything. + * @param constraint the constraint to check. + */ + bool isBlocked(NotNull constraint); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result. + * @param subType the sub-type to unify. + * @param superType the super-type to unify. + */ + void unify(TypeId subType, TypeId superType, Location location); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result. + * @param subPack the sub-type pack to unify. + * @param superPack the super-type pack to unify. + */ + void unify(TypePackId subPack, TypePackId superPack, Location location); + +private: /** * Marks a constraint as being blocked on a type or type pack. The constraint * solver will not attempt to dispatch blocked constraints until their @@ -60,10 +103,7 @@ struct ConstraintSolver * @param target the type or type pack pointer that the constraint is blocked on. * @param constraint the constraint to block. **/ - void block_(BlockedConstraintId target, const Constraint* constraint); - void block(const Constraint* target, const Constraint* constraint); - void block(TypeId target, const Constraint* constraint); - void block(TypePackId target, const Constraint* constraint); + void block_(BlockedConstraintId target, NotNull constraint); /** * Informs the solver that progress has been made on a type or type pack. The @@ -72,33 +112,6 @@ struct ConstraintSolver * @param progressed the type or type pack pointer that has progressed. **/ void unblock_(BlockedConstraintId progressed); - void unblock(const Constraint* progressed); - void unblock(TypeId progressed); - void unblock(TypePackId progressed); - - /** - * Returns whether the constraint is blocked on anything. - * @param constraint the constraint to check. - */ - bool isBlocked(const Constraint* constraint); - - void reportErrors(const std::vector& errors); - - /** - * Creates a new Unifier and performs a single unification operation. Commits - * the result and reports errors if necessary. - * @param subType the sub-type to unify. - * @param superType the super-type to unify. - */ - void unify(TypeId subType, TypeId superType); - - /** - * Creates a new Unifier and performs a single unification operation. Commits - * the result and reports errors if necessary. - * @param subPack the sub-type pack to unify. - * @param superPack the super-type pack to unify. - */ - void unify(TypePackId subPack, TypePackId superPack); }; void dump(Scope2* rootScope, struct ToStringOptions& opts); diff --git a/Analysis/include/Luau/ConstraintSolverLogger.h b/Analysis/include/Luau/ConstraintSolverLogger.h new file mode 100644 index 00000000..2b195d71 --- /dev/null +++ b/Analysis/include/Luau/ConstraintSolverLogger.h @@ -0,0 +1,26 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ToString.h" + +#include +#include +#include + +namespace Luau +{ + +struct ConstraintSolverLogger +{ + std::string compileOutput(); + void captureBoundarySnapshot(const Scope2* rootScope, std::vector>& unsolvedConstraints); + void prepareStepSnapshot(const Scope2* rootScope, NotNull current, std::vector>& unsolvedConstraints); + void commitPreparedStepSnapshot(); + +private: + std::vector snapshots; + std::optional preparedSnapshot; + ToStringOptions opts; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 58be0ffe..f4226cc1 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -66,7 +66,7 @@ struct SourceNode } ModuleName name; - std::unordered_set requires; + std::unordered_set requireSet; std::vector> requireLocations; bool dirtySourceModule = true; bool dirtyModule = true; @@ -186,7 +186,7 @@ public: std::unordered_map sourceNodes; std::unordered_map sourceModules; - std::unordered_map requires; + std::unordered_map requireTrace; Stats stats = {}; }; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index f6e077dc..e979b3f0 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -69,6 +69,7 @@ struct Module std::vector>> scope2s; // never empty DenseHashMap astTypes{nullptr}; + DenseHashMap astTypePacks{nullptr}; DenseHashMap astExpectedTypes{nullptr}; DenseHashMap astOriginalCallTypes{nullptr}; DenseHashMap astOverloadResolvedTypes{nullptr}; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 262b54b2..d4c7698b 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -10,6 +10,7 @@ namespace Luau struct InternalErrorReporter; bool isSubtype(TypeId superTy, TypeId subTy, InternalErrorReporter& ice); +bool isSubtype(TypePackId superTy, TypePackId subTy, InternalErrorReporter& ice); std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); diff --git a/Analysis/include/Luau/NotNull.h b/Analysis/include/Luau/NotNull.h index 3d05fdea..f6043e9c 100644 --- a/Analysis/include/Luau/NotNull.h +++ b/Analysis/include/Luau/NotNull.h @@ -9,20 +9,22 @@ namespace Luau { /** A non-owning, non-null pointer to a T. - * - * A NotNull is notionally identical to a T* with the added restriction that it - * can never store nullptr. - * - * The sole conversion rule from T* to NotNull is the single-argument constructor, which - * is intentionally marked explicit. This constructor performs a runtime test to verify - * that the passed pointer is never nullptr. - * - * Pointer arithmetic, increment, decrement, and array indexing are all forbidden. - * - * An implicit coersion from NotNull to T* is afforded, as are the pointer indirection and member - * access operators. (*p and p->prop) * - * The explicit delete statement is permitted on a NotNull through this implicit conversion. + * A NotNull is notionally identical to a T* with the added restriction that + * it can never store nullptr. + * + * The sole conversion rule from T* to NotNull is the single-argument + * constructor, which is intentionally marked explicit. This constructor + * performs a runtime test to verify that the passed pointer is never nullptr. + * + * Pointer arithmetic, increment, decrement, and array indexing are all + * forbidden. + * + * An implicit coersion from NotNull to T* is afforded, as are the pointer + * indirection and member access operators. (*p and p->prop) + * + * The explicit delete statement is permitted (but not recommended) on a + * NotNull through this implicit conversion. */ template struct NotNull @@ -36,6 +38,11 @@ struct NotNull explicit NotNull(std::nullptr_t) = delete; void operator=(std::nullptr_t) = delete; + template + NotNull(NotNull other) + : ptr(other.get()) + {} + operator T*() const noexcept { return ptr; @@ -56,6 +63,12 @@ struct NotNull T& operator+(int) = delete; T& operator-(int) = delete; + T* get() const noexcept + { + return ptr; + } + +private: T* ptr; }; @@ -68,7 +81,7 @@ template struct hash> { size_t operator()(const Luau::NotNull& p) const { - return std::hash()(p.ptr); + return std::hash()(p.get()); } }; diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index b32d684e..f46f0cb5 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -6,9 +6,10 @@ namespace Luau { +struct TypeArena; struct Scope2; void quantify(TypeId ty, TypeLevel level); -void quantify(TypeId ty, Scope2* scope); +TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope); } // namespace Luau diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index c25545f5..f69d133e 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -19,7 +19,7 @@ struct RequireTraceResult { DenseHashMap exprs{nullptr}; - std::vector> requires; + std::vector> requireList; }; RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName); diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h new file mode 100644 index 00000000..a6c7a3e3 --- /dev/null +++ b/Analysis/include/Luau/TypeChecker2.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Module.h" + +namespace Luau +{ + +void check(const SourceModule& sourceModule, Module* module); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 183cc053..28adc9d9 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -138,25 +138,25 @@ struct TypeChecker void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); - ExprResult checkExpr( + WithPredicate checkExpr( const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt, bool forceSingleton = false); - ExprResult checkExpr(const ScopePtr& scope, const AstExprLocal& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprCall& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprLocal& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); TypeId checkBinaryOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); - ExprResult checkExpr(const ScopePtr& scope, const AstExprBinary& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprError& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType); @@ -179,11 +179,11 @@ struct TypeChecker void checkArgumentList( const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); - ExprResult checkExprPack(const ScopePtr& scope, const AstExpr& expr); - ExprResult checkExprPack(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); + WithPredicate checkExprPack(const ScopePtr& scope, const AstExprCall& expr); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); - std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, + std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); @@ -191,7 +191,7 @@ struct TypeChecker const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, const std::vector& errors); - ExprResult checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, + WithPredicate checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil = false, const std::vector& lhsAnnotations = {}, const std::vector>& expectedTypes = {}); @@ -234,7 +234,7 @@ struct TypeChecker ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); - void unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location); std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); @@ -412,7 +412,6 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; - const TypeId optionalNumberType; const TypePackId anyTypePack; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index b59e7c64..ff7708d4 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -84,6 +84,24 @@ using Tags = std::vector; using ModuleName = std::string; +/** A TypeVar that cannot be computed. + * + * BlockedTypeVars essentially serve as a way to encode partial ordering on the + * constraint graph. Until a BlockedTypeVar is unblocked by its owning + * constraint, nothing at all can be said about it. Constraints that need to + * process a BlockedTypeVar cannot be dispatched. + * + * Whenever a BlockedTypeVar is added to the graph, we also record a constraint + * that will eventually unblock it. + */ +struct BlockedTypeVar +{ + BlockedTypeVar(); + int index; + + static int nextIndex; +}; + struct PrimitiveTypeVar { enum Type @@ -231,29 +249,29 @@ struct FunctionDefinition // TODO: Do we actually need this? We'll find out later if we can delete this. // Does not exactly belong in TypeVar.h, but this is the only way to appease the compiler. template -struct ExprResult +struct WithPredicate { T type; PredicateVec predicates; }; -using MagicFunction = std::function>( - struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, ExprResult)>; +using MagicFunction = std::function>( + struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; struct FunctionTypeVar { // Global monomorphic function - FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Global polymorphic function - FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local monomorphic function - FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local polymorphic function - FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); TypeLevel level; @@ -263,7 +281,7 @@ struct FunctionTypeVar std::vector genericPacks; TypePackId argTypes; std::vector> argNames; - TypePackId retType; + TypePackId retTypes; std::optional definition; MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. bool hasSelf; @@ -442,7 +460,7 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; struct TypeVar final @@ -555,7 +573,6 @@ struct SingletonTypes const TypeId trueType; const TypeId falseType; const TypeId anyType; - const TypeId optionalNumberType; const TypePackId anyTypePack; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 627b52ca..b51a485e 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -110,7 +110,7 @@ private: void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy); public: - void unifyLowerBound(TypePackId subTy, TypePackId superTy); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel); // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index f3839915..642522c9 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -209,7 +209,7 @@ struct GenericTypeVarVisitor if (visit(ty, *ftv)) { traverse(ftv->argTypes); - traverse(ftv->retType); + traverse(ftv->retTypes); } } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index a8319c59..8a63901f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,7 +13,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); LUAU_FASTFLAG(LuauSelfCallAutocompleteFix2) static const std::unordered_set kStatementStartingKeywords = { @@ -268,14 +267,14 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { if (FFlag::LuauSelfCallAutocompleteFix2) { - if (std::optional firstRetTy = first(ftv->retType)) + if (std::optional firstRetTy = first(ftv->retTypes)) return checkTypeMatch(typeArena, *firstRetTy, expectedType); return false; } else { - auto [retHead, retTail] = flatten(ftv->retType); + auto [retHead, retTail] = flatten(ftv->retTypes); if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return true; @@ -454,7 +453,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } else if (auto indexFunction = get(followed)) { - std::optional indexFunctionResult = first(indexFunction->retType); + std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } @@ -493,7 +492,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen); else if (auto indexFunction = get(followed)) { - std::optional indexFunctionResult = first(indexFunction->retType); + std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } @@ -742,7 +741,7 @@ static std::optional findTypeElementAt(AstType* astType, TypeId ty, Posi if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) return element; - if (auto element = findTypeElementAt(type->returnTypes, ftv->retType, position)) + if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) return element; } @@ -958,7 +957,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = get(follow(*it))) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) inferredType = *ty; } } @@ -1050,7 +1049,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, i)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); } @@ -1067,7 +1066,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, ~0u)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); } } @@ -1266,7 +1265,7 @@ static bool autocompleteIfElseExpression( if (!parent) return false; - if (FFlag::LuauIfElseExprFixCompletionIssue && node->is()) + if (node->is()) { // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else // expression. diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 98737b43..2f57e23c 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -19,16 +19,16 @@ LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) namespace Luau { -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -263,10 +263,10 @@ void registerBuiltinTypes(TypeChecker& typeChecker) attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); } -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; (void)scope; @@ -287,10 +287,10 @@ static std::optional> magicFunctionSelect( if (size_t(offset) < v.size()) { std::vector result(v.begin() + offset, v.end()); - return ExprResult{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; + return WithPredicate{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; } else if (tail) - return ExprResult{*tail}; + return WithPredicate{*tail}; } typechecker.reportError(TypeError{arg1->location, GenericError{"bad argument #1 to select (index out of range)"}}); @@ -298,16 +298,16 @@ static std::optional> magicFunctionSelect( else if (AstExprConstantString* str = arg1->as()) { if (str->value.size == 1 && str->value.data[0] == '#') - return ExprResult{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; + return WithPredicate{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; } return std::nullopt; } -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -343,7 +343,7 @@ static std::optional> magicFunctionSetMetaTable( if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) { - return ExprResult{}; + return WithPredicate{}; } if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) @@ -356,7 +356,7 @@ static std::optional> magicFunctionSetMetaTable( } } - return ExprResult{arena.addTypePack({mtTy})}; + return WithPredicate{arena.addTypePack({mtTy})}; } } else if (get(target) || get(target) || isTableIntersection(target)) @@ -367,13 +367,13 @@ static std::optional> magicFunctionSetMetaTable( typechecker.reportError(TypeError{expr.location, GenericError{"setmetatable should take a table"}}); } - return ExprResult{arena.addTypePack({target})}; + return WithPredicate{arena.addTypePack({target})}; } -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, predicates] = exprResult; + auto [paramPack, predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -382,7 +382,7 @@ static std::optional> magicFunctionAssert( { std::optional fst = first(*tail); if (!fst) - return ExprResult{paramPack}; + return WithPredicate{paramPack}; head.push_back(*fst); } @@ -397,13 +397,13 @@ static std::optional> magicFunctionAssert( head[0] = *newhead; } - return ExprResult{arena.addTypePack(TypePack{std::move(head), tail})}; + return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } -static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -436,7 +436,7 @@ static std::optional> magicFunctionPack( TypeId packedTable = arena.addType( TableTypeVar{{{"n", {typechecker.numberType}}}, TableIndexer(typechecker.numberType, result), scope->level, TableState::Sealed}); - return ExprResult{arena.addTypePack({packedTable})}; + return WithPredicate{arena.addTypePack({packedTable})}; } static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) @@ -461,8 +461,8 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) return good; } -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { TypeArena& arena = typechecker.currentModule->internalTypes; @@ -476,7 +476,7 @@ static std::optional> magicFunctionRequire( return std::nullopt; if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr)) - return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; + return WithPredicate{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; return std::nullopt; } diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 9180f309..248262ce 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -47,6 +47,7 @@ struct TypeCloner void operator()(const Unifiable::Generic& t); void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); + void operator()(const BlockedTypeVar& t); void operator()(const PrimitiveTypeVar& t); void operator()(const ConstrainedTypeVar& t); void operator()(const SingletonTypeVar& t); @@ -158,6 +159,11 @@ void TypeCloner::operator()(const Unifiable::Error& t) defaultClone(t); } +void TypeCloner::operator()(const BlockedTypeVar& t) +{ + defaultClone(t); +} + void TypeCloner::operator()(const PrimitiveTypeVar& t) { defaultClone(t); @@ -200,7 +206,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) ftv->tags = t.tags; ftv->argTypes = clone(t.argTypes, dest, cloneState); ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, cloneState); + ftv->retTypes = clone(t.retTypes, dest, cloneState); ftv->hasNoGenerics = t.hasNoGenerics; } @@ -391,7 +397,7 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log) if (const FunctionTypeVar* ftv = get(ty)) { - FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; clone.genericPacks = ftv->genericPacks; clone.magicFunction = ftv->magicFunction; diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp new file mode 100644 index 00000000..6cb0e4ee --- /dev/null +++ b/Analysis/src/Constraint.cpp @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Constraint.h" + +namespace Luau +{ + +Constraint::Constraint(ConstraintV&& c, Location location) + : c(std::move(c)) + , location(location) +{ +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index c8f77ddf..fa627e7a 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -5,16 +5,7 @@ namespace Luau { -Constraint::Constraint(ConstraintV&& c) - : c(std::move(c)) -{ -} - -Constraint::Constraint(ConstraintV&& c, std::vector dependencies) - : c(std::move(c)) - , dependencies(dependencies) -{ -} +const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp std::optional Scope2::lookup(Symbol sym) { @@ -68,10 +59,10 @@ Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) return borrow; } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) +void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv, Location location) { LUAU_ASSERT(scope); - scope->constraints.emplace_back(new Constraint{std::move(cv)}); + scope->constraints.emplace_back(new Constraint{std::move(cv), location}); } void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) @@ -99,10 +90,18 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) visit(scope, s); else if (auto s = stat->as()) visit(scope, s); + else if (auto f = stat->as()) + visit(scope, f); else if (auto f = stat->as()) visit(scope, f); else if (auto r = stat->as()) visit(scope, r); + else if (auto a = stat->as()) + visit(scope, a); + else if (auto e = stat->as()) + checkPack(scope, e->expr); + else if (auto i = stat->as()) + visit(scope, i); else LUAU_ASSERT(0); } @@ -121,12 +120,30 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) scope->bindings[local] = ty; } - for (size_t i = 0; i < local->vars.size; ++i) + for (size_t i = 0; i < local->values.size; ++i) { - if (i < local->values.size) + if (local->values.data[i]->is()) + { + // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. + // See the test TypeInfer/infer_locals_with_nil_value. + // Better flow awareness should make this obsolete. + } + else if (i == local->values.size - 1) + { + TypePackId exprPack = checkPack(scope, local->values.data[i]); + + if (i < local->vars.size) + { + std::vector tailValues{varTypes.begin() + i, varTypes.end()}; + TypePackId tailPack = arena->addTypePack(std::move(tailValues)); + addConstraint(scope, PackSubtypeConstraint{exprPack, tailPack}, local->location); + } + } + else { TypeId exprType = check(scope, local->values.data[i]); - addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}); + if (i < varTypes.size()) + addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}, local->vars.data[i]->location); } } } @@ -138,7 +155,7 @@ void addConstraints(Constraint* constraint, Scope2* scope) scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); for (const auto& c : scope->constraints) - constraint->dependencies.push_back(c.get()); + constraint->dependencies.push_back(NotNull{c.get()}); for (Scope2* childScope : scope->children) addConstraints(constraint, childScope); @@ -155,31 +172,75 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function TypeId functionType = nullptr; auto ty = scope->lookup(function->name); - LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. - - functionType = freshType(scope); - scope->bindings[function->name] = functionType; - - Scope2* innerScope = childScope(function->func->body->location, scope); - TypePackId returnType = freshTypePack(scope); - innerScope->returnType = returnType; - - std::vector argTypes; - - for (AstLocal* local : function->func->args) + if (ty.has_value()) { - TypeId t = freshType(innerScope); - argTypes.push_back(t); - innerScope->bindings[local] = t; // TODO annotations + // TODO: This is duplicate definition of a local function. Is this allowed? + functionType = *ty; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[function->name] = functionType; } - for (AstStat* stat : function->func->body->body) - visit(innerScope, stat); + auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + innerScope->bindings[function->name] = actualFunctionType; - FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; - TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + checkFunctionBody(innerScope, function->func); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}, function->location}}; + addConstraints(c.get(), innerScope); + + addConstraint(scope, std::move(c)); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) +{ + // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. + // With or without self + + TypeId functionType = nullptr; + + auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + + if (AstExprLocal* localName = function->name->as()) + { + std::optional existingFunctionTy = scope->lookup(localName->local); + if (existingFunctionTy) + { + // Duplicate definition + functionType = *existingFunctionTy; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[localName->local] = functionType; + } + innerScope->bindings[localName->local] = actualFunctionType; + } + else if (AstExprGlobal* globalName = function->name->as()) + { + std::optional existingFunctionTy = scope->lookup(globalName->name); + if (existingFunctionTy) + { + // Duplicate definition + functionType = *existingFunctionTy; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + rootScope->bindings[globalName->name] = functionType; + } + innerScope->bindings[globalName->name] = actualFunctionType; + } + else if (AstExprIndexName* indexName = function->name->as()) + { + LUAU_ASSERT(0); // not yet implemented + } + + checkFunctionBody(innerScope, function->func); + + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}, function->location}}; addConstraints(c.get(), innerScope); addConstraint(scope, std::move(c)); @@ -190,7 +251,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) LUAU_ASSERT(scope); TypePackId exprTypes = checkPack(scope, ret->list); - addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); + addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}, ret->location); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) @@ -201,6 +262,28 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) visit(scope, stat); } +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) +{ + TypePackId varPackId = checkExprList(scope, assign->vars); + TypePackId valuePack = checkPack(scope, assign->values); + + addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}, assign->location); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) +{ + check(scope, ifStatement->condition); + + Scope2* thenScope = childScope(ifStatement->thenbody->location, scope); + visit(thenScope, ifStatement->thenbody); + + if (ifStatement->elsebody) + { + Scope2* elseScope = childScope(ifStatement->elsebody->location, scope); + visit(elseScope, ifStatement->elsebody); + } +} + TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) { LUAU_ASSERT(scope); @@ -224,75 +307,256 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray e return arena->addTypePack(TypePack{std::move(types), last}); } +TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray& exprs) +{ + TypePackId result = arena->addTypePack({}); + TypePack* resultPack = getMutable(result); + LUAU_ASSERT(resultPack); + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* expr = exprs.data[i]; + if (i < exprs.size - 1) + resultPack->head.push_back(check(scope, expr)); + else + resultPack->tail = checkPack(scope, expr); + } + + if (resultPack->head.empty() && resultPack->tail) + return *resultPack->tail; + else + return result; +} + TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) { LUAU_ASSERT(scope); - // TEMP TEMP TEMP HACK HACK HACK FIXME FIXME - TypeId t = check(scope, expr); - return arena->addTypePack({t}); + TypePackId result = nullptr; + + if (AstExprCall* call = expr->as()) + { + std::vector args; + + for (AstExpr* arg : call->args) + { + args.push_back(check(scope, arg)); + } + + // TODO self + + TypeId fnType = check(scope, call->func); + + astOriginalCallTypes[call->func] = fnType; + + TypeId instantiatedType = freshType(scope); + addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}, expr->location); + + TypePackId rets = freshTypePack(scope); + FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); + TypeId inferredFnType = arena->addType(ftv); + + addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}, expr->location); + result = rets; + } + else + { + TypeId t = check(scope, expr); + result = arena->addTypePack({t}); + } + + LUAU_ASSERT(result); + astTypePacks[expr] = result; + return result; } TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) { LUAU_ASSERT(scope); - if (auto a = expr->as()) - return singletonTypes.stringType; - else if (auto a = expr->as()) - return singletonTypes.numberType; - else if (auto a = expr->as()) - return singletonTypes.booleanType; - else if (auto a = expr->as()) - return singletonTypes.nilType; + TypeId result = nullptr; + + if (auto group = expr->as()) + result = check(scope, group->expr); + else if (expr->is()) + result = singletonTypes.stringType; + else if (expr->is()) + result = singletonTypes.numberType; + else if (expr->is()) + result = singletonTypes.booleanType; + else if (expr->is()) + result = singletonTypes.nilType; else if (auto a = expr->as()) { std::optional ty = scope->lookup(a->local); if (ty) - return *ty; + result = *ty; else - return singletonTypes.errorRecoveryType(singletonTypes.anyType); // FIXME? Record an error at this point? + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? + } + else if (auto g = expr->as()) + { + std::optional ty = scope->lookup(g->name); + if (ty) + result = *ty; + else + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? } else if (auto a = expr->as()) { - std::vector args; - - for (AstExpr* arg : a->args) + TypePackId packResult = checkPack(scope, expr); + if (auto f = first(packResult)) + return *f; + else if (get(packResult)) { - args.push_back(check(scope, arg)); + TypeId typeResult = freshType(scope); + TypePack onePack{{typeResult}, freshTypePack(scope)}; + TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); + + addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}, expr->location); + + return typeResult; } - - TypeId fnType = check(scope, a->func); - TypeId instantiatedType = freshType(scope); - addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); - - TypeId firstRet = freshType(scope); - TypePackId rets = arena->addTypePack(TypePack{{firstRet}, arena->addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); - FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); - TypeId inferredFnType = arena->addType(ftv); - - addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); - return firstRet; + } + else if (auto a = expr->as()) + { + auto [fnType, functionScope] = checkFunctionSignature(scope, a); + checkFunctionBody(functionScope, a); + return fnType; + } + else if (auto indexName = expr->as()) + { + result = check(scope, indexName); + } + else if (auto table = expr->as()) + { + result = checkExprTable(scope, table); } else { LUAU_ASSERT(0); - return freshType(scope); + result = freshType(scope); + } + + LUAU_ASSERT(result); + astTypes[expr] = result; + return result; +} + +TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) +{ + TypeId obj = check(scope, indexName->expr); + TypeId result = freshType(scope); + + TableTypeVar::Props props{{indexName->index.value, Property{result}}}; + const std::optional indexer; + TableTypeVar ttv{std::move(props), indexer, TypeLevel{}, TableState::Free}; + + TypeId expectedTableType = arena->addType(std::move(ttv)); + + addConstraint(scope, SubtypeConstraint{obj, expectedTableType}, indexName->location); + + return result; +} + +TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) +{ + TypeId ty = arena->addType(TableTypeVar{}); + TableTypeVar* ttv = getMutable(ty); + LUAU_ASSERT(ttv); + + auto createIndexer = [this, scope, ttv]( + TypeId currentIndexType, TypeId currentResultType, Location itemLocation, std::optional keyLocation) { + if (!ttv->indexer) + { + TypeId indexType = this->freshType(scope); + TypeId resultType = this->freshType(scope); + ttv->indexer = TableIndexer{indexType, resultType}; + } + + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}, keyLocation ? *keyLocation : itemLocation); + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}, itemLocation); + }; + + for (const AstExprTable::Item& item : expr->items) + { + TypeId itemTy = check(scope, item.value); + + if (item.key) + { + // Even though we don't need to use the type of the item's key if + // it's a string constant, we still want to check it to populate + // astTypes. + TypeId keyTy = check(scope, item.key); + + if (AstExprConstantString* key = item.key->as()) + { + ttv->props[key->value.begin()] = {itemTy}; + } + else + { + createIndexer(keyTy, itemTy, item.value->location, item.key->location); + } + } + else + { + TypeId numberType = singletonTypes.numberType; + createIndexer(numberType, itemTy, item.value->location, std::nullopt); + } + } + + return ty; +} + +std::pair ConstraintGraphBuilder::checkFunctionSignature(Scope2* parent, AstExprFunction* fn) +{ + Scope2* innerScope = childScope(fn->body->location, parent); + TypePackId returnType = freshTypePack(innerScope); + innerScope->returnType = returnType; + + std::vector argTypes; + + for (AstLocal* local : fn->args) + { + TypeId t = freshType(innerScope); + argTypes.push_back(t); + innerScope->bindings[local] = t; // TODO annotations + } + + FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + LUAU_ASSERT(actualFunctionType); + astTypes[fn] = actualFunctionType; + + return {actualFunctionType, innerScope}; +} + +void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* fn) +{ + for (AstStat* stat : fn->body->body) + visit(scope, stat); + + // If it is possible for execution to reach the end of the function, the return type must be compatible with () + + if (nullptr != getFallthrough(fn->body)) + { + TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever + addConstraint(scope, PackSubtypeConstraint{scope->returnType, empty}, fn->body->location); } } -static void collectConstraints(std::vector& result, Scope2* scope) +void collectConstraints(std::vector>& result, Scope2* scope) { for (const auto& c : scope->constraints) - result.push_back(c.get()); + result.push_back(NotNull{c.get()}); for (Scope2* child : scope->children) collectConstraints(result, child); } -std::vector collectConstraints(Scope2* rootScope) +std::vector> collectConstraints(Scope2* rootScope) { - std::vector result; + std::vector> result; collectConstraints(result, rootScope); return result; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index f40cd4b3..41dfd892 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -7,6 +7,7 @@ #include "Luau/Unifier.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { @@ -58,11 +59,11 @@ ConstraintSolver::ConstraintSolver(TypeArena* arena, Scope2* rootScope) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) { - for (const Constraint* c : constraints) + for (NotNull c : constraints) { - unsolvedConstraints.insert(c); + unsolvedConstraints.push_back(c); - for (const Constraint* dep : c->dependencies) + for (NotNull dep : c->dependencies) { block(dep, c); } @@ -74,8 +75,6 @@ void ConstraintSolver::run() if (done()) return; - bool progress = false; - ToStringOptions opts; if (FFlag::DebugLuauLogSolver) @@ -84,44 +83,80 @@ void ConstraintSolver::run() dump(this, opts); } - do + if (FFlag::DebugLuauLogSolverToJson) { - progress = false; + logger.captureBoundarySnapshot(rootScope, unsolvedConstraints); + } - auto it = begin(unsolvedConstraints); - auto endIt = end(unsolvedConstraints); + auto runSolverPass = [&](bool force) { + bool progress = false; - while (it != endIt) + size_t i = 0; + while (i < unsolvedConstraints.size()) { - if (isBlocked(*it)) + NotNull c = unsolvedConstraints[i]; + if (!force && isBlocked(c)) { - ++it; + ++i; continue; } - std::string saveMe = FFlag::DebugLuauLogSolver ? toString(**it, opts) : std::string{}; + std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; - bool success = tryDispatch(*it); - progress = progress || success; + if (FFlag::DebugLuauLogSolverToJson) + { + logger.prepareStepSnapshot(rootScope, c, unsolvedConstraints); + } + + bool success = tryDispatch(c, force); + + progress |= success; - auto saveIt = it; - ++it; if (success) { - unsolvedConstraints.erase(saveIt); + unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + + if (FFlag::DebugLuauLogSolverToJson) + { + logger.commitPreparedStepSnapshot(); + } + if (FFlag::DebugLuauLogSolver) { + if (force) + printf("Force "); printf("Dispatched\n\t%s\n", saveMe.c_str()); dump(this, opts); } } + else + ++i; + + if (force && success) + return true; } + + return progress; + }; + + bool progress = false; + do + { + progress = runSolverPass(false); + if (!progress) + progress |= runSolverPass(true); } while (progress); if (FFlag::DebugLuauLogSolver) + { dumpBindings(rootScope, opts); + } - LUAU_ASSERT(done()); + if (FFlag::DebugLuauLogSolverToJson) + { + logger.captureBoundarySnapshot(rootScope, unsolvedConstraints); + printf("Logger output:\n%s\n", logger.compileOutput().c_str()); + } } bool ConstraintSolver::done() @@ -129,21 +164,21 @@ bool ConstraintSolver::done() return unsolvedConstraints.empty(); } -bool ConstraintSolver::tryDispatch(const Constraint* constraint) +bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) { - if (isBlocked(constraint)) + if (!force && isBlocked(constraint)) return false; bool success = false; if (auto sc = get(*constraint)) - success = tryDispatch(*sc); + success = tryDispatch(*sc, constraint, force); else if (auto psc = get(*constraint)) - success = tryDispatch(*psc); + success = tryDispatch(*psc, constraint, force); else if (auto gc = get(*constraint)) - success = tryDispatch(*gc); + success = tryDispatch(*gc, constraint, force); else if (auto ic = get(*constraint)) - success = tryDispatch(*ic, constraint); + success = tryDispatch(*ic, constraint, force); else LUAU_ASSERT(0); @@ -155,65 +190,66 @@ bool ConstraintSolver::tryDispatch(const Constraint* constraint) return success; } -bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c) +bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { - unify(c.subType, c.superType); + if (isBlocked(c.subType)) + return block(c.subType, constraint); + else if (isBlocked(c.superType)) + return block(c.superType, constraint); + + unify(c.subType, c.superType, constraint->location); + unblock(c.subType); unblock(c.superType); return true; } -bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c) +bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) { - unify(c.subPack, c.superPack); + unify(c.subPack, c.superPack, constraint->location); unblock(c.subPack); unblock(c.superPack); return true; } -bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& constraint) +bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force) { - unify(constraint.subType, constraint.superType); + if (isBlocked(c.sourceType)) + return block(c.sourceType, constraint); - quantify(constraint.superType, constraint.scope); - unblock(constraint.subType); - unblock(constraint.superType); + if (isBlocked(c.generalizedType)) + asMutable(c.generalizedType)->ty.emplace(c.sourceType); + else + unify(c.generalizedType, c.sourceType, constraint->location); + + TypeId generalized = quantify(arena, c.sourceType, c.scope); + *asMutable(c.sourceType) = *generalized; + + unblock(c.generalizedType); + unblock(c.sourceType); return true; } -bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, const Constraint* constraint) +bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force) { - TypeId superType = follow(c.superType); - if (const FunctionTypeVar* ftv = get(superType)) - { - if (!ftv->generalized) - { - block(superType, constraint); - return false; - } - } - else if (get(superType)) - { - block(superType, constraint); - return false; - } - // TODO: Error if it's a primitive or something + if (isBlocked(c.superType)) + return block(c.superType, constraint); Instantiation inst(TxnLog::empty(), arena, TypeLevel{}); std::optional instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - unify(c.subType, *instantiated); + unify(c.subType, *instantiated, constraint->location); unblock(c.subType); return true; } -void ConstraintSolver::block_(BlockedConstraintId target, const Constraint* constraint) +void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); @@ -221,19 +257,21 @@ void ConstraintSolver::block_(BlockedConstraintId target, const Constraint* cons count += 1; } -void ConstraintSolver::block(const Constraint* target, const Constraint* constraint) +void ConstraintSolver::block(NotNull target, NotNull constraint) { block_(target, constraint); } -void ConstraintSolver::block(TypeId target, const Constraint* constraint) +bool ConstraintSolver::block(TypeId target, NotNull constraint) { block_(target, constraint); + return false; } -void ConstraintSolver::block(TypePackId target, const Constraint* constraint) +bool ConstraintSolver::block(TypePackId target, NotNull constraint) { block_(target, constraint); + return false; } void ConstraintSolver::unblock_(BlockedConstraintId progressed) @@ -243,7 +281,7 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) return; // unblocked should contain a value always, because of the above check - for (const Constraint* unblockedConstraint : it->second) + for (NotNull unblockedConstraint : it->second) { auto& count = blockedConstraints[unblockedConstraint]; // This assertion being hit indicates that `blocked` and @@ -257,7 +295,7 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) blocked.erase(it); } -void ConstraintSolver::unblock(const Constraint* progressed) +void ConstraintSolver::unblock(NotNull progressed) { return unblock_(progressed); } @@ -272,35 +310,33 @@ void ConstraintSolver::unblock(TypePackId progressed) return unblock_(progressed); } -bool ConstraintSolver::isBlocked(const Constraint* constraint) +bool ConstraintSolver::isBlocked(TypeId ty) +{ + return nullptr != get(follow(ty)); +} + +bool ConstraintSolver::isBlocked(NotNull constraint) { auto blockedIt = blockedConstraints.find(constraint); return blockedIt != blockedConstraints.end() && blockedIt->second > 0; } -void ConstraintSolver::reportErrors(const std::vector& errors) -{ - this->errors.insert(end(this->errors), begin(errors), end(errors)); -} - -void ConstraintSolver::unify(TypeId subType, TypeId superType) +void ConstraintSolver::unify(TypeId subType, TypeId superType, Location location) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, location, Covariant, sharedState}; u.tryUnify(subType, superType); u.log.commit(); - reportErrors(u.errors); } -void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack) +void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, Location location) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, location, Covariant, sharedState}; u.tryUnify(subPack, superPack); u.log.commit(); - reportErrors(u.errors); } } // namespace Luau diff --git a/Analysis/src/ConstraintSolverLogger.cpp b/Analysis/src/ConstraintSolverLogger.cpp new file mode 100644 index 00000000..2f93c280 --- /dev/null +++ b/Analysis/src/ConstraintSolverLogger.cpp @@ -0,0 +1,139 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintSolverLogger.h" + +namespace Luau +{ + +static std::string dumpScopeAndChildren(const Scope2* scope, ToStringOptions& opts) +{ + std::string output = "{\"bindings\":{"; + + bool comma = false; + for (const auto& [name, type] : scope->bindings) + { + if (comma) + output += ","; + + output += "\""; + output += name.c_str(); + output += "\": \""; + + ToStringResult result = toStringDetailed(type, opts); + opts.nameMap = std::move(result.nameMap); + output += result.name; + output += "\""; + + comma = true; + } + + output += "},\"children\":["; + comma = false; + + for (const Scope2* child : scope->children) + { + if (comma) + output += ","; + + output += dumpScopeAndChildren(child, opts); + comma = true; + } + + output += "]}"; + return output; +} + +static std::string dumpConstraintsToDot(std::vector>& constraints, ToStringOptions& opts) +{ + std::string result = "digraph Constraints {\\n"; + + std::unordered_set> contained; + for (NotNull c : constraints) + { + contained.insert(c); + } + + for (NotNull c : constraints) + { + std::string id = std::to_string(reinterpret_cast(c.get())); + result += id; + result += " [label=\\\""; + result += toString(*c, opts).c_str(); + result += "\\\"];\\n"; + + for (NotNull dep : c->dependencies) + { + if (contained.count(dep) == 0) + continue; + + result += std::to_string(reinterpret_cast(dep.get())); + result += " -> "; + result += id; + result += ";\\n"; + } + } + + result += "}"; + + return result; +} + +std::string ConstraintSolverLogger::compileOutput() +{ + std::string output = "["; + bool comma = false; + + for (const std::string& snapshot : snapshots) + { + if (comma) + output += ","; + output += snapshot; + + comma = true; + } + + output += "]"; + return output; +} + +void ConstraintSolverLogger::captureBoundarySnapshot(const Scope2* rootScope, std::vector>& unsolvedConstraints) +{ + std::string snapshot = "{\"type\":\"boundary\",\"rootScope\":"; + + snapshot += dumpScopeAndChildren(rootScope, opts); + snapshot += ",\"constraintGraph\":\""; + snapshot += dumpConstraintsToDot(unsolvedConstraints, opts); + snapshot += "\"}"; + + snapshots.push_back(std::move(snapshot)); +} + +void ConstraintSolverLogger::prepareStepSnapshot( + const Scope2* rootScope, NotNull current, std::vector>& unsolvedConstraints) +{ + // LUAU_ASSERT(!preparedSnapshot); + + std::string snapshot = "{\"type\":\"step\",\"rootScope\":"; + + snapshot += dumpScopeAndChildren(rootScope, opts); + snapshot += ",\"constraintGraph\":\""; + snapshot += dumpConstraintsToDot(unsolvedConstraints, opts); + snapshot += "\",\"currentId\":\""; + snapshot += std::to_string(reinterpret_cast(current.get())); + snapshot += "\",\"current\":\""; + snapshot += toString(*current, opts); + snapshot += "\"}"; + + preparedSnapshot = std::move(snapshot); +} + +void ConstraintSolverLogger::commitPreparedStepSnapshot() +{ + if (preparedSnapshot) + { + snapshots.push_back(std::move(*preparedSnapshot)); + preparedSnapshot = std::nullopt; + } +} + +} // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 741a35cf..9e025062 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,16 +1,17 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" -#include "Luau/Common.h" #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/Config.h" -#include "Luau/FileResolver.h" #include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintSolver.h" +#include "Luau/FileResolver.h" #include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" @@ -216,7 +217,7 @@ ErrorVec accumulateErrors( continue; const SourceNode& sourceNode = it->second; - queue.insert(queue.end(), sourceNode.requires.begin(), sourceNode.requires.end()); + queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end()); // FIXME: If a module has a syntax error, we won't be able to re-report it here. // The solution is probably to move errors from Module to SourceNode @@ -586,7 +587,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec path.push_back(top); // push children - for (const ModuleName& dep : top->requires) + for (const ModuleName& dep : top->requireSet) { auto it = sourceNodes.find(dep); if (it != sourceNodes.end()) @@ -738,7 +739,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) { - for (const auto& dep : module.second.requires) + for (const auto& dep : module.second.requireSet) reverseDeps[dep].push_back(module.first); } @@ -797,9 +798,14 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco cs.run(); result->scope2s = std::move(cgb.scopes); + result->astTypes = std::move(cgb.astTypes); + result->astTypePacks = std::move(cgb.astTypePacks); + result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); result->clonePublicInterface(iceHandler); + Luau::check(sourceModule, result.get()); + return result; } @@ -841,8 +847,8 @@ std::pair Frontend::getSourceNode(CheckResult& check SourceModule result = parse(name, source->source, opts); result.type = source->type; - RequireTraceResult& requireTrace = requires[name]; - requireTrace = traceRequires(fileResolver, result.root, name); + RequireTraceResult& require = requireTrace[name]; + require = traceRequires(fileResolver, result.root, name); SourceNode& sourceNode = sourceNodes[name]; SourceModule& sourceModule = sourceModules[name]; @@ -851,7 +857,7 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceModule.environmentName = environmentName; sourceNode.name = name; - sourceNode.requires.clear(); + sourceNode.requireSet.clear(); sourceNode.requireLocations.clear(); sourceNode.dirtySourceModule = false; @@ -861,10 +867,10 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceNode.dirtyModuleForAutocomplete = true; } - for (const auto& [moduleName, location] : requireTrace.requires) - sourceNode.requires.insert(moduleName); + for (const auto& [moduleName, location] : require.requireList) + sourceNode.requireSet.insert(moduleName); - sourceNode.requireLocations = requireTrace.requires; + sourceNode.requireLocations = require.requireList; return {&sourceNode, &sourceModule}; } @@ -925,8 +931,8 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const std::optional FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) { // FIXME I think this can be pushed into the FileResolver. - auto it = frontend->requires.find(currentModuleName); - if (it == frontend->requires.end()) + auto it = frontend->requireTrace.find(currentModuleName); + if (it == frontend->requireTrace.end()) { // CLI-43699 // If we can't find the current module name, that's because we bypassed the frontend's initializer @@ -1025,7 +1031,7 @@ void Frontend::clear() sourceModules.clear(); moduleResolver.modules.clear(); moduleResolverForAutocomplete.modules.clear(); - requires.clear(); + requireTrace.clear(); } } // namespace Luau diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index f145a511..77c62422 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -40,7 +40,7 @@ TypeId Instantiation::clean(TypeId ty) const FunctionTypeVar* ftv = log->getMutable(ty); LUAU_ASSERT(ftv); - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.magicFunction = ftv->magicFunction; clone.tags = ftv->tags; clone.argNames = ftv->argNames; diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 200b7d1b..50868e56 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -2282,7 +2282,7 @@ private: size_t getReturnCount(TypeId ty) { if (auto ftv = get(ty)) - return size(ftv->retType); + return size(ftv->retTypes); if (auto itv = get(ty)) { @@ -2291,7 +2291,7 @@ private: for (TypeId part : itv->parts) if (auto ftv = get(follow(part))) - result = std::max(result, size(ftv->retType)); + result = std::max(result, size(ftv->retTypes)); return result; } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 11403be5..d36665e2 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -17,6 +17,7 @@ LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); LUAU_FASTFLAGVARIABLE(LuauReplaceReplacer, false); +LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau { @@ -273,6 +274,18 @@ bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) return ok; } +bool isSubtype(TypePackId subPack, TypePackId superPack, InternalErrorReporter& ice) +{ + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(subPack, superPack); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; +} + template static bool areNormal_(const T& t, const std::unordered_set& seen, InternalErrorReporter& ice) { @@ -390,6 +403,7 @@ struct Normalize final : TypeVarVisitor bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override { CHECK_ITERATION_LIMIT(false); + LUAU_ASSERT(!ty->normal); ConstrainedTypeVar* ctv = const_cast(&ctvRef); @@ -401,14 +415,21 @@ struct Normalize final : TypeVarVisitor std::vector newParts = normalizeUnion(parts); - const bool normal = areNormal(newParts, seen, ice); - - if (newParts.size() == 1) - *asMutable(ty) = BoundTypeVar{newParts[0]}; + if (FFlag::LuauQuantifyConstrained) + { + ctv->parts = std::move(newParts); + } else - *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + { + const bool normal = areNormal(newParts, seen, ice); - asMutable(ty)->normal = normal; + if (newParts.size() == 1) + *asMutable(ty) = BoundTypeVar{newParts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + + asMutable(ty)->normal = normal; + } return false; } @@ -421,9 +442,9 @@ struct Normalize final : TypeVarVisitor return false; traverse(ftv.argTypes); - traverse(ftv.retType); + traverse(ftv.retTypes); - asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice); + asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retTypes, seen, ice); return false; } @@ -465,7 +486,14 @@ struct Normalize final : TypeVarVisitor checkNormal(ttv.indexer->indexResultType); } - asMutable(ty)->normal = normal; + // An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal. + if (FFlag::LuauQuantifyConstrained) + { + if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal)) + asMutable(ty)->normal = normal; + } + else + asMutable(ty)->normal = normal; return false; } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 21775373..2004d153 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -2,15 +2,32 @@ #include "Luau/Quantify.h" +#include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header +#include "Luau/TxnLog.h" +#include "Luau/Substitution.h" #include "Luau/VisitTypeVar.h" #include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header LUAU_FASTFLAG(LuauAlwaysQuantify); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false) namespace Luau { +/// @return true if outer encloses inner +static bool subsumes(Scope2* outer, Scope2* inner) +{ + while (inner) + { + if (inner == outer) + return true; + inner = inner->parent; + } + + return false; +} + struct Quantifier final : TypeVarOnceVisitor { TypeLevel level; @@ -62,6 +79,34 @@ struct Quantifier final : TypeVarOnceVisitor return false; } + bool visit(TypeId ty, const ConstrainedTypeVar&) override + { + if (FFlag::LuauQuantifyConstrained) + { + ConstrainedTypeVar* ctv = getMutable(ty); + + seenMutableType = true; + + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ctv->scope) : !level.subsumes(ctv->level)) + return false; + + std::vector opts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic + for (TypeId opt : opts) + traverse(opt); + + if (opts.size() == 1) + *asMutable(ty) = BoundTypeVar{opts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(opts)}; + + return false; + } + else + return true; + } + bool visit(TypeId ty, const TableTypeVar&) override { LUAU_ASSERT(getMutable(ty)); @@ -73,8 +118,12 @@ struct Quantifier final : TypeVarOnceVisitor if (ttv.state == TableState::Free) seenMutableType = true; - if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) - return false; + if (!FFlag::LuauQuantifyConstrained) + { + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) + return false; + } + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ttv.scope) : !level.subsumes(ttv.level)) { if (ttv.state == TableState::Unsealed) @@ -156,4 +205,104 @@ void quantify(TypeId ty, Scope2* scope) ftv->generalized = true; } +struct PureQuantifier : Substitution +{ + Scope2* scope; + std::vector insertedGenerics; + std::vector insertedGenericPacks; + + PureQuantifier(const TxnLog* log, TypeArena* arena, Scope2* scope) + : Substitution(log, arena) + , scope(scope) + { + } + + bool isDirty(TypeId ty) override + { + LUAU_ASSERT(ty == follow(ty)); + + if (auto ftv = get(ty)) + { + return subsumes(scope, ftv->scope); + } + else if (auto ttv = get(ty)) + { + return ttv->state == TableState::Free && subsumes(scope, ttv->scope); + } + + return false; + } + + bool isDirty(TypePackId tp) override + { + if (auto ftp = get(tp)) + { + return subsumes(scope, ftp->scope); + } + + return false; + } + + TypeId clean(TypeId ty) override + { + if (auto ftv = get(ty)) + { + TypeId result = arena->addType(GenericTypeVar{}); + insertedGenerics.push_back(result); + return result; + } + else if (auto ttv = get(ty)) + { + TypeId result = arena->addType(TableTypeVar{}); + TableTypeVar* resultTable = getMutable(result); + LUAU_ASSERT(resultTable); + + *resultTable = *ttv; + resultTable->scope = nullptr; + resultTable->state = TableState::Generic; + + return result; + } + + return ty; + } + + TypePackId clean(TypePackId tp) override + { + if (auto ftp = get(tp)) + { + TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}}); + insertedGenericPacks.push_back(result); + return result; + } + + return tp; + } + + bool ignoreChildren(TypeId ty) override + { + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } +}; + +TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) +{ + PureQuantifier quantifier{TxnLog::empty(), arena, scope}; + std::optional result = quantifier.substitute(ty); + LUAU_ASSERT(result); + + FunctionTypeVar* ftv = getMutable(*result); + LUAU_ASSERT(ftv); + ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); + + // TODO: Set hasNoGenerics. + + return *result; +} + } // namespace Luau diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 8ed245fb..c036a7a5 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -28,7 +28,7 @@ struct RequireTracer : AstVisitor AstExprGlobal* global = expr->func->as(); if (global && global->name == "require" && expr->args.size >= 1) - requires.push_back(expr); + requireCalls.push_back(expr); return true; } @@ -84,9 +84,9 @@ struct RequireTracer : AstVisitor ModuleInfo moduleContext{currentModuleName}; // seed worklist with require arguments - work.reserve(requires.size()); + work.reserve(requireCalls.size()); - for (AstExprCall* require : requires) + for (AstExprCall* require : requireCalls) work.push_back(require->args.data[0]); // push all dependent expressions to the work stack; note that the vector is modified during traversal @@ -125,15 +125,15 @@ struct RequireTracer : AstVisitor } // resolve all requires according to their argument - result.requires.reserve(requires.size()); + result.requireList.reserve(requireCalls.size()); - for (AstExprCall* require : requires) + for (AstExprCall* require : requireCalls) { AstExpr* arg = require->args.data[0]; if (const ModuleInfo* info = result.exprs.find(arg)) { - result.requires.push_back({info->name, require->location}); + result.requireList.push_back({info->name, require->location}); ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info! result.exprs[require] = std::move(infoCopy); @@ -151,7 +151,7 @@ struct RequireTracer : AstVisitor DenseHashMap locals; std::vector work; - std::vector requires; + std::vector requireCalls; }; RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 5a22deeb..9c4ce829 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -27,7 +27,7 @@ void Tarjan::visitChildren(TypeId ty, int index) if (const FunctionTypeVar* ftv = get(ty)) { visitChild(ftv->argTypes); - visitChild(ftv->retType); + visitChild(ftv->retTypes); } else if (const TableTypeVar* ttv = get(ty)) { @@ -442,7 +442,7 @@ void Substitution::replaceChildren(TypeId ty) if (FunctionTypeVar* ftv = getMutable(ty)) { ftv->argTypes = replace(ftv->argTypes); - ftv->retType = replace(ftv->retType); + ftv->retTypes = replace(ftv->retTypes); } else if (TableTypeVar* ttv = getMutable(ty)) { diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 9b396c80..6b677bb8 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -154,7 +154,7 @@ void StateDot::visitChildren(TypeId ty, int index) finishNode(); visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retType, index, "ret"); + visitChild(ftv->retTypes, index, "ret"); } else if (const TableTypeVar* ttv = get(ty)) { diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 04d15cf7..81dc0467 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -226,6 +226,11 @@ struct StringifierState result.name += s; } + void emit(int i) + { + emit(std::to_string(i).c_str()); + } + void indent() { indentation += 4; @@ -394,6 +399,13 @@ struct TypeVarStringifier state.emit("]]"); } + void operator()(TypeId, const BlockedTypeVar& btv) + { + state.emit("*blocked-"); + state.emit(btv.index); + state.emit("*"); + } + void operator()(TypeId, const PrimitiveTypeVar& ptv) { switch (ptv.type) @@ -480,8 +492,8 @@ struct TypeVarStringifier if (FFlag::LuauLowerBoundsCalculation) { - auto retBegin = begin(ftv.retType); - auto retEnd = end(ftv.retType); + auto retBegin = begin(ftv.retTypes); + auto retEnd = end(ftv.retTypes); if (retBegin != retEnd) { ++retBegin; @@ -491,7 +503,7 @@ struct TypeVarStringifier } else { - if (auto retPack = get(follow(ftv.retType))) + if (auto retPack = get(follow(ftv.retTypes))) { if (retPack->head.size() == 1 && !retPack->tail) plural = false; @@ -501,7 +513,7 @@ struct TypeVarStringifier if (plural) state.emit("("); - stringify(ftv.retType); + stringify(ftv.retTypes); if (plural) state.emit(")"); @@ -1303,14 +1315,14 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp state.emit("): "); - size_t retSize = size(ftv.retType); - bool hasTail = !finite(ftv.retType); - bool wrap = get(follow(ftv.retType)) && (hasTail ? retSize != 0 : retSize != 1); + size_t retSize = size(ftv.retTypes); + bool hasTail = !finite(ftv.retTypes); + bool wrap = get(follow(ftv.retTypes)) && (hasTail ? retSize != 0 : retSize != 1); if (wrap) state.emit("("); - tvs.stringify(ftv.retType); + tvs.stringify(ftv.retTypes); if (wrap) state.emit(")"); @@ -1385,9 +1397,9 @@ std::string toString(const Constraint& c, ToStringOptions& opts) } else if (const GeneralizationConstraint* gc = Luau::get_if(&c.c)) { - ToStringResult subStr = toStringDetailed(gc->subType, opts); + ToStringResult subStr = toStringDetailed(gc->generalizedType, opts); opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(gc->superType, opts); + ToStringResult superStr = toStringDetailed(gc->sourceType, opts); opts.nameMap = std::move(superStr.nameMap); return subStr.name + " ~ gen " + superStr.name; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 0f4534b7..6cca7127 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -94,6 +94,11 @@ public: } } + AstType* operator()(const BlockedTypeVar& btv) + { + return allocator->alloc(Location(), std::nullopt, AstName("*blocked*")); + } + AstType* operator()(const ConstrainedTypeVar& ctv) { AstArray types; @@ -271,7 +276,7 @@ public: } AstArray returnTypes; - const auto& [retVector, retTail] = flatten(ftv.retType); + const auto& [retVector, retTail] = flatten(ftv.retTypes); returnTypes.size = retVector.size(); returnTypes.data = static_cast(allocator->allocate(sizeof(AstType*) * returnTypes.size)); for (size_t i = 0; i < returnTypes.size; ++i) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp new file mode 100644 index 00000000..7f5ba683 --- /dev/null +++ b/Analysis/src/TypeChecker2.cpp @@ -0,0 +1,160 @@ + +#include "Luau/TypeChecker2.h" + +#include + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/Clone.h" +#include "Luau/Normalize.h" + +namespace Luau +{ + +struct TypeChecker2 : public AstVisitor +{ + const SourceModule* sourceModule; + Module* module; + InternalErrorReporter ice; // FIXME accept a pointer from Frontend + + TypeChecker2(const SourceModule* sourceModule, Module* module) + : sourceModule(sourceModule) + , module(module) + { + } + + using AstVisitor::visit; + + TypePackId lookupPack(AstExpr* expr) + { + TypePackId* tp = module->astTypePacks.find(expr); + LUAU_ASSERT(tp); + return follow(*tp); + } + + TypeId lookupType(AstExpr* expr) + { + TypeId* ty = module->astTypes.find(expr); + LUAU_ASSERT(ty); + return follow(*ty); + } + + bool visit(AstStatAssign* assign) override + { + size_t count = std::min(assign->vars.size, assign->values.size); + + for (size_t i = 0; i < count; ++i) + { + AstExpr* lhs = assign->vars.data[i]; + TypeId* lhsType = module->astTypes.find(lhs); + LUAU_ASSERT(lhsType); + + AstExpr* rhs = assign->values.data[i]; + TypeId* rhsType = module->astTypes.find(rhs); + LUAU_ASSERT(rhsType); + + if (!isSubtype(*rhsType, *lhsType, ice)) + { + reportError(TypeMismatch{*lhsType, *rhsType}, rhs->location); + } + } + + return true; + } + + bool visit(AstExprCall* call) override + { + TypePackId expectedRetType = lookupPack(call); + TypeId functionType = lookupType(call->func); + + TypeArena arena; + TypePack args; + for (const auto& arg : call->args) + { + TypeId argTy = module->astTypes[arg]; + LUAU_ASSERT(argTy); + args.head.push_back(argTy); + } + + TypePackId argsTp = arena.addTypePack(args); + FunctionTypeVar ftv{argsTp, expectedRetType}; + TypeId expectedType = arena.addType(ftv); + if (!isSubtype(expectedType, functionType, ice)) + { + unfreeze(module->interfaceTypes); + CloneState cloneState; + expectedType = clone(expectedType, module->interfaceTypes, cloneState); + freeze(module->interfaceTypes); + reportError(TypeMismatch{expectedType, functionType}, call->location); + } + + return true; + } + + bool visit(AstExprIndexName* indexName) override + { + TypeId leftType = lookupType(indexName->expr); + TypeId resultType = lookupType(indexName); + + // leftType must have a property called indexName->index + + if (auto ttv = get(leftType)) + { + auto it = ttv->props.find(indexName->index.value); + if (it == ttv->props.end()) + { + reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); + } + else if (!isSubtype(resultType, it->second.type, ice)) + { + reportError(TypeMismatch{resultType, it->second.type}, indexName->location); + } + } + else + { + reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); + } + + return true; + } + + bool visit(AstExprConstantNumber* number) override + { + TypeId actualType = lookupType(number); + TypeId numberType = getSingletonTypes().numberType; + + if (!isSubtype(actualType, numberType, ice)) + { + reportError(TypeMismatch{actualType, numberType}, number->location); + } + + return true; + } + + bool visit(AstExprConstantString* string) override + { + TypeId actualType = lookupType(string); + TypeId stringType = getSingletonTypes().stringType; + + if (!isSubtype(actualType, stringType, ice)) + { + reportError(TypeMismatch{actualType, stringType}, string->location); + } + + return true; + } + + void reportError(TypeErrorData&& data, const Location& location) + { + module->errors.emplace_back(location, sourceModule->name, std::move(data)); + } +}; + +void check(const SourceModule& sourceModule, Module* module) +{ + TypeChecker2 typeChecker{&sourceModule, module}; + + sourceModule.root->visit(&typeChecker); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 447cd029..fd1b3b85 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -18,6 +18,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" #include #include @@ -30,7 +31,6 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) -LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) @@ -42,9 +42,9 @@ LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); -LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) +LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) @@ -260,7 +260,6 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(getSingletonTypes().booleanType) , threadType(getSingletonTypes().threadType) , anyType(getSingletonTypes().anyType) - , optionalNumberType(getSingletonTypes().optionalNumberType) , anyTypePack(getSingletonTypes().anyTypePack) , duplicateTypeAliases{{false, {}}} { @@ -679,7 +678,7 @@ static std::optional tryGetTypeGuardPredicate(const AstExprBinary& ex void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) { - ExprResult result = checkExpr(scope, *statement.condition); + WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr ifScope = childScope(scope, statement.thenbody->location); resolve(result.predicates, ifScope, true); @@ -712,7 +711,7 @@ ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Locat void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) { - ExprResult result = checkExpr(scope, *statement.condition); + WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr whileScope = childScope(scope, statement.body->location); resolve(result.predicates, whileScope, true); @@ -728,16 +727,64 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } -void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location) +void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location) { Unifier state = mkUnifier(location); - state.unifyLowerBound(subTy, superTy); + state.unifyLowerBound(subTy, superTy, demotedLevel); state.log.commit(); reportErrors(state.errors); } +struct Demoter : Substitution +{ + Demoter(TypeArena* arena) + : Substitution(TxnLog::empty(), arena) + { + } + + bool isDirty(TypeId ty) override + { + return get(ty); + } + + bool isDirty(TypePackId tp) override + { + return get(tp); + } + + TypeId clean(TypeId ty) override + { + auto ftv = get(ty); + LUAU_ASSERT(ftv); + return addType(FreeTypeVar{demotedLevel(ftv->level)}); + } + + TypePackId clean(TypePackId tp) override + { + auto ftp = get(tp); + LUAU_ASSERT(ftp); + return addTypePack(TypePackVar{FreeTypePack{demotedLevel(ftp->level)}}); + } + + TypeLevel demotedLevel(TypeLevel level) + { + return TypeLevel{level.level + 5000, level.subLevel}; + } + + void demote(std::vector>& expectedTypes) + { + if (!FFlag::LuauQuantifyConstrained) + return; + for (std::optional& ty : expectedTypes) + { + if (ty) + ty = substitute(*ty); + } + } +}; + void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; @@ -760,11 +807,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) } } + Demoter demoter{¤tModule->internalTypes}; + demoter.demote(expectedTypes); + TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; if (FFlag::LuauReturnTypeInferenceInNonstrict ? FFlag::LuauLowerBoundsCalculation : useConstrainedIntersections()) { - unifyLowerBound(retPack, scope->returnType, return_.location); + unifyLowerBound(retPack, scope->returnType, demoter.demotedLevel(scope->level), return_.location); return; } @@ -1230,7 +1280,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(retPack, varPack, forin.location); } else - unify(iterFunc->retType, varPack, forin.location); + unify(iterFunc->retTypes, varPack, forin.location); check(loopScope, *forin.body); } @@ -1611,7 +1661,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) @@ -1620,7 +1670,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& return {errorRecoveryType(scope)}; } - ExprResult result; + WithPredicate result; if (auto a = expr.as()) result = checkExpr(scope, *a->expr, expectedType); @@ -1682,7 +1732,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& return result; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) { std::optional lvalue = tryGetLValue(expr); LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprLocal is an LValue. @@ -1696,7 +1746,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) { std::optional lvalue = tryGetLValue(expr); LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprGlobal is an LValue. @@ -1708,7 +1758,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) { TypePackId varargPack = checkExprPack(scope, expr).type; @@ -1738,9 +1788,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) { - ExprResult result = checkExprPack(scope, expr); + WithPredicate result = checkExprPack(scope, expr); TypePackId retPack = follow(result.type); if (auto pack = get(retPack)) @@ -1770,7 +1820,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa ice("Unknown TypePack type!", expr.location); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) { Name name = expr.index.value; @@ -2031,7 +2081,7 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) return ty; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { TypeId ty = checkLValue(scope, expr); @@ -2042,7 +2092,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn return {ty}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) { auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, expectedType); @@ -2108,8 +2158,7 @@ TypeId TypeChecker::checkExprTable( if (errors.empty()) exprType = expectedProp.type; } - else if (expectedTable->indexer && (FFlag::LuauExpectedPropTypeFromIndexer ? maybeString(expectedTable->indexer->indexType) - : isString(expectedTable->indexer->indexType))) + else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType)) { ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); if (errors.empty()) @@ -2147,7 +2196,7 @@ TypeId TypeChecker::checkExprTable( return addType(table); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) @@ -2201,7 +2250,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa { if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; - else if (FFlag::LuauExpectedPropTypeFromIndexer && expectedIndexType && maybeString(*expectedIndexType)) + else if (expectedIndexType && maybeString(*expectedIndexType)) expectedResultType = expectedIndexResultType; } else if (expectedUnion) @@ -2236,9 +2285,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa return {checkExprTable(scope, expr, fieldTypes, expectedType)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) { - ExprResult result = checkExpr(scope, *expr.expr); + WithPredicate result = checkExpr(scope, *expr.expr); TypeId operandType = follow(result.type); switch (expr.op) @@ -2466,62 +2515,50 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); - if (FFlag::LuauSuccessTypingForEqualityOperations) + if (leftMetatable != rightMetatable) { - if (leftMetatable != rightMetatable) + bool matches = false; + if (isEquality) { - bool matches = false; - if (isEquality) + if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) { - if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) + for (TypeId leftOption : utv) { - for (TypeId leftOption : utv) + if (getMetatable(follow(leftOption)) == rightMetatable) { - if (getMetatable(follow(leftOption)) == rightMetatable) + matches = true; + break; + } + } + } + + if (!matches) + { + if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) + { + for (TypeId rightOption : utv) + { + if (getMetatable(follow(rightOption)) == leftMetatable) { matches = true; break; } } } - - if (!matches) - { - if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) - { - for (TypeId rightOption : utv) - { - if (getMetatable(follow(rightOption)) == leftMetatable) - { - matches = true; - break; - } - } - } - } - } - - - if (!matches) - { - reportError( - expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorRecoveryType(booleanType); } } - } - else - { - if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) + + + if (!matches) { reportError( expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); return errorRecoveryType(booleanType); } } + if (leftMetatable) { std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); @@ -2532,7 +2569,7 @@ TypeId TypeChecker::checkRelationalOperation( if (isEquality) { Unifier state = mkUnifier(expr.location); - state.tryUnify(addTypePack({booleanType}), ftv->retType); + state.tryUnify(addTypePack({booleanType}), ftv->retTypes); if (!state.errors.empty()) { @@ -2721,7 +2758,7 @@ TypeId TypeChecker::checkBinaryOperation( } } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) { if (expr.op == AstExprBinary::And) { @@ -2752,8 +2789,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); - ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); + WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); + WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); PredicateVec predicates; @@ -2770,18 +2807,18 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi } else { - ExprResult lhs = checkExpr(scope, *expr.left); - ExprResult rhs = checkExpr(scope, *expr.right); + WithPredicate lhs = checkExpr(scope, *expr.left); + WithPredicate rhs = checkExpr(scope, *expr.right); // Intentionally discarding predicates with other operators. return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; } } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) { TypeId annotationType = resolveType(scope, *expr.annotation); - ExprResult result = checkExpr(scope, *expr.expr, annotationType); + WithPredicate result = checkExpr(scope, *expr.expr, annotationType); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. if (canUnify(annotationType, result.type, expr.location).empty()) @@ -2794,7 +2831,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy return {errorRecoveryType(annotationType), std::move(result.predicates)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) { const size_t oldSize = currentModule->errors.size(); @@ -2808,17 +2845,17 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { - ExprResult result = checkExpr(scope, *expr.condition); + WithPredicate result = checkExpr(scope, *expr.condition); ScopePtr trueScope = childScope(scope, expr.trueExpr->location); resolve(result.predicates, trueScope, true); - ExprResult trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); + WithPredicate trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); resolve(result.predicates, falseScope, false); - ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); + WithPredicate falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); if (falseType.type == trueType.type) return {trueType.type}; @@ -3170,7 +3207,7 @@ std::pair TypeChecker::checkFunctionSignature( retPack = anyTypePack; else if (expectedFunctionType) { - auto [head, tail] = flatten(expectedFunctionType->retType); + auto [head, tail] = flatten(expectedFunctionType->retTypes); // Do not infer 'nil' as function return type if (!tail && head.size() == 1 && isNil(head[0])) @@ -3354,7 +3391,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE if (useConstrainedIntersections()) { - TypePackId retPack = follow(funTy->retType); + TypePackId retPack = follow(funTy->retTypes); // It is possible for a function to have no annotation and no return statement, and yet still have an ascribed return type // if it is expected to conform to some other interface. (eg the function may be a lambda passed as a callback) if (!hasReturn(function.body) && !function.returnAnnotation.has_value() && get(retPack)) @@ -3367,20 +3404,20 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE else { // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (get_if(&funTy->retType->ty)) - *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + if (get_if(&funTy->retTypes->ty)) + *asMutable(funTy->retTypes) = TypePack{{}, std::nullopt}; } bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; - if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retType))) + if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retTypes))) { // If we're in nonstrict mode we want to only report this missing return // statement if there are type annotations on the function. In strict mode // we report it regardless. if (!isNonstrictMode() || function.returnAnnotation) { - reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retType}); + reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retTypes}); } } } @@ -3388,7 +3425,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE ice("Checking non functional type"); } -ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) +WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) return checkExprPack(scope, *a); @@ -3654,7 +3691,7 @@ void TypeChecker::checkArgumentList( } } -ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) { // evaluate type of function // decompose an intersection into its component overloads @@ -3722,7 +3759,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); - ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); + WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); TypePackId argPack = argListResult.type; if (get(argPack)) @@ -3766,7 +3803,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A if (!overload && !overloadsThatDont.empty()) overload = get(overloadsThatDont[0]); if (overload) - return {errorRecoveryTypePack(overload->retType)}; + return {errorRecoveryTypePack(overload->retTypes)}; return {errorRecoveryTypePack(retPack)}; } @@ -3775,7 +3812,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st { std::vector> expectedTypes; - auto assignOption = [this, &expectedTypes](size_t index, std::optional ty) { + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { if (index == expectedTypes.size()) { expectedTypes.push_back(ty); @@ -3790,7 +3827,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } else { - std::vector result = reduceUnion({*el, *ty}); + std::vector result = reduceUnion({*el, ty}); el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); } } @@ -3810,7 +3847,8 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st if (argsTail) { - if (const VariadicTypePack* vtp = get(follow(*argsTail))) + argsTail = follow(*argsTail); + if (const VariadicTypePack* vtp = get(*argsTail)) { while (index < argumentCount) assignOption(index++, vtp->ty); @@ -3819,11 +3857,14 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } } + Demoter demoter{¤tModule->internalTypes}; + demoter.demote(expectedTypes); + return expectedTypes; } -std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, +std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { LUAU_ASSERT(argLocations); @@ -3918,14 +3959,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (ftv->magicFunction) { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 - if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) + if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) return *ret; } Unifier state = mkUnifier(expr.location); // Unify return types - checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); + checkArgumentList(scope, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { return {}; @@ -3996,7 +4037,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // we eagerly assume that that's what you actually meant and we commit to it. // This could be incorrect if the function has an additional overload that // actually works. - // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + // checkArgumentList(scope, editedState, retPack, ftv->retTypes, retLocations, CountMismatch::Return); return true; } } @@ -4027,7 +4068,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // we eagerly assume that that's what you actually meant and we commit to it. // This could be incorrect if the function has an additional overload that // actually works. - // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + // checkArgumentList(scope, editedState, retPack, ftv->retTypes, retLocations, CountMismatch::Return); return true; } } @@ -4085,7 +4126,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast // Unify return types if (const FunctionTypeVar* ftv = get(overload)) { - checkArgumentList(scope, state, retPack, ftv->retType, {}); + checkArgumentList(scope, state, retPack, ftv->retTypes, {}); checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); } @@ -4110,7 +4151,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast return; } -ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, +WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) { TypePackId pack = addTypePack(TypePack{}); @@ -4401,10 +4442,24 @@ TypeId Anyification::clean(TypeId ty) } else if (auto ctv = get(ty)) { - auto [t, ok] = normalize(ty, *arena, *iceHandler); - if (!ok) - normalizationTooComplex = true; - return t; + if (FFlag::LuauQuantifyConstrained) + { + std::vector copy = ctv->parts; + for (TypeId& ty : copy) + ty = replace(ty); + TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)}); + auto [t, ok] = normalize(res, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; + } + else + { + auto [t, ok] = normalize(ty, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; + } } else return anyType; diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index ba09df5f..3d97e6eb 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -66,7 +66,7 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t } else if (const auto& itf = get(index)) { - std::optional r = first(follow(itf->retType)); + std::optional r = first(follow(itf->retTypes)); if (!r) return getSingletonTypes().nilType; else diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 33bfe254..57762937 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -29,8 +29,8 @@ LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) namespace Luau { -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); TypeId follow(TypeId t) { @@ -408,41 +408,48 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } -FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) +BlockedTypeVar::BlockedTypeVar() + : index(++nextIndex) +{ +} + +int BlockedTypeVar::nextIndex = 0; + +FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } -FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) +FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : level(level) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } -FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, +FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } FunctionTypeVar::FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retType, std::optional defn, bool hasSelf) + TypePackId retTypes, std::optional defn, bool hasSelf) : level(level) , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { @@ -488,7 +495,7 @@ bool areEqual(SeenSet& seen, const FunctionTypeVar& lhs, const FunctionTypeVar& if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes)) return false; - if (!areEqual(seen, *lhs.retType, *rhs.retType)) + if (!areEqual(seen, *lhs.retTypes, *rhs.retTypes)) return false; return true; @@ -678,7 +685,6 @@ static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent* static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; -static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}, /*persistent*/ true}; static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; static TypePackVar errorTypePack_{Unifiable::Error{}}; @@ -692,7 +698,6 @@ SingletonTypes::SingletonTypes() , trueType(&trueType_) , falseType(&falseType_) , anyType(&anyType_) - , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) , arena(new TypeArena) { @@ -825,7 +830,7 @@ void persist(TypeId ty) else if (auto ftv = get(t)) { persist(ftv->argTypes); - persist(ftv->retType); + persist(ftv->retTypes); } else if (auto ttv = get(t)) { @@ -1100,10 +1105,10 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha return result; } -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -1142,7 +1147,7 @@ std::optional> magicFunctionFormat( if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); - return ExprResult{arena.addTypePack({typechecker.stringType})}; + return WithPredicate{arena.addTypePack({typechecker.stringType})}; } std::vector filterMap(TypeId type, TypeIdPredicate predicate) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 414b05f4..877663de 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) +LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau { @@ -1288,13 +1289,13 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); innerState.ctx = CountMismatch::Result; - innerState.tryUnify_(subFunction->retType, superFunction->retType); + innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); if (!reported) { if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); - else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) + else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) reportError( @@ -1312,7 +1313,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); ctx = CountMismatch::Result; - tryUnify_(subFunction->retType, superFunction->retType); + tryUnify_(subFunction->retTypes, superFunction->retTypes); } if (FFlag::LuauTxnLogRefreshFunctionPointers) @@ -2177,7 +2178,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas else if (auto fun = state.log.getMutable(ty)) { queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retTypes, anyTypePack); } else if (auto table = state.log.getMutable(ty)) { @@ -2322,7 +2323,7 @@ void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy) superC->parts.push_back(subTy); } -void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy) +void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel) { // The duplication between this and regular typepack unification is tragic. @@ -2357,7 +2358,7 @@ void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy) if (!freeTailPack) return; - TypeLevel level = freeTailPack->level; + TypeLevel level = FFlag::LuauQuantifyConstrained ? demotedLevel : freeTailPack->level; TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 597b2f0a..a34f7603 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1075,6 +1075,8 @@ void BytecodeBuilder::validate() const LUAU_ASSERT(i <= insns.size()); } + std::vector openCaptures; + // second pass: validate the rest of the bytecode for (size_t i = 0; i < insns.size();) { @@ -1121,6 +1123,8 @@ void BytecodeBuilder::validate() const case LOP_CLOSEUPVALS: VREG(LUAU_INSN_A(insn)); + while (openCaptures.size() && openCaptures.back() >= LUAU_INSN_A(insn)) + openCaptures.pop_back(); break; case LOP_GETIMPORT: @@ -1388,8 +1392,12 @@ void BytecodeBuilder::validate() const switch (LUAU_INSN_A(insn)) { case LCT_VAL: + VREG(LUAU_INSN_B(insn)); + break; + case LCT_REF: VREG(LUAU_INSN_B(insn)); + openCaptures.push_back(LUAU_INSN_B(insn)); break; case LCT_UPVAL: @@ -1409,6 +1417,12 @@ void BytecodeBuilder::validate() const LUAU_ASSERT(i <= insns.size()); } + // all CAPTURE REF instructions must have a CLOSEUPVALS instruction after them in the bytecode stream + // this doesn't guarantee safety as it doesn't perform basic block based analysis, but if this fails + // then the bytecode is definitely unsafe to run since the compiler won't generate backwards branches + // except for loop edges + LUAU_ASSERT(openCaptures.empty()); + #undef VREG #undef VREGEND #undef VUPVAL diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 7431cde4..52dc9242 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -246,6 +246,14 @@ struct Compiler f.canInline = true; f.stackSize = stackSize; f.costModel = modelCost(func->body, func->args.data, func->args.size); + + // track functions that only ever return a single value so that we can convert multret calls to fixedret calls + if (allPathsEndWithReturn(func->body)) + { + ReturnVisitor returnVisitor(this); + stat->visit(&returnVisitor); + f.returnsOne = returnVisitor.returnsOne; + } } upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes @@ -260,6 +268,19 @@ struct Compiler { if (AstExprCall* expr = node->as()) { + // Optimization: convert multret calls to functions that always return one value to fixedret calls; this facilitates inlining + if (options.optimizationLevel >= 2) + { + AstExprFunction* func = getFunctionExpr(expr->func); + Function* fi = func ? functions.find(func) : nullptr; + + if (fi && fi->returnsOne) + { + compileExprTemp(node, target); + return false; + } + } + // We temporarily swap out regTop to have targetTop work correctly... // This is a crude hack but it's necessary for correctness :( RegScope rs(this, target); @@ -447,7 +468,9 @@ struct Compiler return false; } - // TODO: we can compile multret functions if all returns of the function are multret as well + // we can't inline multret functions because the caller expects L->top to be adjusted: + // - inlined return compiles to a JUMP, and we don't have an instruction that adjusts L->top arbitrarily + // - even if we did, right now all L->top adjustments are immediately consumed by the next instruction, and for now we want to preserve that if (multRet) { bytecode.addDebugRemark("inlining failed: can't convert fixed returns to multret"); @@ -492,7 +515,7 @@ struct Compiler size_t oldLocals = localStack.size(); // note that we push the frame early; this is needed to block recursive inline attempts - inlineFrames.push_back({func, target, targetCount}); + inlineFrames.push_back({func, oldLocals, target, targetCount}); // evaluate all arguments; note that we don't emit code for constant arguments (relying on constant folding) for (size_t i = 0; i < func->args.size; ++i) @@ -593,6 +616,8 @@ struct Compiler { for (size_t i = 0; i < targetCount; ++i) bytecode.emitABC(LOP_LOADNIL, uint8_t(target + i), 0, 0); + + closeLocals(oldLocals); } popLocals(oldLocals); @@ -2355,6 +2380,8 @@ struct Compiler compileExprListTemp(stat->list, frame.target, frame.targetCount, /* targetTop= */ false); + closeLocals(frame.localOffset); + if (!fallthrough) { size_t jumpLabel = bytecode.emitLabel(); @@ -3316,6 +3343,48 @@ struct Compiler std::vector upvals; }; + struct ReturnVisitor: AstVisitor + { + Compiler* self; + bool returnsOne = true; + + ReturnVisitor(Compiler* self) + : self(self) + { + } + + bool visit(AstExpr* expr) override + { + return false; + } + + bool visit(AstStatReturn* stat) override + { + if (stat->list.size == 1) + { + AstExpr* value = stat->list.data[0]; + + if (AstExprCall* expr = value->as()) + { + AstExprFunction* func = self->getFunctionExpr(expr->func); + Function* fi = func ? self->functions.find(func) : nullptr; + + returnsOne &= fi && fi->returnsOne; + } + else if (value->is()) + { + returnsOne = false; + } + } + else + { + returnsOne = false; + } + + return false; + } + }; + struct RegScope { RegScope(Compiler* self) @@ -3351,6 +3420,7 @@ struct Compiler uint64_t costModel = 0; unsigned int stackSize = 0; bool canInline = false; + bool returnsOne = false; }; struct Local @@ -3384,6 +3454,8 @@ struct Compiler { AstExprFunction* func; + size_t localOffset; + uint8_t target; uint8_t targetCount; diff --git a/Sources.cmake b/Sources.cmake index 99007e89..f261cba6 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -65,12 +65,13 @@ target_sources(Luau.CodeGen PRIVATE target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/AstQuery.h Analysis/include/Luau/Autocomplete.h - Analysis/include/Luau/NotNull.h Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Clone.h Analysis/include/Luau/Config.h + Analysis/include/Luau/Constraint.h Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintSolver.h + Analysis/include/Luau/ConstraintSolverLogger.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h @@ -97,6 +98,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TxnLog.h Analysis/include/Luau/TypeArena.h Analysis/include/Luau/TypeAttach.h + Analysis/include/Luau/TypeChecker2.h Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypePack.h @@ -113,8 +115,10 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/BuiltinDefinitions.cpp Analysis/src/Clone.cpp Analysis/src/Config.cpp + Analysis/src/Constraint.cpp Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintSolver.cpp + Analysis/src/ConstraintSolverLogger.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp Analysis/src/Instantiation.cpp @@ -136,6 +140,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TxnLog.cpp Analysis/src/TypeArena.cpp Analysis/src/TypeAttach.cpp + Analysis/src/TypeChecker2.cpp Analysis/src/TypedAllocator.cpp Analysis/src/TypeInfer.cpp Analysis/src/TypePack.cpp @@ -245,7 +250,6 @@ if(TARGET Luau.UnitTest) tests/AstQuery.test.cpp tests/AstVisitor.test.cpp tests/Autocomplete.test.cpp - tests/NotNull.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp tests/Config.test.cpp diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 5e02c2ea..bdcb85cb 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -418,7 +418,7 @@ typedef struct Table CommonHeader; - uint8_t flags; /* 1<

flags = 0 +#define invalidateTMcache(t) t->tmcache = 0 // empty hash data points to dummynode so that we can always dereference it const LuaNode luaH_dummynode = { @@ -479,7 +479,7 @@ Table* luaH_new(lua_State* L, int narray, int nhash) Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = NULL; - t->flags = cast_byte(~0); + t->tmcache = cast_byte(~0); t->array = NULL; t->sizearray = 0; t->lastfree = 0; @@ -778,7 +778,7 @@ Table* luaH_clone(lua_State* L, Table* tt) Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = tt->metatable; - t->flags = tt->flags; + t->tmcache = tt->tmcache; t->array = NULL; t->sizearray = 0; t->lsizenode = 0; @@ -835,5 +835,5 @@ void luaH_clear(Table* tt) } /* back to empty -> no tag methods present */ - tt->flags = cast_byte(~0); + tt->tmcache = cast_byte(~0); } diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 9b99506b..e7df4e53 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -88,8 +88,8 @@ const TValue* luaT_gettm(Table* events, TMS event, TString* ename) const TValue* tm = luaH_getstr(events, ename); LUAU_ASSERT(event <= TM_EQ); if (ttisnil(tm)) - { /* no tag method? */ - events->flags |= cast_byte(1u << event); /* cache this fact */ + { /* no tag method? */ + events->tmcache |= cast_byte(1u << event); /* cache this fact */ return NULL; } else diff --git a/VM/src/ltm.h b/VM/src/ltm.h index e1b95c21..a5223941 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -41,10 +41,10 @@ typedef enum } TMS; // clang-format on -#define gfasttm(g, et, e) ((et) == NULL ? NULL : ((et)->flags & (1u << (e))) ? NULL : luaT_gettm(et, e, (g)->tmname[e])) +#define gfasttm(g, et, e) ((et) == NULL ? NULL : ((et)->tmcache & (1u << (e))) ? NULL : luaT_gettm(et, e, (g)->tmname[e])) #define fasttm(l, et, e) gfasttm(l->global, et, e) -#define fastnotm(et, e) ((et) == NULL || ((et)->flags & (1u << (e)))) +#define fastnotm(et, e) ((et) == NULL || ((et)->tmcache & (1u << (e)))) LUAI_DATA const char* const luaT_typenames[]; LUAI_DATA const char* const luaT_eventname[]; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index dea1ab19..f3b0bcad 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1992,6 +1992,7 @@ local fp: @1= f auto ac = autocomplete('1'); + REQUIRE_EQ("({| x: number, y: number |}) -> number", toString(requireType("f"))); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } @@ -2620,7 +2621,6 @@ a = if temp then even elseif true then temp else e@9 TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_else_regression") { - ScopedFastFlag FFlagLuauIfElseExprFixCompletionIssue("LuauIfElseExprFixCompletionIssue", true); check(R"( local abcdef = 0; local temp = false diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 6eee254e..036bf124 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4992,6 +4992,147 @@ RETURN R1 1 )"); } +TEST_CASE("InlineCapture") +{ + // if the argument is captured by a nested closure, normally we can rely on capture by value + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local x = ... +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +NEWCLOSURE R2 P1 +CAPTURE VAL R1 +RETURN R2 1 +)"); + + // if the argument is a constant, we move it to a register so that capture by value can happen + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local y = foo(42) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 42 +NEWCLOSURE R1 P1 +CAPTURE VAL R2 +RETURN R1 1 +)"); + + // if the argument is an externally mutated variable, we copy it to an argument and capture it by value + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local x x = 42 +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADN R1 42 +MOVE R3 R1 +NEWCLOSURE R2 P1 +CAPTURE VAL R3 +RETURN R2 1 +)"); + + // finally, if the argument is mutated internally, we must capture it by reference and close the upvalue + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 42 + return function() return a end +end + +local y = foo() +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R2 +ORK R2 R2 K1 +NEWCLOSURE R1 P1 +CAPTURE REF R2 +CLOSEUPVALS R2 +RETURN R1 1 +)"); + + // note that capture might need to be performed during the fallthrough block + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 42 + print(function() return a end) +end + +local x = ... +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R3 R1 +ORK R3 R3 K1 +GETIMPORT R4 3 +NEWCLOSURE R5 P1 +CAPTURE REF R3 +CALL R4 1 0 +LOADNIL R2 +CLOSEUPVALS R3 +RETURN R2 1 +)"); + + // note that mutation and capture might be inside internal control flow + // TODO: this has an oddly redundant CLOSEUPVALS after JUMP; it's not due to inlining, and is an artifact of how StatBlock/StatReturn interact + // fixing this would reduce the number of redundant CLOSEUPVALS a bit but it only affects bytecode size as these instructions aren't executed + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + if not a then + local b b = 42 + return function() return b end + end +end + +local x = ... +local y = foo(x) +return y, x +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +JUMPIF R1 L0 +LOADNIL R3 +LOADN R3 42 +NEWCLOSURE R2 P1 +CAPTURE REF R3 +CLOSEUPVALS R3 +JUMP L1 +CLOSEUPVALS R3 +L0: LOADNIL R2 +L1: MOVE R3 R2 +MOVE R4 R1 +RETURN R3 2 +)"); +} + TEST_CASE("InlineFallthrough") { // if the function doesn't return, we still fill the results with nil @@ -5044,27 +5185,6 @@ RETURN R1 -1 )"); } -TEST_CASE("InlineCapture") -{ - // can't inline function with nested functions that capture locals because they might be constants - CHECK_EQ("\n" + compileFunction(R"( -local function foo(a) - local function bar() - return a - end - return bar() -end -)", - 1, 2), - R"( -NEWCLOSURE R1 P0 -CAPTURE VAL R0 -MOVE R2 R1 -CALL R2 0 -1 -RETURN R2 -1 -)"); -} - TEST_CASE("InlineArgMismatch") { // when inlining a function, we must respect all the usual rules @@ -5491,6 +5611,96 @@ RETURN R2 1 )"); } +TEST_CASE("InlineMultret") +{ + // inlining a function in multret context is prohibited since we can't adjust L->top outside of CALL/GETVARARGS + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a() +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // however, if we can deduce statically that a function always returns a single value, the inlining will work + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // this analysis will also propagate through other functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local function bar(a) + return foo(a) +end + +return bar(42) +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +LOADN R2 42 +RETURN R2 1 +)"); + + // we currently don't do this analysis fully for recursive functions since they can't be inlined anyway + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return foo(a) +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +CAPTURE VAL R0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // and unfortunately we can't do this analysis for builtins or method calls due to getfenv + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return math.abs(a) +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); +} + TEST_CASE("ReturnConsecutive") { // we can return a single local directly diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp index ab5af4f6..96b21613 100644 --- a/tests/ConstraintGraphBuilder.test.cpp +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -17,13 +17,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello_world") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(2 == constraints.size()); ToStringOptions opts; - CHECK("a <: string" == toString(*constraints[0], opts)); - CHECK("b <: a" == toString(*constraints[1], opts)); + CHECK("string <: a" == toString(*constraints[0], opts)); + CHECK("a <: b" == toString(*constraints[1], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") @@ -36,15 +36,34 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); - REQUIRE(4 == constraints.size()); + REQUIRE(3 == constraints.size()); ToStringOptions opts; - CHECK("a <: string" == toString(*constraints[0], opts)); - CHECK("b <: number" == toString(*constraints[1], opts)); - CHECK("c <: boolean" == toString(*constraints[2], opts)); - CHECK("d <: nil" == toString(*constraints[3], opts)); + CHECK("string <: a" == toString(*constraints[0], opts)); + CHECK("number <: b" == toString(*constraints[1], opts)); + CHECK("boolean <: c" == toString(*constraints[2], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "nil_primitive") +{ + AstStatBlock* block = parse(R"( + local function a() return nil end + local b = a() + )"); + + cgb.visit(block); + auto constraints = collectConstraints(cgb.rootScope); + + ToStringOptions opts; + REQUIRE(5 <= constraints.size()); + + CHECK("*blocked-1* ~ gen () -> (a...)" == toString(*constraints[0], opts)); + CHECK("b ~ inst *blocked-1*" == toString(*constraints[1], opts)); + CHECK("() -> (c...) <: b" == toString(*constraints[2], opts)); + CHECK("c... <: d" == toString(*constraints[3], opts)); + CHECK("nil <: a..." == toString(*constraints[4], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") @@ -55,15 +74,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(4 == constraints.size()); ToStringOptions opts; - CHECK("a <: string" == toString(*constraints[0], opts)); + CHECK("string <: a" == toString(*constraints[0], opts)); CHECK("b ~ inst a" == toString(*constraints[1], opts)); - CHECK("(string) -> (c, d...) <: b" == toString(*constraints[2], opts)); - CHECK("e <: c" == toString(*constraints[3], opts)); + CHECK("(string) -> (c...) <: b" == toString(*constraints[2], opts)); + CHECK("c... <: d" == toString(*constraints[3], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") @@ -75,13 +94,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(2 == constraints.size()); ToStringOptions opts; - CHECK("a ~ gen (b) -> (c...)" == toString(*constraints[0], opts)); - CHECK("b <: c..." == toString(*constraints[1], opts)); + CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); + CHECK("a <: b..." == toString(*constraints[1], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") @@ -93,15 +112,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(4 == constraints.size()); ToStringOptions opts; - CHECK("a ~ gen (b) -> (c...)" == toString(*constraints[0], opts)); - CHECK("d ~ inst a" == toString(*constraints[1], opts)); - CHECK("(b) -> (e, f...) <: d" == toString(*constraints[2], opts)); - CHECK("e <: c..." == toString(*constraints[3], opts)); + CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); + CHECK("c ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); + CHECK("(a) -> (d...) <: c" == toString(*constraints[2], opts)); + CHECK("d... <: b..." == toString(*constraints[3], opts)); } TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 232ec2de..ac22f65b 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -345,7 +345,7 @@ void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) if (error.location.begin.line >= lines.size()) { os << "\tSource not available?" << std::endl; - return; + continue; } std::string_view theLine = lines[error.location.begin.line]; @@ -430,6 +430,7 @@ ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() : Fixture() , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { + BlockedTypeVar::nextIndex = 0; } ModuleName fromString(std::string_view name) diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index c0554669..b9c24704 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -97,8 +97,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "find_a_require") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(1, res.requires.size()); - CHECK_EQ(res.requires[0].first, "Modules/Foo/Bar"); + CHECK_EQ(1, res.requireList.size()); + CHECK_EQ(res.requireList[0].first, "Modules/Foo/Bar"); } // It could be argued that this should not work. @@ -113,7 +113,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "find_a_require_inside_a_function") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(1, res.requires.size()); + CHECK_EQ(1, res.requireList.size()); } TEST_CASE_FIXTURE(FrontendFixture, "real_source") @@ -138,7 +138,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "real_source") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(8, res.requires.size()); + CHECK_EQ(8, res.requireList.size()); } TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 89b13ab1..d585b731 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -102,7 +102,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") const FunctionTypeVar* ftv = get(methodType); REQUIRE(ftv != nullptr); - std::optional methodReturnType = first(ftv->retType); + std::optional methodReturnType = first(ftv->retTypes); REQUIRE(methodReturnType); CHECK_EQ(methodReturnType, counterCopy); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index c0556103..50dcbad0 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -13,6 +13,57 @@ using namespace Luau; TEST_SUITE_BEGIN("NonstrictModeTests"); +TEST_CASE_FIXTURE(Fixture, "globals") +{ + CheckResult result = check(R"( + --!nonstrict + foo = true + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals2") +{ + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + --!nonstrict + foo = function() return 1 end + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> number", toString(tm->wantedType)); + CHECK_EQ("string", toString(tm->givenType)); + CHECK_EQ("() -> number", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals_everywhere") +{ + CheckResult result = check(R"( + --!nonstrict + foo = 1 + + if true then + bar = 2 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_number_or_string") { ScopedFastFlag sff[]{{"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}}; @@ -51,7 +102,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") REQUIRE_EQ("any", toString(args[0])); REQUIRE_EQ("any", toString(args[1])); - auto rets = flatten(ftv->retType).first; + auto rets = flatten(ftv->retTypes).first; REQUIRE_EQ(0, rets.size()); } diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 2876175d..284230c9 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -837,6 +837,7 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -867,16 +868,17 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect CHECK("{+ y: number +}" == toString(args[2])); CHECK("{+ z: string +}" == toString(args[3])); - std::vector ret = flatten(ftv->retType).first; + std::vector ret = flatten(ftv->retTypes).first; REQUIRE(1 == ret.size()); - CHECK("{| x: a & {- w: boolean, y: number, z: string -} |}" == toString(ret[0])); + CHECK("{| x: a & {+ w: boolean, y: number, z: string +} |}" == toString(ret[0])); } TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_3") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -906,16 +908,17 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect CHECK("t1 where t1 = {+ y: t1 +}" == toString(args[1])); CHECK("{+ z: string +}" == toString(args[2])); - std::vector ret = flatten(ftv->retType).first; + std::vector ret = flatten(ftv->retTypes).first; REQUIRE(1 == ret.size()); - CHECK("{| x: {- x: boolean, y: t1, z: string -} |} where t1 = {+ y: t1 +}" == toString(ret[0])); + CHECK("{| x: {+ x: boolean, y: t1, z: string +} |} where t1 = {+ y: t1 +}" == toString(ret[0])); } TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_4") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -944,13 +947,13 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect REQUIRE(3 == args.size()); CHECK("{+ x: boolean +}" == toString(args[0])); - CHECK("{+ y: t1 +} where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(args[1])); + CHECK("{+ y: t1 +} where t1 = {| x: {+ x: boolean, y: t1, z: string +} |}" == toString(args[1])); CHECK("{+ z: string +}" == toString(args[2])); - std::vector ret = flatten(ftv->retType).first; + std::vector ret = flatten(ftv->retTypes).first; REQUIRE(1 == ret.size()); - CHECK("t1 where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(ret[0])); + CHECK("t1 where t1 = {| x: {+ x: boolean, y: t1, z: string +} |}" == toString(ret[0])); } TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") @@ -1062,4 +1065,29 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_does_not_convert_ever") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + --!strict + local function f() + if math.random() > 0.5 then + return true + end + type Ret = typeof(f()) + if math.random() > 0.5 then + return "something" + end + return "something" :: Ret + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("() -> boolean | string", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp index 1a323c85..ed1c25ec 100644 --- a/tests/NotNull.test.cpp +++ b/tests/NotNull.test.cpp @@ -75,9 +75,9 @@ TEST_CASE("basic_stuff") t->y = 3.14f; const NotNull u = t; - // u->x = 44; // nope + u->x = 44; int v = u->x; - CHECK(v == 5); + CHECK(v == 44); bar(a); @@ -96,8 +96,11 @@ TEST_CASE("basic_stuff") TEST_CASE("hashable") { std::unordered_map, const char*> map; - NotNull a{new int(8)}; - NotNull b{new int(10)}; + int a_ = 8; + int b_ = 10; + + NotNull a{&a_}; + NotNull b{&b_}; std::string hello = "hello"; std::string world = "world"; @@ -108,9 +111,47 @@ TEST_CASE("hashable") CHECK_EQ(2, map.size()); CHECK_EQ(hello.c_str(), map[a]); CHECK_EQ(world.c_str(), map[b]); +} - delete a; - delete b; +TEST_CASE("const") +{ + int p = 0; + int q = 0; + + NotNull n{&p}; + + *n = 123; + + NotNull m = n; // Conversion from NotNull to NotNull is allowed + + CHECK(123 == *m); // readonly access of m is ok + + // *m = 321; // nope. m points at const data. + + // NotNull o = m; // nope. Conversion from NotNull to NotNull is forbidden + + NotNull n2{&q}; + m = n2; // ok. m points to const data, but is not itself const + + const NotNull m2 = n; + // m2 = n2; // nope. m2 is const. + *m2 = 321; // ok. m2 is const, but points to mutable data + + CHECK(321 == *n); +} + +TEST_CASE("const_compatibility") +{ + int* raw = new int(8); + + NotNull a(raw); + NotNull b(raw); + NotNull c = a; + // NotNull d = c; // nope - no conversion from const to non-const + + CHECK_EQ(*c, 8); + + delete raw; } TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index b9e1ae96..ccdd2b37 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -70,7 +70,7 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") const FunctionTypeVar* ftv = get(fiftyType); REQUIRE(ftv != nullptr); - TypePackId retPack = ftv->retType; + TypePackId retPack = ftv->retTypes; const TypePack* tp = get(retPack); REQUIRE(tp != nullptr); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index a28ba49e..036a667a 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -45,7 +45,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_return_type") const FunctionTypeVar* takeFiveType = get(requireType("take_five")); REQUIRE(takeFiveType != nullptr); - std::vector retVec = flatten(takeFiveType->retType).first; + std::vector retVec = flatten(takeFiveType->retTypes).first; REQUIRE(!retVec.empty()); REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); @@ -345,7 +345,7 @@ TEST_CASE_FIXTURE(Fixture, "local_function") const FunctionTypeVar* ftv = get(h); REQUIRE(ftv != nullptr); - std::optional rt = first(ftv->retType); + std::optional rt = first(ftv->retTypes); REQUIRE(bool(rt)); TypeId retType = follow(*rt); @@ -361,7 +361,7 @@ TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") LUAU_REQUIRE_NO_ERRORS(result); const Luau::FunctionTypeVar* fn = get(requireType("p")); REQUIRE(fn); - auto ret = first(fn->retType); + auto ret = first(fn->retTypes); REQUIRE(ret); REQUIRE(get(follow(*ret))); } @@ -460,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotat const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); - std::optional retType = first(functionType->retType); + std::optional retType = first(functionType->retTypes); REQUIRE(retType); CHECK(get(*retType)); } @@ -1619,4 +1619,56 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "quantify_constrained_types") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + --!strict + local function foo(f) + f(5) + f("hi") + local function g() + return f + end + local h = g() + h(true) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("((boolean | number | string) -> (a...)) -> ()", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_o_with_another_argument_after_foo_was_quantified") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + local function f(o) + local t = {} + t[o] = true + + local function foo(o) + o.m1(5) + t[o] = nil + end + + o.m1("hi") + + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: check the normalized type of f +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index fbda8bec..edb5adcf 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -224,7 +224,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") const FunctionTypeVar* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); - auto [rets, varrets] = flatten(idFun->retType); + auto [rets, varrets] = flatten(idFun->retTypes); CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); @@ -247,7 +247,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") const FunctionTypeVar* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); - auto [rets, varrets] = flatten(idFun->retType); + auto [rets, varrets] = flatten(idFun->retTypes); CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); @@ -882,7 +882,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") const FunctionTypeVar* foo = get(follow(fooProp->type)); REQUIRE(bool(foo)); - std::optional ret_ = first(foo->retType); + std::optional ret_ = first(foo->retTypes); REQUIRE(bool(ret_)); TypeId ret = follow(*ret_); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 03614938..fd9b1dd4 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -90,7 +90,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") const FunctionTypeVar* functionType = get(requireType("add")); - std::optional retType = first(functionType->retType); + std::optional retType = first(functionType->retTypes); REQUIRE(retType.has_value()); CHECK_EQ(typeChecker.numberType, follow(*retType)); CHECK_EQ(requireType("n"), typeChecker.numberType); @@ -777,8 +777,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_branch_succeeds") { - ScopedFastFlag sff("LuauSuccessTypingForEqualityOperations", true); - CheckResult result = check(R"( local mm = {} type Foo = typeof(setmetatable({}, mm)) diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 22fb3b69..487e5979 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -472,6 +472,7 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") ScopedFastFlag sff[]{ {"LuauLowerBoundsCalculation", true}, {"LuauNormalizeFlagIsConservative", true}, + {"LuauQuantifyConstrained", true}, }; CheckResult result = check(R"( @@ -494,8 +495,8 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") )"); LUAU_REQUIRE_NO_ERRORS(result); - // TODO: We're missing generics a... and b... - CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); + // TODO: We're missing generics b... + CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 207b3cff..cefba4b2 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -13,8 +13,8 @@ using namespace Luau; namespace { -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +std::optional> magicFunctionInstanceIsA( + TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { if (expr.args.size != 1) return std::nullopt; @@ -32,7 +32,7 @@ std::optional> magicFunctionInstanceIsA( unfreeze(typeChecker.globalTypes); TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); freeze(typeChecker.globalTypes); - return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; + return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } struct RefinementClassFixture : Fixture diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index d622d4af..87d49651 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -642,7 +642,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") const TableTypeVar* argType = get(follow(argVec[0])); REQUIRE(argType != nullptr); - std::vector retVec = flatten(ftv->retType).first; + std::vector retVec = flatten(ftv->retTypes).first; const TableTypeVar* retType = get(follow(retVec[0])); REQUIRE(retType != nullptr); @@ -691,7 +691,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") const FunctionTypeVar* fType = get(requireType("f")); REQUIRE(fType != nullptr); - auto retType_ = first(fType->retType); + auto retType_ = first(fType->retTypes); REQUIRE(bool(retType_)); auto retType = get(follow(*retType_)); @@ -1881,7 +1881,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") REQUIRE(prop.type); const FunctionTypeVar* ftv = get(follow(prop.type)); REQUIRE(ftv); - const TypePack* res = get(follow(ftv->retType)); + const TypePack* res = get(follow(ftv->retTypes)); REQUIRE(res); REQUIRE(res->head.size() == 1); const MetatableTypeVar* mtv = get(follow(res->head[0])); @@ -2584,7 +2584,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_sc const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); REQUIRE(newType); - std::optional newRetType = *first(newType->retType); + std::optional newRetType = *first(newType->retTypes); REQUIRE(newRetType); const MetatableTypeVar* newRet = get(follow(*newRetType)); @@ -2977,7 +2977,6 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") { - ScopedFastFlag luauExpectedPropTypeFromIndexer{"LuauExpectedPropTypeFromIndexer", true}; ScopedFastFlag luauSubtypingAddOptPropsToUnsealedTables{"LuauSubtypingAddOptPropsToUnsealedTables", true}; CheckResult result = check(R"( @@ -2992,8 +2991,6 @@ TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra_2") { - ScopedFastFlag luauExpectedPropTypeFromIndexer{"LuauExpectedPropTypeFromIndexer", true}; - CheckResult result = check(R"( type X = {[any]: string | boolean} diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index cf0c9881..6257cda6 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -13,8 +13,9 @@ #include -LUAU_FASTFLAG(LuauLowerBoundsCalculation) -LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) +LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); using namespace Luau; @@ -43,10 +44,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_error") CheckResult result = check("local a = 7 local b = 'hi' a = b"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{ - requireType("a"), - requireType("b"), - }})); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); } TEST_CASE_FIXTURE(Fixture, "tc_error_2") @@ -86,6 +84,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { ScopedFastFlag sff[]{ + {"DebugLuauDeferredConstraintResolution", false}, {"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}, }; @@ -236,10 +235,14 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("x", err->key); - CHECK_EQ("*unknown*", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); - CHECK_EQ("*unknown*", toString(requireType("e"))); - CHECK_EQ("*unknown*", toString(requireType("f"))); + // TODO: Should we assert anything about these tests when DCR is being used? + if (!FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("*unknown*", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("*unknown*", toString(requireType("e"))); + CHECK_EQ("*unknown*", toString(requireType("f"))); + } } TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowing") @@ -352,40 +355,6 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") CHECK(nullptr != get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "globals") -{ - CheckResult result = check(R"( - --!nonstrict - foo = true - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "globals2") -{ - ScopedFastFlag sff[]{ - {"LuauReturnTypeInferenceInNonstrict", true}, - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - --!nonstrict - foo = function() return 1 end - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("() -> number", toString(tm->wantedType)); - CHECK_EQ("string", toString(tm->givenType)); - CHECK_EQ("() -> number", toString(requireType("foo"))); -} - TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") { CheckResult result = check(R"( @@ -400,23 +369,6 @@ TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") CHECK_EQ("foo", us->name); } -TEST_CASE_FIXTURE(Fixture, "globals_everywhere") -{ - CheckResult result = check(R"( - --!nonstrict - foo = 1 - - if true then - bar = 2 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_do") { CheckResult result = check(R"( @@ -447,21 +399,6 @@ TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") CHECK_EQ("any", toString(requireType("value"))); } -// TEST_CASE_FIXTURE(Fixture, "infer_method_signature_of_argument") -// { -// CheckResult result = check(R"( -// function f(a) -// if a.cond then -// return a.method() -// end -// end -// )"); - -// LUAU_REQUIRE_NO_ERRORS(result); - -// CHECK_EQ("A", toString(requireType("f"))); -// } - TEST_CASE_FIXTURE(Fixture, "cyclic_follow") { check(R"( diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 118863fe..bcd30498 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -26,7 +26,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const FunctionTypeVar* takeTwoType = get(requireType("take_two")); REQUIRE(takeTwoType != nullptr); - const auto& [returns, tail] = flatten(takeTwoType->retType); + const auto& [returns, tail] = flatten(takeTwoType->retTypes); CHECK_EQ(2, returns.size()); CHECK_EQ(typeChecker.numberType, follow(returns[0])); @@ -73,7 +73,7 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const FunctionTypeVar* takeOneMoreType = get(requireType("take_three")); REQUIRE(takeOneMoreType != nullptr); - const auto& [rets, tail] = flatten(takeOneMoreType->retType); + const auto& [rets, tail] = flatten(takeOneMoreType->retTypes); REQUIRE_EQ(3, rets.size()); CHECK_EQ(typeChecker.numberType, follow(rets[0])); @@ -105,10 +105,10 @@ TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") LUAU_REQUIRE_NO_ERRORS(result); const FunctionTypeVar* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); - CHECK_EQ(0, size(fTy->retType)); + CHECK_EQ(0, size(fTy->retTypes)); const FunctionTypeVar* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); - CHECK_EQ(0, size(gTy->retType)); + CHECK_EQ(0, size(gTy->retTypes)); } TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") @@ -125,15 +125,15 @@ TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") const FunctionTypeVar* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); - CHECK_EQ(1, size(follow(fTy->retType))); + CHECK_EQ(1, size(follow(fTy->retTypes))); const FunctionTypeVar* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); - CHECK_EQ(0, size(gTy->retType)); + CHECK_EQ(0, size(gTy->retTypes)); const FunctionTypeVar* hTy = get(requireType("h")); REQUIRE(hTy != nullptr); - CHECK_EQ(0, size(hTy->retType)); + CHECK_EQ(0, size(hTy->retTypes)); } TEST_CASE_FIXTURE(Fixture, "varargs_inference_through_multiple_scopes") diff --git a/tools/natvis/Analysis.natvis b/tools/natvis/Analysis.natvis index 5de0140e..b9ea3141 100644 --- a/tools/natvis/Analysis.natvis +++ b/tools/natvis/Analysis.natvis @@ -6,40 +6,40 @@ - {{ index=0, value={*($T1*)storage} }} - {{ index=1, value={*($T2*)storage} }} - {{ index=2, value={*($T3*)storage} }} - {{ index=3, value={*($T4*)storage} }} - {{ index=4, value={*($T5*)storage} }} - {{ index=5, value={*($T6*)storage} }} - {{ index=6, value={*($T7*)storage} }} - {{ index=7, value={*($T8*)storage} }} - {{ index=8, value={*($T9*)storage} }} - {{ index=9, value={*($T10*)storage} }} - {{ index=10, value={*($T11*)storage} }} - {{ index=11, value={*($T12*)storage} }} - {{ index=12, value={*($T13*)storage} }} - {{ index=13, value={*($T14*)storage} }} - {{ index=14, value={*($T15*)storage} }} - {{ index=15, value={*($T16*)storage} }} - {{ index=16, value={*($T17*)storage} }} - {{ index=17, value={*($T18*)storage} }} - {{ index=18, value={*($T19*)storage} }} - {{ index=19, value={*($T20*)storage} }} - {{ index=20, value={*($T21*)storage} }} - {{ index=21, value={*($T22*)storage} }} - {{ index=22, value={*($T23*)storage} }} - {{ index=23, value={*($T24*)storage} }} - {{ index=24, value={*($T25*)storage} }} - {{ index=25, value={*($T26*)storage} }} - {{ index=26, value={*($T27*)storage} }} - {{ index=27, value={*($T28*)storage} }} - {{ index=28, value={*($T29*)storage} }} - {{ index=29, value={*($T30*)storage} }} - {{ index=30, value={*($T31*)storage} }} - {{ index=31, value={*($T32*)storage} }} + {{ typeId=0, value={*($T1*)storage} }} + {{ typeId=1, value={*($T2*)storage} }} + {{ typeId=2, value={*($T3*)storage} }} + {{ typeId=3, value={*($T4*)storage} }} + {{ typeId=4, value={*($T5*)storage} }} + {{ typeId=5, value={*($T6*)storage} }} + {{ typeId=6, value={*($T7*)storage} }} + {{ typeId=7, value={*($T8*)storage} }} + {{ typeId=8, value={*($T9*)storage} }} + {{ typeId=9, value={*($T10*)storage} }} + {{ typeId=10, value={*($T11*)storage} }} + {{ typeId=11, value={*($T12*)storage} }} + {{ typeId=12, value={*($T13*)storage} }} + {{ typeId=13, value={*($T14*)storage} }} + {{ typeId=14, value={*($T15*)storage} }} + {{ typeId=15, value={*($T16*)storage} }} + {{ typeId=16, value={*($T17*)storage} }} + {{ typeId=17, value={*($T18*)storage} }} + {{ typeId=18, value={*($T19*)storage} }} + {{ typeId=19, value={*($T20*)storage} }} + {{ typeId=20, value={*($T21*)storage} }} + {{ typeId=21, value={*($T22*)storage} }} + {{ typeId=22, value={*($T23*)storage} }} + {{ typeId=23, value={*($T24*)storage} }} + {{ typeId=24, value={*($T25*)storage} }} + {{ typeId=25, value={*($T26*)storage} }} + {{ typeId=26, value={*($T27*)storage} }} + {{ typeId=27, value={*($T28*)storage} }} + {{ typeId=28, value={*($T29*)storage} }} + {{ typeId=29, value={*($T30*)storage} }} + {{ typeId=30, value={*($T31*)storage} }} + {{ typeId=31, value={*($T32*)storage} }} - typeId + typeId *($T1*)storage *($T2*)storage *($T3*)storage From ce9f4e23ae1bb20cf6a1c017f3d15de1c66e97bf Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Tue, 21 Jun 2022 13:14:30 -0700 Subject: [PATCH 087/102] Update performance.md (#553) Document function inlining and loop unrolling. --- docs/_pages/performance.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/_pages/performance.md b/docs/_pages/performance.md index 10e2341c..34b24b03 100644 --- a/docs/_pages/performance.md +++ b/docs/_pages/performance.md @@ -185,3 +185,13 @@ While large tables can be a problem for incremental GC in general since currentl The incremental garbage collector in Luau runs three phases for each cycle: mark, atomic and sweep. Mark incrementally traverses all live objects, atomic finishes various operations that need to happen without mutator intervention (see previous section), and sweep traverses all objects in the heap, reclaiming memory used by dead objects and performing minor fixup for live objects. While objects allocated during the mark phase are traversed in the same cycle and thus may get reclaimed, objects allocated during the sweep phase are considered live. Because of this, the faster the sweep phase completes, the less garbage will accumulate; and, of course, the less time sweeping takes the less overhead there is from this phase of garbage collection on the process. Since sweeping traverses the whole heap, we maximize the efficiency of this traversal by allocating garbage-collected objects of the same size in 16 KB pages, and traversing each page at a time, which is otherwise known as a paged sweeper. This ensures good locality of reference as consecutively swept objects are contiugous in memory, and allows us to spend no memory for each object on sweep-related data or allocation metadata, since paged sweeper doesn't need to be able to free objects without knowing which page they are in. Compared to linked list based sweeping that Lua/LuaJIT implement, paged sweeper is 2-3x faster, and saves 16 bytes per object on 64-bit platforms. + +## Function inlining and loop unrolling + +By default, the bytecode compiler performs a series of optimizations that result in faster execution of the code, but they preserve both execution semantics and debuggability. For example, a function call is compiled as a function call, which may be observable via `debug.traceback`; a loop is compiled as a loop, which may be observable via `lua_getlocal`. To help improve performance in cases where these restrictions can be relaxed, the bytecode compiler implements additional optimizations when optimization level 2 is enabled (which requires using `-O2` switch when using Luau CLI), namely function inlining and loop unrolling. + +Only loops with loop bounds known at compile time, such as `for i=1,4 do`, can be unrolled. The loop body must be simple enough for the optimization to be profitable; compiler uses heuristics to estimate the performance benefit and automatically decide if unrolling should be performed. + +Only local functions (defined either as `local function foo` or `local foo = function`) can be inlined. The function body must be simple enough for the optimization to be profitable; compiler uses heuristics to estimate the performance benefit and automatically decide if each call to the function should be inlined instead. Additionally recursive invocations of a function can’t be inlined at this time, and inlining is completely disabled for modules that use `getfenv`/`setfenv` functions. + +In both cases, in addition to removing the overhead associated with function calls or loop iteration, these optimizations can additionally benefit by enabling additional optimizations, such as constant folding of expressions dependent on loop iteration variable or constant function arguments, or using more efficient instructions for certain expressions when the inputs to these instructions are constants. From e0ac24d1ed4bb755f9d3b85718f6259fd76245e9 Mon Sep 17 00:00:00 2001 From: Qualadore <93345551+Qualadore@users.noreply.github.com> Date: Wed, 22 Jun 2022 08:01:34 -0800 Subject: [PATCH 088/102] Correct string.find and string.match return types (#554) --- Analysis/src/TypeVar.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 57762937..0f53f990 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -750,13 +750,15 @@ TypeId SingletonTypes::makeStringMetatable() TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, {"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}}, - {"find", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber, optionalBoolean}, {}, {optionalNumber, optionalNumber})}}, + {"find", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})})}}, {"format", {formatFn}}, // FIXME {"gmatch", {gmatchFunc}}, {"gsub", {gsubFunc}}, {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, {"lower", {stringToStringType}}, - {"match", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber}, {}, {optionalString})}}, + {"match", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), + arena->addTypePack(TypePackVar{VariadicTypePack{optionalString}})})}}, {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, {"reverse", {stringToStringType}}, {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, From ca32d1bf9de594736cf52f88cf32a621462db9e3 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Wed, 22 Jun 2022 09:27:05 -0700 Subject: [PATCH 089/102] Update library.md (#555) Fix string.match and string.find type definitions --- docs/_pages/library.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/_pages/library.md b/docs/_pages/library.md index eeada336..ff3075ac 100644 --- a/docs/_pages/library.md +++ b/docs/_pages/library.md @@ -488,7 +488,7 @@ function string.char(args: ...number): string Returns the string that contains a byte for every input number; all inputs must be integers in `[0..255]` range. ``` -function string.find(s: string, p: string, init: number?, plain: boolean?): (number?, number?) +function string.find(s: string, p: string, init: number?, plain: boolean?): (number?, number?, ...string) ``` Tries to find an instance of pattern `p` in the string `s`, starting from position `init` (defaults to 1). When `plain` is true, the search is using raw case-insensitive string equality, otherwise `p` should be a [string pattern](https://www.lua.org/manual/5.3/manual.html#6.4.1). If a match is found, returns the position of the match and the length of the match, followed by the pattern captures; otherwise returns `nil`. @@ -536,7 +536,7 @@ function string.lower(s: string): string Returns a string where each byte corresponds to the lower-case ASCII version of the input byte in the source string. ``` -function string.match(s: string, p: string, init: number?): (number?, number?) +function string.match(s: string, p: string, init: number?): ...string? ``` Tries to find an instance of pattern `p` in the string `s`, starting from position `init` (defaults to 1). `p` should be a [string pattern](https://www.lua.org/manual/5.3/manual.html#6.4.1). If a match is found, returns all pattern captures, or entire matching substring if no captures are present, otherwise returns `nil`. From 778e62c8f74f7cb2f48f119e70195a57b05e56bc Mon Sep 17 00:00:00 2001 From: Alan Jeffrey <403333+asajeffrey@users.noreply.github.com> Date: Wed, 22 Jun 2022 13:15:41 -0500 Subject: [PATCH 090/102] RFC: never and unknown types (#434) Co-authored-by: Alexander McCord <11488393+alexmccord@users.noreply.github.com> --- rfcs/none-and-unknown-types.md | 144 +++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 rfcs/none-and-unknown-types.md diff --git a/rfcs/none-and-unknown-types.md b/rfcs/none-and-unknown-types.md new file mode 100644 index 00000000..0085ef07 --- /dev/null +++ b/rfcs/none-and-unknown-types.md @@ -0,0 +1,144 @@ +# Add never and unknown types + +## Summary + +Add `unknown` and `never` types that are inhabited by everything and nothing respectively. + +## Motivation + +There are lots of cases in local type inference, semantic subtyping, +and type normalization, where it would be useful to have top and +bottom types. Currently, `any` is filling that role, but it has +special "switch off the type system" superpowers. + +Any use of `unknown` must be narrowed by type refinements unless another `unknown` or `any` is expected. For +example a function which can return any value is: + +```lua + function anything() : unknown ... end +``` + +and can be used as: + +```lua + local x = anything() + if type(x) == "number" then + print(x + 1) + end +``` + +The type of this function cannot be given concisely in current +Luau. The nearest equivalent is `any`, but this switches off the type system, for example +if the type of `anything` is `() -> any` then the following code typechecks: + +```lua + local x = anything() + print(x + 1) +``` + +This is fine in nonstrict mode, but strict mode should flag this as an error. + +The `never` type comes up whenever type inference infers incompatible types for a variable, for example + +```lua + function oops(x) + print("hi " .. x) -- constrains x must be a string + print(math.abs(x)) -- constrains x must be a number + end +``` + +The most general type of `x` is `string & number`, so this code gives +a type error, but we still need to provide a type for `oops`. With a +`never` type, we can infer the type `oops : (never) -> ()`. + +or when exhaustive type casing is achieved: + +```lua + function f(x: string | number) + if type(x) == "string" then + -- x : string + elseif type(x) == "number" then + -- x : number + else + -- x : never + end + end +``` + +or even when the type casing is simply nonsensical: + +```lua + function f(x: string | number) + if type(x) == "string" and type(x) == "number" then + -- x : string & number which is never + end + end +``` + +The `never` type is also useful in cases such as tagged unions where +some of the cases are impossible. For example: + +```lua + type Result = { err: false, val: T } | { err: true, err: E } +``` + +For code which we know is successful, we would like to be able to +indicate that the error case is impossible. With a `never` type, we +can do this with `Result`. Similarly, code which cannot succeed +has type `Result`. + +These types can _almost_ be defined in current Luau, but only quite verbosely: + +```lua + type never = number & string + type unknown = nil | number | boolean | string | {} | (...never) -> (...unknown) +``` + +But even for `unknown` it is impossible to include every single data types, e.g. every root class. + +Providing `never` and `unknown` as built-in types makes the code for +type inference simpler, for example we have a way to present a union +type with no options (as `never`). Otherwise we have to contend with ad hoc +corner cases. + +## Design + +Add: + +* a type `never`, inhabited by nothing, and +* a type `unknown`, inhabited by everything. + +And under success types (nonstrict mode), `unknown` is exactly equivalent to `any` because `unknown` +encompasses everything as does `any`. + +The interesting thing is that `() -> (never, string)` is equivalent to `() -> never` because all +values in a pack must be inhabitable in order for the pack itself to also be inhabitable. In fact, +the type `() -> never` is not completely accurate, it should be `() -> (never, ...never)` to avoid +cascading type errors. Ditto for when an expression list `f(), g()` where the resulting type pack is +`(never, string, number)` is still the same as `(never, ...never)`. + +```lua + function f(): never error() end + function g(): string return "" end + + -- no cascading type error where count mismatches, because the expression list f(), g() + -- was made to return (never, ...never) due to the presence of a never type in the pack + local x, y, z = f(), g() + -- x : never + -- y : never + -- z : never +``` + +## Drawbacks + +Another bit of complexity budget spent. + +These types will be visible to creators, so yay bikeshedding! + +Replacing `any` with `unknown` is a breaking change: code in strict mode may now produce errors. + +## Alternatives + +Stick with the current use of `any` for these cases. + +Make `never` and `unknown` type aliases rather than built-ins. From 1757234f0154a893882ad3e927f7dc218273f155 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Wed, 22 Jun 2022 11:16:38 -0700 Subject: [PATCH 091/102] Rename none-and-unknown-types.md to never-and-unknown-types.md This makes the type names match. --- rfcs/{none-and-unknown-types.md => never-and-unknown-types.md} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename rfcs/{none-and-unknown-types.md => never-and-unknown-types.md} (99%) diff --git a/rfcs/none-and-unknown-types.md b/rfcs/never-and-unknown-types.md similarity index 99% rename from rfcs/none-and-unknown-types.md rename to rfcs/never-and-unknown-types.md index 0085ef07..d996afc6 100644 --- a/rfcs/none-and-unknown-types.md +++ b/rfcs/never-and-unknown-types.md @@ -1,4 +1,4 @@ -# Add never and unknown types +# never and unknown types ## Summary From 348ad4d4176f1b8f5154cbbe725171154eafa9cb Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Wed, 22 Jun 2022 12:54:48 -0700 Subject: [PATCH 092/102] Update STATUS.md Add never and unknown types --- rfcs/STATUS.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/rfcs/STATUS.md b/rfcs/STATUS.md index d2fe86f0..6bfa865d 100644 --- a/rfcs/STATUS.md +++ b/rfcs/STATUS.md @@ -26,3 +26,9 @@ This document tracks unimplemented RFCs. [RFC: Lower bounds calculation](https://github.com/Roblox/luau/blob/master/rfcs/lower-bounds-calculation.md) **Status**: Implemented but not fully rolled out yet. + +## never and unknown types + +[RFC: never and unknown types](https://github.com/Roblox/luau/blob/master/rfcs/never-and-unknown-types.md) + +**Status**: Needs implementation From 08ab7da4db08cb7457743718971712a1aef4f0a5 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 23 Jun 2022 18:56:00 -0700 Subject: [PATCH 093/102] Sync to upstream/release/533 (#560) --- Analysis/include/Luau/Constraint.h | 14 +- .../include/Luau/ConstraintGraphBuilder.h | 41 ++- Analysis/include/Luau/ConstraintSolver.h | 5 +- .../include/Luau/ConstraintSolverLogger.h | 4 +- Analysis/include/Luau/Error.h | 45 ++- Analysis/include/Luau/IostreamHelpers.h | 1 + Analysis/include/Luau/Module.h | 9 +- Analysis/include/Luau/Normalize.h | 4 +- Analysis/include/Luau/RecursionCounter.h | 15 +- Analysis/include/Luau/Scope.h | 18 + Analysis/include/Luau/TypeVar.h | 1 - Analysis/include/Luau/Unifiable.h | 1 + Analysis/include/Luau/Unifier.h | 4 - Analysis/include/Luau/VisitTypeVar.h | 2 +- Analysis/src/Clone.cpp | 4 +- Analysis/src/Constraint.cpp | 3 +- Analysis/src/ConstraintGraphBuilder.cpp | 285 +++++++++++++-- Analysis/src/ConstraintSolver.cpp | 35 +- Analysis/src/Error.cpp | 93 ++++- Analysis/src/Frontend.cpp | 2 + Analysis/src/IostreamHelpers.cpp | 2 + Analysis/src/Module.cpp | 1 - Analysis/src/Quantify.cpp | 9 +- Analysis/src/Scope.cpp | 32 ++ Analysis/src/ToString.cpp | 6 + Analysis/src/TypeChecker2.cpp | 173 +++++++++ Analysis/src/TypeInfer.cpp | 80 ++--- Analysis/src/TypeVar.cpp | 22 +- Analysis/src/Unifiable.cpp | 8 + Analysis/src/Unifier.cpp | 331 +----------------- CLI/Analyze.cpp | 4 + CMakeLists.txt | 9 + Common/include/Luau/Bytecode.h | 19 +- Compiler/include/Luau/BytecodeBuilder.h | 2 + Compiler/src/BytecodeBuilder.cpp | 16 +- Compiler/src/Compiler.cpp | 6 +- VM/src/ludata.cpp | 2 + VM/src/lvmload.cpp | 4 +- tests/Compiler.test.cpp | 34 +- tests/Fixture.h | 2 +- tests/Module.test.cpp | 1 - tests/Normalize.test.cpp | 1 - tests/RuntimeLimits.test.cpp | 7 +- tests/ToString.test.cpp | 2 - tests/TypeInfer.aliases.test.cpp | 87 ++++- tests/TypeInfer.annotations.test.cpp | 85 ++++- tests/TypeInfer.generics.test.cpp | 13 - tests/TypeInfer.modules.test.cpp | 20 +- tests/TypeInfer.refinements.test.cpp | 2 - tests/TypeInfer.singletons.test.cpp | 4 - tests/TypeInfer.tables.test.cpp | 73 ---- tests/TypeInfer.test.cpp | 3 - tests/TypeInfer.unionTypes.test.cpp | 6 - tests/VisitTypeVar.test.cpp | 9 +- 54 files changed, 968 insertions(+), 693 deletions(-) diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index c62166e2..8a41c9e8 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -1,10 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Location.h" #include "Luau/NotNull.h" #include "Luau/Variant.h" +#include #include #include @@ -47,18 +47,24 @@ struct InstantiationConstraint TypeId superType; }; -using ConstraintV = Variant; +// name(namedType) = name +struct NameConstraint +{ + TypeId namedType; + std::string name; +}; + +using ConstraintV = Variant; using ConstraintPtr = std::unique_ptr; struct Constraint { - Constraint(ConstraintV&& c, Location location); + explicit Constraint(ConstraintV&& c); Constraint(const Constraint&) = delete; Constraint& operator=(const Constraint&) = delete; ConstraintV c; - Location location; std::vector> dependencies; }; diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index da774a2a..9b118691 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -17,20 +17,7 @@ namespace Luau { -struct Scope2 -{ - // The parent scope of this scope. Null if there is no parent (i.e. this - // is the module-level scope). - Scope2* parent = nullptr; - // All the children of this scope. - std::vector children; - std::unordered_map bindings; // TODO: I think this can be a DenseHashMap - TypePackId returnType; - // All constraints belonging to this scope. - std::vector constraints; - - std::optional lookup(Symbol sym); -}; +struct Scope2; struct ConstraintGraphBuilder { @@ -47,6 +34,10 @@ struct ConstraintGraphBuilder // A mapping of AST node to TypePackId. DenseHashMap astTypePacks{nullptr}; DenseHashMap astOriginalCallTypes{nullptr}; + // Types resolved from type annotations. Analogous to astTypes. + DenseHashMap astResolvedTypes{nullptr}; + // Type packs resolved from type annotations. Analogous to astTypePacks. + DenseHashMap astResolvedTypePacks{nullptr}; explicit ConstraintGraphBuilder(TypeArena* arena); @@ -73,9 +64,8 @@ struct ConstraintGraphBuilder * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param cv the constraint variant to add. - * @param location the location to attribute to the constraint. */ - void addConstraint(Scope2* scope, ConstraintV cv, Location location); + void addConstraint(Scope2* scope, ConstraintV cv); /** * Adds a constraint to a given scope. @@ -99,6 +89,7 @@ struct ConstraintGraphBuilder void visit(Scope2* scope, AstStatReturn* ret); void visit(Scope2* scope, AstStatAssign* assign); void visit(Scope2* scope, AstStatIf* ifStatement); + void visit(Scope2* scope, AstStatTypeAlias* alias); TypePackId checkExprList(Scope2* scope, const AstArray& exprs); @@ -124,6 +115,24 @@ struct ConstraintGraphBuilder * @param fn the function expression to check. */ void checkFunctionBody(Scope2* scope, AstExprFunction* fn); + + /** + * Resolves a type from its AST annotation. + * @param scope the scope that the type annotation appears within. + * @param ty the AST annotation to resolve. + * @return the type of the AST annotation. + **/ + TypeId resolveType(Scope2* scope, AstType* ty); + + /** + * Resolves a type pack from its AST annotation. + * @param scope the scope that the type annotation appears within. + * @param tp the AST annotation to resolve. + * @return the type pack of the AST annotation. + **/ + TypePackId resolveTypePack(Scope2* scope, AstTypePack* tp); + + TypePackId resolveTypePack(Scope2* scope, const AstTypeList& list); }; /** diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 7e6d4461..4870157f 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -55,6 +55,7 @@ struct ConstraintSolver bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const NameConstraint& c, NotNull constraint); void block(NotNull target, NotNull constraint); /** @@ -85,7 +86,7 @@ struct ConstraintSolver * @param subType the sub-type to unify. * @param superType the super-type to unify. */ - void unify(TypeId subType, TypeId superType, Location location); + void unify(TypeId subType, TypeId superType); /** * Creates a new Unifier and performs a single unification operation. Commits @@ -93,7 +94,7 @@ struct ConstraintSolver * @param subPack the sub-type pack to unify. * @param superPack the super-type pack to unify. */ - void unify(TypePackId subPack, TypePackId superPack, Location location); + void unify(TypePackId subPack, TypePackId superPack); private: /** diff --git a/Analysis/include/Luau/ConstraintSolverLogger.h b/Analysis/include/Luau/ConstraintSolverLogger.h index 2b195d71..55336a23 100644 --- a/Analysis/include/Luau/ConstraintSolverLogger.h +++ b/Analysis/include/Luau/ConstraintSolverLogger.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/Constraint.h" +#include "Luau/NotNull.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index b4530674..a1323960 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -169,6 +169,13 @@ struct GenericError bool operator==(const GenericError& rhs) const; }; +struct InternalError +{ + std::string message; + + bool operator==(const InternalError& rhs) const; +}; + struct CannotCallNonFunction { TypeId ty; @@ -293,12 +300,12 @@ struct NormalizationTooComplex } }; -using TypeErrorData = - Variant; +using TypeErrorData = Variant; struct TypeError { @@ -339,7 +346,13 @@ T* get(TypeError& e) using ErrorVec = std::vector; +struct TypeErrorToStringOptions +{ + FileResolver* fileResolver = nullptr; +}; + std::string toString(const TypeError& error); +std::string toString(const TypeError& error, TypeErrorToStringOptions options); bool containsParseErrorName(const TypeError& error); @@ -356,4 +369,24 @@ struct InternalErrorReporter [[noreturn]] void ice(const std::string& message); }; +class InternalCompilerError : public std::exception { +public: + explicit InternalCompilerError(const std::string& message, const std::string& moduleName) + : message(message) + , moduleName(moduleName) + { + } + explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location) + : message(message) + , moduleName(moduleName) + , location(location) + { + } + virtual const char* what() const throw(); + + const std::string message; + const std::string moduleName; + const std::optional location; +}; + } // namespace Luau diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h index ee994296..05b94516 100644 --- a/Analysis/include/Luau/IostreamHelpers.h +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -30,6 +30,7 @@ std::ostream& operator<<(std::ostream& lhs, const OccursCheckFailed& error); std::ostream& operator<<(std::ostream& lhs, const UnknownRequire& error); std::ostream& operator<<(std::ostream& lhs, const UnknownPropButFoundLikeProp& e); std::ostream& operator<<(std::ostream& lhs, const GenericError& error); +std::ostream& operator<<(std::ostream& lhs, const InternalError& error); std::ostream& operator<<(std::ostream& lhs, const FunctionExitsWithoutReturning& error); std::ostream& operator<<(std::ostream& lhs, const MissingProperties& error); std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index e979b3f0..39f8dfb7 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -1,10 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Error.h" #include "Luau/FileResolver.h" #include "Luau/ParseOptions.h" -#include "Luau/Error.h" #include "Luau/ParseResult.h" +#include "Luau/Scope.h" #include "Luau/TypeArena.h" #include @@ -19,7 +20,9 @@ struct Module; using ScopePtr = std::shared_ptr; using ModulePtr = std::shared_ptr; -struct Scope2; + +class AstType; +class AstTypePack; /// Root of the AST of a parsed source file struct SourceModule @@ -73,6 +76,8 @@ struct Module DenseHashMap astExpectedTypes{nullptr}; DenseHashMap astOriginalCallTypes{nullptr}; DenseHashMap astOverloadResolvedTypes{nullptr}; + DenseHashMap astResolvedTypes{nullptr}; + DenseHashMap astResolvedTypePacks{nullptr}; std::unordered_map declaredGlobals; ErrorVec errors; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index d4c7698b..f5fd9886 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -9,8 +9,8 @@ namespace Luau struct InternalErrorReporter; -bool isSubtype(TypeId superTy, TypeId subTy, InternalErrorReporter& ice); -bool isSubtype(TypePackId superTy, TypePackId subTy, InternalErrorReporter& ice); +bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, InternalErrorReporter& ice); std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index 03ae2c83..f964dbfe 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAG(LuauRecursionLimitException); - namespace Luau { @@ -39,21 +37,12 @@ private: struct RecursionLimiter : RecursionCounter { - // TODO: remove ctx after LuauRecursionLimitException is removed - RecursionLimiter(int* count, int limit, const char* ctx) + RecursionLimiter(int* count, int limit) : RecursionCounter(count) { - LUAU_ASSERT(ctx); if (limit > 0 && *count > limit) { - if (FFlag::LuauRecursionLimitException) - throw RecursionLimitException(); - else - { - std::string m = "Internal recursion counter limit exceeded: "; - m += ctx; - throw std::runtime_error(m); - } + throw RecursionLimitException(); } } }; diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 45338409..cef4b94f 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Constraint.h" #include "Luau/Location.h" #include "Luau/TypeVar.h" @@ -64,4 +65,21 @@ struct Scope std::unordered_map typeAliasTypePackParameters; }; +struct Scope2 +{ + // The parent scope of this scope. Null if there is no parent (i.e. this + // is the module-level scope). + Scope2* parent = nullptr; + // All the children of this scope. + std::vector children; + std::unordered_map bindings; // TODO: I think this can be a DenseHashMap + std::unordered_map typeBindings; + TypePackId returnType; + // All constraints belonging to this scope. + std::vector constraints; + + std::optional lookup(Symbol sym); + std::optional lookupTypeBinding(const Name& name); +}; + } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index ff7708d4..20f4107c 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -287,7 +287,6 @@ struct FunctionTypeVar bool hasSelf; Tags tags; bool hasNoGenerics = false; - bool generalized = false; }; enum class TableState diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index fdc39481..4ff91714 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -117,6 +117,7 @@ struct Generic explicit Generic(const Name& name); explicit Generic(Scope2* scope); Generic(TypeLevel level, const Name& name); + Generic(Scope2* scope, const Name& name); int index; TypeLevel level; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index b51a485e..4af324cb 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -79,12 +79,8 @@ private: void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); - void DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); - void tryUnifyFreeTable(TypeId subTy, TypeId superTy); - void tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); - void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer); TypeId widen(TypeId ty); TypePackId widen(TypePackId tp); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 642522c9..5fd43f0b 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -169,7 +169,7 @@ struct GenericTypeVarVisitor void traverse(TypeId ty) { - RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit, "TypeVarVisitor"}; + RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit}; if (visit_detail::hasSeen(seen, ty)) { diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 248262ce..df4e0a6b 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -317,7 +317,7 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) if (tp->persistent) return tp; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypePackId"); + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); TypePackId& res = cloneState.seenTypePacks[tp]; @@ -335,7 +335,7 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (typeId->persistent) return typeId; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypeId"); + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); TypeId& res = cloneState.seenTypes[typeId]; diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 6cb0e4ee..64e3a666 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -5,9 +5,8 @@ namespace Luau { -Constraint::Constraint(ConstraintV&& c, Location location) +Constraint::Constraint(ConstraintV&& c) : c(std::move(c)) - , location(location) { } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index fa627e7a..d9e8d238 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -2,28 +2,13 @@ #include "Luau/ConstraintGraphBuilder.h" +#include "Luau/Scope.h" + namespace Luau { const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp -std::optional Scope2::lookup(Symbol sym) -{ - Scope2* s = this; - - while (true) - { - auto it = s->bindings.find(sym); - if (it != s->bindings.end()) - return it->second; - - if (s->parent) - s = s->parent; - else - return std::nullopt; - } -} - ConstraintGraphBuilder::ConstraintGraphBuilder(TypeArena* arena) : singletonTypes(getSingletonTypes()) , arena(arena) @@ -59,10 +44,10 @@ Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) return borrow; } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv, Location location) +void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) { LUAU_ASSERT(scope); - scope->constraints.emplace_back(new Constraint{std::move(cv), location}); + scope->constraints.emplace_back(new Constraint{std::move(cv)}); } void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) @@ -79,6 +64,13 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) rootScope = scopes.back().second.get(); rootScope->returnType = freshTypePack(rootScope); + // TODO: We should share the global scope. + rootScope->typeBindings["nil"] = singletonTypes.nilType; + rootScope->typeBindings["number"] = singletonTypes.numberType; + rootScope->typeBindings["string"] = singletonTypes.stringType; + rootScope->typeBindings["boolean"] = singletonTypes.booleanType; + rootScope->typeBindings["thread"] = singletonTypes.threadType; + visit(rootScope, block); } @@ -102,6 +94,8 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) checkPack(scope, e->expr); else if (auto i = stat->as()) visit(scope, i); + else if (auto a = stat->as()) + visit(scope, a); else LUAU_ASSERT(0); } @@ -114,8 +108,14 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) for (AstLocal* local : local->vars) { - // TODO annotations TypeId ty = freshType(scope); + + if (local->annotation) + { + TypeId annotation = resolveType(scope, local->annotation); + addConstraint(scope, SubtypeConstraint{ty, annotation}); + } + varTypes.push_back(ty); scope->bindings[local] = ty; } @@ -136,14 +136,14 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) { std::vector tailValues{varTypes.begin() + i, varTypes.end()}; TypePackId tailPack = arena->addTypePack(std::move(tailValues)); - addConstraint(scope, PackSubtypeConstraint{exprPack, tailPack}, local->location); + addConstraint(scope, PackSubtypeConstraint{exprPack, tailPack}); } } else { TypeId exprType = check(scope, local->values.data[i]); if (i < varTypes.size()) - addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}, local->vars.data[i]->location); + addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}); } } } @@ -188,7 +188,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function checkFunctionBody(innerScope, function->func); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}, function->location}}; + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; addConstraints(c.get(), innerScope); addConstraint(scope, std::move(c)); @@ -240,7 +240,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) checkFunctionBody(innerScope, function->func); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}, function->location}}; + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; addConstraints(c.get(), innerScope); addConstraint(scope, std::move(c)); @@ -251,13 +251,26 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) LUAU_ASSERT(scope); TypePackId exprTypes = checkPack(scope, ret->list); - addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}, ret->location); + addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) { LUAU_ASSERT(scope); + // In order to enable mutually-recursive type aliases, we need to + // populate the type bindings before we actually check any of the + // alias statements. Since we're not ready to actually resolve + // any of the annotations, we just use a fresh type for now. + for (AstStat* stat : block->body) + { + if (auto alias = stat->as()) + { + TypeId initialType = freshType(scope); + scope->typeBindings[alias->name.value] = initialType; + } + } + for (AstStat* stat : block->body) visit(scope, stat); } @@ -267,7 +280,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) TypePackId varPackId = checkExprList(scope, assign->vars); TypePackId valuePack = checkPack(scope, assign->values); - addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}, assign->location); + addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) @@ -284,6 +297,28 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) } } +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) +{ + // TODO: Exported type aliases + // TODO: Generic type aliases + + auto it = scope->typeBindings.find(alias->name.value); + // This should always be here since we do a separate pass over the + // AST to set up typeBindings. If it's not, we've somehow skipped + // this alias in that first pass. + LUAU_ASSERT(it != scope->typeBindings.end()); + + TypeId ty = resolveType(scope, alias->type); + + // Rather than using a subtype constraint, we instead directly bind + // the free type we generated in the first pass to the resolved type. + // This prevents a case where you could cause another constraint to + // bind the free alias type to an unrelated type, causing havoc. + asMutable(it->second)->ty.emplace(ty); + + addConstraint(scope, NameConstraint{ty, alias->name.value}); +} + TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) { LUAU_ASSERT(scope); @@ -350,13 +385,13 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) astOriginalCallTypes[call->func] = fnType; TypeId instantiatedType = freshType(scope); - addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}, expr->location); + addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); TypePackId rets = freshTypePack(scope); FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); TypeId inferredFnType = arena->addType(ftv); - addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}, expr->location); + addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); result = rets; } else @@ -413,7 +448,7 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) TypePack onePack{{typeResult}, freshTypePack(scope)}; TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); - addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}, expr->location); + addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}); return typeResult; } @@ -454,7 +489,7 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) TypeId expectedTableType = arena->addType(std::move(ttv)); - addConstraint(scope, SubtypeConstraint{obj, expectedTableType}, indexName->location); + addConstraint(scope, SubtypeConstraint{obj, expectedTableType}); return result; } @@ -465,8 +500,7 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) TableTypeVar* ttv = getMutable(ty); LUAU_ASSERT(ttv); - auto createIndexer = [this, scope, ttv]( - TypeId currentIndexType, TypeId currentResultType, Location itemLocation, std::optional keyLocation) { + auto createIndexer = [this, scope, ttv](TypeId currentIndexType, TypeId currentResultType) { if (!ttv->indexer) { TypeId indexType = this->freshType(scope); @@ -474,8 +508,8 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) ttv->indexer = TableIndexer{indexType, resultType}; } - addConstraint(scope, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}, keyLocation ? *keyLocation : itemLocation); - addConstraint(scope, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}, itemLocation); + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}); + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}); }; for (const AstExprTable::Item& item : expr->items) @@ -495,13 +529,13 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) } else { - createIndexer(keyTy, itemTy, item.value->location, item.key->location); + createIndexer(keyTy, itemTy); } } else { TypeId numberType = singletonTypes.numberType; - createIndexer(numberType, itemTy, item.value->location, std::nullopt); + createIndexer(numberType, itemTy); } } @@ -514,15 +548,29 @@ std::pair ConstraintGraphBuilder::checkFunctionSignature(Scope2 TypePackId returnType = freshTypePack(innerScope); innerScope->returnType = returnType; + if (fn->returnAnnotation) + { + TypePackId annotatedRetType = resolveTypePack(innerScope, *fn->returnAnnotation); + addConstraint(innerScope, PackSubtypeConstraint{returnType, annotatedRetType}); + } + std::vector argTypes; for (AstLocal* local : fn->args) { TypeId t = freshType(innerScope); argTypes.push_back(t); - innerScope->bindings[local] = t; // TODO annotations + innerScope->bindings[local] = t; + + if (local->annotation) + { + TypeId argAnnotation = resolveType(innerScope, local->annotation); + addConstraint(innerScope, SubtypeConstraint{t, argAnnotation}); + } } + // TODO: Vararg annotation. + FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); @@ -541,10 +589,171 @@ void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* f if (nullptr != getFallthrough(fn->body)) { TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever - addConstraint(scope, PackSubtypeConstraint{scope->returnType, empty}, fn->body->location); + addConstraint(scope, PackSubtypeConstraint{scope->returnType, empty}); } } +TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) +{ + TypeId result = nullptr; + + if (auto ref = ty->as()) + { + // TODO: Support imported types w/ require tracing. + // TODO: Support generic type references. + LUAU_ASSERT(!ref->prefix); + LUAU_ASSERT(!ref->hasParameterList); + + // TODO: If it doesn't exist, should we introduce a free binding? + // This is probably important for handling type aliases. + result = scope->lookupTypeBinding(ref->name.value).value_or(singletonTypes.errorRecoveryType()); + } + else if (auto tab = ty->as()) + { + TableTypeVar::Props props; + std::optional indexer; + + for (const AstTableProp& prop : tab->props) + { + std::string name = prop.name.value; + // TODO: Recursion limit. + TypeId propTy = resolveType(scope, prop.type); + // TODO: Fill in location. + props[name] = {propTy}; + } + + if (tab->indexer) + { + // TODO: Recursion limit. + indexer = TableIndexer{ + resolveType(scope, tab->indexer->indexType), + resolveType(scope, tab->indexer->resultType), + }; + } + + // TODO: Remove TypeLevel{} here, we don't need it. + result = arena->addType(TableTypeVar{props, indexer, TypeLevel{}, TableState::Sealed}); + } + else if (auto fn = ty->as()) + { + // TODO: Generic functions. + // TODO: Scope (though it may not be needed). + // TODO: Recursion limit. + TypePackId argTypes = resolveTypePack(scope, fn->argTypes); + TypePackId returnTypes = resolveTypePack(scope, fn->returnTypes); + + // TODO: Is this the right constructor to use? + result = arena->addType(FunctionTypeVar{argTypes, returnTypes}); + + FunctionTypeVar* ftv = getMutable(result); + ftv->argNames.reserve(fn->argNames.size); + for (const auto& el : fn->argNames) + { + if (el) + { + const auto& [name, location] = *el; + ftv->argNames.push_back(FunctionArgument{name.value, location}); + } + else + { + ftv->argNames.push_back(std::nullopt); + } + } + } + else if (auto tof = ty->as()) + { + // TODO: Recursion limit. + TypeId exprType = check(scope, tof->expr); + result = exprType; + } + else if (auto unionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : unionAnnotation->types) + { + // TODO: Recursion limit. + parts.push_back(resolveType(scope, part)); + } + + result = arena->addType(UnionTypeVar{parts}); + } + else if (auto intersectionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : intersectionAnnotation->types) + { + // TODO: Recursion limit. + parts.push_back(resolveType(scope, part)); + } + + result = arena->addType(IntersectionTypeVar{parts}); + } + else if (auto boolAnnotation = ty->as()) + { + result = arena->addType(SingletonTypeVar(BooleanSingleton{boolAnnotation->value})); + } + else if (auto stringAnnotation = ty->as()) + { + result = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)})); + } + else if (ty->is()) + { + result = singletonTypes.errorRecoveryType(); + } + else + { + LUAU_ASSERT(0); + result = singletonTypes.errorRecoveryType(); + } + + astResolvedTypes[ty] = result; + return result; +} + +TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* tp) +{ + TypePackId result; + if (auto expl = tp->as()) + { + result = resolveTypePack(scope, expl->typeList); + } + else if (auto var = tp->as()) + { + TypeId ty = resolveType(scope, var->variadicType); + result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); + } + else if (auto gen = tp->as()) + { + result = arena->addTypePack(TypePackVar{GenericTypePack{scope, gen->genericName.value}}); + } + else + { + LUAU_ASSERT(0); + result = singletonTypes.errorRecoveryTypePack(); + } + + astResolvedTypePacks[tp] = result; + return result; +} + +TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeList& list) +{ + std::vector head; + + for (AstType* headTy : list.types) + { + head.push_back(resolveType(scope, headTy)); + } + + std::optional tail = std::nullopt; + if (list.tailType) + { + tail = resolveTypePack(scope, list.tailType); + } + + return arena->addTypePack(TypePack{head, tail}); +} + void collectConstraints(std::vector>& result, Scope2* scope) { for (const auto& c : scope->constraints) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 41dfd892..9e355236 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -2,6 +2,7 @@ #include "Luau/ConstraintSolver.h" #include "Luau/Instantiation.h" +#include "Luau/Location.h" #include "Luau/Quantify.h" #include "Luau/ToString.h" #include "Luau/Unifier.h" @@ -179,6 +180,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*gc, constraint, force); else if (auto ic = get(*constraint)) success = tryDispatch(*ic, constraint, force); + else if (auto nc = get(*constraint)) + success = tryDispatch(*nc, constraint); else LUAU_ASSERT(0); @@ -197,7 +200,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNulllocation); + unify(c.subType, c.superType); unblock(c.subType); unblock(c.superType); @@ -207,7 +210,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { - unify(c.subPack, c.superPack, constraint->location); + unify(c.subPack, c.superPack); unblock(c.subPack); unblock(c.superPack); @@ -222,7 +225,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullty.emplace(c.sourceType); else - unify(c.generalizedType, c.sourceType, constraint->location); + unify(c.generalizedType, c.sourceType); TypeId generalized = quantify(arena, c.sourceType, c.scope); *asMutable(c.sourceType) = *generalized; @@ -243,12 +246,28 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - unify(c.subType, *instantiated, constraint->location); + unify(c.subType, *instantiated); unblock(c.subType); return true; } +bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull constraint) +{ + if (isBlocked(c.namedType)) + return block(c.namedType, constraint); + + TypeId target = follow(c.namedType); + if (TableTypeVar* ttv = getMutable(target)) + ttv->name = c.name; + else if (MetatableTypeVar* mtv = getMutable(target)) + mtv->syntheticName = c.name; + else + return block(c.namedType, constraint); + + return true; +} + void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); @@ -321,19 +340,19 @@ bool ConstraintSolver::isBlocked(NotNull constraint) return blockedIt != blockedConstraints.end() && blockedIt->second > 0; } -void ConstraintSolver::unify(TypeId subType, TypeId superType, Location location) +void ConstraintSolver::unify(TypeId subType, TypeId superType) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, location, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; u.tryUnify(subType, superType); u.log.commit(); } -void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, Location location) +void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, location, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; u.tryUnify(subPack, superPack); u.log.commit(); diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index f443a3cc..93cb65b9 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,6 +7,9 @@ #include +LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false) +LUAU_FASTFLAGVARIABLE(LuauUseInternalCompilerErrorException, false) + static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { std::string s = "expects "; @@ -49,6 +52,8 @@ namespace Luau struct ErrorConverter { + FileResolver* fileResolver = nullptr; + std::string operator()(const Luau::TypeMismatch& tm) const { std::string givenTypeName = Luau::toString(tm.givenType); @@ -62,8 +67,18 @@ struct ErrorConverter { if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) { - result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + - "' from '" + *wantedDefinitionModule + "'"; + if (FFlag::LuauTypeMismatchModuleNameResolution && fileResolver != nullptr) + { + std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); + std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); + result = "Type '" + givenTypeName + "' from '" + givenModuleName + "' could not be converted into '" + wantedTypeName + + "' from '" + wantedModuleName + "'"; + } + else + { + result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + + "' from '" + *wantedDefinitionModule + "'"; + } } } } @@ -78,7 +93,14 @@ struct ErrorConverter if (!tm.reason.empty()) result += tm.reason + " "; - result += Luau::toString(*tm.error); + if (FFlag::LuauTypeMismatchModuleNameResolution) + { + result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); + } + else + { + result += Luau::toString(*tm.error); + } } else if (!tm.reason.empty()) { @@ -280,6 +302,11 @@ struct ErrorConverter return e.message; } + std::string operator()(const Luau::InternalError& e) const + { + return e.message; + } + std::string operator()(const Luau::CannotCallNonFunction& e) const { return "Cannot call non-function " + toString(e.ty); @@ -598,6 +625,11 @@ bool GenericError::operator==(const GenericError& rhs) const return message == rhs.message; } +bool InternalError::operator==(const InternalError& rhs) const +{ + return message == rhs.message; +} + bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const { return ty == rhs.ty; @@ -685,7 +717,12 @@ bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const std::string toString(const TypeError& error) { - ErrorConverter converter; + return toString(error, TypeErrorToStringOptions{}); +} + +std::string toString(const TypeError& error, TypeErrorToStringOptions options) +{ + ErrorConverter converter{options.fileResolver}; return Luau::visit(converter, error.data); } @@ -773,6 +810,9 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) else if constexpr (std::is_same_v) { } + else if constexpr (std::is_same_v) + { + } else if constexpr (std::is_same_v) { e.ty = clone(e.ty); @@ -847,22 +887,51 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) void InternalErrorReporter::ice(const std::string& message, const Location& location) { - std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message); + if (FFlag::LuauUseInternalCompilerErrorException) + { + InternalCompilerError error(message, moduleName, location); - if (onInternalError) - onInternalError(error.what()); + if (onInternalError) + onInternalError(error.what()); - throw error; + throw error; + } + else + { + std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message); + + if (onInternalError) + onInternalError(error.what()); + + throw error; + } } void InternalErrorReporter::ice(const std::string& message) { - std::runtime_error error("Internal error in " + moduleName + ": " + message); + if (FFlag::LuauUseInternalCompilerErrorException) + { + InternalCompilerError error(message, moduleName); - if (onInternalError) - onInternalError(error.what()); + if (onInternalError) + onInternalError(error.what()); - throw error; + throw error; + } + else + { + std::runtime_error error("Internal error in " + moduleName + ": " + message); + + if (onInternalError) + onInternalError(error.what()); + + throw error; + } +} + +const char* InternalCompilerError::what() const throw() +{ + return this->message.data(); } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 9e025062..85c5dbc8 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -801,6 +801,8 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco result->astTypes = std::move(cgb.astTypes); result->astTypePacks = std::move(cgb.astTypePacks); result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); + result->astResolvedTypes = std::move(cgb.astResolvedTypes); + result->astResolvedTypePacks = std::move(cgb.astResolvedTypePacks); result->clonePublicInterface(iceHandler); diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 048167ae..e4fac455 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -111,6 +111,8 @@ static void errorToString(std::ostream& stream, const T& err) } else if constexpr (std::is_same_v) stream << "GenericError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "InternalError { " << err.message << " }"; else if constexpr (std::is_same_v) stream << "CannotCallNonFunction { " << toString(err.ty) << " }"; else if constexpr (std::is_same_v) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 4d157e6f..95eb125e 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -11,7 +11,6 @@ #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" -#include "Luau/ConstraintGraphBuilder.h" // FIXME: For Scope2 TODO pull out into its own header #include diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 2004d153..40e14c68 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -2,11 +2,10 @@ #include "Luau/Quantify.h" -#include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header -#include "Luau/TxnLog.h" +#include "Luau/Scope.h" #include "Luau/Substitution.h" +#include "Luau/TxnLog.h" #include "Luau/VisitTypeVar.h" -#include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header LUAU_FASTFLAG(LuauAlwaysQuantify); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); @@ -177,8 +176,6 @@ void quantify(TypeId ty, TypeLevel level) if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) ftv->hasNoGenerics = true; - - ftv->generalized = true; } void quantify(TypeId ty, Scope2* scope) @@ -201,8 +198,6 @@ void quantify(TypeId ty, Scope2* scope) if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) ftv->hasNoGenerics = true; - - ftv->generalized = true; } struct PureQuantifier : Substitution diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 011e28d4..66aaee1f 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -121,4 +121,36 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } +std::optional Scope2::lookup(Symbol sym) +{ + Scope2* s = this; + + while (true) + { + auto it = s->bindings.find(sym); + if (it != s->bindings.end()) + return it->second; + + if (s->parent) + s = s->parent; + else + return std::nullopt; + } +} + +std::optional Scope2::lookupTypeBinding(const Name& name) +{ + Scope2* s = this; + while (s) + { + auto it = s->typeBindings.find(name); + if (it != s->typeBindings.end()) + return it->second; + + s = s->parent; + } + + return std::nullopt; +} + } // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 81dc0467..7a458964 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1411,6 +1411,12 @@ std::string toString(const Constraint& c, ToStringOptions& opts) opts.nameMap = std::move(superStr.nameMap); return subStr.name + " ~ inst " + superStr.name; } + else if (const NameConstraint* nc = Luau::get(c)) + { + ToStringResult namedStr = toStringDetailed(nc->namedType, opts); + opts.nameMap = std::move(namedStr.nameMap); + return "@name(" + namedStr.name + ") = " + nc->name; + } else { LUAU_ASSERT(false); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 7f5ba683..63e5800f 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -7,6 +7,9 @@ #include "Luau/AstQuery.h" #include "Luau/Clone.h" #include "Luau/Normalize.h" +#include "Luau/ConstraintGraphBuilder.h" // FIXME move Scope2 into its own header +#include "Luau/Unifier.h" +#include "Luau/ToString.h" namespace Luau { @@ -39,6 +42,104 @@ struct TypeChecker2 : public AstVisitor return follow(*ty); } + TypeId lookupAnnotation(AstType* annotation) + { + TypeId* ty = module->astResolvedTypes.find(annotation); + LUAU_ASSERT(ty); + return follow(*ty); + } + + TypePackId reconstructPack(AstArray exprs, TypeArena& arena) + { + std::vector head; + + for (size_t i = 0; i < exprs.size - 1; ++i) + { + head.push_back(lookupType(exprs.data[i])); + } + + TypePackId tail = lookupPack(exprs.data[exprs.size - 1]); + return arena.addTypePack(TypePack{head, tail}); + } + + Scope2* findInnermostScope(Location location) + { + Scope2* bestScope = module->getModuleScope2(); + Location bestLocation = module->scope2s[0].first; + + for (size_t i = 0; i < module->scope2s.size(); ++i) + { + auto& [scopeBounds, scope] = module->scope2s[i]; + if (scopeBounds.encloses(location)) + { + if (scopeBounds.begin > bestLocation.begin || scopeBounds.end < bestLocation.end) + { + bestScope = scope.get(); + bestLocation = scopeBounds; + } + } + else + { + // TODO: Is this sound? This relies on the fact that scopes are inserted + // into the scope list in the order that they appear in the AST. + break; + } + } + + return bestScope; + } + + bool visit(AstStatLocal* local) override + { + for (size_t i = 0; i < local->values.size; ++i) + { + AstExpr* value = local->values.data[i]; + if (i == local->values.size - 1) + { + if (i < local->values.size) + { + TypePackId valueTypes = lookupPack(value); + auto it = begin(valueTypes); + for (size_t j = i; j < local->vars.size; ++j) + { + if (it == end(valueTypes)) + { + break; + } + + AstLocal* var = local->vars.data[i]; + if (var->annotation) + { + TypeId varType = lookupAnnotation(var->annotation); + if (!isSubtype(*it, varType, ice)) + { + reportError(TypeMismatch{varType, *it}, value->location); + } + } + + ++it; + } + } + } + else + { + TypeId valueType = lookupType(value); + AstLocal* var = local->vars.data[i]; + + if (var->annotation) + { + TypeId varType = lookupAnnotation(var->annotation); + if (!isSubtype(varType, valueType, ice)) + { + reportError(TypeMismatch{varType, valueType}, value->location); + } + } + } + } + + return true; + } + bool visit(AstStatAssign* assign) override { size_t count = std::min(assign->vars.size, assign->values.size); @@ -62,6 +163,30 @@ struct TypeChecker2 : public AstVisitor return true; } + bool visit(AstStatReturn* ret) override + { + Scope2* scope = findInnermostScope(ret->location); + TypePackId expectedRetType = scope->returnType; + + TypeArena arena; + TypePackId actualRetType = reconstructPack(ret->list, arena); + + UnifierSharedState sharedState{&ice}; + Unifier u{&arena, Mode::Strict, ret->location, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(actualRetType, expectedRetType); + const bool ok = u.errors.empty() && u.log.empty(); + + if (!ok) + { + for (const TypeError& e : u.errors) + module->errors.push_back(e); + } + + return true; + } + bool visit(AstExprCall* call) override { TypePackId expectedRetType = lookupPack(call); @@ -91,6 +216,35 @@ struct TypeChecker2 : public AstVisitor return true; } + bool visit(AstExprFunction* fn) override + { + TypeId inferredFnTy = lookupType(fn); + const FunctionTypeVar* inferredFtv = get(inferredFnTy); + LUAU_ASSERT(inferredFtv); + + auto argIt = begin(inferredFtv->argTypes); + for (const auto& arg : fn->args) + { + if (argIt == end(inferredFtv->argTypes)) + break; + + if (arg->annotation) + { + TypeId inferredArgTy = *argIt; + TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + + if (!isSubtype(annotatedArgTy, inferredArgTy, ice)) + { + reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); + } + } + + ++argIt; + } + + return true; + } + bool visit(AstExprIndexName* indexName) override { TypeId leftType = lookupType(indexName->expr); @@ -144,6 +298,25 @@ struct TypeChecker2 : public AstVisitor return true; } + bool visit(AstType* ty) override + { + return true; + } + + bool visit(AstTypeReference* ty) override + { + Scope2* scope = findInnermostScope(ty->location); + + // TODO: Imported types + // TODO: Generic types + if (!scope->lookupTypeBinding(ty->name.value)) + { + reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); + } + + return true; + } + void reportError(TypeErrorData&& data, const Location& location) { module->errors.emplace_back(location, sourceModule->name, std::move(data)); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index fd1b3b85..44635e88 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,13 +35,9 @@ LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) -LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) -LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) -LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); -LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) LUAU_FASTFLAG(LuauQuantifyConstrained) @@ -275,22 +271,15 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) { - if (FFlag::LuauRecursionLimitException) - { - try - { - return checkWithoutRecursionCheck(module, mode, environmentScope); - } - catch (const RecursionLimitException&) - { - reportErrorCodeTooComplex(module.root->location); - return std::move(currentModule); - } - } - else + try { return checkWithoutRecursionCheck(module, mode, environmentScope); } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(module.root->location); + return std::move(currentModule); + } } ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -445,22 +434,15 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } - if (FFlag::LuauRecursionLimitException) - { - try - { - checkBlockWithoutRecursionCheck(scope, block); - } - catch (const RecursionLimitException&) - { - reportErrorCodeTooComplex(block.location); - return; - } - } - else + try { checkBlockWithoutRecursionCheck(scope, block); } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(block.location); + return; + } } void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) @@ -1917,7 +1899,7 @@ std::optional TypeChecker::getIndexTypeFromType( for (TypeId t : utv) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeForType unions"); + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); // Not needed when we normalize types. if (get(follow(t))) @@ -1967,7 +1949,7 @@ std::optional TypeChecker::getIndexTypeFromType( for (TypeId t : itv->parts) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeFromType intersections"); + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) parts.push_back(*ty); @@ -2190,7 +2172,7 @@ TypeId TypeChecker::checkExprTable( } } - TableState state = (expr.items.size == 0 || isNonstrictMode() || FFlag::LuauUnsealedTableLiteral) ? TableState::Unsealed : TableState::Sealed; + TableState state = TableState::Unsealed; TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; table.definitionModuleName = currentModuleName; return addType(table); @@ -5175,9 +5157,7 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack bool ApplyTypeFunction::isDirty(TypeId ty) { - if (FFlag::LuauApplyTypeFunctionFix && typeArguments.count(ty)) - return true; - else if (!FFlag::LuauApplyTypeFunctionFix && get(ty)) + if (typeArguments.count(ty)) return true; else if (const FreeTypeVar* ftv = get(ty)) { @@ -5191,9 +5171,7 @@ bool ApplyTypeFunction::isDirty(TypeId ty) bool ApplyTypeFunction::isDirty(TypePackId tp) { - if (FFlag::LuauApplyTypeFunctionFix && typePackArguments.count(tp)) - return true; - else if (!FFlag::LuauApplyTypeFunctionFix && get(tp)) + if (typePackArguments.count(tp)) return true; else return false; @@ -5218,29 +5196,15 @@ bool ApplyTypeFunction::ignoreChildren(TypePackId tp) TypeId ApplyTypeFunction::clean(TypeId ty) { TypeId& arg = typeArguments[ty]; - if (FFlag::LuauApplyTypeFunctionFix) - { - LUAU_ASSERT(arg); - return arg; - } - else if (arg) - return arg; - else - return addType(FreeTypeVar{level}); + LUAU_ASSERT(arg); + return arg; } TypePackId ApplyTypeFunction::clean(TypePackId tp) { TypePackId& arg = typePackArguments[tp]; - if (FFlag::LuauApplyTypeFunctionFix) - { - LUAU_ASSERT(arg); - return arg; - } - else if (arg) - return arg; - else - return addTypePack(FreeTypePack{level}); + LUAU_ASSERT(arg); + return arg; } TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, @@ -5273,7 +5237,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId target = follow(instantiated); bool needsClone = follow(tf.type) == target; - bool shouldMutate = (!FFlag::LuauOnlyMutateInstantiatedTables || getTableType(tf.type)); + bool shouldMutate = getTableType(tf.type); TableTypeVar* ttv = getMutableTableType(target); if (shouldMutate && ttv && needsClone) diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 0f53f990..ade70d72 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,7 +23,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) namespace Luau @@ -172,22 +171,15 @@ bool isString(TypeId ty) // Returns true when ty is a supertype of string bool maybeString(TypeId ty) { - if (FFlag::LuauSubtypingAddOptPropsToUnsealedTables) - { - ty = follow(ty); + ty = follow(ty); - if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) - return true; + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) + return true; - if (auto utv = get(ty)) - return std::any_of(begin(utv), end(utv), maybeString); + if (auto utv = get(ty)) + return std::any_of(begin(utv), end(utv), maybeString); - return false; - } - else - { - return isString(ty); - } + return false; } bool isThread(TypeId ty) @@ -369,7 +361,7 @@ bool maybeSingleton(TypeId ty) bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) { - RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit, "hasLength"); + RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit); ty = follow(ty); diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index fe878358..8d23aa49 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -53,6 +53,14 @@ Generic::Generic(TypeLevel level, const Name& name) { } +Generic::Generic(Scope2* scope, const Name& name) + : index(++nextIndex) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + int Generic::nextIndex = 0; Error::Error() diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 877663de..6147e118 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -17,11 +17,8 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); -LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); -LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau @@ -354,7 +351,7 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "TypeId tryUnify_"); + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); ++sharedState.counters.iterationCount; @@ -983,7 +980,7 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "TypePackId tryUnify_"); + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); ++sharedState.counters.iterationCount; @@ -1316,12 +1313,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal tryUnify_(subFunction->retTypes, superFunction->retTypes); } - if (FFlag::LuauTxnLogRefreshFunctionPointers) - { - // Updating the log may have invalidated the function pointers - superFunction = log.getMutable(superTy); - subFunction = log.getMutable(subTy); - } + // Updating the log may have invalidated the function pointers + superFunction = log.getMutable(superTy); + subFunction = log.getMutable(subTy); ctx = context; @@ -1360,9 +1354,6 @@ struct Resetter void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { - if (!FFlag::LuauTableSubtypingVariance2) - return DEPRECATED_tryUnifyTables(subTy, superTy, isIntersection); - TableTypeVar* superTable = log.getMutable(superTy); TableTypeVar* subTable = log.getMutable(subTy); @@ -1379,8 +1370,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && - !isOptional(superProp.type)) + if (subIter == subTable->props.end() && subTable->state == TableState::Unsealed && !isOptional(superProp.type)) missingProperties.push_back(propName); } @@ -1398,7 +1388,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto superIter = superTable->props.find(propName); - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || !isOptional(subProp.type))) + if (superIter == superTable->props.end()) extraProperties.push_back(propName); } @@ -1443,7 +1433,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); } - else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) + else if (subTable->state == TableState::Unsealed && isOptional(prop.type)) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) @@ -1512,9 +1502,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (variance == Covariant) { } - else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && isOptional(prop.type)) - { - } else if (superTable->state == TableState::Free) { PendingType* pendingSuper = log.queue(superTy); @@ -1639,296 +1626,6 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); } -void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - Resetter resetter{&variance}; - variance = Invariant; - - TableTypeVar* superTable = log.getMutable(superTy); - TableTypeVar* subTable = log.getMutable(subTy); - - if (!superTable || !subTable) - ice("passed non-table types to unifyTables"); - - if (superTable->state == TableState::Sealed && subTable->state == TableState::Sealed) - return tryUnifySealedTables(subTy, superTy, isIntersection); - else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Unsealed) || - (superTable->state == TableState::Unsealed && subTable->state == TableState::Sealed)) - return tryUnifySealedTables(subTy, superTy, isIntersection); - else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Generic) || - (superTable->state == TableState::Generic && subTable->state == TableState::Sealed)) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - else if ((superTable->state == TableState::Free) != (subTable->state == TableState::Free)) // one table is free and the other is not - { - TypeId freeTypeId = subTable->state == TableState::Free ? subTy : superTy; - TypeId otherTypeId = subTable->state == TableState::Free ? superTy : subTy; - - return tryUnifyFreeTable(otherTypeId, freeTypeId); - } - else if (superTable->state == TableState::Free && subTable->state == TableState::Free) - { - tryUnifyFreeTable(subTy, superTy); - - // avoid creating a cycle when the types are already pointing at each other - if (follow(superTy) != follow(subTy)) - { - log.bindTable(superTy, subTy); - } - return; - } - else if (superTable->state != TableState::Sealed && subTable->state != TableState::Sealed) - { - // All free tables are checked in one of the branches above - LUAU_ASSERT(superTable->state != TableState::Free); - LUAU_ASSERT(subTable->state != TableState::Free); - - // Tables must have exactly the same props and their types must all unify - // I honestly have no idea if this is remotely close to reasonable. - for (const auto& [name, prop] : superTable->props) - { - const auto& r = subTable->props.find(name); - if (r == subTable->props.end()) - reportError(TypeError{location, UnknownProperty{subTy, name}}); - else - tryUnify_(r->second.type, prop.type); - } - - if (superTable->indexer && subTable->indexer) - tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - else if (superTable->indexer) - { - // passing/assigning a table without an indexer to something that has one - // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. - if (subTable->state == TableState::Unsealed) - { - log.changeIndexer(subTy, superTable->indexer); - } - else - reportError(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); - } - } - else if (superTable->state == TableState::Sealed) - { - // lt is sealed and so it must be possible for rt to have precisely the same shape - // Verify that this is the case, then bind rt to lt. - ice("unsealed tables are not working yet", location); - } - else if (subTable->state == TableState::Sealed) - return tryUnifyTables(superTy, subTy, isIntersection); - else - ice("tryUnifyTables"); -} - -void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - TableTypeVar* freeTable = log.getMutable(superTy); - TableTypeVar* subTable = log.getMutable(subTy); - - if (!freeTable || !subTable) - ice("passed non-table types to tryUnifyFreeTable"); - - // Any properties in freeTable must unify with those in otherTable. - // Then bind freeTable to otherTable. - for (const auto& [freeName, freeProp] : freeTable->props) - { - if (auto subProp = findTablePropertyRespectingMeta(subTy, freeName)) - { - tryUnify_(*subProp, freeProp.type); - - /* - * TypeVars are commonly cyclic, so it is entirely possible - * for unifying a property of a table to change the table itself! - * We need to check for this and start over if we notice this occurring. - * - * I believe this is guaranteed to terminate eventually because this will - * only happen when a free table is bound to another table. - */ - if (!log.getMutable(superTy) || !log.getMutable(subTy)) - return tryUnify_(subTy, superTy); - - if (TableTypeVar* pendingFreeTtv = log.getMutable(superTy); pendingFreeTtv && pendingFreeTtv->boundTo) - return tryUnify_(subTy, superTy); - } - else - { - // If the other table is also free, then we are learning that it has more - // properties than we previously thought. Else, it is an error. - if (subTable->state == TableState::Free) - { - PendingType* pendingSub = log.queue(subTy); - TableTypeVar* pendingSubTtv = getMutable(pendingSub); - LUAU_ASSERT(pendingSubTtv); - pendingSubTtv->props.insert({freeName, freeProp}); - } - else - reportError(TypeError{location, UnknownProperty{subTy, freeName}}); - } - } - - if (freeTable->indexer && subTable->indexer) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnifyIndexer(*subTable->indexer, *freeTable->indexer); - - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - - log.concat(std::move(innerState.log)); - } - else if (subTable->state == TableState::Free && freeTable->indexer) - { - log.changeIndexer(superTy, subTable->indexer); - } - - if (!freeTable->boundTo && subTable->state != TableState::Free) - { - log.bindTable(superTy, subTy); - } -} - -void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - TableTypeVar* superTable = log.getMutable(superTy); - TableTypeVar* subTable = log.getMutable(subTy); - - if (!superTable || !subTable) - ice("passed non-table types to unifySealedTables"); - - std::vector missingPropertiesInSuper; - bool isUnnamedTable = subTable->name == std::nullopt && subTable->syntheticName == std::nullopt; - bool errorReported = false; - - // Optimization: First test that the property sets are compatible without doing any recursive unification - if (!subTable->indexer) - { - for (const auto& [propName, superProp] : superTable->props) - { - auto subIter = subTable->props.find(propName); - if (subIter == subTable->props.end() && !isOptional(superProp.type)) - missingPropertiesInSuper.push_back(propName); - } - - if (!missingPropertiesInSuper.empty()) - { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); - return; - } - } - - Unifier innerState = makeChildUnifier(); - - // Tables must have exactly the same props and their types must all unify - for (const auto& it : superTable->props) - { - const auto& r = subTable->props.find(it.first); - if (r == subTable->props.end()) - { - if (isOptional(it.second.type)) - continue; - - missingPropertiesInSuper.push_back(it.first); - - innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - } - else - { - if (isUnnamedTable && r->second.location) - { - size_t oldErrorSize = innerState.errors.size(); - Location old = innerState.location; - innerState.location = *r->second.location; - innerState.tryUnify_(r->second.type, it.second.type); - innerState.location = old; - - if (oldErrorSize != innerState.errors.size() && !errorReported) - { - errorReported = true; - reportError(innerState.errors.back()); - } - } - else - { - innerState.tryUnify_(r->second.type, it.second.type); - } - } - } - - if (superTable->indexer || subTable->indexer) - { - if (superTable->indexer && subTable->indexer) - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - else if (subTable->state == TableState::Unsealed) - { - if (superTable->indexer && !subTable->indexer) - { - log.changeIndexer(subTy, superTable->indexer); - } - } - else if (superTable->state == TableState::Unsealed) - { - if (subTable->indexer && !superTable->indexer) - { - log.changeIndexer(superTy, subTable->indexer); - } - } - else if (superTable->indexer) - { - innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); - for (const auto& [name, type] : subTable->props) - { - const auto& it = superTable->props.find(name); - if (it == superTable->props.end()) - innerState.tryUnify_(type.type, superTable->indexer->indexResultType); - } - } - else - innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - } - - if (!errorReported) - log.concat(std::move(innerState.log)); - else - return; - - if (!missingPropertiesInSuper.empty()) - { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); - return; - } - - // If the superTy is an immediate part of an intersection type, do not do extra-property check. - // Otherwise, we would falsely generate an extra-property-error for 's' in this code: - // local a: {n: number} & {s: string} = {n=1, s=""} - // When checking against the table '{n: number}'. - if (!isIntersection && superTable->state != TableState::Unsealed && !superTable->indexer) - { - // Check for extra properties in the subTy - std::vector extraPropertiesInSub; - - for (const auto& [subKey, subProp] : subTable->props) - { - const auto& superIt = superTable->props.find(subKey); - if (superIt == superTable->props.end()) - { - if (isOptional(subProp.type)) - continue; - - extraPropertiesInSub.push_back(subKey); - } - } - - if (!extraPropertiesInSub.empty()) - { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); - return; - } - } - - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); -} - void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { const MetatableTypeVar* superMetatable = get(superTy); @@ -2068,14 +1765,6 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) return fail(); } -void Unifier::tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - - tryUnify_(subIndexer.indexType, superIndexer.indexType); - tryUnify_(subIndexer.indexResultType, superIndexer.indexResultType); -} - static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { while (true) @@ -2435,7 +2124,7 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypeId"); + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); auto check = [&](TypeId tv) { occursCheck(seen, needle, tv); @@ -2506,7 +2195,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypePackId"); + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); while (!log.getMutable(haystack)) { diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 8b03ea1a..81db7c35 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -9,6 +9,7 @@ #include "FileUtils.h" LUAU_FASTFLAG(DebugLuauTimeTracing) +LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution) enum class ReportFormat { @@ -49,6 +50,9 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); + else if (FFlag::LuauTypeMismatchModuleNameResolution) + report(format, humanReadableName.c_str(), error.location, "TypeError", + Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); else report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str()); } diff --git a/CMakeLists.txt b/CMakeLists.txt index c624a132..e256e234 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,7 @@ option(LUAU_BUILD_TESTS "Build tests" ON) option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) +option(LUAU_EXTERN_C "Use extern C for all APIs" OFF) if(LUAU_STATIC_CRT) cmake_minimum_required(VERSION 3.15) @@ -115,6 +116,14 @@ target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS}) +if(LUAU_EXTERN_C) + # enable extern "C" for VM (lua.h, lualib.h) and Compiler (luacode.h) to make Luau friendlier to use from non-C++ languages + # note that we enable LUA_USE_LONGJMP=1 as well; otherwise functions like luaL_error will throw C++ exceptions, which can't be done from extern "C" functions + target_compile_definitions(Luau.VM PUBLIC LUA_USE_LONGJMP=1) + target_compile_definitions(Luau.VM PUBLIC LUA_API=extern\"C\") + target_compile_definitions(Luau.Compiler PUBLIC LUACODE_API=extern\"C\") +endif() + if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: # https://developercommunity.visualstudio.com/t/performance-regression-on-a-complex-interpreter-lo/1631863 diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index f71d893c..218bb5d5 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -7,7 +7,7 @@ // Creating the bytecode is outside the scope of this file and is handled by bytecode builder (BytecodeBuilder.h) and bytecode compiler (Compiler.h) // Note that ALL enums declared in this file are order-sensitive since the values are baked into bytecode that needs to be processed by legacy clients. -// Bytecode definitions +// # Bytecode definitions // Bytecode instructions are using "word code" - each instruction is one or many 32-bit words. // The first word in the instruction is always the instruction header, and *must* contain the opcode (enum below) in the least significant byte. // @@ -19,7 +19,7 @@ // Instruction word is sometimes followed by one extra word, indicated as AUX - this is just a 32-bit word and is decoded according to the specification for each opcode. // For each opcode the encoding is *static* - that is, based on the opcode you know a-priory how large the instruction is, with the exception of NEWCLOSURE -// Bytecode indices +// # Bytecode indices // Bytecode instructions commonly refer to integer values that define offsets or indices for various entities. For each type, there's a maximum encodable value. // Note that in some cases, the compiler will set a lower limit than the maximum encodable value is to prevent fragile code into bumping against the limits whenever we change the compilation details. // Additionally, in some specific instructions such as ANDK, the limit on the encoded value is smaller; this means that if a value is larger, a different instruction must be selected. @@ -29,6 +29,15 @@ // Constants: 0-2^23-1. Constants are stored in a table allocated with each proto; to allow for future bytecode tweaks the encodable value is limited to 23 bits. // Closures: 0-2^15-1. Closures are created from child protos via a child index; the limit is for the number of closures immediately referenced in each function. // Jumps: -2^23..2^23. Jump offsets are specified in word increments, so jumping over an instruction may sometimes require an offset of 2 or more. + +// # Bytecode versions +// Bytecode serialized format embeds a version number, that dictates both the serialized form as well as the allowed instructions. As long as the bytecode version falls into supported +// range (indicated by LBC_BYTECODE_MIN / LBC_BYTECODE_MAX) and was produced by Luau compiler, it should load and execute correctly. +// +// Note that Luau runtime doesn't provide indefinite bytecode compatibility: support for older versions gets removed over time. As such, bytecode isn't a durable storage format and it's expected +// that Luau users can recompile bytecode from source on Luau version upgrades if necessary. + +// Bytecode opcode, part of the instruction header enum LuauOpcode { // NOP: noop @@ -380,8 +389,10 @@ enum LuauOpcode // Bytecode tags, used internally for bytecode encoded as a string enum LuauBytecodeTag { - // Bytecode version - LBC_VERSION = 2, + // Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled + LBC_VERSION_MIN = 2, + LBC_VERSION_MAX = 2, + LBC_VERSION_TARGET = 2, // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index dbe54299..6ec10b53 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -119,6 +119,8 @@ public: static std::string getError(const std::string& message); + static uint8_t getVersion(); + private: struct Constant { diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index a34f7603..301cf255 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -9,6 +9,9 @@ namespace Luau { +static_assert(LBC_VERSION_TARGET >= LBC_VERSION_MIN && LBC_VERSION_TARGET <= LBC_VERSION_MAX, "Invalid bytecode version setup"); +static_assert(LBC_VERSION_MAX <= 127, "Bytecode version should be 7-bit so that we can extend the serialization to use varint transparently"); + static const uint32_t kMaxConstantCount = 1 << 23; static const uint32_t kMaxClosureCount = 1 << 15; @@ -572,7 +575,10 @@ void BytecodeBuilder::finalize() bytecode.reserve(capacity); // assemble final bytecode blob - bytecode = char(LBC_VERSION); + uint8_t version = getVersion(); + LUAU_ASSERT(version >= LBC_VERSION_MIN && version <= LBC_VERSION_MAX); + + bytecode = char(version); writeStringTable(bytecode); @@ -1040,7 +1046,7 @@ void BytecodeBuilder::expandJumps() std::string BytecodeBuilder::getError(const std::string& message) { - // 0 acts as a special marker for error bytecode (it's equal to LBC_VERSION for valid bytecode blobs) + // 0 acts as a special marker for error bytecode (it's equal to LBC_VERSION_TARGET for valid bytecode blobs) std::string result; result += char(0); result += message; @@ -1048,6 +1054,12 @@ std::string BytecodeBuilder::getError(const std::string& message) return result; } +uint8_t BytecodeBuilder::getVersion() +{ + // This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags + return LBC_VERSION_TARGET; +} + #ifdef LUAU_ASSERTENABLED void BytecodeBuilder::validate() const { diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 52dc9242..e732256b 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -16,8 +16,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false) - LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThresholdMaxBoost, 300) @@ -2672,7 +2670,7 @@ struct Compiler else if (builtin.isGlobal("pairs")) // for .. in pairs(t) { skipOp = LOP_FORGPREP_NEXT; - loopOp = FFlag::LuauCompileIterNoPairs ? LOP_FORGLOOP : LOP_FORGLOOP_NEXT; + loopOp = LOP_FORGLOOP; } } else if (stat->values.size == 2) @@ -2682,7 +2680,7 @@ struct Compiler if (builtin.isGlobal("next")) // for .. in next,t { skipOp = LOP_FORGPREP_NEXT; - loopOp = FFlag::LuauCompileIterNoPairs ? LOP_FORGLOOP : LOP_FORGLOOP_NEXT; + loopOp = LOP_FORGLOOP; } } } diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index 28152689..c2110cb3 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -26,6 +26,8 @@ void luaU_freeudata(lua_State* L, Udata* u, lua_Page* page) { void (*dtor)(lua_State*, void*) = nullptr; dtor = L->global->udatagc[u->tag]; + // TODO: access to L here is highly unsafe since this is called during internal GC traversal + // certain operations such as lua_getthreaddata are okay, but by and large this risks crashes on improper use if (dtor) dtor(L, u->data); } diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 8b742f1c..86afddd2 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -154,11 +154,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size return 1; } - if (version != LBC_VERSION) + if (version < LBC_VERSION_MIN || version > LBC_VERSION_MAX) { char chunkid[LUA_IDSIZE]; luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); - lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, LBC_VERSION, version); + lua_pushfstring(L, "%s: bytecode version mismatch (expected [%d..%d], got %d)", chunkid, LBC_VERSION_MIN, LBC_VERSION_MAX, version); return 1; } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 036bf124..655e48cb 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,8 +261,6 @@ L1: RETURN R0 0 TEST_CASE("ForBytecode") { - ScopedFastFlag sff2("LuauCompileIterNoPairs", false); - // basic for loop: variable directly refers to internal iteration index (R2) CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( LOADN R2 1 @@ -329,7 +327,7 @@ L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -L1: FORGLOOP_NEXT R0 L0 +L1: FORGLOOP R0 L0 2 RETURN R0 0 )"); @@ -342,7 +340,7 @@ L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -L1: FORGLOOP_NEXT R0 L0 +L1: FORGLOOP R0 L0 2 RETURN R0 0 )"); } @@ -2262,8 +2260,6 @@ TEST_CASE("TypeAliasing") TEST_CASE("DebugLineInfo") { - ScopedFastFlag sff("LuauCompileIterNoPairs", false); - Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -2313,7 +2309,7 @@ return result 15: L0: MOVE R7 R1 15: MOVE R8 R5 15: CONCAT R1 R7 R8 -14: L1: FORGLOOP_NEXT R2 L0 +14: L1: FORGLOOP R2 L0 1 17: RETURN R1 1 )"); } @@ -2545,8 +2541,6 @@ a TEST_CASE("DebugSource") { - ScopedFastFlag sff("LuauCompileIterNoPairs", false); - const char* source = R"( local kSelectedBiomes = { ['Mountains'] = true, @@ -2614,7 +2608,7 @@ L0: MOVE R7 R1 MOVE R8 R5 CONCAT R1 R7 R8 14: for k in pairs(kSelectedBiomes) do -L1: FORGLOOP_NEXT R2 L0 +L1: FORGLOOP R2 L0 1 17: return result RETURN R1 1 )"); @@ -2622,8 +2616,6 @@ RETURN R1 1 TEST_CASE("DebugLocals") { - ScopedFastFlag sff("LuauCompileIterNoPairs", false); - const char* source = R"( function foo(e, f) local a = 1 @@ -2661,12 +2653,12 @@ end local 0: reg 5, start pc 5 line 5, end pc 8 line 5 local 1: reg 6, start pc 14 line 8, end pc 18 line 8 local 2: reg 7, start pc 14 line 8, end pc 18 line 8 -local 3: reg 3, start pc 21 line 12, end pc 24 line 12 -local 4: reg 3, start pc 26 line 16, end pc 30 line 16 -local 5: reg 0, start pc 0 line 3, end pc 34 line 21 -local 6: reg 1, start pc 0 line 3, end pc 34 line 21 -local 7: reg 2, start pc 1 line 4, end pc 34 line 21 -local 8: reg 3, start pc 34 line 21, end pc 34 line 21 +local 3: reg 3, start pc 22 line 12, end pc 25 line 12 +local 4: reg 3, start pc 27 line 16, end pc 31 line 16 +local 5: reg 0, start pc 0 line 3, end pc 35 line 21 +local 6: reg 1, start pc 0 line 3, end pc 35 line 21 +local 7: reg 2, start pc 1 line 4, end pc 35 line 21 +local 8: reg 3, start pc 35 line 21, end pc 35 line 21 3: LOADN R2 1 4: LOADN R5 1 4: LOADN R3 3 @@ -2683,7 +2675,7 @@ local 8: reg 3, start pc 34 line 21, end pc 34 line 21 8: MOVE R9 R6 8: MOVE R10 R7 8: CALL R8 2 0 -7: L3: FORGLOOP_NEXT R3 L2 +7: L3: FORGLOOP R3 L2 2 11: LOADN R3 2 12: GETIMPORT R4 1 12: LOADN R5 2 @@ -3795,8 +3787,6 @@ RETURN R0 1 TEST_CASE("SharedClosure") { - ScopedFastFlag sff("LuauCompileIterNoPairs", false); - // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... @@ -3939,7 +3929,7 @@ L2: GETIMPORT R5 1 NEWCLOSURE R6 P1 CAPTURE VAL R3 CALL R5 1 0 -L3: FORGLOOP_NEXT R0 L2 +L3: FORGLOOP R0 L2 2 LOADN R2 1 LOADN R0 10 LOADN R1 1 diff --git a/tests/Fixture.h b/tests/Fixture.h index ffcd4b9e..0e3735f6 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -2,13 +2,13 @@ #pragma once #include "Luau/Config.h" -#include "Luau/ConstraintGraphBuilder.h" #include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/IostreamHelpers.h" #include "Luau/Linter.h" #include "Luau/Location.h" #include "Luau/ModuleResolver.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index d585b731..7c2f4d1c 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -279,7 +279,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; - ScopedFastFlag sff{"LuauRecursionLimitException", true}; TypeArena src; diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 284230c9..a474b6e7 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -12,7 +12,6 @@ using namespace Luau; struct NormalizeFixture : Fixture { ScopedFastFlag sff1{"LuauLowerBoundsCalculation", true}; - ScopedFastFlag sff2{"LuauTableSubtypingVariance2", true}; }; void createSomeClasses(TypeChecker& typeChecker) diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index bef38fc3..6619147b 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -264,10 +264,13 @@ TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") } )LUA"; + CheckResult result = check(src); + CodeTooComplex ctc; + if (FFlag::LuauLowerBoundsCalculation) - (void)check(src); + LUAU_REQUIRE_ERRORS(result); else - CHECK_THROWS_AS(check(src), std::exception); + CHECK(hasError(result, &ctc)); } TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 4d2e94ee..e03069a9 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -409,8 +409,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local base = {} function base:one() return 1 end diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 86cc9701..d6f0a0c8 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -7,8 +7,21 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + TEST_SUITE_BEGIN("TypeAliases"); +TEST_CASE_FIXTURE(Fixture, "basic_alias") +{ + CheckResult result = check(R"( + type T = number + local x: T = 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("x"))); +} + TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") { CheckResult result = check(R"( @@ -24,6 +37,63 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); } +TEST_CASE_FIXTURE(Fixture, "names_are_ascribed") +{ + CheckResult result = check(R"( + type T = { x: number } + local x: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("T", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") +{ + // This is a tricky case. In order to support recursive type aliases, + // we first walk the block and generate free types as placeholders. + // We then walk the AST as normal. If we declare a type alias as below, + // we generate a free type. We then begin our normal walk, examining + // local x: T = "foo", which establishes two constraints: + // a <: b + // string <: a + // We then visit the type alias, and establish that + // b <: number + // Then, when solving these constraints, we dispatch them in the order + // they appear above. This means that a ~ b, and a ~ string, thus + // b ~ string. This means the b <: number constraint has no effect. + // Essentially we've "stolen" the alias's type out from under it. + // This test ensures that we don't actually do this. + CheckResult result = check(R"( + local x: T = "foo" + type T = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK(result.errors[0] == TypeError{ + Location{{1, 21}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); + } + else + { + CHECK(result.errors[0] == TypeError{ + Location{{1, 8}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); + } +} + TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") { CheckResult result = check(R"( @@ -41,7 +111,22 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_whe CHECK_EQ(typeChecker.numberType, tm->givenType); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") +{ + CheckResult result = check(R"( + --!strict + type T = { f: number, g: U } + type U = { h: number, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = 3, g = { h = 5, i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_generic_aliases") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index ccdd2b37..3e2ad6dc 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -30,11 +30,21 @@ TEST_CASE_FIXTURE(Fixture, "successful_check") dumpErrors(result); } +TEST_CASE_FIXTURE(Fixture, "variable_type_is_supertype") +{ + CheckResult result = check(R"( + local x: number = 1 + local y: number? = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "function_parameters_can_have_annotations") { CheckResult result = check(R"( function double(x: number) - return x * 2 + return 2 end local four = double(2) @@ -47,7 +57,7 @@ TEST_CASE_FIXTURE(Fixture, "function_parameter_annotations_are_checked") { CheckResult result = check(R"( function double(x: number) - return x * 2 + return 2 end local four = double("two") @@ -70,13 +80,13 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") const FunctionTypeVar* ftv = get(fiftyType); REQUIRE(ftv != nullptr); - TypePackId retPack = ftv->retTypes; + TypePackId retPack = follow(ftv->retTypes); const TypePack* tp = get(retPack); REQUIRE(tp != nullptr); REQUIRE_EQ(1, tp->head.size()); - REQUIRE_EQ(typeChecker.anyType, tp->head[0]); + REQUIRE_EQ(typeChecker.anyType, follow(tp->head[0])); } TEST_CASE_FIXTURE(Fixture, "function_return_multret_annotations_are_checked") @@ -116,6 +126,23 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotation_should_continuously_parse LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "unknown_type_reference_generates_error") +{ + CheckResult result = check(R"( + local x: IDoNotExist + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{ + Location{{1, 17}, {1, 28}}, + getMainSourceModule()->name, + UnknownSymbol{ + "IDoNotExist", + UnknownSymbol::Context::Type, + }, + }); +} + TEST_CASE_FIXTURE(Fixture, "typeof_variable_type_annotation_should_return_its_type") { CheckResult result = check(R"( @@ -632,7 +659,10 @@ int AssertionCatcher::tripped; TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice") { - ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", false}, + }; AssertionCatcher ac; @@ -646,9 +676,10 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice") TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_handler") { - ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; - - AssertionCatcher ac; + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", false}, + }; bool caught = false; @@ -662,8 +693,44 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_handler") std::runtime_error); CHECK_EQ(true, caught); +} - frontend.iceHandler.onInternalError = {}; +TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag") +{ + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", true}, + }; + + AssertionCatcher ac; + + CHECK_THROWS_AS(check(R"( + local a: _luau_ice = 55 + )"), + InternalCompilerError); + + LUAU_ASSERT(1 == AssertionCatcher::tripped); +} + +TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag_handler") +{ + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", true}, + }; + + bool caught = false; + + frontend.iceHandler.onInternalError = [&](const char*) { + caught = true; + }; + + CHECK_THROWS_AS(check(R"( + local a: _luau_ice = 55 + )"), + InternalCompilerError); + + CHECK_EQ(true, caught); } TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index edb5adcf..97ba0808 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -700,11 +700,6 @@ end TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") { - ScopedFastFlag sffs[] = { - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( --!strict -- At one point this produced a UAF @@ -979,8 +974,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - // Mutability in type function application right now can create strange recursive types CheckResult result = check(R"( type Table = { a: number } @@ -1015,8 +1008,6 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) @@ -1123,8 +1114,6 @@ TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics1") { - ScopedFastFlag sff{"LuauApplyTypeFunctionFix", true}; - // https://github.com/Roblox/luau/issues/484 CheckResult result = check(R"( --!strict @@ -1153,8 +1142,6 @@ local complex: ComplexObject = { TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics2") { - ScopedFastFlag sff{"LuauApplyTypeFunctionFix", true}; - // https://github.com/Roblox/luau/issues/484 CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index afec20bf..a0f670f1 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -12,8 +12,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauTableSubtypingVariance2) - TEST_SUITE_BEGIN("TypeInferModules"); TEST_CASE_FIXTURE(BuiltinsFixture, "require") @@ -326,16 +324,9 @@ local b: B.T = a CheckResult result = frontend.check("game/C"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTableSubtypingVariance2) - { - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); - } - else - { - CHECK_EQ(toString(result.errors[0]), "Type 'T' from 'game/A' could not be converted into 'T' from 'game/B'"); - } } TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict_instantiated") @@ -367,16 +358,9 @@ local b: B.T = a CheckResult result = frontend.check("game/D"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTableSubtypingVariance2) - { - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); - } - else - { - CHECK_EQ(toString(result.errors[0]), "Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'"); - } } TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index cefba4b2..3f5dad3d 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -353,8 +353,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_non_binary_expressions_actually_resol TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local t: {x: number?} = {x = nil} diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index a90f434f..4a88abee 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -260,10 +260,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { - ScopedFastFlag sffs[]{ - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( --!strict local x: { ["<>"] : number } diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 87d49651..77a2928c 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -276,8 +276,6 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local a = {} a.x = 99 @@ -347,8 +345,6 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict function foo(o) @@ -370,8 +366,6 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local T = {} T.bar = 'hello' @@ -477,8 +471,6 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict local t = { u = {} } @@ -512,8 +504,6 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_ TEST_CASE_FIXTURE(Fixture, "width_subtyping") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict function f(x : { q : number }) @@ -772,8 +762,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } )"); @@ -783,8 +771,6 @@ TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") TEST_CASE_FIXTURE(Fixture, "array_factory_function") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( function empty() return {} end local array: {string} = empty() @@ -1175,8 +1161,6 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local t = {x = 1} function t.m() end @@ -1187,8 +1171,6 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_unsealed_table_is_ok") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local t = {x = 1} function t:m() end @@ -1468,11 +1450,6 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_indexer") { - ScopedFastFlag sff[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( type StringToStringMap = { [string]: string } local rt: StringToStringMap = { ["foo"] = 1 } @@ -1518,11 +1495,6 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end foo({ a = 1 }) @@ -1609,8 +1581,6 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multipl TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local vec3 = {x = 1, y = 2, z = 3} local vec1 = {x = 1} @@ -1998,8 +1968,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_prope TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_strict") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict local buttons = {} @@ -2013,8 +1981,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_prope TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - CheckResult result = check(R"( type A = { x: number, y: number } type B = { x: number, y: string } @@ -2031,8 +1997,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - CheckResult result = check(R"( type AS = { x: number, y: number } type BS = { x: number, y: string } @@ -2054,11 +2018,6 @@ caused by: TEST_CASE_FIXTURE(BuiltinsFixture, "error_detailed_metatable_prop") { - ScopedFastFlag sff[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); local b1 = setmetatable({ x = 2, y = "hello" }, { __call = function(s) end }); @@ -2085,8 +2044,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") { - ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - CheckResult result = check(R"( type A = { [number]: string } type B = { [string]: string } @@ -2103,8 +2060,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") { - ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - CheckResult result = check(R"( type A = { [number]: number } type B = { [number]: string } @@ -2121,10 +2076,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { - ScopedFastFlag sffs[]{ - {"LuauTableSubtypingVariance2", true}, - }; - CheckResult result = check(R"( --!strict type Super = { x : number } @@ -2140,11 +2091,6 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") { - ScopedFastFlag sffs[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( --!strict type Super = { x : number } @@ -2166,10 +2112,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") { - ScopedFastFlag sffs[]{ - {"LuauTableSubtypingVariance2", true}, - }; - CheckResult result = check(R"( --!strict type Super = { x : number } @@ -2185,10 +2127,6 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_type_call") { - ScopedFastFlag sff[]{ - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( local b b = setmetatable({}, {__call = b}) @@ -2201,11 +2139,6 @@ b() TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") { - ScopedFastFlag sffs[] = { - {"LuauTableSubtypingVariance2", true}, - {"LuauSubtypingAddOptPropsToUnsealedTables", true}, - }; - CheckResult result = check(R"( --!strict local function setNumber(t: { p: number? }, x:number) t.p = x end @@ -2706,8 +2639,6 @@ type t0 = any TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning_2") { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - CheckResult result = check(R"( type X = T type K = X @@ -2725,8 +2656,6 @@ type K = X TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_3") { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - CheckResult result = check(R"( type X = T local a = {} @@ -2977,8 +2906,6 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") { - ScopedFastFlag luauSubtypingAddOptPropsToUnsealedTables{"LuauSubtypingAddOptPropsToUnsealedTables", true}; - CheckResult result = check(R"( type X = { { x: boolean?, y: boolean? } } diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 6257cda6..6a048b26 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -887,8 +887,6 @@ end TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") { - ScopedFastFlag subtypingVariance{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict --!nolint @@ -928,7 +926,6 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") { ScopedFastInt sfi("LuauTypeInferRecursionLimit", 2); - ScopedFastFlag sff{"LuauRecursionLimitException", true}; CheckResult result = check(R"( function complex() diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index d19d80cb..2b48133d 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -428,12 +428,6 @@ y = x TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") { - ScopedFastFlag sffs[] = { - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - {"LuauSubtypingAddOptPropsToUnsealedTables", true}, - }; - CheckResult result = check(R"( -- the difference between this and unify_unsealed_table_union_check is the type annotation on x local t = { x = 3, y = true } diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitTypeVar.test.cpp index 01960fbe..4fba694a 100644 --- a/tests/VisitTypeVar.test.cpp +++ b/tests/VisitTypeVar.test.cpp @@ -10,14 +10,9 @@ using namespace Luau; LUAU_FASTINT(LuauVisitRecursionLimit) -struct VisitTypeVarFixture : Fixture -{ - ScopedFastFlag flag2 = {"LuauRecursionLimitException", true}; -}; - TEST_SUITE_BEGIN("VisitTypeVar"); -TEST_CASE_FIXTURE(VisitTypeVarFixture, "throw_when_limit_is_exceeded") +TEST_CASE_FIXTURE(Fixture, "throw_when_limit_is_exceeded") { ScopedFastInt sfi{"LuauVisitRecursionLimit", 3}; @@ -30,7 +25,7 @@ TEST_CASE_FIXTURE(VisitTypeVarFixture, "throw_when_limit_is_exceeded") CHECK_THROWS_AS(toString(tType), RecursionLimitException); } -TEST_CASE_FIXTURE(VisitTypeVarFixture, "dont_throw_when_limit_is_high_enough") +TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") { ScopedFastInt sfi{"LuauVisitRecursionLimit", 8}; From e91d80ee25f17af665b5df244e7e8bd77ff9d4f1 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 23 Jun 2022 18:56:19 -0700 Subject: [PATCH 094/102] Update compatibility.md (#559) --- docs/_pages/compatibility.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/_pages/compatibility.md b/docs/_pages/compatibility.md index cdeb4fad..d1686c2f 100644 --- a/docs/_pages/compatibility.md +++ b/docs/_pages/compatibility.md @@ -49,20 +49,20 @@ Sandboxing challenges are [covered in the dedicated section](sandbox). |---------|--------|------| | yieldable pcall/xpcall | ✔️ | | | yieldable metamethods | ❌ | significant performance implications | -| ephemeron tables | ❌ | this complicates the garbage collector esp. for large weak tables | -| emergency garbage collector | ❌ | Luau runs in environments where handling memory exhaustion in emergency situations is not tenable | +| ephemeron tables | ❌ | this complicates and slows down the garbage collector esp. for large weak tables | +| emergency garbage collector | 🤷‍ | Luau runs in environments where handling memory exhaustion in emergency situations is not tenable | | goto statement | ❌ | this complicates the compiler, makes control flow unstructured and doesn't address a significant need | | finalizers for tables | ❌ | no `__gc` support due to sandboxing and performance/complexity | | no more fenv for threads or functions | 😞 | we love this, but it breaks compatibility | | tables honor the `__len` metamethod | 🤷‍♀️ | performance implications, no strong use cases | hex and `\z` escapes in strings | ✔️ | | | support for hexadecimal floats | 🤷‍♀️ | no strong use cases | -| order metamethods work for different types | ❌ | no strong use cases and more complicated semantics + compat | +| order metamethods work for different types | ❌ | no strong use cases and more complicated semantics, compatibility and performance implications | | empty statement | 🤷‍♀️ | less useful in Lua than in JS/C#/C/C++ | -| `break` statement may appear in the middle of a block | 🤷‍♀️ | we'd like to do it for return/continue as well but there be dragons | +| `break` statement may appear in the middle of a block | 🤷‍♀️ | we'd like to do it consistently for `break`/`return`/`continue` but there be dragons | | arguments for function called through `xpcall` | ✔️ | | | optional base in `math.log` | ✔️ | | -| optional separator in `string.rep` | 🤷‍♀️ | no real use cases | +| optional separator in `string.rep` | 🤷‍♀️ | no strong use cases | | new metamethods `__pairs` and `__ipairs` | ❌ | superseded by `__iter` | | frontier patterns | ✔️ | | | `%g` in patterns | ✔️ | | @@ -83,7 +83,7 @@ Ephemeron tables may be implemented at some point since they do have valid uses |---------|--------|------| | `\u` escapes in strings | ✔️ | | | integers (64-bit by default) | ❌ | backwards compatibility and performance implications | -| bitwise operators | ❌ | `bit32` library covers this | +| bitwise operators | ❌ | `bit32` library covers this in absence of 64-bit integers | | basic utf-8 support | ✔️ | we include `utf8` library and other UTF8 features | | functions for packing and unpacking values (string.pack/unpack/packsize) | ✔️ | | | floor division | ❌ | no strong use cases, syntax overlaps with C comments | @@ -95,16 +95,16 @@ Ephemeron tables may be implemented at some point since they do have valid uses It's important to highlight integer support and bitwise operators. For Luau, it's rare that a full 64-bit integer type is necessary - double-precision types support integers up to 2^53 (in Lua which is used in embedded space, integers may be more appealing in environments without a native 64-bit FPU). However, there's a *lot* of value in having a single number type, both from performance perspective and for consistency. Notably, Lua doesn't handle integer overflow properly, so using integers also carries compatibility implications. -If integers are taken out of the equation, bitwise operators make much less sense; additionally, `bit32` library is more fully featured (includes commonly used operations such as rotates and arithmetic shift; bit extraction/replacement is also more readable). Adding operators along with metamethods for all of them increases complexity, which means this feature isn't worth it on the balance. +If integers are taken out of the equation, bitwise operators make less sense, as integers aren't a first class feature; additionally, `bit32` library is more fully featured (includes commonly used operations such as rotates and arithmetic shift; bit extraction/replacement is also more readable). Adding operators along with metamethods for all of them increases complexity, which means this feature isn't worth it on the balance. Common arguments for this include a more familiar syntax, which, while true, gets more nuanced as `^` isn't available as a xor operator, and arithmetic right shift isn't expressible without yet another operator, and performance, which in Luau is substantially better than in Lua because `bit32` library uses VM builtins instead of expensive function calls. -Floor division is less harmful, but it's used rarely enough that `math.floor(a/b)` seems like an adequate replacement; additionally, `//` is a comment in C-derived languages and we may decide to adopt it in addition to `--` at some point. +Floor division is much less complex, but it's used rarely enough that `math.floor(a/b)` seems like an adequate replacement; additionally, `//` is a comment in C-derived languages and we may decide to adopt it in addition to `--` at some point. ## Lua 5.4 | feature | status | notes | |--|--|--| | new generational mode for garbage collection | 🔜 | we're working on gc optimizations and generational mode is on our radar -| to-be-closed variables | ❌ | the syntax is ugly and inconsistent with how we'd like to do attributes long-term; no strong use cases in our domain | +| to-be-closed variables | ❌ | the syntax is inconsistent with how we'd like to do attributes long-term; no strong use cases in our domain | | const variables | ❌ | while there's some demand for const variables, we'd never adopt this syntax | | new implementation for math.random | ✔️ | our RNG is based on PCG, unlike Lua 5.4 which uses Xoroshiro | | optional `init` argument to `string.gmatch` | 🤷‍♀️ | no strong use cases | @@ -112,14 +112,14 @@ Floor division is less harmful, but it's used rarely enough that `math.floor(a/b | coercions string-to-number moved to the string library | 😞 | we love this, but it breaks compatibility | | new format `%p` in `string.format` | 🤷‍♀️ | no strong use cases | | `utf8` library accepts codepoints up to 2^31 | 🤷‍♀️ | no strong use cases | -| The use of the `__lt` metamethod to emulate `__le` has been removed | 😞 | breaks compatibility and doesn't seem very interesting otherwise | +| The use of the `__lt` metamethod to emulate `__le` has been removed | ❌ | breaks compatibility and complicates comparison overloading story | | When finalizing objects, Lua will call `__gc` metamethods that are not functions | ❌ | no `__gc` support due to sandboxing and performance/complexity | | The function print calls `__tostring` instead of tostring to format its arguments. | ✔️ | | | By default, the decoding functions in the utf8 library do not accept surrogates. | 😞 | breaks compatibility and doesn't seem very interesting otherwise | -Lua has a beautiful syntax and frankly we're disappointed in the ``/`` which takes away from that beauty. Taking syntax aside, `` isn't very useful in Luau - its dominant use case is for code that works with external resources like files or sockets, but we don't provide such APIs - and has a very large complexity cost, evidences by a lot of bug fixes since the initial implementation in 5.4 work versions. `` in Luau doesn't matter for performance - our multi-pass compiler is already able to analyze the usage of the variable to know if it's modified or not and extract all performance gains from it - so the only use here is for code readability, where the `` syntax is... suboptimal. +Taking syntax aside (which doesn't feel idiomatic or beautiful), `` isn't very useful in Luau - its dominant use case is for code that works with external resources like files or sockets, but we don't provide such APIs - and has a very large complexity cost, evidences by a lot of bug fixes since the initial implementation in 5.4 work versions. `` in Luau doesn't matter for performance - our multi-pass compiler is already able to analyze the usage of the variable to know if it's modified or not and extract all performance gains from it - so the only use here is for code readability, where the `` syntax is... suboptimal. -If we do end up introducing const variables, it would be through a `const var = value` syntax, which is backwards compatible through a context-sensitive keyword similar to `type`. +If we do end up introducing const variables, it would be through a `const var = value` syntax, which is backwards compatible through a context-sensitive keyword similar to `type`. That said, there's ambiguity wrt whether `const` should simply behave like a read-only variable, ala JavaScript, or if it should represent a stronger contract, for example by limiting the expressions on the right hand side to ones compiler can evaluate ahead of time, or by freezing table values and thus guaranteeing immutability. ## Differences from Lua From 5e405b58b3c7889de9ab9feaee34a254cef775d1 Mon Sep 17 00:00:00 2001 From: Allan N Jeremy Date: Fri, 24 Jun 2022 19:46:29 +0300 Subject: [PATCH 095/102] Added multi-os runners for benchmark & implemented luau analyze (#542) --- .github/workflows/benchmark.yml | 217 ++++- bench/measure_time.py | 43 + bench/static_analysis/LuauPolyfillMap.lua | 962 ++++++++++++++++++++++ scripts/run-with-cachegrind.sh | 9 +- 4 files changed, 1217 insertions(+), 14 deletions(-) create mode 100644 bench/measure_time.py create mode 100644 bench/static_analysis/LuauPolyfillMap.lua diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 68f63006..d4ac82ad 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -4,7 +4,6 @@ on: push: branches: - master - paths-ignore: - "docs/**" - "papers/**" @@ -13,12 +12,13 @@ on: - "prototyping/**" jobs: - benchmarks-run: - name: Run ${{ matrix.bench.title }} + windows: + name: Run ${{ matrix.bench.title }} (Windows ${{matrix.arch}}) strategy: fail-fast: false matrix: - os: [ubuntu-latest] + os: [windows-latest] + arch: [Win32, x64] bench: - { script: "run-benchmarks", @@ -32,7 +32,93 @@ jobs: runs-on: ${{ matrix.os }} steps: - - name: Checkout Luau + - name: Checkout Luau repository + uses: actions/checkout@v3 + + - name: Build Luau + shell: bash # necessary for fail-fast + run: | + mkdir build && cd build + cmake .. -DCMAKE_BUILD_TYPE=Release + cmake --build . --target Luau.Repl.CLI --config Release + cmake --build . --target Luau.Analyze.CLI --config Release + + - name: Move build files to root + run: | + move build/RelWithDebInfo/* . + + - name: Check dir structure + run: | + ls build/RelWithDebInfo + ls + - uses: actions/setup-python@v3 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + python -m pip install requests + python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Run benchmark + run: | + python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt + + - name: Checkout Benchmark Results repository + uses: actions/checkout@v3 + with: + repository: ${{ matrix.benchResultsRepo.name }} + ref: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + + - name: Store ${{ matrix.bench.title }} result + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} (Windows ${{matrix.arch}}) + tool: "benchmarkluau" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + alert-threshold: 150% + fail-threshold: 200% + fail-on-alert: true + comment-on-alert: true + comment-always: true + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Push benchmark results + if: github.event_name == 'push' + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add ./dev/bench/data.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. + + unix: + name: Run ${{ matrix.bench.title }} (${{ matrix.os}}) + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + bench: + - { + script: "run-benchmarks", + timeout: 12, + title: "Luau Benchmarks", + cachegrindTitle: "Performance", + cachegrindIterCount: 20, + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Luau repository uses: actions/checkout@v3 - name: Build Luau @@ -48,18 +134,21 @@ jobs: python -m pip install requests python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose - - name: Install valgrind - run: | - sudo apt-get install valgrind - - name: Run benchmark run: | python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt + - name: Install valgrind + if: matrix.os == 'ubuntu-latest' + run: | + sudo apt-get install valgrind + - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) + if: matrix.os == 'ubuntu-latest' run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 | tee -a ${{ matrix.bench.script }}-output.txt - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) + if: matrix.os == 'ubuntu-latest' run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle }}" ${{ matrix.bench.cachegrindIterCount }} | tee -a ${{ matrix.bench.script }}-output.txt - name: Checkout Benchmark Results repository @@ -78,12 +167,14 @@ jobs: output-file-path: ./${{ matrix.bench.script }}-output.txt external-data-json-path: ./gh-pages/dev/bench/data.json alert-threshold: 150% - fail-threshold: 1000% - fail-on-alert: false + fail-threshold: 200% + fail-on-alert: true comment-on-alert: true + comment-always: true github-token: ${{ secrets.GITHUB_TOKEN }} - - name: Store ${{ matrix.bench.title }} result + - name: Store ${{ matrix.bench.title }} result (CacheGrind) + if: matrix.os == 'ubuntu-latest' uses: Roblox/rhysd-github-action-benchmark@v-luau with: name: ${{ matrix.bench.title }} (CacheGrind) @@ -97,7 +188,107 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} - name: Push benchmark results - + if: github.event_name == 'push' + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add ./dev/bench/data.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. + + static-analysis: + name: Run ${{ matrix.bench.title }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + engine: + - { channel: stable, version: latest } + bench: + - { + script: "run-analyze", + timeout: 12, + title: "Luau Analyze", + cachegrindTitle: "Performance", + cachegrindIterCount: 20, + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + with: + token: "${{ secrets.BENCH_GITHUB_TOKEN }}" + + - name: Build Luau + run: make config=release luau luau-analyze + + - uses: actions/setup-python@v4 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + sudo pip install requests numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Install valgrind + run: | + sudo apt-get install valgrind + + - name: Run Luau Analyze on static file + run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt + + - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) + run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Checkout Benchmark Results repository + uses: actions/checkout@v3 + with: + repository: ${{ matrix.benchResultsRepo.name }} + ref: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + + - name: Store ${{ matrix.bench.title }} result + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} + tool: "benchmarkluau" + + gh-pages-branch: "main" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + alert-threshold: 150% + fail-threshold: 200% + fail-on-alert: true + comment-on-alert: true + comment-always: true + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Store ${{ matrix.bench.title }} result (CacheGrind) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} + tool: "roblox" + gh-pages-branch: "main" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + alert-threshold: 150% + fail-threshold: 200% + fail-on-alert: true + comment-on-alert: true + comment-always: true + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Push benchmark results + if: github.event_name == 'push' run: | echo "Pushing benchmark results..." cd gh-pages diff --git a/bench/measure_time.py b/bench/measure_time.py new file mode 100644 index 00000000..c41c7d2c --- /dev/null +++ b/bench/measure_time.py @@ -0,0 +1,43 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +import os, sys, time, numpy + +try: + import scipy + from scipy import mean, stats +except ModuleNotFoundError: + print("Warning: scipy package is not installed, confidence values will not be available") + stats = None + +duration_list = [] + +DEFAULT_CYCLES_TO_RUN = 100 +cycles_to_run = DEFAULT_CYCLES_TO_RUN + +try: + cycles_to_run = sys.argv[3] if sys.argv[3] else DEFAULT_CYCLES_TO_RUN + cycles_to_run = int(cycles_to_run) +except IndexError: + pass +except (ValueError, TypeError): + cycles_to_run = DEFAULT_CYCLES_TO_RUN + print("Error: Cycles to run argument must be an integer. Using default value of {}".format(DEFAULT_CYCLES_TO_RUN)) + +# Numpy complains if we provide a cycle count of less than 3 ~ default to 3 whenever a lower value is provided +cycles_to_run = cycles_to_run if cycles_to_run > 2 else 3 + +for i in range(1,cycles_to_run): + start = time.perf_counter() + + # Run the code you want to measure here + os.system(sys.argv[1]) + + end = time.perf_counter() + + duration_ms = (end - start) * 1000 + duration_list.append(duration_ms) + +# Stats +mean = numpy.mean(duration_list) +std_err = stats.sem(duration_list) + +print("SUCCESS: {} : {:.2f}ms +/- {:.2f}% on luau ".format('duration', mean,std_err)) diff --git a/bench/static_analysis/LuauPolyfillMap.lua b/bench/static_analysis/LuauPolyfillMap.lua new file mode 100644 index 00000000..1cfd0181 --- /dev/null +++ b/bench/static_analysis/LuauPolyfillMap.lua @@ -0,0 +1,962 @@ +-- This file is part of the Roblox luau-polyfill repository and is licensed under MIT License; see LICENSE.txt for details +--!nonstrict +-- #region Array +-- Array related +local Array = {} +local Object = {} +local Map = {} + +type Array = { [number]: T } +type callbackFn = (element: V, key: K, map: Map) -> () +type callbackFnWithThisArg = (thisArg: Object, value: V, key: K, map: Map) -> () +type Map = { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + forEach: (self: Map, callback: callbackFn | callbackFnWithThisArg, thisArg: Object?) -> (), + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + [K]: V, + _map: { [K]: V }, + _array: { [number]: K }, +} +type mapFn = (element: T, index: number) -> U +type mapFnWithThisArg = (thisArg: any, element: T, index: number) -> U +type Object = { [string]: any } +type Table = { [T]: V } +type Tuple = Array + +local Set = {} + +-- #region Array +function Array.isArray(value: any): boolean + if typeof(value) ~= "table" then + return false + end + if next(value) == nil then + -- an empty table is an empty array + return true + end + + local length = #value + + if length == 0 then + return false + end + + local count = 0 + local sum = 0 + for key in pairs(value) do + if typeof(key) ~= "number" then + return false + end + if key % 1 ~= 0 or key < 1 then + return false + end + count += 1 + sum += key + end + + return sum == (count * (count + 1) / 2) +end + +function Array.from( + value: string | Array | Object, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? +): Array + if value == nil then + error("cannot create array from a nil value") + end + local valueType = typeof(value) + + local array = {} + + if valueType == "table" and Array.isArray(value) then + if mapFn then + for i = 1, #(value :: Array) do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, (value :: Array)[i], i) + else + array[i] = (mapFn :: mapFn)((value :: Array)[i], i) + end + end + else + for i = 1, #(value :: Array) do + array[i] = (value :: Array)[i] + end + end + elseif instanceOf(value, Set) then + if mapFn then + for i, v in (value :: any):ipairs() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, v, i) + else + array[i] = (mapFn :: mapFn)(v, i) + end + end + else + for i, v in (value :: any):ipairs() do + array[i] = v + end + end + elseif instanceOf(value, Map) then + if mapFn then + for i, v in (value :: any):ipairs() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, v, i) + else + array[i] = (mapFn :: mapFn)(v, i) + end + end + else + for i, v in (value :: any):ipairs() do + array[i] = v + end + end + elseif valueType == "string" then + if mapFn then + for i = 1, (value :: string):len() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, (value :: any):sub(i, i), i) + else + array[i] = (mapFn :: mapFn)((value :: any):sub(i, i), i) + end + end + else + for i = 1, (value :: string):len() do + array[i] = (value :: any):sub(i, i) + end + end + end + + return array +end + +type callbackFnArrayMap = (element: T, index: number, array: Array) -> U +type callbackFnWithThisArgArrayMap = (thisArg: V, element: T, index: number, array: Array) -> U + +-- Implements Javascript's `Array.prototype.map` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/map +function Array.map( + t: Array, + callback: callbackFnArrayMap | callbackFnWithThisArgArrayMap, + thisArg: V? +): Array + if typeof(t) ~= "table" then + error(string.format("Array.map called on %s", typeof(t))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local len = #t + local A = {} + local k = 1 + + while k <= len do + local kValue = t[k] + + if kValue ~= nil then + local mappedValue + + if thisArg ~= nil then + mappedValue = (callback :: callbackFnWithThisArgArrayMap)(thisArg, kValue, k, t) + else + mappedValue = (callback :: callbackFnArrayMap)(kValue, k, t) + end + + A[k] = mappedValue + end + k += 1 + end + + return A +end + +type Function = (any, any, number, any) -> any + +-- Implements Javascript's `Array.prototype.reduce` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/reduce +function Array.reduce(array: Array, callback: Function, initialValue: any?): any + if typeof(array) ~= "table" then + error(string.format("Array.reduce called on %s", typeof(array))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local length = #array + + local value + local initial = 1 + + if initialValue ~= nil then + value = initialValue + else + initial = 2 + if length == 0 then + error("reduce of empty array with no initial value") + end + value = array[1] + end + + for i = initial, length do + value = callback(value, array[i], i, array) + end + + return value +end + +type callbackFnArrayForEach = (element: T, index: number, array: Array) -> () +type callbackFnWithThisArgArrayForEach = (thisArg: U, element: T, index: number, array: Array) -> () + +-- Implements Javascript's `Array.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/forEach +function Array.forEach( + t: Array, + callback: callbackFnArrayForEach | callbackFnWithThisArgArrayForEach, + thisArg: U? +): () + if typeof(t) ~= "table" then + error(string.format("Array.forEach called on %s", typeof(t))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local len = #t + local k = 1 + + while k <= len do + local kValue = t[k] + + if thisArg ~= nil then + (callback :: callbackFnWithThisArgArrayForEach)(thisArg, kValue, k, t) + else + (callback :: callbackFnArrayForEach)(kValue, k, t) + end + + if #t < len then + -- don't iterate on removed items, don't iterate more than original length + len = #t + end + k += 1 + end +end +-- #endregion + +-- #region Set +Set.__index = Set + +type callbackFnSet = (value: T, key: T, set: Set) -> () +type callbackFnWithThisArgSet = (thisArg: Object, value: T, key: T, set: Set) -> () + +export type Set = { + size: number, + -- method definitions + add: (self: Set, T) -> Set, + clear: (self: Set) -> (), + delete: (self: Set, T) -> boolean, + forEach: (self: Set, callback: callbackFnSet | callbackFnWithThisArgSet, thisArg: Object?) -> (), + has: (self: Set, T) -> boolean, + ipairs: (self: Set) -> any, +} + +type Iterable = { ipairs: (any) -> any } + +function Set.new(iterable: Array | Set | Iterable | string | nil): Set + local array = {} + local map = {} + if iterable ~= nil then + local arrayIterable: Array + -- ROBLOX TODO: remove type casting from (iterable :: any).ipairs in next release + if typeof(iterable) == "table" then + if Array.isArray(iterable) then + arrayIterable = Array.from(iterable :: Array) + elseif typeof((iterable :: Iterable).ipairs) == "function" then + -- handle in loop below + elseif _G.__DEV__ then + error("cannot create array from an object-like table") + end + elseif typeof(iterable) == "string" then + arrayIterable = Array.from(iterable :: string) + else + error(("cannot create array from value of type `%s`"):format(typeof(iterable))) + end + + if arrayIterable then + for _, element in ipairs(arrayIterable) do + if not map[element] then + map[element] = true + table.insert(array, element) + end + end + elseif typeof(iterable) == "table" and typeof((iterable :: Iterable).ipairs) == "function" then + for _, element in (iterable :: Iterable):ipairs() do + if not map[element] then + map[element] = true + table.insert(array, element) + end + end + end + end + + return (setmetatable({ + size = #array, + _map = map, + _array = array, + }, Set) :: any) :: Set +end + +function Set:add(value) + if not self._map[value] then + -- Luau FIXME: analyze should know self is Set which includes size as a number + self.size = self.size :: number + 1 + self._map[value] = true + table.insert(self._array, value) + end + return self +end + +function Set:clear() + self.size = 0 + table.clear(self._map) + table.clear(self._array) +end + +function Set:delete(value): boolean + if not self._map[value] then + return false + end + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number - 1 + self._map[value] = nil + local index = table.find(self._array, value) + if index then + table.remove(self._array, index) + end + return true +end + +-- Implements Javascript's `Map.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Set/forEach +function Set:forEach(callback: callbackFnSet | callbackFnWithThisArgSet, thisArg: Object?): () + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + return Array.forEach(self._array, function(value: T) + if thisArg ~= nil then + (callback :: callbackFnWithThisArgSet)(thisArg, value, value, self) + else + (callback :: callbackFnSet)(value, value, self) + end + end) +end + +function Set:has(value): boolean + return self._map[value] ~= nil +end + +function Set:ipairs() + return ipairs(self._array) +end + +-- #endregion Set + +-- #region Object +function Object.entries(value: string | Object | Array): Array + assert(value :: any ~= nil, "cannot get entries from a nil value") + local valueType = typeof(value) + + local entries: Array> = {} + if valueType == "table" then + for key, keyValue in pairs(value :: Object) do + -- Luau FIXME: Luau should see entries as Array, given object is [string]: any, but it sees it as Array> despite all the manual annotation + table.insert(entries, { key :: string, keyValue :: any }) + end + elseif valueType == "string" then + for i = 1, string.len(value :: string) do + entries[i] = { tostring(i), string.sub(value :: string, i, i) } + end + end + + return entries +end + +-- #endregion + +-- #region instanceOf + +-- ROBLOX note: Typed tbl as any to work with strict type analyze +-- polyfill for https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/instanceof +function instanceOf(tbl: any, class) + assert(typeof(class) == "table", "Received a non-table as the second argument for instanceof") + + if typeof(tbl) ~= "table" then + return false + end + + local ok, hasNew = pcall(function() + return class.new ~= nil and tbl.new == class.new + end) + if ok and hasNew then + return true + end + + local seen = { tbl = true } + + while tbl and typeof(tbl) == "table" do + tbl = getmetatable(tbl) + if typeof(tbl) == "table" then + tbl = tbl.__index + + if tbl == class then + return true + end + end + + -- if we still have a valid table then check against seen + if typeof(tbl) == "table" then + if seen[tbl] then + return false + end + seen[tbl] = true + end + end + + return false +end +-- #endregion + +function Map.new(iterable: Array>?): Map + local array = {} + local map = {} + if iterable ~= nil then + local arrayFromIterable + local iterableType = typeof(iterable) + if iterableType == "table" then + if #iterable > 0 and typeof(iterable[1]) ~= "table" then + error("cannot create Map from {K, V} form, it must be { {K, V}... }") + end + + arrayFromIterable = Array.from(iterable) + else + error(("cannot create array from value of type `%s`"):format(iterableType)) + end + + for _, entry in ipairs(arrayFromIterable) do + local key = entry[1] + if _G.__DEV__ then + if key == nil then + error("cannot create Map from a table that isn't an array.") + end + end + local val = entry[2] + -- only add to array if new + if map[key] == nil then + table.insert(array, key) + end + -- always assign + map[key] = val + end + end + + return (setmetatable({ + size = #array, + _map = map, + _array = array, + }, Map) :: any) :: Map +end + +function Map:set(key: K, value: V): Map + -- preserve initial insertion order + if self._map[key] == nil then + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number + 1 + table.insert(self._array, key) + end + -- always update value + self._map[key] = value + return self +end + +function Map:get(key) + return self._map[key] +end + +function Map:clear() + local table_: any = table + self.size = 0 + table_.clear(self._map) + table_.clear(self._array) +end + +function Map:delete(key): boolean + if self._map[key] == nil then + return false + end + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number - 1 + self._map[key] = nil + local index = table.find(self._array, key) + if index then + table.remove(self._array, index) + end + return true +end + +-- Implements Javascript's `Map.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Map/forEach +function Map:forEach(callback: callbackFn | callbackFnWithThisArg, thisArg: Object?): () + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + return Array.forEach(self._array, function(key: K) + local value: V = self._map[key] :: V + + if thisArg ~= nil then + (callback :: callbackFnWithThisArg)(thisArg, value, key, self) + else + (callback :: callbackFn)(value, key, self) + end + end) +end + +function Map:has(key): boolean + return self._map[key] ~= nil +end + +function Map:keys() + return self._array +end + +function Map:values() + return Array.map(self._array, function(key) + return self._map[key] + end) +end + +function Map:entries() + return Array.map(self._array, function(key) + return { key, self._map[key] } + end) +end + +function Map:ipairs() + return ipairs(self:entries()) +end + +function Map.__index(self, key) + local mapProp = rawget(Map, key) + if mapProp ~= nil then + return mapProp + end + + return Map.get(self, key) +end + +function Map.__newindex(table_, key, value) + table_:set(key, value) +end + +local function coerceToMap(mapLike: Map | Table): Map + return instanceOf(mapLike, Map) and mapLike :: Map -- ROBLOX: order is preservered + or Map.new(Object.entries(mapLike)) -- ROBLOX: order is not preserved +end + +-- local function coerceToTable(mapLike: Map | Table): Table +-- if not instanceOf(mapLike, Map) then +-- return mapLike +-- end + +-- -- create table from map +-- return Array.reduce(mapLike:entries(), function(tbl, entry) +-- tbl[entry[1]] = entry[2] +-- return tbl +-- end, {}) +-- end + +-- #region Tests to verify it works as expected +local function it(description: string, fn: () -> ()) + local ok, result = pcall(fn) + + if not ok then + error("Failed test: " .. description .. "\n" .. result) + end +end + +local AN_ITEM = "bar" +local ANOTHER_ITEM = "baz" + +-- #region [Describe] "Map" +-- #region [Child Describe] "constructors" +it("creates an empty array", function() + local foo = Map.new() + assert(foo.size == 0) +end) + +it("creates a Map from an array", function() + local foo = Map.new({ + { AN_ITEM, "foo" }, + { ANOTHER_ITEM, "val" }, + }) + assert(foo.size == 2) + assert(foo:has(AN_ITEM) == true) + assert(foo:has(ANOTHER_ITEM) == true) +end) + +it("creates a Map from an array with duplicate keys", function() + local foo = Map.new({ + { AN_ITEM, "foo1" }, + { AN_ITEM, "foo2" }, + }) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == "foo2") + + assert(#foo:keys() == 1 and foo:keys()[1] == AN_ITEM) + assert(#foo:values() == 1 and foo:values()[1] == "foo2") + assert(#foo:entries() == 1) + assert(#foo:entries()[1] == 2) + + assert(foo:entries()[1][1] == AN_ITEM) + assert(foo:entries()[1][2] == "foo2") +end) + +it("preserves the order of keys first assignment", function() + local foo = Map.new({ + { AN_ITEM, "foo1" }, + { ANOTHER_ITEM, "bar" }, + { AN_ITEM, "foo2" }, + }) + assert(foo.size == 2) + assert(foo:get(AN_ITEM) == "foo2") + assert(foo:get(ANOTHER_ITEM) == "bar") + + assert(foo:keys()[1] == AN_ITEM) + assert(foo:keys()[2] == ANOTHER_ITEM) + assert(foo:values()[1] == "foo2") + assert(foo:values()[2] == "bar") + assert(foo:entries()[1][1] == AN_ITEM) + assert(foo:entries()[1][2] == "foo2") + assert(foo:entries()[2][1] == ANOTHER_ITEM) + assert(foo:entries()[2][2] == "bar") +end) +-- #endregion + +-- #region [Child Describe] "type" +it("instanceOf return true for an actual Map object", function() + local foo = Map.new() + assert(instanceOf(foo, Map) == true) +end) + +it("instanceOf return false for an regular plain object", function() + local foo = {} + assert(instanceOf(foo, Map) == false) +end) +-- #endregion + +-- #region [Child Describe] "set" +it("returns the Map object", function() + local foo = Map.new() + assert(foo:set(1, "baz") == foo) +end) + +it("increments the size if the element is added for the first time", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo.size == 1) +end) + +it("does not increment the size the second time an element is added", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(AN_ITEM, "val") + assert(foo.size == 1) +end) + +it("sets values correctly to true/false", function() + -- Luau FIXME: Luau insists that arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + foo:set(AN_ITEM, false) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == false) + + foo:set(AN_ITEM, true) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == true) + + foo:set(AN_ITEM, false) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == false) +end) + +-- #endregion + +-- #region [Child Describe] "get" +it("returns value of item from provided key", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:get(AN_ITEM) == "foo") +end) + +it("returns nil if the item is not in the Map", function() + local foo = Map.new() + assert(foo:get(AN_ITEM) == nil) +end) +-- #endregion + +-- #region [Child Describe] "clear" +it("sets the size to zero", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:clear() + assert(foo.size == 0) +end) + +it("removes the items from the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:clear() + assert(foo:has(AN_ITEM) == false) +end) +-- #endregion + +-- #region [Child Describe] "delete" +it("removes the items from the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(AN_ITEM) + assert(foo:has(AN_ITEM) == false) +end) + +it("returns true if the item was in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:delete(AN_ITEM) == true) +end) + +it("returns false if the item was not in the Map", function() + local foo = Map.new() + assert(foo:delete(AN_ITEM) == false) +end) + +it("decrements the size if the item was in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(AN_ITEM) + assert(foo.size == 0) +end) + +it("does not decrement the size if the item was not in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(ANOTHER_ITEM) + assert(foo.size == 1) +end) + +it("deletes value set to false", function() + -- Luau FIXME: Luau insists arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + + foo:delete(AN_ITEM) + + assert(foo.size == 0) + assert(foo:get(AN_ITEM) == nil) +end) +-- #endregion + +-- #region [Child Describe] "has" +it("returns true if the item is in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:has(AN_ITEM) == true) +end) + +it("returns false if the item is not in the Map", function() + local foo = Map.new() + assert(foo:has(AN_ITEM) == false) +end) + +it("returns correctly with value set to false", function() + -- Luau FIXME: Luau insists arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + + assert(foo:has(AN_ITEM) == true) +end) +-- #endregion + +-- #region [Child Describe] "keys / values / entries" +it("returns array of elements", function() + local myMap = Map.new() + myMap:set(AN_ITEM, "foo") + myMap:set(ANOTHER_ITEM, "val") + + assert(myMap:keys()[1] == AN_ITEM) + assert(myMap:keys()[2] == ANOTHER_ITEM) + + assert(myMap:values()[1] == "foo") + assert(myMap:values()[2] == "val") + + assert(myMap:entries()[1][1] == AN_ITEM) + assert(myMap:entries()[1][2] == "foo") + assert(myMap:entries()[2][1] == ANOTHER_ITEM) + assert(myMap:entries()[2][2] == "val") +end) +-- #endregion + +-- #region [Child Describe] "__index" +it("can access fields directly without using get", function() + local typeName = "size" + + local foo = Map.new({ + { AN_ITEM, "foo" }, + { ANOTHER_ITEM, "val" }, + { typeName, "buzz" }, + }) + + assert(foo.size == 3) + assert(foo[AN_ITEM] == "foo") + assert(foo[ANOTHER_ITEM] == "val") + assert(foo:get(typeName) == "buzz") +end) +-- #endregion + +-- #region [Child Describe] "__newindex" +it("can set fields directly without using set", function() + local foo = Map.new() + + assert(foo.size == 0) + + foo[AN_ITEM] = "foo" + foo[ANOTHER_ITEM] = "val" + foo.fizz = "buzz" + + assert(foo.size == 3) + assert(foo:get(AN_ITEM) == "foo") + assert(foo:get(ANOTHER_ITEM) == "val") + assert(foo:get("fizz") == "buzz") +end) +-- #endregion + +-- #region [Child Describe] "ipairs" +local function makeArray(...) + local array = {} + for _, item in ... do + table.insert(array, item) + end + return array +end + +it("iterates on the elements by their insertion order", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + assert(makeArray(foo:ipairs())[1][1] == AN_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "foo") + assert(makeArray(foo:ipairs())[2][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[2][2] == "val") +end) + +it("does not iterate on removed elements", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + foo:delete(AN_ITEM) + assert(makeArray(foo:ipairs())[1][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "val") +end) + +it("iterates on elements if the added back to the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + foo:delete(AN_ITEM) + foo:set(AN_ITEM, "food") + assert(makeArray(foo:ipairs())[1][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "val") + assert(makeArray(foo:ipairs())[2][1] == AN_ITEM) + assert(makeArray(foo:ipairs())[2][2] == "food") +end) +-- #endregion + +-- #region [Child Describe] "Integration Tests" +-- it("MDN Examples", function() +-- local myMap = Map.new() :: Map + +-- local keyString = "a string" +-- local keyObj = {} +-- local keyFunc = function() end + +-- -- setting the values +-- myMap:set(keyString, "value associated with 'a string'") +-- myMap:set(keyObj, "value associated with keyObj") +-- myMap:set(keyFunc, "value associated with keyFunc") + +-- assert(myMap.size == 3) + +-- -- getting the values +-- assert(myMap:get(keyString) == "value associated with 'a string'") +-- assert(myMap:get(keyObj) == "value associated with keyObj") +-- assert(myMap:get(keyFunc) == "value associated with keyFunc") + +-- assert(myMap:get("a string") == "value associated with 'a string'") + +-- assert(myMap:get({}) == nil) -- nil, because keyObj !== {} +-- assert(myMap:get(function() -- nil because keyFunc !== function () {} +-- end) == nil) +-- end) + +it("handles non-traditional keys", function() + local myMap = Map.new() :: Map + + local falseKey = false + local trueKey = true + local negativeKey = -1 + local emptyKey = "" + + myMap:set(falseKey, "apple") + myMap:set(trueKey, "bear") + myMap:set(negativeKey, "corgi") + myMap:set(emptyKey, "doge") + + assert(myMap.size == 4) + + assert(myMap:get(falseKey) == "apple") + assert(myMap:get(trueKey) == "bear") + assert(myMap:get(negativeKey) == "corgi") + assert(myMap:get(emptyKey) == "doge") + + myMap:delete(falseKey) + myMap:delete(trueKey) + myMap:delete(negativeKey) + myMap:delete(emptyKey) + + assert(myMap.size == 0) +end) +-- #endregion + +-- #endregion [Describe] "Map" + +-- #region [Describe] "coerceToMap" +it("returns the same object if instance of Map", function() + local map = Map.new() + assert(coerceToMap(map) == map) + + map = Map.new({}) + assert(coerceToMap(map) == map) + + map = Map.new({ { AN_ITEM, "foo" } }) + assert(coerceToMap(map) == map) +end) +-- #endregion [Describe] "coerceToMap" + +-- #endregion Tests to verify it works as expected diff --git a/scripts/run-with-cachegrind.sh b/scripts/run-with-cachegrind.sh index eb4a8c3f..787043ff 100644 --- a/scripts/run-with-cachegrind.sh +++ b/scripts/run-with-cachegrind.sh @@ -25,10 +25,17 @@ now_ms() { ITERATION_COUNT=$4 START_TIME=$(now_ms) +ARGS=( "$@" ) +REST_ARGS="${ARGS[@]:4}" + valgrind \ --quiet \ --tool=cachegrind \ - "$1" "$2" >/dev/null + "$1" "$2" $REST_ARGS>/dev/null + +ARGS=( "$@" ) +REST_ARGS="${ARGS[@]:4}" + TIME_ELAPSED=$(bc <<< "$(now_ms) - ${START_TIME}") From 224d35bc9e151149f256c90e5ecc0790b08e8d0b Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 24 Jun 2022 18:16:12 -0700 Subject: [PATCH 096/102] Update benchmark.yml Attempt to fix Windows and other builds --- .github/workflows/benchmark.yml | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index d4ac82ad..20a51b00 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -45,12 +45,8 @@ jobs: - name: Move build files to root run: | - move build/RelWithDebInfo/* . + move build/Release/* . - - name: Check dir structure - run: | - ls build/RelWithDebInfo - ls - uses: actions/setup-python@v3 with: python-version: "3.9" @@ -171,7 +167,7 @@ jobs: fail-on-alert: true comment-on-alert: true comment-always: true - github-token: ${{ secrets.GITHUB_TOKEN }} + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Store ${{ matrix.bench.title }} result (CacheGrind) if: matrix.os == 'ubuntu-latest' @@ -185,7 +181,7 @@ jobs: fail-threshold: 1000% fail-on-alert: false comment-on-alert: true - github-token: ${{ secrets.GITHUB_TOKEN }} + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Push benchmark results if: github.event_name == 'push' @@ -205,8 +201,6 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - engine: - - { channel: stable, version: latest } bench: - { script: "run-analyze", @@ -270,7 +264,7 @@ jobs: fail-on-alert: true comment-on-alert: true comment-always: true - github-token: ${{ secrets.GITHUB_TOKEN }} + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Store ${{ matrix.bench.title }} result (CacheGrind) uses: Roblox/rhysd-github-action-benchmark@v-luau @@ -285,7 +279,7 @@ jobs: fail-on-alert: true comment-on-alert: true comment-always: true - github-token: ${{ secrets.GITHUB_TOKEN }} + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Push benchmark results if: github.event_name == 'push' From 9846a6c7b9dc9699693d062835eba07f61ff9126 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 24 Jun 2022 18:26:15 -0700 Subject: [PATCH 097/102] Update benchmark.yml Remove all alert/comment functionality --- .github/workflows/benchmark.yml | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 20a51b00..60fc6bef 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -1,5 +1,3 @@ -name: Luau Benchmarks - on: push: branches: @@ -76,11 +74,6 @@ jobs: tool: "benchmarkluau" output-file-path: ./${{ matrix.bench.script }}-output.txt external-data-json-path: ./gh-pages/dev/bench/data.json - alert-threshold: 150% - fail-threshold: 200% - fail-on-alert: true - comment-on-alert: true - comment-always: true github-token: ${{ secrets.GITHUB_TOKEN }} - name: Push benchmark results @@ -162,11 +155,6 @@ jobs: tool: "benchmarkluau" output-file-path: ./${{ matrix.bench.script }}-output.txt external-data-json-path: ./gh-pages/dev/bench/data.json - alert-threshold: 150% - fail-threshold: 200% - fail-on-alert: true - comment-on-alert: true - comment-always: true github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Store ${{ matrix.bench.title }} result (CacheGrind) @@ -177,10 +165,6 @@ jobs: tool: "roblox" output-file-path: ./${{ matrix.bench.script }}-output.txt external-data-json-path: ./gh-pages/dev/bench/data.json - alert-threshold: 150% - fail-threshold: 1000% - fail-on-alert: false - comment-on-alert: true github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Push benchmark results @@ -259,11 +243,6 @@ jobs: gh-pages-branch: "main" output-file-path: ./${{ matrix.bench.script }}-output.txt external-data-json-path: ./gh-pages/dev/bench/data.json - alert-threshold: 150% - fail-threshold: 200% - fail-on-alert: true - comment-on-alert: true - comment-always: true github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Store ${{ matrix.bench.title }} result (CacheGrind) @@ -274,11 +253,6 @@ jobs: gh-pages-branch: "main" output-file-path: ./${{ matrix.bench.script }}-output.txt external-data-json-path: ./gh-pages/dev/bench/data.json - alert-threshold: 150% - fail-threshold: 200% - fail-on-alert: true - comment-on-alert: true - comment-always: true github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Push benchmark results From 4cd0443913ba5af65a61d7f5c46b0e25b79ad7f7 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 24 Jun 2022 18:30:26 -0700 Subject: [PATCH 098/102] Update benchmark.yml Cleaner names --- .github/workflows/benchmark.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 60fc6bef..4df7b2f3 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -1,3 +1,5 @@ +name: benchmark + on: push: branches: @@ -11,7 +13,7 @@ on: jobs: windows: - name: Run ${{ matrix.bench.title }} (Windows ${{matrix.arch}}) + name: windows-${{matrix.arch}} strategy: fail-fast: false matrix: @@ -89,7 +91,7 @@ jobs: cd .. unix: - name: Run ${{ matrix.bench.title }} (${{ matrix.os}}) + name: ${{matrix.os}} strategy: fail-fast: false matrix: @@ -180,7 +182,7 @@ jobs: cd .. static-analysis: - name: Run ${{ matrix.bench.title }} + name: luau-analyze strategy: fail-fast: false matrix: From 13e50a9cac327fc2dda40ca6936fbff8c1206ad6 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Mon, 27 Jun 2022 09:05:50 -0700 Subject: [PATCH 099/102] Update library.md (#564) Clarify behavior of shifts for out of range values. --- docs/_pages/library.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/_pages/library.md b/docs/_pages/library.md index ff3075ac..f419d2bf 100644 --- a/docs/_pages/library.md +++ b/docs/_pages/library.md @@ -647,7 +647,7 @@ All functions in the `bit32` library treat input numbers as 32-bit unsigned inte function bit32.arshift(n: number, i: number): number ``` -Shifts `n` by `i` bits to the right (if `i` is negative, a left shift is performed instead). The most significant bit of `n` is propagated during the shift. +Shifts `n` by `i` bits to the right (if `i` is negative, a left shift is performed instead). The most significant bit of `n` is propagated during the shift. When `i` is larger than 31, returns an integer with all bits set to the sign bit of `n`. When `i` is smaller than `-31`, 0 is returned. ``` function bit32.band(args: ...number): number @@ -695,7 +695,7 @@ Rotates `n` to the left by `i` bits (if `i` is negative, a right rotate is perfo function bit32.lshift(n: number, i: number): number ``` -Shifts `n` to the left by `i` bits (if `i` is negative, a right shift is performed instead). +Shifts `n` to the left by `i` bits (if `i` is negative, a right shift is performed instead). When `i` is outside of `[-31..31]` range, returns 0. ``` function bit32.replace(n: number, r: number, f: number, w: number?): number @@ -713,7 +713,7 @@ Rotates `n` to the right by `i` bits (if `i` is negative, a left rotate is perfo function bit32.rshift(n: number, i: number): number ``` -Shifts `n` to the right by `i` bits (if `i` is negative, a left shift is performed instead). +Shifts `n` to the right by `i` bits (if `i` is negative, a left shift is performed instead). When `i` is outside of `[-31..31]` range, returns 0. ``` function bit32.countlz(n: number): number From fd82e926286765468f39048538c483d0ccae3a73 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Tue, 28 Jun 2022 09:06:59 -0700 Subject: [PATCH 100/102] RFC: Support `__len` metamethod for tables and `rawlen` function (#536) --- rfcs/len-metamethod-rawlen.md | 43 +++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 rfcs/len-metamethod-rawlen.md diff --git a/rfcs/len-metamethod-rawlen.md b/rfcs/len-metamethod-rawlen.md new file mode 100644 index 00000000..45284b71 --- /dev/null +++ b/rfcs/len-metamethod-rawlen.md @@ -0,0 +1,43 @@ +# Support `__len` metamethod for tables and `rawlen` function + +## Summary + +`__len` metamethod will be called by `#` operator on tables, matching Lua 5.2 + +## Motivation + +Lua 5.1 invokes `__len` only on userdata objects, whereas Lua 5.2 extends this to tables. In addition to making `__len` metamethod more uniform and making Luau +more compatible with later versions of Lua, this has the important advantage which is that it makes it possible to implement an index based container. + +Before `__iter` and `__len` it was possible to implement a custom container using `__index`/`__newindex`, but to iterate through the container a custom function was +necessary, because Luau didn't support generalized iteration, `__pairs`/`__ipairs` from Lua 5.2, or `#` override. + +With generalized iteration, a custom container can implement its own iteration behavior so as long as code uses `for k,v in obj` iteration style, the container can +be interfaced with the same way as a table. However, when the container uses integer indices, manual iteration via `#` would still not work - which is required for some +more complicated algorithms, or even to simply iterate through the container backwards. + +Supporting `__len` would make it possible to implement a custom integer based container that exposes the same interface as a table does. + +## Design + +`#v` will call `__len` metamethod if the object is a table and the metamethod exists; the result of the metamethod will be returned if it's a number (an error will be raised otherwise). + +`table.` functions that implicitly compute table length, such as `table.getn`, `table.insert`, will continue using the actual table length. This is consistent with the +general policy that Luau doesn't support metamethods in `table.` functions. + +A new function, `rawlen(v)`, will be added to the standard library; given a string or a table, it will return the length of the object without calling any metamethods. +The new function has the previous behavior of `#` operator with the exception of not supporting userdata inputs, as userdata doesn't have an inherent definition of length. + +## Drawbacks + +`#` is an operator that is used frequently and as such an extra metatable check here may impact performance. However, `#` is usually called on tables without metatables, +and even when it is, using the existing metamethod-absence-caching approach we use for many other metamethods a test version of the change to support `__len` shows no +statistically significant difference on existing benchmark suite. This does complicate the `#` computation a little more which may affect JIT as well, but even if the +table doesn't have a metatable the process of computing `#` involves a series of condition checks and as such will likely require slow paths anyway. + +This is technically changing semantics of `#` when called on tables with an existing `__len` metamethod, and as such has a potential to change behavior of an existing valid program. +That said, it's unlikely that any table would have a metatable with `__len` metamethod as outside of userdata it would not anything, and this drawback is not feasible to resolve with any alternate version of the proposal. + +## Alternatives + +Do not implement `__len`. From c29b803046752838b83b5d2e726e53aa9188e6c2 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Tue, 28 Jun 2022 09:08:12 -0700 Subject: [PATCH 101/102] Update STATUS.md Add __len metamethod --- rfcs/STATUS.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/rfcs/STATUS.md b/rfcs/STATUS.md index 6bfa865d..23a1be83 100644 --- a/rfcs/STATUS.md +++ b/rfcs/STATUS.md @@ -32,3 +32,9 @@ This document tracks unimplemented RFCs. [RFC: never and unknown types](https://github.com/Roblox/luau/blob/master/rfcs/never-and-unknown-types.md) **Status**: Needs implementation + +## __len metamethod for tables and rawlen function + +[RFC: Support __len metamethod for tables and rawlen function](https://github.com/Roblox/luau/blob/master/rfcs/len-metamethod-rawlen.md) + +**Status**: Needs implementation From ee82f1e9973393c3060b869a969642df873d4575 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Tue, 28 Jun 2022 23:13:13 -0700 Subject: [PATCH 102/102] Update sandbox.md Since we don't have a formal proof, clarify that we don't have known bugs. --- docs/_pages/sandbox.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/_pages/sandbox.md b/docs/_pages/sandbox.md index d1d7d118..a7ed7476 100644 --- a/docs/_pages/sandbox.md +++ b/docs/_pages/sandbox.md @@ -4,11 +4,11 @@ title: Sandboxing toc: true --- -Luau is safe to embed. Broadly speaking, this means that even in the face of untrusted (and in Roblox case, actively malicious) code, the language and the standard library don't allow any unsafe access to the underlying system, and don't have any bugs that allow escaping out of the sandbox (e.g. to gain native code execution through ROP gadgets et al). Additionally, the VM provides extra features to implement isolation of privileged code from unprivileged code and protect one from the other; this is important if the embedding environment decides to expose some APIs that may not be safe to call from untrusted code, for example because they do provide controlled access to the underlying system or risk PII exposure through fingerprinting etc. +Luau is safe to embed. Broadly speaking, this means that even in the face of untrusted (and in Roblox case, actively malicious) code, the language and the standard library don't allow unsafe access to the underlying system, and don't have known bugs that allow escaping out of the sandbox (e.g. to gain native code execution through ROP gadgets et al). Additionally, the VM provides extra features to implement isolation of privileged code from unprivileged code and protect one from the other; this is important if the embedding environment decides to expose some APIs that may not be safe to call from untrusted code, for example because they do provide controlled access to the underlying system or risk PII exposure through fingerprinting etc. This safety is achieved through a combination of removing features from the standard library that are unsafe, adding features to the VM that make it possible to implement sandboxing and isolation, and making sure the implementation is safe from memory safety issues using fuzzing. -Of course, since the entire stack is implemented in C++, the sandboxing isn't formally proven - in theory, compiler or the standard library can have exploitable vulnerabilities. In practice these are usually found and fixed quickly. While implementing the stack in a safer language such as Rust would make it easier to provide these guarantees, to our knowledge (based on prior art) this would make it difficult to reach the level of performance required. +Of course, since the entire stack is implemented in C++, the sandboxing isn't formally proven - in theory, compiler or the standard library can have exploitable vulnerabilities. In practice these are very rare and usually found and fixed quickly. While implementing the stack in a safer language such as Rust would make it easier to provide these guarantees, to our knowledge (based on prior art) this would make it difficult to reach the level of performance required. ## Library