// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Config.h" #include "Luau/Differ.h" #include "Luau/Error.h" #include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/IostreamHelpers.h" #include "Luau/Linter.h" #include "Luau/Location.h" #include "Luau/ModuleResolver.h" #include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/Type.h" #include "IostreamOptional.h" #include "ScopedFlags.h" #include "doctest.h" #include #include #include namespace Luau { struct TypeChecker; struct TestFileResolver : FileResolver , ModuleResolver { std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; const ModulePtr getModule(const ModuleName& moduleName) const override; bool moduleExists(const ModuleName& moduleName) const override; std::optional readSource(const ModuleName& name) override; std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; std::string getHumanReadableModuleName(const ModuleName& name) const override; std::optional getEnvironmentForModule(const ModuleName& name) const override; std::unordered_map source; std::unordered_map sourceTypes; std::unordered_map environments; }; struct TestConfigResolver : ConfigResolver { Config defaultConfig; std::unordered_map configFiles; const Config& getConfig(const ModuleName& name) const override; }; struct Fixture { explicit Fixture(bool freeze = true, bool prepareAutocomplete = false); ~Fixture(); // Throws Luau::ParseErrors if the parse fails. AstStatBlock* parse(const std::string& source, const ParseOptions& parseOptions = {}); CheckResult check(Mode mode, std::string source); CheckResult check(const std::string& source); LintResult lint(const std::string& source, const std::optional& lintOptions = {}); LintResult lintModule(const ModuleName& moduleName, const std::optional& lintOptions = {}); /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult matchParseError(const std::string& source, const std::string& message, std::optional location = std::nullopt); // Verify a parse error occurs and the parse error message has the specified prefix ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); ModulePtr getMainModule(); SourceModule* getMainSourceModule(); std::optional getPrimitiveType(TypeId ty); std::optional getType(const std::string& name); TypeId requireType(const std::string& name); TypeId requireType(const ModuleName& moduleName, const std::string& name); TypeId requireType(const ModulePtr& module, const std::string& name); TypeId requireType(const ScopePtr& scope, const std::string& name); std::optional findTypeAtPosition(Position position); TypeId requireTypeAtPosition(Position position); std::optional findExpectedTypeAtPosition(Position position); std::optional lookupType(const std::string& name); std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); TypeId requireTypeAlias(const std::string& name); TypeId requireExportedType(const ModuleName& moduleName, const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; TestFileResolver fileResolver; TestConfigResolver configResolver; NullModuleResolver moduleResolver; std::unique_ptr sourceModule; Frontend frontend; InternalErrorReporter ice; NotNull builtinTypes; std::string decorateWithTypes(const std::string& code); void dumpErrors(std::ostream& os, const std::vector& errors); void dumpErrors(const CheckResult& cr); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void validateErrors(const std::vector& errors); std::string getErrors(const CheckResult& cr); void registerTestTypes(); LoadDefinitionFileResult loadDefinition(const std::string& source); }; struct BuiltinsFixture : Fixture { BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); }; ModuleName fromString(std::string_view name); template std::optional get(const std::map& map, const Name& name) { auto it = map.find(name); if (it != map.end()) return std::optional(it->second); else return std::nullopt; } std::string rep(const std::string& s, size_t n); bool isInArena(TypeId t, const TypeArena& arena); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void dump(const std::string& name, TypeId ty); void dump(const std::vector& constraints); std::optional lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2) std::optional linearSearchForBinding(Scope* scope, const char* name); void registerHiddenTypes(Frontend* frontend); void createSomeClasses(Frontend* frontend); template struct DifferFixtureGeneric : BaseFixture { void compareNe(TypeId left, TypeId right, const std::string& expectedMessage) { std::string diffMessage; try { DifferResult diffRes = diff(left, right); REQUIRE_MESSAGE(diffRes.diffError.has_value(), "Differ did not report type error, even though types are unequal"); diffMessage = diffRes.diffError->toString(); } catch (const InternalCompilerError& e) { REQUIRE_MESSAGE(false, ("InternalCompilerError: " + e.message)); } CHECK_EQ(expectedMessage, diffMessage); } void compareTypesNe(const std::string& leftSymbol, const std::string& rightSymbol, const std::string& expectedMessage) { compareNe(BaseFixture::requireType(leftSymbol), BaseFixture::requireType(rightSymbol), expectedMessage); } void compareEq(TypeId left, TypeId right) { try { DifferResult diffRes = diff(left, right); CHECK_MESSAGE(!diffRes.diffError.has_value(), diffRes.diffError->toString()); } catch (const InternalCompilerError& e) { REQUIRE_MESSAGE(false, ("InternalCompilerError: " + e.message)); } } void compareTypesEq(const std::string& leftSymbol, const std::string& rightSymbol) { compareEq(BaseFixture::requireType(leftSymbol), BaseFixture::requireType(rightSymbol)); } }; using DifferFixture = DifferFixtureGeneric; using DifferFixtureWithBuiltins = DifferFixtureGeneric; } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ do \ { \ auto&& r = (result); \ validateErrors(r.errors); \ REQUIRE(!r.errors.empty()); \ } while (false) #define LUAU_REQUIRE_ERROR_COUNT(count, result) \ do \ { \ auto&& r = (result); \ validateErrors(r.errors); \ REQUIRE_MESSAGE(count == r.errors.size(), getErrors(r)); \ } while (false) #define LUAU_REQUIRE_NO_ERRORS(result) LUAU_REQUIRE_ERROR_COUNT(0, result)