// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/EqSatSimplification.h" #include "Luau/EqSatSimplificationImpl.h" #include "Luau/EGraph.h" #include "Luau/Id.h" #include "Luau/Language.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypeArena.h" #include "Luau/TypeFunction.h" #include "Luau/VisitType.h" #include #include #include #include #include #include #include LUAU_FASTFLAGVARIABLE(DebugLuauLogSimplification) LUAU_FASTFLAGVARIABLE(DebugLuauLogSimplificationToDot) LUAU_FASTFLAGVARIABLE(DebugLuauExtraEqSatSanityChecks) namespace Luau::EqSatSimplification { using Id = Luau::EqSat::Id; using EGraph = Luau::EqSat::EGraph; using Luau::EqSat::Slice; TTable::TTable(Id basis) { storage.push_back(basis); } // I suspect that this is going to become a performance hotspot. It would be // nice to avoid allocating propTypes_ TTable::TTable(Id basis, std::vector propNames_, std::vector propTypes_) : propNames(std::move(propNames_)) { storage.reserve(propTypes_.size() + 1); storage.push_back(basis); storage.insert(storage.end(), propTypes_.begin(), propTypes_.end()); LUAU_ASSERT(storage.size() == 1 + propTypes_.size()); } Id TTable::getBasis() const { LUAU_ASSERT(!storage.empty()); return storage[0]; } Slice TTable::propTypes() const { LUAU_ASSERT(propNames.size() + 1 == storage.size()); return Slice{storage.data() + 1, propNames.size()}; } Slice TTable::mutableOperands() { return Slice{storage.data(), storage.size()}; } Slice TTable::operands() const { return Slice{storage.data(), storage.size()}; } bool TTable::operator==(const TTable& rhs) const { return storage == rhs.storage && propNames == rhs.propNames; } size_t TTable::Hash::operator()(const TTable& value) const { size_t hash = 0; // We're using pointers here, which does mean platform divergence. I think // it's okay? (famous last words, I know) for (StringId s : value.propNames) EqSat::hashCombine(hash, EqSat::languageHash(s)); EqSat::hashCombine(hash, EqSat::languageHash(value.storage)); return hash; } uint32_t StringCache::add(std::string_view s) { size_t hash = std::hash()(s); if (uint32_t* it = strings.find(hash)) return *it; char* storage = static_cast(allocator.allocate(s.size())); memcpy(storage, s.data(), s.size()); uint32_t result = uint32_t(views.size()); views.emplace_back(storage, s.size()); strings[hash] = result; return result; } std::string_view StringCache::asStringView(StringId id) const { LUAU_ASSERT(id < views.size()); return views[id]; } std::string StringCache::asString(StringId id) const { return std::string{asStringView(id)}; } template Simplify::Data Simplify::make(const EGraph&, const T&) const { return true; } void Simplify::join(Data& left, const Data& right) const { left = left || right; } using EClass = Luau::EqSat::EClass; // A terminal type is a type that does not contain any other types. // Examples: any, unknown, number, string, boolean, nil, table, class, thread, function // // All class types are also terminal. static bool isTerminal(const EType& node) { return node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || node.get(); } static bool areTerminalAndDefinitelyDisjoint(const EType& lhs, const EType& rhs) { // If either node is non-terminal, then we early exit: we're not going to // do a state space search for whether something like: // (A | B | C | D) & (E | F | G | H) // ... is a disjoint intersection. if (!isTerminal(lhs) || !isTerminal(rhs)) return false; // Special case some types that aren't strict, disjoint subsets. if (lhs.get() || lhs.get()) return !(rhs.get() || rhs.get()); // Handling strings / booleans: these are the types for which we // expect something like: // // "foo" & ~"bar" // // ... to simplify to "foo". if (lhs.get()) return !(rhs.get() || rhs.get()); if (lhs.get()) return !(rhs.get() || rhs.get()); if (auto lhsSString = lhs.get()) { auto rhsSString = rhs.get(); if (!rhsSString) return !rhs.get(); return lhsSString->value() != rhsSString->value(); } if (auto lhsSBoolean = lhs.get()) { auto rhsSBoolean = rhs.get(); if (!rhsSBoolean) return !rhs.get(); return lhsSBoolean->value() != rhsSBoolean->value(); } // At this point: // - We know both nodes are terminal // - We know that the LHS is not any boolean, string, or class // At this point, we have two classes of checks left: // - Whether the two enodes are exactly the same set (now that the static // sets have been covered). // - Whether one of the enodes is a large semantic set such as TAny, // TUnknown, or TError. return !( lhs.index() == rhs.index() || lhs.get() || rhs.get() || lhs.get() || rhs.get() || lhs.get() || rhs.get() || lhs.get() || rhs.get() || lhs.get() || rhs.get() ); } static bool isTerminal(const EGraph& egraph, Id eclass) { const auto& nodes = egraph[eclass].nodes; return std::any_of( nodes.begin(), nodes.end(), [](auto& a) { return isTerminal(a); } ); } Id mkUnion(EGraph& egraph, std::vector parts) { if (parts.size() == 0) return egraph.add(TNever{}); else if (parts.size() == 1) return parts[0]; else return egraph.add(Union{std::move(parts)}); } Id mkIntersection(EGraph& egraph, std::vector parts) { if (parts.size() == 0) return egraph.add(TUnknown{}); else if (parts.size() == 1) return parts[0]; else return egraph.add(Intersection{std::move(parts)}); } struct ListRemover { std::unordered_map>& mappings2; TypeId ty; ~ListRemover() { mappings2.erase(ty); } }; /* * Crucial subtlety: It is very extremely important that enodes and eclasses are * immutable. Mutating an enode would mean that it is no longer equivalent to * other nodes in the same eclass. * * At the same time, many TypeIds are NOT immutable! * * The thing that makes this navigable is that it is okay if the same TypeId is * imported as a different Id at different times as type inference runs. For * example, if we at one point import a BlockedType as a TOpaque, and later * import that same TypeId as some other enode type, this is all completely * okay. * * The main thing we have to be very cautious about, I think, is unsealed * tables. Unsealed table types have properties imperatively inserted into them * as type inference runs. If we were to encode that TypeId as part of an * enode, we could run into a situation where the egraph makes incorrect * assumptions about the table. * * The solution is pretty simple: Never use the contents of a mutable TypeId in * any reduction rule. TOpaque is always okay because we never actually poke * around inside the TypeId to do anything. */ Id toId( EGraph& egraph, NotNull builtinTypes, std::unordered_map& mappingIdToClass, std::unordered_map>& typeToMappingId, // (TypeId: (MappingId, count)) std::unordered_set& boundNodes, StringCache& strings, TypeId ty ) { ty = follow(ty); // First, handle types which do not contain other types. They obviously // cannot participate in cycles, so we don't have to check for that. if (auto freeTy = get(ty)) return egraph.add(TOpaque{ty}); else if (get(ty)) return egraph.add(TOpaque{ty}); else if (auto prim = get(ty)) { switch (prim->type) { case Luau::PrimitiveType::NilType: return egraph.add(TNil{}); case Luau::PrimitiveType::Boolean: return egraph.add(TBoolean{}); case Luau::PrimitiveType::Number: return egraph.add(TNumber{}); case Luau::PrimitiveType::String: return egraph.add(TString{}); case Luau::PrimitiveType::Thread: return egraph.add(TThread{}); case Luau::PrimitiveType::Function: return egraph.add(TTopFunction{}); case Luau::PrimitiveType::Table: return egraph.add(TTopTable{}); case Luau::PrimitiveType::Buffer: return egraph.add(TBuffer{}); default: LUAU_ASSERT(!"Unimplemented"); return egraph.add(Invalid{}); } } else if (auto s = get(ty)) { if (auto bs = get(s)) return egraph.add(SBoolean{bs->value}); else if (auto ss = get(s)) return egraph.add(SString{strings.add(ss->value)}); else LUAU_ASSERT(!"Unexpected"); } else if (get(ty)) return egraph.add(TOpaque{ty}); else if (get(ty)) return egraph.add(TOpaque{ty}); else if (get(ty)) return egraph.add(TFunction{ty}); else if (ty == builtinTypes->classType) return egraph.add(TTopClass{}); else if (get(ty)) return egraph.add(TClass{ty}); else if (get(ty)) return egraph.add(TAny{}); else if (get(ty)) return egraph.add(TError{}); else if (get(ty)) return egraph.add(TUnknown{}); else if (get(ty)) return egraph.add(TNever{}); // Now handle composite types. if (auto it = typeToMappingId.find(ty); it != typeToMappingId.end()) { auto& [mappingId, count] = it->second; ++count; Id res = egraph.add(TBound{mappingId}); boundNodes.insert(res); return res; } typeToMappingId.emplace(ty, std::pair{mappingIdToClass.size(), 0}); ListRemover lr{typeToMappingId, ty}; auto cache = [&](Id res) { const auto& [mappingId, count] = typeToMappingId.at(ty); if (count > 0) mappingIdToClass.emplace(mappingId, res); return res; }; if (auto tt = get(ty)) return egraph.add(TImportedTable{ty}); else if (get(ty)) return egraph.add(TOpaque{ty}); else if (auto ut = get(ty)) { std::vector parts; for (TypeId part : ut) parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); return cache(mkUnion(egraph, std::move(parts))); } else if (auto it = get(ty)) { std::vector parts; for (TypeId part : it) parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); LUAU_ASSERT(parts.size() > 1); return cache(mkIntersection(egraph, std::move(parts))); } else if (auto negation = get(ty)) { Id part = toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, negation->ty); return cache(egraph.add(Negation{std::array{part}})); } else if (auto tfun = get(ty)) { LUAU_ASSERT(tfun->packArguments.empty()); std::vector parts; parts.reserve(tfun->typeArguments.size()); for (TypeId part : tfun->typeArguments) parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); // This looks sily, but we're making a copy of the specific // `TypeFunctionInstanceType` outside of the provided arena so that // we can access the members without fear of the specific TFIT being // overwritten with a bound type. return cache(egraph.add(TTypeFun{ std::make_shared( tfun->function, tfun->typeArguments, tfun->packArguments, tfun->userFuncName, tfun->userFuncData ), std::move(parts) })); } else if (get(ty)) return egraph.add(TNoRefine{}); else { LUAU_ASSERT(!"Unhandled Type"); return cache(egraph.add(Invalid{})); } } Id toId(EGraph& egraph, NotNull builtinTypes, std::unordered_map& mappingIdToClass, StringCache& strings, TypeId ty) { std::unordered_map> typeToMappingId; std::unordered_set boundNodes; Id id = toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, ty); for (Id id : boundNodes) { for (const auto [tb, _index] : Query(&egraph, id)) { Id bindee = mappingIdToClass.at(tb->value()); egraph.merge(id, bindee); } } egraph.rebuild(); return egraph.find(id); } // We apply a penalty to cyclic types to guide the system away from them where // possible. static const int CYCLE_PENALTY = 5000; // Composite types have cost equal to the sum of the costs of their parts plus a // constant factor. static const int SET_TYPE_PENALTY = 1; static const int TABLE_TYPE_PENALTY = 2; static const int NEGATION_PENALTY = 2; static const int TFUN_PENALTY = 2; // FIXME. We don't have an accurate way to score a TImportedTable table against // a TTable. static const int IMPORTED_TABLE_PENALTY = 50; // TBound shouldn't ever be selected as the best node of a class unless we are // debugging eqsat itself and need to stringify eclasses. We thus penalize it // so heavily that we'll use any other alternative. static const int BOUND_PENALTY = 999999999; // TODO iteration count limit // TODO also: accept an argument which is the maximum cost to consider before // abandoning the count. // TODO: the egraph should be the first parameter. static size_t computeCost(std::unordered_map& bestNodes, const EGraph& egraph, std::unordered_map& costs, Id id) { if (auto it = costs.find(id); it != costs.end()) return it->second; const std::vector& nodes = egraph[id].nodes; size_t minCost = std::numeric_limits::max(); size_t bestNode = std::numeric_limits::max(); const auto updateCost = [&](size_t cost, size_t node) { if (cost < minCost) { minCost = cost; bestNode = node; } }; // First, quickly scan for a terminal type. If we can find one, it is obviously the best. for (size_t index = 0; index < nodes.size(); ++index) { if (isTerminal(nodes[index])) { minCost = 1; bestNode = index; costs[id] = 1; const auto [iter, isFresh] = bestNodes.insert({id, index}); // If we are forcing the cost function to select a specific node, // then we still need to traverse into that node, even if this // particular node is the obvious choice under normal circumstances. if (isFresh || iter->second == index) return 1; } } // If we recur into this type before this call frame completes, it is // because this type participates in a cycle. costs[id] = CYCLE_PENALTY; auto computeChildren = [&](Slice parts, size_t maxCost) -> std::optional { size_t cost = 0; for (Id part : parts) { cost += computeCost(bestNodes, egraph, costs, part); // Abandon this node if it is too costly if (cost > maxCost) return std::nullopt; } return cost; }; size_t startIndex = 0; size_t endIndex = nodes.size(); // FFlag::DebugLuauLogSimplification will sometimes stringify an Id and pass // in a prepopulated bestNodes map. If that mapping already has an index // for this Id, don't look at the other nodes of this class. if (auto it = bestNodes.find(id); it != bestNodes.end()) { LUAU_ASSERT(it->second < nodes.size()); startIndex = it->second; endIndex = startIndex + 1; } for (size_t index = startIndex; index < endIndex; ++index) { const auto& node = nodes[index]; if (node.get()) updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound. else if (node.get()) { minCost = 1; bestNode = index; } else if (auto tbl = node.get()) { // TODO: We could make the penalty a parameter to computeChildren. std::optional maybeCost = computeChildren(tbl->operands(), minCost); if (maybeCost) updateCost(TABLE_TYPE_PENALTY + *maybeCost, index); } else if (node.get()) { minCost = IMPORTED_TABLE_PENALTY; bestNode = index; } else if (auto u = node.get()) { std::optional maybeCost = computeChildren(u->operands(), minCost); if (maybeCost) updateCost(SET_TYPE_PENALTY + *maybeCost, index); } else if (auto i = node.get()) { std::optional maybeCost = computeChildren(i->operands(), minCost); if (maybeCost) updateCost(SET_TYPE_PENALTY + *maybeCost, index); } else if (auto negation = node.get()) { std::optional maybeCost = computeChildren(negation->operands(), minCost); if (maybeCost) updateCost(NEGATION_PENALTY + *maybeCost, index); } else if (auto tfun = node.get()) { std::optional maybeCost = computeChildren(tfun->operands(), minCost); if (maybeCost) updateCost(TFUN_PENALTY + *maybeCost, index); } } LUAU_ASSERT(bestNode < nodes.size()); costs[id] = minCost; bestNodes.insert({id, bestNode}); return minCost; } static std::unordered_map computeBestResult(const EGraph& egraph, Id id, const std::unordered_map& forceNodes) { std::unordered_map costs; std::unordered_map bestNodes = forceNodes; computeCost(bestNodes, egraph, costs, id); return bestNodes; } static std::unordered_map computeBestResult(const EGraph& egraph, Id id) { std::unordered_map costs; std::unordered_map bestNodes; computeCost(bestNodes, egraph, costs, id); return bestNodes; } TypeId fromId( EGraph& egraph, const StringCache& strings, NotNull builtinTypes, NotNull arena, const std::unordered_map& bestNodes, std::unordered_map& seen, std::vector& newTypeFunctions, Id rootId ); TypeId flattenTableNode( EGraph& egraph, const StringCache& strings, NotNull builtinTypes, NotNull arena, const std::unordered_map& bestNodes, std::unordered_map& seen, std::vector& newTypeFunctions, Id rootId ) { std::vector stack; std::unordered_set seenIds; Id id = rootId; const TImportedTable* importedTable = nullptr; while (true) { size_t index = bestNodes.at(id); const auto& eclass = egraph[id]; const auto [_iter, isFresh] = seenIds.insert(id); if (!isFresh) { // If a TTable is its own basis, it must be the case that some other // node on this eclass is a TImportedTable. Let's use that. bool found = false; for (size_t i = 0; i < eclass.nodes.size(); ++i) { if (eclass.nodes[i].get()) { found = true; index = i; break; } } if (!found) { // If we couldn't find one, we don't know what to do. Use ErrorType. LUAU_ASSERT(0); return builtinTypes->errorType; } } const auto& node = eclass.nodes[index]; if (const TTable* ttable = node.get()) { stack.push_back(ttable); id = ttable->getBasis(); continue; } else if (const TImportedTable* ti = node.get()) { importedTable = ti; break; } else LUAU_ASSERT(0); } TableType resultTable; if (importedTable) { const TableType* t = Luau::get(importedTable->value()); LUAU_ASSERT(t); resultTable = *t; // Intentional shallow clone here } while (!stack.empty()) { const TTable* t = stack.back(); stack.pop_back(); for (size_t i = 0; i < t->propNames.size(); ++i) { StringId propName = t->propNames[i]; const Id propType = t->propTypes()[i]; resultTable.props[strings.asString(propName)] = Property{fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, propType)}; } } return arena->addType(std::move(resultTable)); } TypeId fromId( EGraph& egraph, const StringCache& strings, NotNull builtinTypes, NotNull arena, const std::unordered_map& bestNodes, std::unordered_map& seen, std::vector& newTypeFunctions, Id rootId ) { if (auto it = seen.find(rootId); it != seen.end()) return it->second; size_t index = bestNodes.at(rootId); LUAU_ASSERT(index <= egraph[rootId].nodes.size()); const EType& node = egraph[rootId].nodes[index]; if (node.get()) return builtinTypes->nilType; else if (node.get()) return builtinTypes->booleanType; else if (node.get()) return builtinTypes->numberType; else if (node.get()) return builtinTypes->stringType; else if (node.get()) return builtinTypes->threadType; else if (node.get()) return builtinTypes->functionType; else if (node.get()) return builtinTypes->tableType; else if (node.get()) return builtinTypes->classType; else if (node.get()) return builtinTypes->bufferType; else if (auto opaque = node.get()) return opaque->value(); else if (auto b = node.get()) return b->value() ? builtinTypes->trueType : builtinTypes->falseType; else if (auto s = node.get()) return arena->addType(SingletonType{StringSingleton{strings.asString(s->value())}}); else if (auto fun = node.get()) return fun->value(); else if (auto tbl = node.get()) { TypeId res = arena->addType(BlockedType{}); seen[rootId] = res; TypeId flattened = flattenTableNode(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); asMutable(res)->ty.emplace(flattened); return flattened; } else if (auto tbl = node.get()) return tbl->value(); else if (auto cls = node.get()) return cls->value(); else if (node.get()) return builtinTypes->anyType; else if (node.get()) return builtinTypes->errorType; else if (node.get()) return builtinTypes->unknownType; else if (node.get()) return builtinTypes->neverType; else if (auto u = node.get()) { Slice parts = u->operands(); if (parts.empty()) return builtinTypes->neverType; else if (parts.size() == 1) { TypeId placeholder = arena->addType(BlockedType{}); seen[rootId] = placeholder; auto result = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); if (follow(result) == placeholder) { emplaceType(asMutable(placeholder), "EGRAPH-SINGLETON-CYCLE"); } else { emplaceType(asMutable(placeholder), result); } return result; } else { TypeId res = arena->addType(BlockedType{}); seen[rootId] = res; std::vector partTypes; partTypes.reserve(parts.size()); for (Id part : parts) partTypes.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); asMutable(res)->ty.emplace(std::move(partTypes)); return res; } } else if (auto i = node.get()) { Slice parts = i->operands(); if (parts.empty()) return builtinTypes->neverType; else if (parts.size() == 1) { LUAU_ASSERT(parts[0] != rootId); return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); } else { TypeId res = arena->addType(BlockedType{}); seen[rootId] = res; std::vector partTypes; partTypes.reserve(parts.size()); for (Id part : parts) partTypes.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); asMutable(res)->ty.emplace(std::move(partTypes)); return res; } } else if (auto negation = node.get()) { TypeId res = arena->addType(BlockedType{}); seen[rootId] = res; TypeId ty = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, negation->operands()[0]); asMutable(res)->ty.emplace(ty); return res; } else if (auto tfun = node.get()) { TypeId res = arena->addType(BlockedType{}); seen[rootId] = res; std::vector args; for (Id part : tfun->operands()) args.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); auto oldInstance = tfun->value(); asMutable(res)->ty.emplace( oldInstance->function, std::move(args), std::vector(), oldInstance->userFuncName, oldInstance->userFuncData ); newTypeFunctions.push_back(res); return res; } else if (node.get()) return builtinTypes->errorType; else if (node.get()) return builtinTypes->noRefineType; else { LUAU_ASSERT(!"Unimplemented"); return nullptr; } } static TypeId fromId( EGraph& egraph, const StringCache& strings, NotNull builtinTypes, NotNull arena, const std::unordered_map& forceNodes, std::vector& newTypeFunctions, Id rootId ) { const std::unordered_map bestNodes = computeBestResult(egraph, rootId, forceNodes); std::unordered_map seen; return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); } static TypeId fromId( EGraph& egraph, const StringCache& strings, NotNull builtinTypes, NotNull arena, std::vector& newTypeFunctions, Id rootId ) { const std::unordered_map bestNodes = computeBestResult(egraph, rootId); std::unordered_map seen; return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); } Subst::Subst(Id eclass, Id newClass, std::string desc) : eclass(std::move(eclass)) , newClass(std::move(newClass)) , desc(std::move(desc)) { } std::string mkDesc( EGraph& egraph, const StringCache& strings, NotNull arena, NotNull builtinTypes, Id from, Id to, const std::unordered_map& forceNodes, const std::string& rule ) { if (!FFlag::DebugLuauLogSimplification) return ""; std::vector newTypeFunctions; TypeId fromTy = fromId(egraph, strings, builtinTypes, arena, forceNodes, newTypeFunctions, from); TypeId toTy = fromId(egraph, strings, builtinTypes, arena, forceNodes, newTypeFunctions, to); ToStringOptions opts; opts.useQuestionMarks = false; const int RULE_PADDING = 35; const std::string rulePadding(std::max(0, RULE_PADDING - rule.size()), ' '); const std::string fromIdStr = ""; // "(" + std::to_string(uint32_t(from)) + ") "; const std::string toIdStr = ""; // "(" + std::to_string(uint32_t(to)) + ") "; return rule + ":" + rulePadding + fromIdStr + toString(fromTy, opts) + " <=> " + toIdStr + toString(toTy, opts); } std::string mkDesc(EGraph& egraph, const StringCache& strings, NotNull arena, NotNull builtinTypes, Id from, Id to, const std::string& rule) { if (!FFlag::DebugLuauLogSimplification) return ""; return mkDesc(egraph, strings, arena, builtinTypes, from, to, {}, rule); } static std::string getNodeName(const StringCache& strings, const EType& node) { if (node.get()) return "nil"; else if (node.get()) return "boolean"; else if (node.get()) return "number"; else if (node.get()) return "string"; else if (node.get()) return "thread"; else if (node.get()) return "function"; else if (node.get()) return "table"; else if (node.get()) return "class"; else if (node.get()) return "buffer"; else if (node.get()) return "opaque"; else if (auto b = node.get()) return b->value() ? "true" : "false"; else if (auto s = node.get()) return "\"" + strings.asString(s->value()) + "\""; else if (node.get()) return "\xe2\x88\xaa"; else if (node.get()) return "\xe2\x88\xa9"; else if (auto cls = node.get()) { const ClassType* ct = get(cls->value()); LUAU_ASSERT(ct); return ct->name; } else if (node.get()) return "any"; else if (node.get()) return "error"; else if (node.get()) return "unknown"; else if (node.get()) return "never"; else if (auto tfun = node.get()) return "tfun " + tfun->value()->function->name; else if (node.get()) return "~"; else if (node.get()) return "invalid?"; else if (node.get()) return "bound"; return "???"; } std::string toDot(const StringCache& strings, const EGraph& egraph) { std::stringstream ss; ss << "digraph G {" << '\n'; ss << " graph [fontsize=10 fontname=\"Verdana\" compound=true];" << '\n'; ss << " node [shape=record fontsize=10 fontname=\"Verdana\"];" << '\n'; std::set populated; for (const auto& [id, eclass] : egraph.getAllClasses()) { for (const auto& node : eclass.nodes) { if (!node.operands().empty()) populated.insert(id); for (Id op : node.operands()) populated.insert(op); } } for (const auto& [id, eclass] : egraph.getAllClasses()) { if (!populated.count(id)) continue; const std::string className = "cluster_" + std::to_string(uint32_t(id)); ss << " subgraph " << className << " {" << '\n'; ss << " node [style=\"rounded,filled\"];" << '\n'; ss << " label = \"" << uint32_t(id) << "\";" << '\n'; ss << " color = blue;" << '\n'; for (size_t index = 0; index < eclass.nodes.size(); ++index) { const auto& node = eclass.nodes[index]; const std::string label = getNodeName(strings, node); const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index); ss << " " << nodeName << " [label=\"" << label << "\"];" << '\n'; } ss << " }" << '\n'; } for (const auto& [id, eclass] : egraph.getAllClasses()) { for (size_t index = 0; index < eclass.nodes.size(); ++index) { const auto& node = eclass.nodes[index]; const std::string label = getNodeName(strings, node); const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index); for (Id op : node.operands()) { op = egraph.find(op); const std::string destNodeName = "n" + std::to_string(uint32_t(op)) + "_0"; ss << " " << nodeName << " -> " << destNodeName << " [lhead=cluster_" << uint32_t(op) << "];" << '\n'; } } } ss << "}" << '\n'; return ss.str(); } template static Tag const* isTag(const EType& node) { return node.get(); } /// Important: Only use this to test for leaf node types like TUnknown and /// TNumber. Things that we know cannot be simplified any further and are safe /// to short-circuit on. /// /// It does a linear scan and exits early, so if a particular eclass has /// multiple "interesting" representations, this function can surprise you. template static Tag const* isTag(const EGraph& egraph, Id id) { for (const auto& node : egraph[id].nodes) { if (auto n = isTag(node)) return n; } return nullptr; } struct RewriteRule { explicit RewriteRule(EGraph* egraph) : egraph(egraph) { } virtual void read(std::vector& substs, Id eclass, const EType* enode) = 0; protected: const EqSat::EClass& get(Id id) { return (*egraph)[id]; } Id find(Id id) { return egraph->find(id); } Id add(EType enode) { return egraph->add(std::move(enode)); } template const Tag* isTag(Id id) { for (const auto& node : (*egraph)[id].nodes) { if (auto n = node.get()) return n; } return nullptr; } template bool isTag(const EType& enode) { return enode.get(); } public: EGraph* egraph; }; enum SubclassRelationship { LeftSuper, RightSuper, Unrelated }; static SubclassRelationship relateClasses(const TClass* leftClass, const TClass* rightClass) { const ClassType* leftClassType = Luau::get(leftClass->value()); const ClassType* rightClassType = Luau::get(rightClass->value()); if (isSubclass(leftClassType, rightClassType)) return RightSuper; else if (isSubclass(rightClassType, leftClassType)) return LeftSuper; else return Unrelated; } // Entirely analogous to NormalizedType except that it operates on eclasses instead of TypeIds. struct CanonicalizedType { std::optional nilPart; std::optional truePart; std::optional falsePart; std::optional numberPart; std::optional stringPart; std::vector stringSingletons; std::optional threadPart; std::optional functionPart; std::optional tablePart; std::vector classParts; std::optional bufferPart; std::optional errorPart; // Functions that have been union'd into the type std::unordered_set functionParts; // Anything that isn't canonical: Intersections, unions, free types, and so on. std::unordered_set otherParts; bool isUnknown() const { return nilPart && truePart && falsePart && numberPart && stringPart && threadPart && functionPart && tablePart && bufferPart; } }; void unionUnknown(EGraph& egraph, CanonicalizedType& ct) { ct.nilPart = egraph.add(TNil{}); ct.truePart = egraph.add(SBoolean{true}); ct.falsePart = egraph.add(SBoolean{false}); ct.numberPart = egraph.add(TNumber{}); ct.stringPart = egraph.add(TString{}); ct.threadPart = egraph.add(TThread{}); ct.functionPart = egraph.add(TTopFunction{}); ct.tablePart = egraph.add(TTopTable{}); ct.bufferPart = egraph.add(TBuffer{}); ct.functionParts.clear(); ct.otherParts.clear(); } void unionAny(EGraph& egraph, CanonicalizedType& ct) { unionUnknown(egraph, ct); ct.errorPart = egraph.add(TError{}); } void unionClasses(EGraph& egraph, std::vector& hereParts, Id there) { if (1 == hereParts.size() && isTag(egraph, hereParts[0])) return; const auto thereClass = isTag(egraph, there); if (!thereClass) return; for (size_t index = 0; index < hereParts.size(); ++index) { const Id herePart = hereParts[index]; if (auto partClass = isTag(egraph, herePart)) { switch (relateClasses(partClass, thereClass)) { case LeftSuper: return; case RightSuper: hereParts[index] = there; std::sort(hereParts.begin(), hereParts.end()); return; case Unrelated: continue; } } } hereParts.push_back(there); std::sort(hereParts.begin(), hereParts.end()); } void unionWithType(EGraph& egraph, CanonicalizedType& ct, Id part) { if (isTag(egraph, part)) ct.nilPart = part; else if (isTag(egraph, part)) ct.truePart = ct.falsePart = part; else if (auto b = isTag(egraph, part)) { if (b->value()) ct.truePart = part; else ct.falsePart = part; } else if (isTag(egraph, part)) ct.numberPart = part; else if (isTag(egraph, part)) ct.stringPart = part; else if (isTag(egraph, part)) ct.stringSingletons.push_back(part); else if (isTag(egraph, part)) ct.threadPart = part; else if (isTag(egraph, part)) { ct.functionPart = part; ct.functionParts.clear(); } else if (isTag(egraph, part)) ct.tablePart = part; else if (isTag(egraph, part)) ct.classParts = {part}; else if (isTag(egraph, part)) ct.bufferPart = part; else if (isTag(egraph, part)) { if (!ct.functionPart) ct.functionParts.insert(part); } else if (auto tclass = isTag(egraph, part)) unionClasses(egraph, ct.classParts, part); else if (isTag(egraph, part)) { unionAny(egraph, ct); return; } else if (isTag(egraph, part)) ct.errorPart = part; else if (isTag(egraph, part)) unionUnknown(egraph, ct); else if (isTag(egraph, part)) { // Nothing } else ct.otherParts.insert(part); } // Find an enode under the given eclass which is simple enough that it could be // subtracted from a CanonicalizedType easily. // // A union is "simple enough" if it is acyclic and is only comprised of terminal // types and unions that are themselves subtractable const EType* findSubtractableClass(const EGraph& egraph, std::unordered_set& seen, Id id) { if (seen.count(id)) return nullptr; const EType* bestUnion = nullptr; std::optional unionSize; for (const auto& node : egraph[id].nodes) { if (isTerminal(node)) return &node; if (const auto u = node.get()) { seen.insert(id); for (Id part : u->operands()) { if (!findSubtractableClass(egraph, seen, part)) return nullptr; } // If multiple unions in this class are all simple enough, prefer // the shortest one. if (!unionSize || u->operands().size() < unionSize) { unionSize = u->operands().size(); bestUnion = &node; } } } return bestUnion; } const EType* findSubtractableClass(const EGraph& egraph, Id id) { std::unordered_set seen; return findSubtractableClass(egraph, seen, id); } // Subtract the type 'part' from 'ct' // Returns true if the subtraction succeeded. This function will fail if 'part` is too complicated. bool subtract(EGraph& egraph, CanonicalizedType& ct, Id part) { const EType* etype = findSubtractableClass(egraph, part); if (!etype) return false; if (etype->get()) ct.nilPart.reset(); else if (etype->get()) { ct.truePart.reset(); ct.falsePart.reset(); } else if (auto b = etype->get()) { if (b->value()) ct.truePart.reset(); else ct.falsePart.reset(); } else if (etype->get()) ct.numberPart.reset(); else if (etype->get()) ct.stringPart.reset(); else if (etype->get()) return false; else if (etype->get()) ct.threadPart.reset(); else if (etype->get()) ct.functionPart.reset(); else if (etype->get()) ct.tablePart.reset(); else if (etype->get()) ct.classParts.clear(); else if (auto tclass = etype->get()) { auto it = std::find(ct.classParts.begin(), ct.classParts.end(), part); if (it != ct.classParts.end()) ct.classParts.erase(it); else return false; } else if (etype->get()) ct.bufferPart.reset(); else if (etype->get()) ct = {}; else if (etype->get()) ct.errorPart.reset(); else if (etype->get()) { std::optional errorPart = ct.errorPart; ct = {}; ct.errorPart = errorPart; } else if (etype->get()) { // Nothing } else if (auto u = etype->get()) { // TODO cycles // TODO this is super promlematic because 'part' represents a whole group of equivalent enodes. for (Id unionPart : u->operands()) { // TODO: This recursive call will require that we re-traverse this // eclass to find the subtractible enode. It would be nice to do the // work just once and reuse it. bool ok = subtract(egraph, ct, unionPart); if (!ok) return false; } } else if (etype->get()) return false; else return false; return true; } Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) { if (ct.isUnknown()) { if (ct.errorPart) return egraph.add(TAny{}); else return egraph.add(TUnknown{}); } std::vector parts; if (ct.nilPart) parts.push_back(*ct.nilPart); if (ct.truePart && ct.falsePart) parts.push_back(egraph.add(TBoolean{})); else if (ct.truePart) parts.push_back(*ct.truePart); else if (ct.falsePart) parts.push_back(*ct.falsePart); if (ct.numberPart) parts.push_back(*ct.numberPart); if (ct.stringPart) parts.push_back(*ct.stringPart); else if (!ct.stringSingletons.empty()) parts.insert(parts.end(), ct.stringSingletons.begin(), ct.stringSingletons.end()); if (ct.threadPart) parts.push_back(*ct.threadPart); if (ct.functionPart) parts.push_back(*ct.functionPart); if (ct.tablePart) parts.push_back(*ct.tablePart); parts.insert(parts.end(), ct.classParts.begin(), ct.classParts.end()); if (ct.bufferPart) parts.push_back(*ct.bufferPart); if (ct.errorPart) parts.push_back(*ct.errorPart); parts.insert(parts.end(), ct.functionParts.begin(), ct.functionParts.end()); parts.insert(parts.end(), ct.otherParts.begin(), ct.otherParts.end()); return mkUnion(egraph, std::move(parts)); } void addChildren(const EGraph& egraph, const EType* enode, VecDeque& worklist) { for (Id id : enode->operands()) worklist.push_back(id); } static bool occurs(EGraph& egraph, Id outerId, Slice operands) { for (const Id i : operands) { if (egraph.find(i) == outerId) return true; } return false; } Simplifier::Simplifier(NotNull arena, NotNull builtinTypes) : arena(arena) , builtinTypes(builtinTypes) , egraph(Simplify{}) { } const EqSat::EClass& Simplifier::get(Id id) const { return egraph[id]; } Id Simplifier::find(Id id) const { return egraph.find(id); } Id Simplifier::add(EType enode) { return egraph.add(std::move(enode)); } template const Tag* Simplifier::isTag(Id id) const { for (const auto& node : get(id).nodes) { if (const Tag* ty = node.get()) return ty; } return nullptr; } template const Tag* Simplifier::isTag(const EType& enode) const { return enode.get(); } void Simplifier::subst(Id from, Id to) { substs.emplace_back(from, to, " - "); } void Simplifier::subst(Id from, Id to, const std::string& ruleName) { std::string desc; if (FFlag::DebugLuauLogSimplification) desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, std::move(ruleName)); substs.emplace_back(from, to, desc); } void Simplifier::subst(Id from, Id to, const std::string& ruleName, const std::unordered_map& forceNodes) { std::string desc; if (FFlag::DebugLuauLogSimplification) desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName); substs.emplace_back(from, to, desc); } void Simplifier::unionClasses(std::vector& hereParts, Id there) { if (1 == hereParts.size() && isTag(hereParts[0])) return; const auto thereClass = isTag(there); if (!thereClass) return; for (size_t index = 0; index < hereParts.size(); ++index) { const Id herePart = hereParts[index]; if (auto partClass = isTag(herePart)) { switch (relateClasses(partClass, thereClass)) { case LeftSuper: return; case RightSuper: hereParts[index] = there; std::sort(hereParts.begin(), hereParts.end()); return; case Unrelated: continue; } } } hereParts.push_back(there); std::sort(hereParts.begin(), hereParts.end()); } void Simplifier::simplifyUnion(Id id) { id = find(id); for (const auto [u, unionIndex] : Query(&egraph, id)) { std::vector newParts; std::unordered_set seen; CanonicalizedType canonicalized; if (occurs(egraph, id, u->operands())) continue; for (Id part : u->operands()) unionWithType(egraph, canonicalized, find(part)); Id resultId = fromCanonicalized(egraph, canonicalized); subst(id, resultId, "simplifyUnion", {{id, unionIndex}}); } } // If one of the nodes matches the given Tag, succeed and return the id and node for the other half. // If neither matches, return nullopt. template static std::optional> matchOne(Id hereId, const EType* hereNode, Id thereId, const EType* thereNode) { if (hereNode->get()) return std::pair{thereId, thereNode}; else if (thereNode->get()) return std::pair{hereId, hereNode}; else return std::nullopt; } // If the two nodes can be intersected into a "simple" type, return that, else return nullopt. std::optional intersectOne(EGraph& egraph, Id hereId, const EType* hereNode, Id thereId, const EType* thereNode) { hereId = egraph.find(hereId); thereId = egraph.find(thereId); if (hereId == thereId) return *hereNode; if (hereNode->get() || thereNode->get()) return TNever{}; if (hereNode->get() || hereNode->get() || hereNode->get() || thereNode->get() || thereNode->get() || thereNode->get() || hereNode->get() || thereNode->get()) return std::nullopt; if (hereNode->get()) return *thereNode; if (thereNode->get()) return *hereNode; if (hereNode->get() || thereNode->get()) return std::nullopt; if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) { const auto [otherId, otherNode] = *res; if (otherNode->get() || otherNode->get()) return *otherNode; else return TNever{}; } if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) { const auto [otherId, otherNode] = *res; if (otherNode->get() || otherNode->get()) return *otherNode; } if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) { const auto [otherId, otherNode] = *res; if (otherNode->get()) return std::nullopt; // TODO else return TNever{}; } if (auto hereClass = hereNode->get()) { if (auto thereClass = thereNode->get()) { switch (relateClasses(hereClass, thereClass)) { case LeftSuper: return *thereNode; case RightSuper: return *hereNode; case Unrelated: return TNever{}; } } else return TNever{}; } if (auto hereBool = hereNode->get()) { if (auto thereBool = thereNode->get()) { if (hereBool->value() == thereBool->value()) return *hereNode; else return TNever{}; } else if (thereNode->get()) return *hereNode; else return TNever{}; } if (auto thereBool = thereNode->get()) { if (auto hereBool = hereNode->get()) { if (thereBool->value() == hereBool->value()) return *thereNode; else return TNever{}; } else if (hereNode->get()) return *thereNode; else return TNever{}; } if (hereNode->get()) { if (thereNode->get()) return TBoolean{}; else if (thereNode->get()) return *thereNode; else return TNever{}; } if (thereNode->get()) { if (hereNode->get()) return TBoolean{}; else if (hereNode->get()) return *hereNode; else return TNever{}; } if (hereNode->get()) { if (thereNode->get()) return *hereNode; else return TNever{}; } if (thereNode->get()) { if (hereNode->get()) return *thereNode; else return TNever{}; } if (hereNode->get()) { if (thereNode->get() || thereNode->get()) return *thereNode; else return TNever{}; } if (thereNode->get()) { if (hereNode->get() || hereNode->get()) return *hereNode; else return TNever{}; } if (hereNode->get() && thereNode->get()) return std::nullopt; if (hereNode->get() && isTerminal(*thereNode)) return TNever{}; if (thereNode->get() && isTerminal(*hereNode)) return TNever{}; if (isTerminal(*hereNode) && isTerminal(*thereNode)) { // We already know that 'here' and 'there' are different classes. return TNever{}; } return std::nullopt; } void Simplifier::uninhabitedIntersection(Id id) { for (const auto [intersection, index] : Query(&egraph, id)) { Slice parts = intersection->operands(); if (parts.empty()) { Id never = egraph.add(TNever{}); subst(id, never, "uninhabitedIntersection"); return; } else if (1 == parts.size()) { subst(id, parts[0], "uninhabitedIntersection"); return; } Id accumulator = egraph.add(TUnknown{}); EType accumulatorNode = TUnknown{}; std::vector unsimplified; if (occurs(egraph, id, parts)) continue; for (Id partId : parts) { if (isTag(partId)) return; bool found = false; const auto& partNodes = egraph[partId].nodes; for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex) { const EType& N = partNodes[partIndex]; if (std::optional intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N)) { if (isTag(*intersection)) { subst(id, egraph.add(TNever{}), "uninhabitedIntersection", {{id, index}, {partId, partIndex}}); return; } accumulator = egraph.add(*intersection); accumulatorNode = *intersection; found = true; break; } } if (!found) unsimplified.push_back(partId); } if ((unsimplified.empty() || !isTag(accumulator)) && find(accumulator) != id) unsimplified.push_back(accumulator); const Id result = mkIntersection(egraph, std::move(unsimplified)); subst(id, result, "uninhabitedIntersection", {{id, index}}); } } void Simplifier::intersectWithNegatedClass(Id id) { for (const auto pair : Query(&egraph, id)) { const Intersection* intersection = pair.first; const size_t intersectionIndex = pair.second; auto trySubst = [&](size_t i, size_t j) { Id iId = intersection->operands()[i]; Id jId = intersection->operands()[j]; for (const auto [negation, negationIndex] : Query(&egraph, jId)) { const Id negated = negation->operands()[0]; if (iId == negated) { subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {jId, negationIndex}}); return; } for (const auto [negatedClass, negatedClassIndex] : Query(&egraph, negated)) { const auto& iNodes = egraph[iId].nodes; for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex) { const EType& iNode = iNodes[iIndex]; if (isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || // isTag(iNode) || // I'm not sure about this one. isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode)) { // eg string & ~SomeClass subst(id, iId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); return; } if (const TClass* class_ = iNode.get()) { switch (relateClasses(class_, negatedClass)) { case LeftSuper: // eg Instance & ~Part // This cannot be meaningfully reduced. continue; case RightSuper: subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); return; case Unrelated: // Part & ~Folder == Part { std::vector newParts; newParts.reserve(intersection->operands().size() - 1); for (Id part : intersection->operands()) { if (part != jId) newParts.push_back(part); } Id substId = egraph.add(Intersection{newParts.begin(), newParts.end()}); subst(id, substId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); } } } } } } }; if (2 != intersection->operands().size()) continue; trySubst(0, 1); trySubst(1, 0); } } void Simplifier::intersectWithNegatedAtom(Id id) { // Let I and ~J be two arbitrary distinct operands of an intersection where // I and J are terminal but are not type variables. (free, generic, or // otherwise opaque) // // If I and J are equal, then the whole intersection is equivalent to never. // // If I and J are inequal, then J & ~I == J for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) { const Slice& intersectionOperands = intersection->operands(); for (size_t i = 0; i < intersectionOperands.size(); ++i) { for (const auto [negation, negationIndex] : Query(&egraph, intersectionOperands[i])) { for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex) { const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex]; if (!isTerminal(*negationOperand) || negationOperand->get()) continue; for (size_t j = 0; j < intersectionOperands.size(); ++j) { if (j == i) continue; for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex) { const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex]; if (!isTerminal(*jNode) || jNode->get()) continue; if (*negationOperand == *jNode) { // eg "Hello" & ~"Hello" // or boolean & ~boolean subst( id, egraph.add(TNever{}), "intersectWithNegatedAtom", {{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}} ); return; } else if (areTerminalAndDefinitelyDisjoint(*jNode, *negationOperand)) { // eg "Hello" & ~"World" // or boolean & ~string std::vector newOperands(intersectionOperands.begin(), intersectionOperands.end()); newOperands.erase(newOperands.begin() + std::vector::difference_type(i)); subst( id, egraph.add(Intersection{newOperands}), "intersectWithNegatedAtom", {{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}} ); } } } } } } } } void Simplifier::intersectWithNoRefine(Id id) { for (const auto pair : Query(&egraph, id)) { const Intersection* intersection = pair.first; const size_t intersectionIndex = pair.second; const Slice intersectionOperands = intersection->operands(); for (size_t index = 0; index < intersectionOperands.size(); ++index) { const auto replace = [&]() { std::vector newOperands{intersectionOperands.begin(), intersectionOperands.end()}; newOperands.erase(newOperands.begin() + index); Id substId = egraph.add(Intersection{std::move(newOperands)}); subst(id, substId, "intersectWithNoRefine", {{id, intersectionIndex}}); }; if (isTag(intersectionOperands[index])) replace(); else { for (const auto [negation, negationIndex] : Query(&egraph, intersectionOperands[index])) { if (isTag(negation->operands()[0])) { replace(); break; } } } } } } /* * Replace x where x = A & (B | x) with A * * Important subtlety: The egraph is routinely going to create cyclic unions and * intersections. We can't arbitrarily remove things from a union just because * it can be referred to in a cyclic way. We must only do this for things that * can only be expressed in a cyclic way. * * As an example, we will bind the following type to true: * * (true | buffer | class | function | number | string | table | thread) & * boolean * * The egraph represented by this type will indeed be cyclic as the 'true' class * includes both 'true' itself and the above type, but removing true from the * union will result is an incorrect judgment! * * The solution (for now) is only to consider a type to be cyclic if it was * cyclic on its original import. * * FIXME: I still don't think this is quite right, but I don't know how to * articulate what the actual rule ought to be. */ void Simplifier::cyclicIntersectionOfUnion(Id id) { // FIXME: This has pretty terrible runtime complexity. for (const auto [i, intersectionIndex] : Query(&egraph, id)) { Slice intersectionParts = i->operands(); for (size_t intersectionOperandIndex = 0; intersectionOperandIndex < intersectionParts.size(); ++intersectionOperandIndex) { const Id intersectionPart = find(intersectionParts[intersectionOperandIndex]); for (const auto [bound, _boundIndex] : Query(&egraph, intersectionPart)) { const Id pointee = find(mappingIdToClass.at(bound->value())); for (const auto [u, unionIndex] : Query(&egraph, pointee)) { const Slice& unionOperands = u->operands(); for (size_t unionOperandIndex = 0; unionOperandIndex < unionOperands.size(); ++unionOperandIndex) { Id unionOperand = find(unionOperands[unionOperandIndex]); if (unionOperand == id) { std::vector newIntersectionParts(intersectionParts.begin(), intersectionParts.end()); newIntersectionParts.erase(newIntersectionParts.begin() + intersectionOperandIndex); subst( id, mkIntersection(egraph, std::move(newIntersectionParts)), "cyclicIntersectionOfUnion", {{id, intersectionIndex}, {pointee, unionIndex}} ); } } } } } } } void Simplifier::cyclicUnionOfIntersection(Id id) { // FIXME: This has pretty terrible runtime complexity. for (const auto [union_, unionIndex] : Query(&egraph, id)) { Slice unionOperands = union_->operands(); for (size_t unionOperandIndex = 0; unionOperandIndex < unionOperands.size(); ++unionOperandIndex) { const Id unionPart = find(unionOperands[unionOperandIndex]); for (const auto [bound, _boundIndex] : Query(&egraph, unionPart)) { const Id pointee = find(mappingIdToClass.at(bound->value())); for (const auto [intersection, intersectionIndex] : Query(&egraph, pointee)) { Slice intersectionOperands = intersection->operands(); for (size_t intersectionOperandIndex = 0; intersectionOperandIndex < intersectionOperands.size(); ++intersectionOperandIndex) { const Id intersectionPart = find(intersectionOperands[intersectionOperandIndex]); if (intersectionPart == id) { std::vector newIntersectionParts(intersectionOperands.begin(), intersectionOperands.end()); newIntersectionParts.erase(newIntersectionParts.begin() + intersectionOperandIndex); if (!newIntersectionParts.empty()) { Id newIntersection = mkIntersection(egraph, std::move(newIntersectionParts)); std::vector newIntersectionParts(unionOperands.begin(), unionOperands.end()); newIntersectionParts.erase(newIntersectionParts.begin() + unionOperandIndex); newIntersectionParts.push_back(newIntersection); subst( id, mkUnion(egraph, std::move(newIntersectionParts)), "cyclicUnionOfIntersection", {{id, unionIndex}, {pointee, intersectionIndex}} ); } } } } } } } } void Simplifier::expandNegation(Id id) { for (const auto [negation, index] : Query{&egraph, id}) { if (isTag(negation->operands()[0])) return; CanonicalizedType canonicalized; unionUnknown(egraph, canonicalized); const bool ok = subtract(egraph, canonicalized, negation->operands()[0]); if (!ok) continue; subst(id, fromCanonicalized(egraph, canonicalized), "expandNegation", {{id, index}}); } } /** * Let A be a class-node having the form B & C1 & ... & Cn * And B be a class-node having the form (D | E) * * Create a class containing the node (C1 & ... & Cn & D) | (C1 & ... & Cn & E) * * This function does nothing and returns nullopt if A and B are cyclic. */ static std::optional distributeIntersectionOfUnion( EGraph& egraph, Id outerClass, const Intersection* outerIntersection, Id innerClass, const Union* innerUnion ) { Slice outerOperands = outerIntersection->operands(); std::vector newOperands; newOperands.reserve(innerUnion->operands().size()); for (Id innerOperand : innerUnion->operands()) { if (isTag(egraph, innerOperand)) continue; if (innerOperand == outerClass) { // Skip cyclic intersections of unions. There's a separate // rule to get rid of those. return std::nullopt; } std::vector intersectionParts; intersectionParts.reserve(outerOperands.size()); intersectionParts.push_back(innerOperand); for (const Id op : outerOperands) { if (isTag(egraph, op)) { break; } if (op != innerClass) intersectionParts.push_back(op); } newOperands.push_back(mkIntersection(egraph, intersectionParts)); } return mkUnion(egraph, std::move(newOperands)); } // A & (B | C) -> (A & B) | (A & C) // // A & B & (C | D) -> A & (B & (C | D)) // -> A & ((B & C) | (B & D)) // -> (A & B & C) | (A & B & D) void Simplifier::intersectionOfUnion(Id id) { id = find(id); for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) { // For each operand O // For each node N // If N is a union U // Create a new union comprised of every operand except O intersected with every operand of U const Slice operands = intersection->operands(); if (operands.size() < 2) return; if (occurs(egraph, id, operands)) continue; for (Id operand : operands) { operand = find(operand); if (operand == id) break; // Optimization: Decline to distribute any unions on an eclass that // also contains a terminal node. if (isTerminal(egraph, operand)) continue; for (const auto [operandUnion, unionIndex] : Query(&egraph, operand)) { if (occurs(egraph, id, operandUnion->operands())) continue; std::optional distributed = distributeIntersectionOfUnion(egraph, id, intersection, operand, operandUnion); if (distributed) subst(id, *distributed, "intersectionOfUnion", {{id, intersectionIndex}, {operand, unionIndex}}); } } } } // {"a": b} & {"a": c, ...} => {"a": b & c, ...} void Simplifier::intersectTableProperty(Id id) { for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) { const Slice intersectionParts = intersection->operands(); for (size_t i = 0; i < intersection->operands().size(); ++i) { const Id iId = intersection->operands()[i]; for (size_t j = 0; j < intersection->operands().size(); ++j) { if (i == j) continue; const Id jId = intersection->operands()[j]; if (iId == jId) continue; for (const auto [table1, table1Index] : Query(&egraph, iId)) { const TableType* table1Ty = Luau::get(table1->value()); LUAU_ASSERT(table1Ty); if (table1Ty->props.size() != 1) continue; for (const auto [table2, table2Index] : Query(&egraph, jId)) { const TableType* table2Ty = Luau::get(table2->value()); LUAU_ASSERT(table2Ty); auto it = table2Ty->props.find(table1Ty->props.begin()->first); if (it != table2Ty->props.end()) { std::vector newIntersectionParts; newIntersectionParts.reserve(intersectionParts.size() - 1); for (size_t index = 0; index < intersectionParts.size(); ++index) { if (index != i && index != j) newIntersectionParts.push_back(intersectionParts[index]); } Id newTableProp = egraph.add(Intersection{ toId(egraph, builtinTypes, mappingIdToClass, stringCache, it->second.type()), toId(egraph, builtinTypes, mappingIdToClass, stringCache, table1Ty->props.begin()->second.type()) }); newIntersectionParts.push_back(egraph.add(TTable{jId, {stringCache.add(it->first)}, {newTableProp}})); subst( id, mkIntersection(egraph, std::move(newIntersectionParts)), "intersectTableProperty", {{id, intersectionIndex}, {iId, table1Index}, {jId, table2Index}} ); } } } } } } } // { prop: never } == never void Simplifier::uninhabitedTable(Id id) { for (const auto [table, tableIndex] : Query(&egraph, id)) { const TableType* tt = Luau::get(table->value()); LUAU_ASSERT(tt); for (const auto& [propName, prop] : tt->props) { if (prop.readTy && Luau::get(follow(*prop.readTy))) { subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); return; } if (prop.writeTy && Luau::get(follow(*prop.writeTy))) { subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); return; } } } for (const auto [table, tableIndex] : Query(&egraph, id)) { for (Id propType : table->propTypes()) { if (isTag(propType)) { subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); return; } } } } void Simplifier::unneededTableModification(Id id) { for (const auto [tbl, tblIndex] : Query(&egraph, id)) { const Id basis = tbl->getBasis(); for (const auto [importedTbl, importedTblIndex] : Query(&egraph, basis)) { const TableType* tt = Luau::get(importedTbl->value()); LUAU_ASSERT(tt); bool skip = false; for (size_t i = 0; i < tbl->propNames.size(); ++i) { StringId propName = tbl->propNames[i]; const Id propType = tbl->propTypes()[i]; Id importedProp = toId(egraph, builtinTypes, mappingIdToClass, stringCache, tt->props.at(stringCache.asString(propName)).type()); if (find(importedProp) != find(propType)) { skip = true; break; } } if (!skip) subst(id, basis, "unneededTableModification", {{id, tblIndex}, {basis, importedTblIndex}}); } } } void Simplifier::builtinTypeFunctions(Id id) { for (const auto [tfun, index] : Query(&egraph, id)) { const Slice& args = tfun->operands(); if (args.size() != 2) continue; const std::string& name = tfun->value()->function->name; if (name == "add" || name == "sub" || name == "mul" || name == "div" || name == "idiv" || name == "pow" || name == "mod") { if (isTag(args[0]) && isTag(args[1])) { subst(id, add(TNumber{}), "builtinTypeFunctions", {{id, index}}); } } } } // Replace union<>, intersect<>, and refine<> with unions or intersections. // These type functions exist primarily to cause simplification to defer until // particular points in execution, so it is safe to get rid of them here. // // It's not clear that these type functions should exist at all. void Simplifier::iffyTypeFunctions(Id id) { for (const auto [tfun, index] : Query(&egraph, id)) { const Slice& args = tfun->operands(); const std::string& name = tfun->value()->function->name; if (name == "union") subst(id, add(Union{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); else if (name == "intersect") subst(id, add(Intersection{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); } } // Replace instances of `lt` and `le` when either X or Y is `number` // or `string` with `boolean`. Lua semantics are that if we see the expression: // // x < y // // ... we error if `x` and `y` don't have the same type. We know that for // `string` and `number`, comparisons will always return a boolean. So if either // of the arguments to `lt<>` are equivalent to `number` or `string`, then the // type is effectively `boolean`: either the other type is equivalent, in which // case we eval to `boolean`, or we diverge (raise an error). void Simplifier::strictMetamethods(Id id) { for (const auto [tfun, index] : Query(&egraph, id)) { const Slice& args = tfun->operands(); const std::string& name = tfun->value()->function->name; if (!(name == "lt" || name == "le") || args.size() != 2) continue; if (isTag(args[0]) || isTag(args[0]) || isTag(args[1]) || isTag(args[1])) { subst(id, add(TBoolean{}), __FUNCTION__, {{id, index}}); } } } static void deleteSimplifier(Simplifier* s) { delete s; } SimplifierPtr newSimplifier(NotNull arena, NotNull builtinTypes) { return SimplifierPtr{new Simplifier(arena, builtinTypes), &deleteSimplifier}; } } // namespace Luau::EqSatSimplification namespace Luau { std::optional eqSatSimplify(NotNull simplifier, TypeId ty) { using namespace Luau::EqSatSimplification; std::unordered_map newMappings; Id rootId = toId(simplifier->egraph, simplifier->builtinTypes, newMappings, simplifier->stringCache, ty); simplifier->mappingIdToClass.insert(newMappings.begin(), newMappings.end()); Simplifier::RewriteRuleFn rules[] = { &Simplifier::simplifyUnion, &Simplifier::uninhabitedIntersection, &Simplifier::intersectWithNegatedClass, &Simplifier::intersectWithNegatedAtom, &Simplifier::intersectWithNoRefine, &Simplifier::cyclicIntersectionOfUnion, &Simplifier::cyclicUnionOfIntersection, &Simplifier::expandNegation, &Simplifier::intersectionOfUnion, &Simplifier::intersectTableProperty, &Simplifier::uninhabitedTable, &Simplifier::unneededTableModification, &Simplifier::builtinTypeFunctions, &Simplifier::iffyTypeFunctions, &Simplifier::strictMetamethods, }; std::unordered_set seen; VecDeque worklist; bool progressed = true; int count = 0; const int MAX_COUNT = 1000; if (FFlag::DebugLuauLogSimplificationToDot) std::ofstream("begin.dot") << toDot(simplifier->stringCache, simplifier->egraph); auto& egraph = simplifier->egraph; const auto& builtinTypes = simplifier->builtinTypes; auto& arena = simplifier->arena; if (FFlag::DebugLuauLogSimplification) printf(">> simplify %s\n", toString(ty).c_str()); while (progressed && count < MAX_COUNT) { progressed = false; worklist.clear(); seen.clear(); rootId = egraph.find(rootId); worklist.push_back(rootId); if (FFlag::DebugLuauLogSimplification) { std::vector newTypeFunctions; const TypeId t = fromId(egraph, simplifier->stringCache, builtinTypes, arena, newTypeFunctions, rootId); std::cout << "Begin (" << uint32_t(egraph.find(rootId)) << ")\t" << toString(t) << '\n'; } while (!worklist.empty() && count < MAX_COUNT) { Id id = egraph.find(worklist.front()); worklist.pop_front(); const bool isFresh = seen.insert(id).second; if (!isFresh) continue; simplifier->substs.clear(); // Optimization: If this class alraedy has a terminal node, don't // try to run any rules on it. bool shouldAbort = false; for (const EType& enode : egraph[id].nodes) { if (isTerminal(enode)) { shouldAbort = true; break; } } if (shouldAbort) continue; for (const EType& enode : egraph[id].nodes) addChildren(egraph, &enode, worklist); for (Simplifier::RewriteRuleFn rule : rules) (simplifier.get()->*rule)(id); if (simplifier->substs.empty()) continue; for (const Subst& subst : simplifier->substs) { if (subst.newClass == subst.eclass) continue; if (FFlag::DebugLuauExtraEqSatSanityChecks) { const Id never = egraph.find(egraph.add(TNever{})); const Id str = egraph.find(egraph.add(TString{})); const Id unk = egraph.find(egraph.add(TUnknown{})); LUAU_ASSERT(never != str); LUAU_ASSERT(never != unk); } const bool isFresh = egraph.merge(subst.newClass, subst.eclass); ++count; if (FFlag::DebugLuauLogSimplification && isFresh) std::cout << "count=" << std::setw(3) << count << "\t" << subst.desc << '\n'; if (FFlag::DebugLuauLogSimplificationToDot) { std::string filename = format("step%03d.dot", count); std::ofstream(filename) << toDot(simplifier->stringCache, egraph); } if (FFlag::DebugLuauExtraEqSatSanityChecks) { const Id never = egraph.find(egraph.add(TNever{})); const Id str = egraph.find(egraph.add(TString{})); const Id unk = egraph.find(egraph.add(TUnknown{})); const Id trueId = egraph.find(egraph.add(SBoolean{true})); LUAU_ASSERT(never != str); LUAU_ASSERT(never != unk); LUAU_ASSERT(never != trueId); } progressed |= isFresh; } egraph.rebuild(); } } EqSatSimplificationResult result; result.result = fromId(egraph, simplifier->stringCache, builtinTypes, arena, result.newTypeFunctions, rootId); if (FFlag::DebugLuauLogSimplification) printf("<< simplify %s\n", toString(result.result).c_str()); return result; } } // namespace Luau