#include "Luau/TypeChecker2.h" #include #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Clone.h" #include "Luau/Instantiation.h" #include "Luau/Normalize.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" namespace Luau { /* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. * TypeChecker2 uses this to maintain knowledge about which scope encloses every * given AstNode. */ struct StackPusher { std::vector>* stack; NotNull scope; explicit StackPusher(std::vector>& stack, Scope* scope) : stack(&stack) , scope(scope) { stack.push_back(NotNull{scope}); } ~StackPusher() { if (stack) { LUAU_ASSERT(stack->back() == scope); stack->pop_back(); } } StackPusher(const StackPusher&) = delete; StackPusher&& operator=(const StackPusher&) = delete; StackPusher(StackPusher&& other) : stack(std::exchange(other.stack, nullptr)) , scope(other.scope) { } }; struct TypeChecker2 { const SourceModule* sourceModule; Module* module; InternalErrorReporter ice; // FIXME accept a pointer from Frontend SingletonTypes& singletonTypes; std::vector> stack; TypeChecker2(const SourceModule* sourceModule, Module* module) : sourceModule(sourceModule) , module(module) , singletonTypes(getSingletonTypes()) { } std::optional pushStack(AstNode* node) { if (Scope** scope = module->astScopes.find(node)) return StackPusher{stack, *scope}; else return std::nullopt; } TypePackId lookupPack(AstExpr* expr) { // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. // We'll just return anyType in these cases. Typechecking against any is very fast and this // allows us not to think about this very much in the actual typechecking logic. TypePackId* tp = module->astTypePacks.find(expr); if (tp) return follow(*tp); else return singletonTypes.anyTypePack; } TypeId lookupType(AstExpr* expr) { // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. // We'll just return anyType in these cases. Typechecking against any is very fast and this // allows us not to think about this very much in the actual typechecking logic. TypeId* ty = module->astTypes.find(expr); if (ty) return follow(*ty); TypePackId* tp = module->astTypePacks.find(expr); if (tp) return flattenPack(*tp); return singletonTypes.anyType; } TypeId lookupAnnotation(AstType* annotation) { TypeId* ty = module->astResolvedTypes.find(annotation); LUAU_ASSERT(ty); return follow(*ty); } TypePackId lookupPackAnnotation(AstTypePack* annotation) { TypePackId* tp = module->astResolvedTypePacks.find(annotation); LUAU_ASSERT(tp); return follow(*tp); } TypePackId reconstructPack(AstArray exprs, TypeArena& arena) { if (exprs.size == 0) return arena.addTypePack(TypePack{{}, std::nullopt}); std::vector head; for (size_t i = 0; i < exprs.size - 1; ++i) { head.push_back(lookupType(exprs.data[i])); } TypePackId tail = lookupPack(exprs.data[exprs.size - 1]); return arena.addTypePack(TypePack{head, tail}); } Scope* findInnermostScope(Location location) { Scope* bestScope = module->getModuleScope().get(); Location bestLocation = module->scopes[0].first; for (size_t i = 0; i < module->scopes.size(); ++i) { auto& [scopeBounds, scope] = module->scopes[i]; if (scopeBounds.encloses(location)) { if (scopeBounds.begin > bestLocation.begin || scopeBounds.end < bestLocation.end) { bestScope = scope.get(); bestLocation = scopeBounds; } } else if (scopeBounds.begin > location.end) { // TODO: Is this sound? This relies on the fact that scopes are inserted // into the scope list in the order that they appear in the AST. break; } } return bestScope; } void visit(AstStat* stat) { auto pusher = pushStack(stat); if (0) {} else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else if (auto s = stat->as()) return visit(s); else LUAU_ASSERT(!"TypeChecker2 encountered an unknown node type"); } void visit(AstStatBlock* block) { auto StackPusher = pushStack(block); for (AstStat* statement : block->body) visit(statement); } void visit(AstStatIf* ifStatement) { visit(ifStatement->condition); visit(ifStatement->thenbody); if (ifStatement->elsebody) visit(ifStatement->elsebody); } void visit(AstStatWhile* whileStatement) { visit(whileStatement->condition); visit(whileStatement->body); } void visit(AstStatRepeat* repeatStatement) { visit(repeatStatement->body); visit(repeatStatement->condition); } void visit(AstStatBreak*) {} void visit(AstStatContinue*) {} void visit(AstStatReturn* ret) { Scope* scope = findInnermostScope(ret->location); TypePackId expectedRetType = scope->returnType; TypeArena arena; TypePackId actualRetType = reconstructPack(ret->list, arena); UnifierSharedState sharedState{&ice}; Unifier u{&arena, Mode::Strict, stack.back(), ret->location, Covariant, sharedState}; u.anyIsTop = true; u.tryUnify(actualRetType, expectedRetType); const bool ok = u.errors.empty() && u.log.empty(); if (!ok) { for (const TypeError& e : u.errors) reportError(e); } for (AstExpr* expr : ret->list) visit(expr); } void visit(AstStatExpr* expr) { visit(expr->expr); } void visit(AstStatLocal* local) { for (size_t i = 0; i < local->values.size; ++i) { AstExpr* value = local->values.data[i]; visit(value); if (i == local->values.size - 1) { if (i < local->values.size) { TypePackId valueTypes = lookupPack(value); auto it = begin(valueTypes); for (size_t j = i; j < local->vars.size; ++j) { if (it == end(valueTypes)) { break; } AstLocal* var = local->vars.data[i]; if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); if (!isSubtype(*it, varType, stack.back(), ice)) { reportError(TypeMismatch{varType, *it}, value->location); } } ++it; } } } else { TypeId valueType = lookupType(value); AstLocal* var = local->vars.data[i]; if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); if (!isSubtype(varType, valueType, stack.back(), ice)) { reportError(TypeMismatch{varType, valueType}, value->location); } } } } } void visit(AstStatFor* forStatement) { if (forStatement->var->annotation) visit(forStatement->var->annotation); visit(forStatement->from); visit(forStatement->to); if (forStatement->step) visit(forStatement->step); visit(forStatement->body); } void visit(AstStatForIn* forInStatement) { for (AstLocal* local : forInStatement->vars) { if (local->annotation) visit(local->annotation); } for (AstExpr* expr : forInStatement->values) visit(expr); visit(forInStatement->body); } void visit(AstStatAssign* assign) { size_t count = std::min(assign->vars.size, assign->values.size); for (size_t i = 0; i < count; ++i) { AstExpr* lhs = assign->vars.data[i]; visit(lhs); TypeId lhsType = lookupType(lhs); AstExpr* rhs = assign->values.data[i]; visit(rhs); TypeId rhsType = lookupType(rhs); if (!isSubtype(rhsType, lhsType, stack.back(), ice)) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } } } void visit(AstStatCompoundAssign* stat) { visit(stat->var); visit(stat->value); } void visit(AstStatFunction* stat) { visit(stat->name); visit(stat->func); } void visit(AstStatLocalFunction* stat) { visit(stat->func); } void visit(const AstTypeList* typeList) { for (AstType* ty : typeList->types) visit(ty); if (typeList->tailType) visit(typeList->tailType); } void visit(AstStatTypeAlias* stat) { for (const AstGenericType& el : stat->generics) { if (el.defaultValue) visit(el.defaultValue); } for (const AstGenericTypePack& el : stat->genericPacks) { if (el.defaultValue) visit(el.defaultValue); } visit(stat->type); } void visit(AstTypeList types) { for (AstType* type : types.types) visit(type); if (types.tailType) visit(types.tailType); } void visit(AstStatDeclareFunction* stat) { visit(stat->params); visit(stat->retTypes); } void visit(AstStatDeclareGlobal* stat) { visit(stat->type); } void visit(AstStatDeclareClass* stat) { for (const AstDeclaredClassProp& prop : stat->props) visit(prop.ty); } void visit(AstStatError* stat) { for (AstExpr* expr : stat->expressions) visit(expr); for (AstStat* s : stat->statements) visit(s); } void visit(AstExpr* expr) { auto StackPusher = pushStack(expr); if (0) {} else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else LUAU_ASSERT(!"TypeChecker2 encountered an unknown expression type"); } void visit(AstExprGroup* expr) { visit(expr->expr); } void visit(AstExprConstantNil* expr) { // TODO! } void visit(AstExprConstantBool* expr) { // TODO! } void visit(AstExprConstantNumber* number) { TypeId actualType = lookupType(number); TypeId numberType = getSingletonTypes().numberType; if (!isSubtype(numberType, actualType, stack.back(), ice)) { reportError(TypeMismatch{actualType, numberType}, number->location); } } void visit(AstExprConstantString* string) { TypeId actualType = lookupType(string); TypeId stringType = getSingletonTypes().stringType; if (!isSubtype(stringType, actualType, stack.back(), ice)) { reportError(TypeMismatch{actualType, stringType}, string->location); } } void visit(AstExprLocal* expr) { // TODO! } void visit(AstExprGlobal* expr) { // TODO! } void visit(AstExprVarargs* expr) { // TODO! } void visit(AstExprCall* call) { visit(call->func); for (AstExpr* arg : call->args) visit(arg); TypeArena arena; Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}}; TypePackId expectedRetType = lookupPack(call); TypeId functionType = lookupType(call->func); TypeId instantiatedFunctionType = instantiation.substitute(functionType).value_or(nullptr); LUAU_ASSERT(functionType); TypePack args; for (AstExpr* arg : call->args) { TypeId argTy = module->astTypes[arg]; LUAU_ASSERT(argTy); args.head.push_back(argTy); } TypePackId argsTp = arena.addTypePack(args); FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); if (!isSubtype(expectedType, instantiatedFunctionType, stack.back(), ice)) { unfreeze(module->interfaceTypes); CloneState cloneState; expectedType = clone(expectedType, module->interfaceTypes, cloneState); freeze(module->interfaceTypes); reportError(TypeMismatch{expectedType, functionType}, call->location); } } void visit(AstExprIndexName* indexName) { TypeId leftType = lookupType(indexName->expr); TypeId resultType = lookupType(indexName); // leftType must have a property called indexName->index std::optional ty = getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); if (ty) { if (!isSubtype(resultType, *ty, stack.back(), ice)) { reportError(TypeMismatch{resultType, *ty}, indexName->location); } } } void visit(AstExprIndexExpr* indexExpr) { // TODO! visit(indexExpr->expr); visit(indexExpr->index); } void visit(AstExprFunction* fn) { auto StackPusher = pushStack(fn); TypeId inferredFnTy = lookupType(fn); const FunctionTypeVar* inferredFtv = get(inferredFnTy); LUAU_ASSERT(inferredFtv); auto argIt = begin(inferredFtv->argTypes); for (const auto& arg : fn->args) { if (argIt == end(inferredFtv->argTypes)) break; if (arg->annotation) { TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), ice)) { reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); } } ++argIt; } visit(fn->body); } void visit(AstExprTable* expr) { // TODO! for (const AstExprTable::Item& item : expr->items) { if (item.key) visit(item.key); visit(item.value); } } void visit(AstExprUnary* expr) { // TODO! visit(expr->expr); } void visit(AstExprBinary* expr) { // TODO! visit(expr->left); visit(expr->right); } void visit(AstExprTypeAssertion* expr) { visit(expr->expr); visit(expr->annotation); TypeId annotationType = lookupAnnotation(expr->annotation); TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. if (isSubtype(annotationType, computedType, stack.back(), ice)) return; if (isSubtype(computedType, annotationType, stack.back(), ice)) return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); } void visit(AstExprIfElse* expr) { // TODO! visit(expr->condition); visit(expr->trueExpr); visit(expr->falseExpr); } void visit(AstExprError* expr) { // TODO! for (AstExpr* e : expr->expressions) visit(e); } /** Extract a TypeId for the first type of the provided pack. * * Note that this may require modifying some types. I hope this doesn't cause problems! */ TypeId flattenPack(TypePackId pack) { pack = follow(pack); while (true) { auto tp = get(pack); if (tp && tp->head.empty() && tp->tail) pack = *tp->tail; else break; } if (auto ty = first(pack)) return *ty; else if (auto vtp = get(pack)) return vtp->ty; else if (auto ftp = get(pack)) { TypeId result = module->internalTypes.addType(FreeTypeVar{ftp->scope}); TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); TypePack& resultPack = asMutable(pack)->ty.emplace(); resultPack.head.assign(1, result); resultPack.tail = freeTail; return result; } else if (get(pack)) return singletonTypes.errorRecoveryType(); else ice.ice("flattenPack got a weird pack!"); } void visit(AstType* ty) { if (auto t = ty->as()) return visit(t); else if (auto t = ty->as()) return visit(t); else if (auto t = ty->as()) return visit(t); else if (auto t = ty->as()) return visit(t); else if (auto t = ty->as()) return visit(t); else if (auto t = ty->as()) return visit(t); } void visit(AstTypeReference* ty) { for (const AstTypeOrPack& param : ty->parameters) { if (param.type) visit(param.type); else visit(param.typePack); } Scope* scope = findInnermostScope(ty->location); LUAU_ASSERT(scope); // TODO: Imported types std::optional alias = scope->lookupType(ty->name.value); if (alias.has_value()) { size_t typesRequired = alias->typeParams.size(); size_t packsRequired = alias->typePackParams.size(); bool hasDefaultTypes = std::any_of(alias->typeParams.begin(), alias->typeParams.end(), [](auto&& el) { return el.defaultValue.has_value(); }); bool hasDefaultPacks = std::any_of(alias->typePackParams.begin(), alias->typePackParams.end(), [](auto&& el) { return el.defaultValue.has_value(); }); if (!ty->hasParameterList) { if ((!alias->typeParams.empty() && !hasDefaultTypes) || (!alias->typePackParams.empty() && !hasDefaultPacks)) { reportError(GenericError{"Type parameter list is required"}, ty->location); } } size_t typesProvided = 0; size_t extraTypes = 0; size_t packsProvided = 0; for (const AstTypeOrPack& p : ty->parameters) { if (p.type) { if (packsProvided != 0) { reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location); } if (typesProvided < typesRequired) { typesProvided += 1; } else { extraTypes += 1; } } else if (p.typePack) { TypePackId tp = lookupPackAnnotation(p.typePack); if (typesProvided < typesRequired && size(tp) == 1 && finite(tp) && first(tp)) { typesProvided += 1; } else { packsProvided += 1; } } } if (extraTypes != 0 && packsProvided == 0) { packsProvided += 1; } for (size_t i = typesProvided; i < typesRequired; ++i) { if (alias->typeParams[i].defaultValue) { typesProvided += 1; } } for (size_t i = packsProvided; i < packsProvided; ++i) { if (alias->typePackParams[i].defaultValue) { packsProvided += 1; } } if (extraTypes == 0 && packsProvided + 1 == packsRequired) { packsProvided += 1; } if (typesProvided != typesRequired || packsProvided != packsRequired) { reportError(IncorrectGenericParameterCount{ /* name */ ty->name.value, /* typeFun */ *alias, /* actualParameters */ typesProvided, /* actualPackParameters */ packsProvided, }, ty->location); } } else { if (scope->lookupPack(ty->name.value)) { reportError( SwappedGenericTypeParameter{ ty->name.value, SwappedGenericTypeParameter::Kind::Type, }, ty->location); } else { reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); } } } void visit(AstTypeTable* table) { // TODO! for (const AstTableProp& prop : table->props) visit(prop.type); if (table->indexer) { visit(table->indexer->indexType); visit(table->indexer->resultType); } } void visit(AstTypeFunction* ty) { // TODO! visit(ty->argTypes); visit(ty->returnTypes); } void visit(AstTypeTypeof* ty) { visit(ty->expr); } void visit(AstTypeUnion* ty) { // TODO! for (AstType* type : ty->types) visit(type); } void visit(AstTypeIntersection* ty) { // TODO! for (AstType* type : ty->types) visit(type); } void visit(AstTypePack* pack) { if (auto p = pack->as()) return visit(p); else if (auto p = pack->as()) return visit(p); else if (auto p = pack->as()) return visit(p); } void visit(AstTypePackExplicit* tp) { // TODO! for (AstType* type : tp->typeList.types) visit(type); if (tp->typeList.tailType) visit(tp->typeList.tailType); } void visit(AstTypePackVariadic* tp) { // TODO! visit(tp->variadicType); } void visit(AstTypePackGeneric* tp) { Scope* scope = findInnermostScope(tp->location); LUAU_ASSERT(scope); std::optional alias = scope->lookupPack(tp->genericName.value); if (!alias.has_value()) { if (scope->lookupType(tp->genericName.value)) { reportError( SwappedGenericTypeParameter{ tp->genericName.value, SwappedGenericTypeParameter::Kind::Pack, }, tp->location); } else { reportError(UnknownSymbol{tp->genericName.value, UnknownSymbol::Context::Type}, tp->location); } } } void reportError(TypeErrorData&& data, const Location& location) { module->errors.emplace_back(location, sourceModule->name, std::move(data)); } void reportError(TypeError e) { module->errors.emplace_back(std::move(e)); } std::optional getIndexTypeFromType( const ScopePtr& scope, TypeId type, const std::string& prop, const Location& location, bool addErrors) { return Luau::getIndexTypeFromType(scope, module->errors, &module->internalTypes, type, prop, location, addErrors, ice); } }; void check(const SourceModule& sourceModule, Module* module) { TypeChecker2 typeChecker{&sourceModule, module}; typeChecker.visit(sourceModule.root); } } // namespace Luau