diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 4b216518..6d5824ac 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -6,6 +6,7 @@ #include "Luau/ConstraintSolver.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" +#include "Luau/Normalize.h" #include "Luau/RecursionCounter.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" @@ -1076,11 +1077,11 @@ IntersectionTypeIterator end(const IntersectionType* itv) return IntersectionTypeIterator{}; } -static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) +static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs*"; - std::vector result; + std::vector result; for (size_t i = 0; i < size; ++i) { @@ -1099,13 +1100,13 @@ static std::vector parseFormatString(NotNull builtinTypes, break; if (data[i] == 'q' || data[i] == 's') - result.push_back(builtinTypes->stringType); + result.push_back(UnionType {{builtinTypes->stringType, builtinTypes->numberType}}); else if (data[i] == '*') - result.push_back(builtinTypes->unknownType); + result.push_back(UnionType {{builtinTypes->unknownType}}); else if (strchr(options, data[i])) - result.push_back(builtinTypes->numberType); + result.push_back(UnionType {{builtinTypes->numberType}}); else - result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType)); + result.push_back(UnionType {{builtinTypes->errorRecoveryType(builtinTypes->anyType)}}); } } @@ -1134,7 +1135,7 @@ std::optional> magicFunctionFormat( if (!fmt) return std::nullopt; - std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); + std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(paramPack); size_t paramOffset = 1; @@ -1145,7 +1146,9 @@ std::optional> magicFunctionFormat( { Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - typechecker.unify(params[i + paramOffset], expected[i], scope, location); + // use arena to flatten union type + const TypeId formatTypes = arena.addType(expected[i]); + typechecker.unify(params[i + paramOffset], formatTypes, scope, location); } // if we know the argument count or if we have too many arguments for sure, we can issue an error @@ -1177,7 +1180,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) if (!fmt) return false; - std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); + std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(context.arguments); size_t paramOffset = 1; @@ -1185,7 +1188,8 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) // unify the prefix one argument at a time for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) { - context.solver->unify(params[i + paramOffset], expected[i], context.solver->rootScope); + const TypeId formatTypes = arena->addType(expected[i]); + context.solver->unify(params[i + paramOffset], formatTypes, context.solver->rootScope); } // if we know the argument count or if we have too many arguments for sure, we can issue an error