luau/Analysis/src/Substitution.cpp
Andy Friesen fe7621ee8c
Sync to upstream/release/573 (#903)
* Work toward affording parallel type checking
* The interface to `LazyType` has changed:
* `LazyType` now takes a second callback that is passed the `LazyType&`
itself. This new callback is responsible for populating the field
`TypeId LazyType::unwrapped`. Multithreaded implementations should
acquire a lock in this callback.
* Modules now retain their `humanReadableNames`. This reduces the number
of cases where type checking has to call back to a `ModuleResolver`.
* https://github.com/Roblox/luau/pull/902
* Add timing info to the Luau REPL compilation output

We've also fixed some bugs and crashes in the new solver as we march
toward readiness.
* Thread ICEs (Internal Compiler Errors) back to the Frontend properly
* Refinements are no longer applied to lvalues
* More miscellaneous stability improvements

Lots of activity in the new JIT engine:

* Implement register spilling/restore for A64
* Correct Luau IR value restore location tracking
* Fixed use-after-free in x86 register allocator spill restore
* Use btz for bit tests
* Finish branch assembly support for A64
* Codesize and performance improvements for A64
* The bit32 library has been implemented for arm and x64

---------

Co-authored-by: Arseny Kapoulkine <arseny.kapoulkine@gmail.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
2023-04-21 15:14:26 -07:00

829 lines
23 KiB
C++

// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Substitution.h"
#include "Luau/Common.h"
#include "Luau/Clone.h"
#include "Luau/TxnLog.h"
#include <algorithm>
#include <stdexcept>
// Gates copying/traversal/replacement of FunctionType::generics, FunctionType::genericPacks and
// VariadicTypePack::hidden in the substitution passes below.
LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false)
// Only declared here (presumably defined in another translation unit). Selects the std::visit-based
// shallowClone and makes Substitution::clone copy pack kinds it would otherwise share.
LUAU_FASTFLAG(LuauClonePublicInterfaceLess2)
// Default per-traversal vertex limit; Tarjan::visitRoot installs it when childLimit is 0.
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000)
// Gates traversal and replacement of ClassType props, parent and metatable.
LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false)
// Gates clearing Tarjan state at the start of substitute() and the replacedTypes/replacedTypePacks
// de-duplication of replaceChildren() calls.
LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false)
namespace Luau
{
// Legacy shallow clone, used only while LuauClonePublicInterfaceLess2 is off.
//
// Follows `ty` through `log` and, reading the log's pending version of the node when there is one,
// copies the outermost node of function, table, metatable, union, intersection, pending-expansion and
// negation types into `dest`. Child ids are shared with the source rather than cloned, and each new
// copy inherits the source's documentationSymbol. Every other kind of type is returned as the
// followed id, uncopied. `alwaysClone` exists for signature parity with shallowClone and is unused.
static TypeId DEPRECATED_shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone)
{
    const TypeId followed = log->follow(ty);

    TypeId source = followed;
    if (auto pendingEntry = log->pending(followed))
        source = &pendingEntry->pending;

    // Stamps a freshly added copy with the documentation symbol of the node it was copied from.
    auto finish = [source](TypeId fresh) {
        asMutable(fresh)->documentationSymbol = source->documentationSymbol;
        return fresh;
    };

    if (const FunctionType* fn = get<FunctionType>(source))
    {
        FunctionType copy = FunctionType{fn->level, fn->scope, fn->argTypes, fn->retTypes, fn->definition, fn->hasSelf};
        copy.generics = fn->generics;
        copy.genericPacks = fn->genericPacks;
        copy.magicFunction = fn->magicFunction;
        copy.dcrMagicFunction = fn->dcrMagicFunction;
        copy.dcrMagicRefinement = fn->dcrMagicRefinement;
        copy.tags = fn->tags;
        copy.argNames = fn->argNames;
        return finish(dest.addType(std::move(copy)));
    }

    if (const TableType* tbl = get<TableType>(source))
    {
        LUAU_ASSERT(!tbl->boundTo);
        TableType copy = TableType{tbl->props, tbl->indexer, tbl->level, tbl->scope, tbl->state};
        copy.definitionModuleName = tbl->definitionModuleName;
        copy.definitionLocation = tbl->definitionLocation;
        copy.name = tbl->name;
        copy.syntheticName = tbl->syntheticName;
        copy.instantiatedTypeParams = tbl->instantiatedTypeParams;
        copy.instantiatedTypePackParams = tbl->instantiatedTypePackParams;
        copy.tags = tbl->tags;
        return finish(dest.addType(std::move(copy)));
    }

    if (const MetatableType* meta = get<MetatableType>(source))
    {
        MetatableType copy = MetatableType{meta->table, meta->metatable};
        copy.syntheticName = meta->syntheticName;
        return finish(dest.addType(std::move(copy)));
    }

    if (const UnionType* uni = get<UnionType>(source))
    {
        UnionType copy;
        copy.options = uni->options;
        return finish(dest.addType(std::move(copy)));
    }

    if (const IntersectionType* inter = get<IntersectionType>(source))
    {
        IntersectionType copy;
        copy.parts = inter->parts;
        return finish(dest.addType(std::move(copy)));
    }

    if (const PendingExpansionType* expansion = get<PendingExpansionType>(source))
    {
        PendingExpansionType copy{expansion->prefix, expansion->name, expansion->typeArguments, expansion->packArguments};
        return finish(dest.addType(std::move(copy)));
    }

    if (const NegationType* neg = get<NegationType>(source))
        return finish(dest.addType(NegationType{neg->ty}));

    return followed;
}
// Shallowly copies the type `ty` resolves to (through `log`) into `dest`.
//
// The outermost node is copied for generic, singleton, function, table, metatable, union,
// intersection and negation types; child ids are shared with the source rather than cloned.
// ClassTypes are copied only when `alwaysClone` is true. Free, blocked, primitive, pending-expansion,
// any, error, unknown, never and lazy types are returned by identity (the followed id) without copying.
// Newly created copies inherit the source's documentationSymbol.
// When LuauClonePublicInterfaceLess2 is off, the legacy DEPRECATED_shallowClone is used instead.
static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone)
{
    if (!FFlag::LuauClonePublicInterfaceLess2)
        return DEPRECATED_shallowClone(ty, dest, log, alwaysClone);

    // Resolve bindings before building the visitor so that alternatives returned by identity yield
    // the followed type, not whatever (possibly bound) id the caller passed in.
    const TypeId followed = log->follow(ty);

    // Read from the log's pending version of the node when there is one. That node is only a view
    // used as the copy source; it is never returned and never mutated.
    TypeId source = followed;
    if (auto pendingEntry = log->pending(followed))
        source = &pendingEntry->pending;

    auto go = [followed, &dest, alwaysClone](auto&& a) {
        using T = std::decay_t<decltype(a)>;
        if constexpr (std::is_same_v<T, FreeType>)
            return followed;
        else if constexpr (std::is_same_v<T, BoundType>)
        {
            // This should never happen, but visit() cannot see it.
            LUAU_ASSERT(!"shallowClone didn't follow its argument!");
            return dest.addType(BoundType{a.boundTo});
        }
        else if constexpr (std::is_same_v<T, GenericType>)
            return dest.addType(a);
        else if constexpr (std::is_same_v<T, BlockedType>)
            return followed;
        else if constexpr (std::is_same_v<T, PrimitiveType>)
            return followed;
        else if constexpr (std::is_same_v<T, PendingExpansionType>)
            return followed;
        else if constexpr (std::is_same_v<T, AnyType>)
            return followed;
        else if constexpr (std::is_same_v<T, ErrorType>)
            return followed;
        else if constexpr (std::is_same_v<T, UnknownType>)
            return followed;
        else if constexpr (std::is_same_v<T, NeverType>)
            return followed;
        else if constexpr (std::is_same_v<T, LazyType>)
            return followed;
        else if constexpr (std::is_same_v<T, SingletonType>)
            return dest.addType(a);
        else if constexpr (std::is_same_v<T, FunctionType>)
        {
            FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf};
            clone.generics = a.generics;
            clone.genericPacks = a.genericPacks;
            clone.magicFunction = a.magicFunction;
            clone.dcrMagicFunction = a.dcrMagicFunction;
            clone.dcrMagicRefinement = a.dcrMagicRefinement;
            clone.tags = a.tags;
            clone.argNames = a.argNames;
            return dest.addType(std::move(clone));
        }
        else if constexpr (std::is_same_v<T, TableType>)
        {
            LUAU_ASSERT(!a.boundTo);
            TableType clone = TableType{a.props, a.indexer, a.level, a.scope, a.state};
            clone.definitionModuleName = a.definitionModuleName;
            clone.definitionLocation = a.definitionLocation;
            clone.name = a.name;
            clone.syntheticName = a.syntheticName;
            clone.instantiatedTypeParams = a.instantiatedTypeParams;
            clone.instantiatedTypePackParams = a.instantiatedTypePackParams;
            clone.tags = a.tags;
            return dest.addType(std::move(clone));
        }
        else if constexpr (std::is_same_v<T, MetatableType>)
        {
            MetatableType clone = MetatableType{a.table, a.metatable};
            clone.syntheticName = a.syntheticName;
            return dest.addType(std::move(clone));
        }
        else if constexpr (std::is_same_v<T, UnionType>)
        {
            UnionType clone;
            clone.options = a.options;
            return dest.addType(std::move(clone));
        }
        else if constexpr (std::is_same_v<T, IntersectionType>)
        {
            IntersectionType clone;
            clone.parts = a.parts;
            return dest.addType(std::move(clone));
        }
        else if constexpr (std::is_same_v<T, ClassType>)
        {
            if (alwaysClone)
            {
                ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName};
                return dest.addType(std::move(clone));
            }
            else
                return followed;
        }
        else if constexpr (std::is_same_v<T, NegationType>)
            return dest.addType(NegationType{a.ty});
        else
            static_assert(always_false_v<T>, "Non-exhaustive shallowClone switch");
    };

    TypeId resTy = visit(go, source->ty);

    // Only decorate types created above. A type returned by identity is not ours to mutate; writing
    // to it would also bypass the TxnLog when `source` is a pending view.
    if (resTy != followed)
        asMutable(resTy)->documentationSymbol = source->documentationSymbol;

    return resTy;
}
// Reports the outgoing edges of `ty` to the traversal by calling visitChild on every type and pack
// the node directly refers to. `ty` must already be followed through `log`; when the log holds a
// pending version of the node, that version's children are reported instead. Nodes for which
// ignoreChildren() returns true report no edges. `index` is not used.
void Tarjan::visitChildren(TypeId ty, int index)
{
    LUAU_ASSERT(ty == log->follow(ty));

    if (ignoreChildren(ty))
        return;

    TypeId node = ty;
    if (auto pendingEntry = log->pending(ty))
        node = &pendingEntry->pending;

    if (const FunctionType* fn = get<FunctionType>(node))
    {
        if (FFlag::LuauSubstitutionFixMissingFields)
        {
            for (TypeId generic : fn->generics)
                visitChild(generic);
            for (TypePackId genericPack : fn->genericPacks)
                visitChild(genericPack);
        }
        visitChild(fn->argTypes);
        visitChild(fn->retTypes);
        return;
    }

    if (const TableType* tbl = get<TableType>(node))
    {
        LUAU_ASSERT(!tbl->boundTo);
        for (const auto& [propName, prop] : tbl->props)
            visitChild(prop.type);
        if (tbl->indexer)
        {
            visitChild(tbl->indexer->indexType);
            visitChild(tbl->indexer->indexResultType);
        }
        for (TypeId param : tbl->instantiatedTypeParams)
            visitChild(param);
        for (TypePackId packParam : tbl->instantiatedTypePackParams)
            visitChild(packParam);
        return;
    }

    if (const MetatableType* meta = get<MetatableType>(node))
    {
        visitChild(meta->table);
        visitChild(meta->metatable);
        return;
    }

    if (const UnionType* uni = get<UnionType>(node))
    {
        for (TypeId option : uni->options)
            visitChild(option);
        return;
    }

    if (const IntersectionType* inter = get<IntersectionType>(node))
    {
        for (TypeId part : inter->parts)
            visitChild(part);
        return;
    }

    if (const PendingExpansionType* expansion = get<PendingExpansionType>(node))
    {
        for (TypeId arg : expansion->typeArguments)
            visitChild(arg);
        for (TypePackId packArg : expansion->packArguments)
            visitChild(packArg);
        return;
    }

    if (FFlag::LuauClassTypeVarsInSubstitution)
    {
        if (const ClassType* cls = get<ClassType>(node))
        {
            for (const auto& [propName, prop] : cls->props)
                visitChild(prop.type);
            if (cls->parent)
                visitChild(*cls->parent);
            if (cls->metatable)
                visitChild(*cls->metatable);
            return;
        }
    }

    if (const NegationType* neg = get<NegationType>(node))
        visitChild(neg->ty);
}
// Reports the outgoing edges of pack `tp` (already followed through `log`) by calling visitChild on
// each element type and tail of a TypePack, or on the element type of a VariadicTypePack. The log's
// pending version of the pack is used when present. `index` is not used.
void Tarjan::visitChildren(TypePackId tp, int index)
{
    LUAU_ASSERT(tp == log->follow(tp));

    if (ignoreChildren(tp))
        return;

    TypePackId node = tp;
    if (auto pendingEntry = log->pending(tp))
        node = &pendingEntry->pending;

    if (const TypePack* pack = get<TypePack>(node))
    {
        for (TypeId element : pack->head)
            visitChild(element);
        if (pack->tail)
            visitChild(*pack->tail);
        return;
    }

    if (const VariadicTypePack* variadic = get<VariadicTypePack>(node))
        visitChild(variadic->ty);
}
// Maps `ty` (followed through `log`) to its Tarjan vertex index, allocating a new vertex the first
// time the type is seen. The returned bool is true exactly when the vertex was just allocated.
// A new vertex gets one slot in each parallel per-vertex array, and its lowlink starts at its own index.
std::pair<int, bool> Tarjan::indexify(TypeId ty)
{
    ty = log->follow(ty);

    if (typeToIndex.contains(ty))
        return {typeToIndex[ty], false};

    const int vertex = int(indexToType.size());
    typeToIndex[ty] = vertex;
    indexToType.push_back(ty);
    indexToPack.push_back(nullptr);
    onStack.push_back(false);
    lowlink.push_back(vertex);
    return {vertex, true};
}

// Pack counterpart of indexify(TypeId): same contract, with the pack stored in indexToPack.
std::pair<int, bool> Tarjan::indexify(TypePackId tp)
{
    tp = log->follow(tp);

    if (packToIndex.contains(tp))
        return {packToIndex[tp], false};

    const int vertex = int(indexToPack.size());
    packToIndex[tp] = vertex;
    indexToType.push_back(nullptr);
    indexToPack.push_back(tp);
    onStack.push_back(false);
    lowlink.push_back(vertex);
    return {vertex, true};
}
// Appends an edge to `ty` (followed through `log`) to the current edge list. edgesTy and edgesTp
// are parallel arrays: at each position exactly one of them holds a non-null target.
void Tarjan::visitChild(TypeId ty)
{
    edgesTy.push_back(log->follow(ty));
    edgesTp.push_back(nullptr);
}

// Pack counterpart of visitChild(TypeId).
void Tarjan::visitChild(TypePackId tp)
{
    edgesTy.push_back(nullptr);
    edgesTp.push_back(log->follow(tp));
}
// Iterative Tarjan strongly-connected-components traversal over everything reachable from the
// entries already on `worklist` (visitRoot pushes the root vertex before calling this).
// Each worklist entry is (vertex index, next edge to examine, one-past-last edge); currEdge == -1
// marks a vertex that has not been entered yet. A vertex's edges are appended to the shared
// edgesTy/edgesTp lists by visitChildren and trimmed back to the parent's range when it finishes.
// Hooks: visitEdge(child, parent) runs once per examined edge. It runs immediately for an
// already-indexed child, and when the child finishes for a freshly discovered one.
// visitSCC(root) runs once per component while the component's members are still on `stack`.
// Returns TooManyChildren as soon as the number of entered vertices reaches childLimit (when
// childLimit > 0), leaving the traversal state partially populated; otherwise returns Ok.
TarjanResult Tarjan::loop()
{
// Normally Tarjan is presented recursively, but this is a hot loop, so worth optimizing
while (!worklist.empty())
{
// Copied by value on purpose: worklist.push_back below may reallocate, so references into it would dangle.
auto [index, currEdge, lastEdge] = worklist.back();
// First visit
if (currEdge == -1)
{
++childCount;
if (childLimit > 0 && childLimit <= childCount)
return TarjanResult::TooManyChildren;
// Enter the vertex: push it on the Tarjan stack and record where its edges begin.
stack.push_back(index);
onStack[index] = true;
currEdge = int(edgesTy.size());
// Fill in edge list of this vertex
if (TypeId ty = indexToType[index])
visitChildren(ty, index);
else if (TypePackId tp = indexToPack[index])
visitChildren(tp, index);
lastEdge = int(edgesTy.size());
}
// Visit children
bool foundFresh = false;
for (; currEdge < lastEdge; currEdge++)
{
int childIndex = -1;
bool fresh = false;
if (auto ty = edgesTy[currEdge])
std::tie(childIndex, fresh) = indexify(ty);
else if (auto tp = edgesTp[currEdge])
std::tie(childIndex, fresh) = indexify(tp);
else
LUAU_ASSERT(false);
if (fresh)
{
// Original recursion point, update the parent continuation point and start the new element
worklist.back() = {index, currEdge + 1, lastEdge};
worklist.push_back({childIndex, -1, -1});
// We need to continue the top-level loop from the start with the new worklist element
foundFresh = true;
break;
}
else if (onStack[childIndex])
{
// Edge to a vertex still on the stack: it belongs to the current component candidate.
lowlink[index] = std::min(lowlink[index], childIndex);
}
visitEdge(childIndex, index);
}
if (foundFresh)
continue;
// All edges examined. If this vertex is the root of an SCC, report it and pop its members.
if (lowlink[index] == index)
{
visitSCC(index);
while (!stack.empty())
{
int popped = stack.back();
stack.pop_back();
onStack[popped] = false;
if (popped == index)
break;
}
}
worklist.pop_back();
// Original return from recursion into a child
if (!worklist.empty())
{
auto [parentIndex, _, parentEndEdge] = worklist.back();
// No need to keep child edges around
edgesTy.resize(parentEndEdge);
edgesTp.resize(parentEndEdge);
// The parent's lowlink absorbs the finished child's, as in recursive Tarjan.
lowlink[parentIndex] = std::min(lowlink[parentIndex], lowlink[index]);
visitEdge(index, parentIndex);
}
}
return TarjanResult::Ok;
}
// Begins a traversal rooted at `ty`. Installs the default child limit (FInt::LuauTarjanChildLimit)
// if none has been configured, resets the per-traversal child counter, registers the root vertex
// and runs loop().
TarjanResult Tarjan::visitRoot(TypeId ty)
{
    if (childLimit == 0)
        childLimit = FInt::LuauTarjanChildLimit;
    childCount = 0;

    const int rootIndex = indexify(log->follow(ty)).first;
    worklist.push_back({rootIndex, -1, -1});
    return loop();
}

// Pack counterpart of visitRoot(TypeId).
TarjanResult Tarjan::visitRoot(TypePackId tp)
{
    if (childLimit == 0)
        childLimit = FInt::LuauTarjanChildLimit;
    childCount = 0;

    const int rootIndex = indexify(log->follow(tp)).first;
    worklist.push_back({rootIndex, -1, -1});
    return loop();
}
// Discards all traversal state so a fresh traversal can start: dirty flags, the vertex <-> node
// maps, per-vertex Tarjan bookkeeping, pending edges and the worklist.
void FindDirty::clearTarjan()
{
    dirty.clear();

    // Vertex <-> node mappings.
    typeToIndex.clear();
    packToIndex.clear();
    indexToType.clear();
    indexToPack.clear();

    // Per-vertex Tarjan bookkeeping.
    stack.clear();
    onStack.clear();
    lowlink.clear();

    // Outstanding work.
    worklist.clear();
    edgesTy.clear();
    edgesTp.clear();
}
// Grows a per-vertex flag vector so that `index` is a valid position; new slots start as false.
template<typename FlagVector>
static void growDirtyFlagsToInclude(FlagVector& flags, int index)
{
    if (size_t(index) >= flags.size())
        flags.resize(index + 1, false);
}

// Returns the dirty flag of vertex `index`, growing the flag vector on demand.
bool FindDirty::getDirty(int index)
{
    growDirtyFlagsToInclude(dirty, index);
    return dirty[index];
}

// Sets the dirty flag of vertex `index` to `d`, growing the flag vector on demand.
void FindDirty::setDirty(int index, bool d)
{
    growDirtyFlagsToInclude(dirty, index);
    dirty[index] = d;
}
// Propagates dirtiness along an edge: a dirty child marks its parent dirty, while a clean child
// leaves the parent's flag as it was.
void FindDirty::visitEdge(int index, int parentIndex)
{
    if (!getDirty(index))
        return;
    setDirty(parentIndex, true);
}
void FindDirty::visitSCC(int index)
{
bool d = getDirty(index);
for (auto it = stack.rbegin(); !d && it != stack.rend(); it++)
{
if (TypeId ty = indexToType[*it])
d = isDirty(ty);
else if (TypePackId tp = indexToPack[*it])
d = isDirty(tp);
if (*it == index)
break;
}
if (!d)
return;
for (auto it = stack.rbegin(); it != stack.rend(); it++)
{
setDirty(*it, true);
if (TypeId ty = indexToType[*it])
foundDirty(ty);
else if (TypePackId tp = indexToPack[*it])
foundDirty(tp);
if (*it == index)
return;
}
}
// Runs the dirtiness search from `ty`; every member of a dirty SCC is reported to foundDirty().
// Returns the traversal outcome (TooManyChildren if the child limit was hit).
TarjanResult FindDirty::findDirty(TypeId ty)
{
    const TarjanResult outcome = visitRoot(ty);
    return outcome;
}

// Pack counterpart of findDirty(TypeId).
TarjanResult FindDirty::findDirty(TypePackId tp)
{
    const TarjanResult outcome = visitRoot(tp);
    return outcome;
}
// Rewrites `ty` by replacing every dirty node reachable from it (see FindDirty) with its cleaned or
// cloned counterpart. It then fixes up the children of each replacement so they refer to replacements too.
// Returns std::nullopt if the traversal does not finish with TarjanResult::Ok; otherwise returns
// the replacement for `ty`, or the followed `ty` itself when it was not replaced.
std::optional<TypeId> Substitution::substitute(TypeId ty)
{
    ty = log->follow(ty);

    const bool reentrant = FFlag::LuauSubstitutionReentrant;

    // Drop traversal state left behind by an earlier call on this object.
    if (reentrant)
        clearTarjan();

    if (findDirty(ty) != TarjanResult::Ok)
        return std::nullopt;

    for (auto [oldTy, newTy] : newTypes)
    {
        if (ignoreChildren(oldTy))
            continue;
        // clearTarjan() does not reset replacedTypes, so a replacement may already be fixed up.
        if (reentrant && replacedTypes.contains(newTy))
            continue;
        replaceChildren(newTy);
        if (reentrant)
            replacedTypes.insert(newTy);
    }

    for (auto [oldTp, newTp] : newPacks)
    {
        if (ignoreChildren(oldTp))
            continue;
        if (reentrant && replacedTypePacks.contains(newTp))
            continue;
        replaceChildren(newTp);
        if (reentrant)
            replacedTypePacks.insert(newTp);
    }

    return replace(ty);
}
// Pack counterpart of substitute(TypeId): replaces every dirty node reachable from `tp`, fixes up
// the children of all replacements, and returns the replacement for `tp` (or the followed `tp`).
// Returns std::nullopt if the traversal does not finish with TarjanResult::Ok.
std::optional<TypePackId> Substitution::substitute(TypePackId tp)
{
    tp = log->follow(tp);

    const bool reentrant = FFlag::LuauSubstitutionReentrant;

    // Drop traversal state left behind by an earlier call on this object.
    if (reentrant)
        clearTarjan();

    if (findDirty(tp) != TarjanResult::Ok)
        return std::nullopt;

    for (auto [oldTy, newTy] : newTypes)
    {
        if (ignoreChildren(oldTy))
            continue;
        // clearTarjan() does not reset replacedTypes, so a replacement may already be fixed up.
        if (reentrant && replacedTypes.contains(newTy))
            continue;
        replaceChildren(newTy);
        if (reentrant)
            replacedTypes.insert(newTy);
    }

    for (auto [oldTp, newTp] : newPacks)
    {
        if (ignoreChildren(oldTp))
            continue;
        if (reentrant && replacedTypePacks.contains(newTp))
            continue;
        replaceChildren(newTp);
        if (reentrant)
            replacedTypePacks.insert(newTp);
    }

    return replace(tp);
}
// Shallowly copies `ty` into this substitution's arena. ClassTypes are copied only when
// LuauClonePublicInterfaceLess2 is on.
TypeId Substitution::clone(TypeId ty)
{
    const bool alwaysClone = FFlag::LuauClonePublicInterfaceLess2;
    return shallowClone(ty, *arena, log, alwaysClone);
}
// Shallowly copies the pack `tp` resolves to (through `log`) into this substitution's arena,
// reading the log's pending version of the pack when present. TypePacks and VariadicTypePacks are
// always copied, with element ids shared rather than cloned. Other pack kinds are copied when
// LuauClonePublicInterfaceLess2 is on; otherwise the followed pack itself is returned.
TypePackId Substitution::clone(TypePackId tp)
{
    const TypePackId followed = log->follow(tp);

    // `source` may point at the log's pending copy of the pack (presumably storage owned by the
    // log), so it is only read from. It must never be handed back as a result: the fallback below
    // returns the real followed pack, matching how shallowClone treats types.
    TypePackId source = followed;
    if (auto pendingEntry = log->pending(followed))
        source = &pendingEntry->pending;

    if (const TypePack* tpp = get<TypePack>(source))
    {
        TypePack clone;
        clone.head = tpp->head;
        clone.tail = tpp->tail;
        return addTypePack(std::move(clone));
    }
    else if (const VariadicTypePack* vtp = get<VariadicTypePack>(source))
    {
        VariadicTypePack clone;
        clone.ty = vtp->ty;
        if (FFlag::LuauSubstitutionFixMissingFields)
            clone.hidden = vtp->hidden;
        return addTypePack(std::move(clone));
    }
    else if (FFlag::LuauClonePublicInterfaceLess2)
    {
        return addTypePack(*source);
    }
    else
        return followed;
}
// Records the replacement for a node in a dirty SCC. Nodes that are themselves dirty are passed
// to clean(); others are shallow-cloned so their children can be rewritten later. The result is
// followed before being stored in newTypes. In reentrant mode an existing mapping is kept.
void Substitution::foundDirty(TypeId ty)
{
    ty = log->follow(ty);

    const bool alreadyMapped = FFlag::LuauSubstitutionReentrant && newTypes.contains(ty);
    if (alreadyMapped)
        return;

    const TypeId replacement = isDirty(ty) ? clean(ty) : clone(ty);
    newTypes[ty] = follow(replacement);
}

// Pack counterpart of foundDirty(TypeId), storing into newPacks.
void Substitution::foundDirty(TypePackId tp)
{
    tp = log->follow(tp);

    const bool alreadyMapped = FFlag::LuauSubstitutionReentrant && newPacks.contains(tp);
    if (alreadyMapped)
        return;

    const TypePackId replacement = isDirty(tp) ? clean(tp) : clone(tp);
    newPacks[tp] = follow(replacement);
}
// Returns the recorded replacement for `ty` (after following it through `log`), or the followed
// `ty` itself when no replacement was recorded.
TypeId Substitution::replace(TypeId ty)
{
    ty = log->follow(ty);
    TypeId* mapped = newTypes.find(ty);
    return mapped ? *mapped : ty;
}

// Pack counterpart of replace(TypeId).
TypePackId Substitution::replace(TypePackId tp)
{
    tp = log->follow(tp);
    TypePackId* mapped = newPacks.find(tp);
    return mapped ? *mapped : tp;
}
// Rewrites, in place, every child reference of `ty` to its replacement via replace(). `ty` must
// already be followed. Nothing is rewritten when ignoreChildren(ty) is true or when `ty` is not
// owned by this substitution's arena.
void Substitution::replaceChildren(TypeId ty)
{
    LUAU_ASSERT(ty == log->follow(ty));

    if (ignoreChildren(ty) || ty->owningArena != arena)
        return;

    if (FunctionType* fn = getMutable<FunctionType>(ty))
    {
        if (FFlag::LuauSubstitutionFixMissingFields)
        {
            for (TypeId& generic : fn->generics)
                generic = replace(generic);
            for (TypePackId& genericPack : fn->genericPacks)
                genericPack = replace(genericPack);
        }
        fn->argTypes = replace(fn->argTypes);
        fn->retTypes = replace(fn->retTypes);
        return;
    }

    if (TableType* tbl = getMutable<TableType>(ty))
    {
        LUAU_ASSERT(!tbl->boundTo);
        for (auto& [propName, prop] : tbl->props)
            prop.type = replace(prop.type);
        if (tbl->indexer)
        {
            tbl->indexer->indexType = replace(tbl->indexer->indexType);
            tbl->indexer->indexResultType = replace(tbl->indexer->indexResultType);
        }
        for (TypeId& param : tbl->instantiatedTypeParams)
            param = replace(param);
        for (TypePackId& packParam : tbl->instantiatedTypePackParams)
            packParam = replace(packParam);
        return;
    }

    if (MetatableType* meta = getMutable<MetatableType>(ty))
    {
        meta->table = replace(meta->table);
        meta->metatable = replace(meta->metatable);
        return;
    }

    if (UnionType* uni = getMutable<UnionType>(ty))
    {
        for (TypeId& option : uni->options)
            option = replace(option);
        return;
    }

    if (IntersectionType* inter = getMutable<IntersectionType>(ty))
    {
        for (TypeId& part : inter->parts)
            part = replace(part);
        return;
    }

    if (PendingExpansionType* expansion = getMutable<PendingExpansionType>(ty))
    {
        for (TypeId& arg : expansion->typeArguments)
            arg = replace(arg);
        for (TypePackId& packArg : expansion->packArguments)
            packArg = replace(packArg);
        return;
    }

    if (FFlag::LuauClassTypeVarsInSubstitution)
    {
        if (ClassType* cls = getMutable<ClassType>(ty))
        {
            for (auto& [propName, prop] : cls->props)
                prop.type = replace(prop.type);
            if (cls->parent)
                cls->parent = replace(*cls->parent);
            if (cls->metatable)
                cls->metatable = replace(*cls->metatable);
            return;
        }
    }

    if (NegationType* neg = getMutable<NegationType>(ty))
        neg->ty = replace(neg->ty);
}
// Pack counterpart of replaceChildren(TypeId): rewrites the head elements and tail of a TypePack, or
// the element type of a VariadicTypePack, in place. Skipped for ignored packs and for packs not
// owned by this substitution's arena.
void Substitution::replaceChildren(TypePackId tp)
{
    LUAU_ASSERT(tp == log->follow(tp));

    if (ignoreChildren(tp) || tp->owningArena != arena)
        return;

    if (TypePack* pack = getMutable<TypePack>(tp))
    {
        for (TypeId& element : pack->head)
            element = replace(element);
        if (pack->tail)
            pack->tail = replace(*pack->tail);
        return;
    }

    if (VariadicTypePack* variadic = getMutable<VariadicTypePack>(tp))
        variadic->ty = replace(variadic->ty);
}
} // namespace Luau