// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/BuiltinTypeFunctions.h" #include "Luau/Config.h" #include "Luau/EqSatSimplification.h" #include "Luau/Error.h" #include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/IostreamHelpers.h" #include "Luau/Linter.h" #include "Luau/Location.h" #include "Luau/ModuleResolver.h" #include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypeFunction.h" #include "IostreamOptional.h" #include "ScopedFlags.h" #include "doctest.h" #include #include #include #include #include LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauForceAllNewSolverTests) LUAU_FASTFLAG(LuauUpdateSetMetatableTypeSignature) LUAU_FASTFLAG(LuauUpdateGetMetatableTypeSignature) #define DOES_NOT_PASS_NEW_SOLVER_GUARD_IMPL(line) ScopedFastFlag sff_##line{FFlag::LuauSolverV2, FFlag::DebugLuauForceAllNewSolverTests}; #define DOES_NOT_PASS_NEW_SOLVER_GUARD() DOES_NOT_PASS_NEW_SOLVER_GUARD_IMPL(__LINE__) namespace Luau { struct TypeChecker; struct TestRequireNode : RequireNode { TestRequireNode(ModuleName moduleName, std::unordered_map* allSources) : moduleName(std::move(moduleName)) , allSources(allSources) { } std::string getLabel() const override; std::string getPathComponent() const override; std::unique_ptr resolvePathToNode(const std::string& path) const override; std::vector> getChildren() const override; std::vector getAvailableAliases() const override; ModuleName moduleName; std::unordered_map* allSources; }; struct TestFileResolver; struct TestRequireSuggester : RequireSuggester { TestRequireSuggester(TestFileResolver* resolver) : resolver(resolver) { } std::unique_ptr getNode(const ModuleName& name) const override; TestFileResolver* resolver; }; struct TestFileResolver : FileResolver , ModuleResolver { TestFileResolver() : FileResolver(std::make_shared(this)) { } std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; const ModulePtr getModule(const ModuleName& moduleName) const override; bool moduleExists(const ModuleName& moduleName) const override; std::optional readSource(const ModuleName& name) override; std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; std::string getHumanReadableModuleName(const ModuleName& name) const override; std::optional getEnvironmentForModule(const ModuleName& name) const override; std::unordered_map source; std::unordered_map sourceTypes; std::unordered_map environments; }; struct TestConfigResolver : ConfigResolver { Config defaultConfig; std::unordered_map configFiles; const Config& getConfig(const ModuleName& name) const override; }; struct Fixture { explicit Fixture(bool prepareAutocomplete = false); ~Fixture(); // Throws Luau::ParseErrors if the parse fails. AstStatBlock* parse(const std::string& source, const ParseOptions& parseOptions = {}); CheckResult check(Mode mode, const std::string& source, std::optional = std::nullopt); CheckResult check(const std::string& source, std::optional = std::nullopt); LintResult lint(const std::string& source, const std::optional& lintOptions = {}); LintResult lintModule(const ModuleName& moduleName, const std::optional& lintOptions = {}); /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult matchParseError(const std::string& source, const std::string& message, std::optional location = std::nullopt); // Verify a parse error occurs and the parse error message has the specified prefix ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); ModulePtr getMainModule(bool forAutocomplete = false); SourceModule* getMainSourceModule(); std::optional getPrimitiveType(TypeId ty); std::optional getType(const std::string& name, bool forAutocomplete = false); TypeId requireType(const std::string& name); TypeId requireType(const ModuleName& moduleName, const std::string& name); TypeId requireType(const ModulePtr& module, const std::string& name); TypeId requireType(const ScopePtr& scope, const std::string& name); std::optional findTypeAtPosition(Position position); TypeId requireTypeAtPosition(Position position); std::optional findExpectedTypeAtPosition(Position position); std::optional lookupType(const std::string& name); std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); TypeId requireTypeAlias(const std::string& name); TypeId requireExportedType(const ModuleName& moduleName, const std::string& name); std::string canonicalize(TypeId ty); // While most flags can be flipped inside the unit test, some code changes affect the state that is part of Fixture initialization // Most often those are changes related to builtin type definitions. // In that case, flag can be forced to 'true' using the example below: // ScopedFastFlag sff_LuauExampleFlagDefinition{FFlag::LuauExampleFlagDefinition, true}; ScopedFastFlag sff_LuauUpdateSetMetatableTypeSignature{FFlag::LuauUpdateSetMetatableTypeSignature, true}; ScopedFastFlag sff_LuauUpdateGetMetatableTypeSignature{FFlag::LuauUpdateGetMetatableTypeSignature, true}; // Arena freezing marks the `TypeArena`'s underlying memory as read-only, raising an access violation whenever you mutate it. // This is useful for tracking down violations of Luau's memory model. ScopedFastFlag sff_DebugLuauFreezeArena{FFlag::DebugLuauFreezeArena, true}; TestFileResolver fileResolver; TestConfigResolver configResolver; NullModuleResolver moduleResolver; std::unique_ptr sourceModule; InternalErrorReporter ice; std::string decorateWithTypes(const std::string& code); void dumpErrors(std::ostream& os, const std::vector& errors); void dumpErrors(const CheckResult& cr); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void validateErrors(const std::vector& errors); std::string getErrors(const CheckResult& cr); void registerTestTypes(); LoadDefinitionFileResult loadDefinition(const std::string& source, bool forAutocomplete = false); // TODO: test theory about dynamic dispatch NotNull getBuiltins(); virtual Frontend& getFrontend(); private: bool hasDumpedErrors = false; protected: bool forAutocomplete = false; std::optional frontend; BuiltinTypes* builtinTypes = nullptr; TypeArena simplifierArena; SimplifierPtr simplifier{nullptr, nullptr}; }; struct BuiltinsFixture : Fixture { explicit BuiltinsFixture(bool prepareAutocomplete = false); // For the purpose of our tests, we're always the latest version of type functions. Frontend& getFrontend() override; }; std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& segments); std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& pathExpr); ModuleName fromString(std::string_view name); template std::optional get(const std::map& map, const Name& name) { auto it = map.find(name); if (it != map.end()) return std::optional(it->second); else return std::nullopt; } std::string rep(const std::string& s, size_t n); bool isInArena(TypeId t, const TypeArena& arena); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void dump(const std::string& name, TypeId ty); void dump(const std::vector& constraints); std::optional lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2) std::optional linearSearchForBinding(Scope* scope, const char* name); void registerHiddenTypes(Frontend& frontend); void createSomeExternTypes(Frontend& frontend); template const E* findError(const CheckResult& result) { for (const auto& e : result.errors) { if (auto p = get(e)) return p; } return nullptr; } } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ do \ { \ auto&& r = (result); \ validateErrors(r.errors); \ REQUIRE(!r.errors.empty()); \ } while (false) #define LUAU_REQUIRE_ERROR_COUNT(count, result) \ do \ { \ auto&& r = (result); \ validateErrors(r.errors); \ REQUIRE_MESSAGE(count == r.errors.size(), getErrors(r)); \ } while (false) #define LUAU_REQUIRE_NO_ERRORS(result) LUAU_REQUIRE_ERROR_COUNT(0, result) #define LUAU_CHECK_ERRORS(result) \ do \ { \ auto&& r = (result); \ validateErrors(r.errors); \ CHECK(!r.errors.empty()); \ } while (false) #define LUAU_CHECK_ERROR_COUNT(count, result) \ do \ { \ auto&& r = (result); \ validateErrors(r.errors); \ CHECK_MESSAGE(count == r.errors.size(), getErrors(r)); \ } while (false) #define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result) #define LUAU_CHECK_HAS_KEY(map, key) \ do \ { \ auto&& _m = (map); \ auto&& _k = (key); \ const size_t count = _m.count(_k); \ CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \ if (!count) \ { \ MESSAGE("Keys: (count " << _m.size() << ")"); \ for (const auto& [k, v] : _m) \ { \ MESSAGE("\tkey: " << k); \ } \ } \ } while (false) #define LUAU_CHECK_HAS_NO_KEY(map, key) \ do \ { \ auto&& _m = (map); \ auto&& _k = (key); \ const size_t count = _m.count(_k); \ CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \ if (count) \ { \ MESSAGE("Keys: (count " << _m.size() << ")"); \ for (const auto& [k, v] : _m) \ { \ MESSAGE("\tkey: " << k); \ } \ } \ } while (false) #define LUAU_REQUIRE_ERROR(result, Type) \ do \ { \ using T = Type; \ const auto& res = (result); \ if (!findError(res)) \ { \ dumpErrors(res); \ REQUIRE_MESSAGE(false, "Expected to find " #Type " error"); \ } \ } while (false) #define LUAU_CHECK_ERROR(result, Type) \ do \ { \ using T = Type; \ const auto& res = (result); \ if (!findError(res)) \ { \ dumpErrors(res); \ CHECK_MESSAGE(false, "Expected to find " #Type " error"); \ } \ } while (false) #define LUAU_REQUIRE_NO_ERROR(result, Type) \ do \ { \ using T = Type; \ const auto& res = (result); \ if (findError(res)) \ { \ dumpErrors(res); \ REQUIRE_MESSAGE(false, "Expected to find no " #Type " error"); \ } \ } while (false) #define LUAU_CHECK_NO_ERROR(result, Type) \ do \ { \ using T = Type; \ const auto& res = (result); \ if (findError(res)) \ { \ dumpErrors(res); \ CHECK_MESSAGE(false, "Expected to find no " #Type " error"); \ } \ } while (false)