Restructure into resolveTypeWorker

This commit is contained in:
JohnnyMorganz 2022-07-04 21:06:34 +01:00
parent b78b3b9bcb
commit 278db12028
2 changed files with 24 additions and 16 deletions

View file

@ -345,6 +345,7 @@ private:
TypePackId freshTypePack(TypeLevel level); TypePackId freshTypePack(TypeLevel level);
TypeId resolveType(const ScopePtr& scope, const AstType& annotation); TypeId resolveType(const ScopePtr& scope, const AstType& annotation);
TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams, TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams,

View file

@ -4884,6 +4884,13 @@ TypePackId TypeChecker::freshTypePack(TypeLevel level)
} }
TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation) TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation)
{
TypeId ty = resolveTypeWorker(scope, annotation);
currentModule->astResolvedTypes[&annotation] = ty;
return ty;
}
TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& annotation)
{ {
if (const auto& lit = annotation.as<AstTypeReference>()) if (const auto& lit = annotation.as<AstTypeReference>())
{ {
@ -4899,7 +4906,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
if (lit->parameters.size != 1 || !lit->parameters.data[0].type) if (lit->parameters.size != 1 || !lit->parameters.data[0].type)
{ {
reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}});
return currentModule->astResolvedTypes[&annotation] = errorRecoveryType(anyType); return errorRecoveryType(anyType);
} }
ToStringOptions opts; ToStringOptions opts;
@ -4909,7 +4916,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
TypeId param = resolveType(scope, *lit->parameters.data[0].type); TypeId param = resolveType(scope, *lit->parameters.data[0].type);
luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str()));
return currentModule->astResolvedTypes[&annotation] = param; return param;
} }
else else
@ -4918,7 +4925,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
if (!tf) if (!tf)
{ {
if (lit->name == kParseNameError) if (lit->name == kParseNameError)
return currentModule->astResolvedTypes[&annotation] = errorRecoveryType(scope); return errorRecoveryType(scope);
std::string typeName; std::string typeName;
if (lit->prefix) if (lit->prefix)
@ -4930,11 +4937,11 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
else else
reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}}); reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}});
return currentModule->astResolvedTypes[&annotation] = errorRecoveryType(scope); return errorRecoveryType(scope);
} }
if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty())
return currentModule->astResolvedTypes[&annotation] = tf->type; return tf->type;
bool parameterCountErrorReported = false; bool parameterCountErrorReported = false;
bool hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { bool hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) {
@ -5085,9 +5092,9 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
// If the generic parameters and the type arguments are the same, we are about to // If the generic parameters and the type arguments are the same, we are about to
// perform an identity substitution, which we can just short-circuit. // perform an identity substitution, which we can just short-circuit.
if (sameTys && sameTps) if (sameTys && sameTps)
return currentModule->astResolvedTypes[&annotation] = tf->type; return tf->type;
return currentModule->astResolvedTypes[&annotation] = instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location);
} }
else if (const auto& table = annotation.as<AstTypeTable>()) else if (const auto& table = annotation.as<AstTypeTable>())
{ {
@ -5102,7 +5109,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed};
ttv.definitionModuleName = currentModuleName; ttv.definitionModuleName = currentModuleName;
return currentModule->astResolvedTypes[&annotation] = addType(std::move(ttv)); return addType(std::move(ttv));
} }
else if (const auto& func = annotation.as<AstTypeFunction>()) else if (const auto& func = annotation.as<AstTypeFunction>())
{ {
@ -5139,12 +5146,12 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
ftv->argNames.push_back(std::nullopt); ftv->argNames.push_back(std::nullopt);
} }
return currentModule->astResolvedTypes[&annotation] = fnType; return fnType;
} }
else if (auto typeOf = annotation.as<AstTypeTypeof>()) else if (auto typeOf = annotation.as<AstTypeTypeof>())
{ {
TypeId ty = checkExpr(scope, *typeOf->expr).type; TypeId ty = checkExpr(scope, *typeOf->expr).type;
return currentModule->astResolvedTypes[&annotation] = ty; return ty;
} }
else if (const auto& un = annotation.as<AstTypeUnion>()) else if (const auto& un = annotation.as<AstTypeUnion>())
{ {
@ -5152,7 +5159,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
for (AstType* ann : un->types) for (AstType* ann : un->types)
types.push_back(resolveType(scope, *ann)); types.push_back(resolveType(scope, *ann));
return currentModule->astResolvedTypes[&annotation] = addType(UnionTypeVar{types}); return addType(UnionTypeVar{types});
} }
else if (const auto& un = annotation.as<AstTypeIntersection>()) else if (const auto& un = annotation.as<AstTypeIntersection>())
{ {
@ -5160,22 +5167,22 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
for (AstType* ann : un->types) for (AstType* ann : un->types)
types.push_back(resolveType(scope, *ann)); types.push_back(resolveType(scope, *ann));
return currentModule->astResolvedTypes[&annotation] = addType(IntersectionTypeVar{types}); return addType(IntersectionTypeVar{types});
} }
else if (const auto& tsb = annotation.as<AstTypeSingletonBool>()) else if (const auto& tsb = annotation.as<AstTypeSingletonBool>())
{ {
return currentModule->astResolvedTypes[&annotation] = singletonType(tsb->value); return singletonType(tsb->value);
} }
else if (const auto& tss = annotation.as<AstTypeSingletonString>()) else if (const auto& tss = annotation.as<AstTypeSingletonString>())
{ {
return currentModule->astResolvedTypes[&annotation] = singletonType(std::string(tss->value.data, tss->value.size)); return singletonType(std::string(tss->value.data, tss->value.size));
} }
else if (annotation.is<AstTypeError>()) else if (annotation.is<AstTypeError>())
return currentModule->astResolvedTypes[&annotation] = errorRecoveryType(scope); return errorRecoveryType(scope);
else else
{ {
reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}}); reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}});
return currentModule->astResolvedTypes[&annotation] = errorRecoveryType(scope); return errorRecoveryType(scope);
} }
} }