luau/Analysis/src/EqSatSimplification.cpp
Vighnesh-V 2e6fdd90a0
Some checks failed
benchmark / callgrind (map[branch:main name:luau-lang/benchmark-data], ubuntu-22.04) (push) Has been cancelled
build / macos (push) Has been cancelled
build / macos-arm (push) Has been cancelled
build / ubuntu (push) Has been cancelled
build / windows (Win32) (push) Has been cancelled
build / windows (x64) (push) Has been cancelled
build / coverage (push) Has been cancelled
build / web (push) Has been cancelled
release / macos (push) Has been cancelled
release / ubuntu (push) Has been cancelled
release / windows (push) Has been cancelled
release / web (push) Has been cancelled
Sync to upstream/release/655 (#1563)
## New Solver
* Type functions should be able to signal whether or not irreducibility
is due to an error
* Do not generate extra expansion constraint for uninvoked user-defined
type functions
* Print in a user-defined type function reports as an error instead of
logging to stdout
* Many e-graphs bugfixes and performance improvements
* Many general bugfixes and improvements to the new solver as a whole
* Fixed issue with user-defined type functions not being able to call
each other
* Infer types of globals under new type solver

## Fragment Autocomplete
* Miscellaneous fixes to make interop with the old solver better

## Runtime
* Support disabling specific built-in functions from being fast-called
or constant-evaluated (Closes #1538)
* New compiler option `disabledBuiltins` accepts a list of library
function names like "tonumber" or "math.cos"
* Added constant folding for vector arithmetic
* Added constant propagation and type inference for vector globals
(Fixes #1511)
* New compiler option `librariesWithKnownMembers` accepts a list of
libraries for members of which a request for constant value and/or type
will be made
* `libraryMemberTypeCb` callback is called to get the type of a global,
return one of the `LuauBytecodeType` values. 'boolean', 'number',
'string' and 'vector' type are supported.
* `libraryMemberConstantCb` callback is called to setup the constant
value of a global. To set a value, C API `luau_set_compile_constant_*`
or C++ API `setCompileConstant*` functions should be used.

---
Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: Daniel Angel <danielangel@roblox.com>
Co-authored-by: Jonathan Kelaty <jkelaty@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>

---------

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Alexander McCord <amccord@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Aviral Goel <agoel@roblox.com>
Co-authored-by: David Cope <dcope@roblox.com>
Co-authored-by: Lily Brown <lbrown@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
Co-authored-by: Junseo Yoo <jyoo@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <61795485+vrn-sn@users.noreply.github.com>
Co-authored-by: Alexander Youngblood <ayoungblood@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Andrew Miranti <amiranti@roblox.com>
Co-authored-by: Shiqi Ai <sai@roblox.com>
Co-authored-by: Yohoo Lin <yohoo@roblox.com>
Co-authored-by: Daniel Angel <danielangel@roblox.com>
Co-authored-by: Jonathan Kelaty <jkelaty@roblox.com>
2024-12-13 13:02:30 -08:00

2632 lines
84 KiB
C++

// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/EqSatSimplification.h"
#include "Luau/EqSatSimplificationImpl.h"
#include "Luau/EGraph.h"
#include "Luau/Id.h"
#include "Luau/Language.h"
#include "Luau/StringUtils.h"
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFunction.h"
#include "Luau/VisitType.h"
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
LUAU_FASTFLAGVARIABLE(DebugLuauLogSimplification)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSimplificationToDot)
LUAU_FASTFLAGVARIABLE(DebugLuauExtraEqSatSanityChecks)
namespace Luau::EqSatSimplification
{
using Id = Luau::EqSat::Id;
using EGraph = Luau::EqSat::EGraph<EType, struct Simplify>;
using Luau::EqSat::Slice;
// Builds a table node that has a basis and no properties. The basis is always
// the first operand held in `storage`.
TTable::TTable(Id basis)
{
    storage.emplace_back(basis);
}
// Builds a table node from a basis plus parallel lists of property names and
// property types. `storage` ends up as [basis, propTypes_...].
//
// This is suspected to become a performance hotspot; avoiding the allocation
// of propTypes_ would be nice.
TTable::TTable(Id basis, std::vector<StringId> propNames_, std::vector<Id> propTypes_)
    : propNames(std::move(propNames_))
{
    const size_t propCount = propTypes_.size();

    storage.reserve(propCount + 1);
    storage.push_back(basis);
    for (Id propType : propTypes_)
        storage.push_back(propType);

    LUAU_ASSERT(storage.size() == propCount + 1);
}
// Returns the basis of this table, i.e. the first operand in `storage`.
Id TTable::getBasis() const
{
    LUAU_ASSERT(!storage.empty());
    return storage.front();
}
// Returns a view over the property types: every operand after the basis.
// Entry i corresponds to propNames[i].
Slice<const Id> TTable::propTypes() const
{
    const size_t propCount = propNames.size();
    LUAU_ASSERT(propCount + 1 == storage.size());

    const Id* firstProp = storage.data() + 1;
    return Slice{firstProp, propCount};
}
// Returns a writable view over all operands (the basis followed by the
// property types).
Slice<Id> TTable::mutableOperands()
{
    Id* first = storage.data();
    const size_t count = storage.size();
    return Slice{first, count};
}
// Returns a read-only view over all operands (the basis followed by the
// property types).
Slice<const Id> TTable::operands() const
{
    const Id* first = storage.data();
    const size_t count = storage.size();
    return Slice{first, count};
}
// Two table nodes are equal when they have the same operands (basis and
// property types) and the same property names.
bool TTable::operator==(const TTable& rhs) const
{
    if (!(storage == rhs.storage))
        return false;

    return propNames == rhs.propNames;
}
// Hashes a table node by folding in each property name, in order, followed by
// the operand storage.
size_t TTable::Hash::operator()(const TTable& value) const
{
    size_t seed = 0;

    // The original author notes that pointers are being hashed here, which
    // implies platform divergence, and that this is believed to be acceptable.
    const auto& names = value.propNames;
    for (size_t i = 0; i < names.size(); ++i)
        EqSat::hashCombine(seed, EqSat::languageHash(names[i]));

    EqSat::hashCombine(seed, EqSat::languageHash(value.storage));

    return seed;
}
// Interns `s` and returns its id. Adding the same bytes twice yields the same
// id; the bytes are copied into storage obtained from `allocator`, and the
// returned id indexes `views`.
//
// `strings` is keyed by the hash of the string only, so a hit in that map does
// not by itself prove that the cached entry holds the same bytes. We verify
// the bytes, and on a genuine hash collision fall back to a linear scan of
// `views` so that distinct strings never share an id.
uint32_t StringCache::add(std::string_view s)
{
    const size_t hash = std::hash<std::string_view>()(s);

    const uint32_t* cached = strings.find(hash);
    if (cached)
    {
        if (views[*cached] == s)
            return *cached;

        // Hash collision with a different string. The colliding string may
        // already have been interned without owning the hash slot, so look
        // for it before allocating a new entry.
        for (size_t i = 0; i < views.size(); ++i)
        {
            if (views[i] == s)
                return uint32_t(i);
        }
    }

    char* storage = static_cast<char*>(allocator.allocate(s.size()));
    memcpy(storage, s.data(), s.size());

    const uint32_t result = uint32_t(views.size());
    views.emplace_back(storage, s.size());

    // On a collision the hash slot stays with its first owner: overwriting it
    // would make the first string miss the fast path and be interned again
    // under a second id.
    if (!cached)
        strings[hash] = result;

    return result;
}
// Looks up the interned string for `id` and returns a view of it. `id` must be
// an id previously returned by add().
std::string_view StringCache::asStringView(StringId id) const
{
    const size_t internedCount = views.size();
    LUAU_ASSERT(id < internedCount);

    return views[id];
}
// Same as asStringView(), but returns an owning copy of the string.
std::string StringCache::asString(StringId id) const
{
    const std::string_view view = asStringView(id);
    return std::string{view};
}
// Analysis hook: produces the analysis data for a newly created enode. The
// data is `true` regardless of the graph or the node.
template<typename T>
Simplify::Data Simplify::make(const EGraph& /*egraph*/, const T& /*enode*/) const
{
    return true;
}
// Analysis hook: combines the data of two eclasses into `left`. The result is
// the logical OR of both sides.
void Simplify::join(Data& left, const Data& right) const
{
    const bool merged = left || right;
    left = merged;
}
using EClass = Luau::EqSat::EClass<EType, Simplify::Data>;
// Reports whether `node` is a terminal type, meaning a type that does not
// contain any other types. Examples: any, unknown, number, string, boolean,
// nil, table, class, thread, function.
//
// Every class type counts as terminal too.
static bool isTerminal(const EType& node)
{
    // Primitive types.
    if (node.get<TNil>() || node.get<TBoolean>() || node.get<TNumber>() || node.get<TString>() || node.get<TThread>())
        return true;

    // Top types of the structural families, buffers, and opaque imports.
    if (node.get<TTopFunction>() || node.get<TTopTable>() || node.get<TTopClass>() || node.get<TBuffer>() || node.get<TOpaque>())
        return true;

    // Singletons, classes, and the broad semantic sets.
    if (node.get<SBoolean>() || node.get<SString>() || node.get<TClass>() || node.get<TAny>() || node.get<TError>() || node.get<TUnknown>())
        return true;

    if (node.get<TNever>() || node.get<TNoRefine>())
        return true;

    return false;
}
static bool areTerminalAndDefinitelyDisjoint(const EType& lhs, const EType& rhs)
{
// If either node is non-terminal, then we early exit: we're not going to
// do a state space search for whether something like:
// (A | B | C | D) & (E | F | G | H)
// ... is a disjoint intersection.
if (!isTerminal(lhs) || !isTerminal(rhs))
return false;
// Special case some types that aren't strict, disjoint subsets.
if (lhs.get<TTopClass>() || lhs.get<TClass>())
return !(rhs.get<TTopClass>() || rhs.get<TClass>());
// Handling strings / booleans: these are the types for which we
// expect something like:
//
// "foo" & ~"bar"
//
// ... to simplify to "foo".
if (lhs.get<TString>())
return !(rhs.get<TString>() || rhs.get<SString>());
if (lhs.get<TBoolean>())
return !(rhs.get<TBoolean>() || rhs.get<SBoolean>());
if (auto lhsSString = lhs.get<SString>())
{
auto rhsSString = rhs.get<SString>();
if (!rhsSString)
return !rhs.get<TString>();
return lhsSString->value() != rhsSString->value();
}
if (auto lhsSBoolean = lhs.get<SBoolean>())
{
auto rhsSBoolean = rhs.get<SBoolean>();
if (!rhsSBoolean)
return !rhs.get<TBoolean>();
return lhsSBoolean->value() != rhsSBoolean->value();
}
// At this point:
// - We know both nodes are terminal
// - We know that the LHS is not any boolean, string, or class
// At this point, we have two classes of checks left:
// - Whether the two enodes are exactly the same set (now that the static
// sets have been covered).
// - Whether one of the enodes is a large semantic set such as TAny,
// TUnknown, or TError.
return !(
lhs.index() == rhs.index() ||
lhs.get<TUnknown>() || rhs.get<TUnknown>() || lhs.get<TAny>() || rhs.get<TAny>() || lhs.get<TNoRefine>() || rhs.get<TNoRefine>() ||
lhs.get<TError>() || rhs.get<TError>() || lhs.get<TOpaque>() || rhs.get<TOpaque>()
);
}
static bool isTerminal(const EGraph& egraph, Id eclass)
{
const auto& nodes = egraph[eclass].nodes;
return std::any_of(
nodes.begin(),
nodes.end(),
[](auto& a)
{
return isTerminal(a);
}
);
}
// Adds a union of `parts` to the egraph and returns its id. An empty union is
// `never`, and a union of a single part is just that part.
Id mkUnion(EGraph& egraph, std::vector<Id> parts)
{
    switch (parts.size())
    {
    case 0:
        return egraph.add(TNever{});
    case 1:
        return parts.front();
    default:
        return egraph.add(Union{std::move(parts)});
    }
}
// Adds an intersection of `parts` to the egraph and returns its id. An empty
// intersection is `unknown`, and an intersection of a single part is just that
// part.
Id mkIntersection(EGraph& egraph, std::vector<Id> parts)
{
    switch (parts.size())
    {
    case 0:
        return egraph.add(TUnknown{});
    case 1:
        return parts.front();
    default:
        return egraph.add(Intersection{std::move(parts)});
    }
}
struct ListRemover
{
std::unordered_map<TypeId, std::pair<size_t, size_t>>& mappings2;
TypeId ty;
~ListRemover()
{
mappings2.erase(ty);
}
};
/*
* Crucial subtlety: It is very extremely important that enodes and eclasses are
* immutable. Mutating an enode would mean that it is no longer equivalent to
* other nodes in the same eclass.
*
* At the same time, many TypeIds are NOT immutable!
*
* The thing that makes this navigable is that it is okay if the same TypeId is
* imported as a different Id at different times as type inference runs. For
* example, if we at one point import a BlockedType as a TOpaque, and later
* import that same TypeId as some other enode type, this is all completely
* okay.
*
* The main thing we have to be very cautious about, I think, is unsealed
* tables. Unsealed table types have properties imperatively inserted into them
* as type inference runs. If we were to encode that TypeId as part of an
* enode, we could run into a situation where the egraph makes incorrect
* assumptions about the table.
*
* The solution is pretty simple: Never use the contents of a mutable TypeId in
* any reduction rule. TOpaque is always okay because we never actually poke
* around inside the TypeId to do anything.
*/
// Recursively imports the type `ty` into `egraph` and returns the Id of the
// resulting eclass.
//
// - mappingIdToClass: filled in with (mapping id -> Id) for every composite
//   type that was referenced again while it was still being imported.
// - typeToMappingId: the composite types currently being imported, each with
//   its mapping id and the number of times it was re-encountered.
// - boundNodes: receives the Id of every TBound placeholder that is created.
// - strings: interning cache used for string singletons.
//
// NOTE(review): a mapping id is taken from mappingIdToClass.size(), which only
// grows inside cache() when count > 0. It looks like two nested composite
// types could therefore be handed the same mapping id - confirm.
Id toId(
EGraph& egraph,
NotNull<BuiltinTypes> builtinTypes,
std::unordered_map<size_t, Id>& mappingIdToClass,
std::unordered_map<TypeId, std::pair<size_t, size_t>>& typeToMappingId, // (TypeId: (MappingId, count))
std::unordered_set<Id>& boundNodes,
StringCache& strings,
TypeId ty
)
{
ty = follow(ty);
// First, handle types which do not contain other types. They obviously
// cannot participate in cycles, so we don't have to check for that.
if (auto freeTy = get<FreeType>(ty))
return egraph.add(TOpaque{ty});
else if (get<GenericType>(ty))
return egraph.add(TOpaque{ty});
else if (auto prim = get<PrimitiveType>(ty))
{
switch (prim->type)
{
case Luau::PrimitiveType::NilType:
return egraph.add(TNil{});
case Luau::PrimitiveType::Boolean:
return egraph.add(TBoolean{});
case Luau::PrimitiveType::Number:
return egraph.add(TNumber{});
case Luau::PrimitiveType::String:
return egraph.add(TString{});
case Luau::PrimitiveType::Thread:
return egraph.add(TThread{});
case Luau::PrimitiveType::Function:
return egraph.add(TTopFunction{});
case Luau::PrimitiveType::Table:
return egraph.add(TTopTable{});
case Luau::PrimitiveType::Buffer:
return egraph.add(TBuffer{});
default:
LUAU_ASSERT(!"Unimplemented");
return egraph.add(Invalid{});
}
}
else if (auto s = get<SingletonType>(ty))
{
if (auto bs = get<BooleanSingleton>(s))
return egraph.add(SBoolean{bs->value});
else if (auto ss = get<StringSingleton>(s))
return egraph.add(SString{strings.add(ss->value)});
else
LUAU_ASSERT(!"Unexpected");
// No return on the unexpected path: control continues into the composite
// handling below.
}
else if (get<BlockedType>(ty))
return egraph.add(TOpaque{ty});
else if (get<PendingExpansionType>(ty))
return egraph.add(TOpaque{ty});
else if (get<FunctionType>(ty))
return egraph.add(TFunction{ty});
else if (ty == builtinTypes->classType)
return egraph.add(TTopClass{});
else if (get<ClassType>(ty))
return egraph.add(TClass{ty});
else if (get<AnyType>(ty))
return egraph.add(TAny{});
else if (get<ErrorType>(ty))
return egraph.add(TError{});
else if (get<UnknownType>(ty))
return egraph.add(TUnknown{});
else if (get<NeverType>(ty))
return egraph.add(TNever{});
// Now handle composite types.
// If this type is already being imported further up the call stack, we have
// found a cycle. Emit a TBound placeholder that refers to it by mapping id
// and record the placeholder in boundNodes.
if (auto it = typeToMappingId.find(ty); it != typeToMappingId.end())
{
auto& [mappingId, count] = it->second;
++count;
Id res = egraph.add(TBound{mappingId});
boundNodes.insert(res);
return res;
}
// Mark this type as in progress. The ListRemover erases the entry again when
// this call returns.
typeToMappingId.emplace(ty, std::pair{mappingIdToClass.size(), 0});
ListRemover lr{typeToMappingId, ty};
// Records the finished Id under this type's mapping id, but only if the type
// was re-encountered (count > 0) while it was being imported.
auto cache = [&](Id res)
{
const auto& [mappingId, count] = typeToMappingId.at(ty);
if (count > 0)
mappingIdToClass.emplace(mappingId, res);
return res;
};
if (auto tt = get<TableType>(ty))
return egraph.add(TImportedTable{ty});
else if (get<MetatableType>(ty))
return egraph.add(TOpaque{ty});
else if (auto ut = get<UnionType>(ty))
{
std::vector<EqSat::Id> parts;
for (TypeId part : ut)
parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part));
return cache(mkUnion(egraph, std::move(parts)));
}
else if (auto it = get<IntersectionType>(ty))
{
std::vector<Id> parts;
for (TypeId part : it)
parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part));
LUAU_ASSERT(parts.size() > 1);
return cache(mkIntersection(egraph, std::move(parts)));
}
else if (auto negation = get<NegationType>(ty))
{
Id part = toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, negation->ty);
return cache(egraph.add(Negation{std::array{part}}));
}
else if (auto tfun = get<TypeFunctionInstanceType>(ty))
{
LUAU_ASSERT(tfun->packArguments.empty());
std::vector<Id> parts;
parts.reserve(tfun->typeArguments.size());
for (TypeId part : tfun->typeArguments)
parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part));
// This looks silly, but we're making a copy of the specific
// `TypeFunctionInstanceType` outside of the provided arena so that
// we can access the members without fear of the specific TFIT being
// overwritten with a bound type.
return cache(egraph.add(TTypeFun{
std::make_shared<const TypeFunctionInstanceType>(
tfun->function, tfun->typeArguments, tfun->packArguments, tfun->userFuncName, tfun->userFuncData
),
std::move(parts)
}));
}
else if (get<NoRefineType>(ty))
return egraph.add(TNoRefine{});
else
{
LUAU_ASSERT(!"Unhandled Type");
return cache(egraph.add(Invalid{}));
}
}
// Entry point for importing a type into the egraph. Runs the recursive import,
// then merges every TBound placeholder created along the way with the eclass
// it refers to, rebuilds the egraph, and returns the canonical Id of `ty`.
Id toId(EGraph& egraph, NotNull<BuiltinTypes> builtinTypes, std::unordered_map<size_t, Id>& mappingIdToClass, StringCache& strings, TypeId ty)
{
    std::unordered_map<TypeId, std::pair<size_t, size_t>> inProgress;
    std::unordered_set<Id> placeholders;

    const Id rootId = toId(egraph, builtinTypes, mappingIdToClass, inProgress, placeholders, strings, ty);

    for (Id placeholder : placeholders)
    {
        for (const auto [tb, _index] : Query<TBound>(&egraph, placeholder))
        {
            const Id target = mappingIdToClass.at(tb->value());
            egraph.merge(placeholder, target);
        }
    }

    egraph.rebuild();

    return egraph.find(rootId);
}
// Cyclic types are penalized so that extraction is steered away from them
// whenever an alternative exists.
static constexpr int CYCLE_PENALTY = 5000;

// The cost of a composite type is the sum of the costs of its parts plus one
// of the constant factors below.
static constexpr int SET_TYPE_PENALTY = 1;
static constexpr int TABLE_TYPE_PENALTY = 2;
static constexpr int NEGATION_PENALTY = 2;
static constexpr int TFUN_PENALTY = 2;

// FIXME. There is no accurate way yet to score a TImportedTable against a
// TTable.
static constexpr int IMPORTED_TABLE_PENALTY = 50;

// TBound should never be picked as the best node of a class unless eqsat
// itself is being debugged and eclasses need to be stringified. The penalty is
// so large that any other alternative wins.
static constexpr int BOUND_PENALTY = 999999999;
// Computes the cost of the eclass `id` and records, in `bestNodes`, the index
// of the enode chosen for it. Costs are memoized in `costs`.
//
// `bestNodes` may arrive prepopulated; an existing entry for an eclass
// restricts the choice for that eclass to the recorded node index.
//
// Returns the cost assigned to `id`.
//
// TODO iteration count limit
// TODO also: accept an argument which is the maximum cost to consider before
// abandoning the count.
// TODO: the egraph should be the first parameter.
static size_t computeCost(std::unordered_map<Id, size_t>& bestNodes, const EGraph& egraph, std::unordered_map<Id, size_t>& costs, Id id)
{
// Memoized result (or the cycle sentinel written further down).
if (auto it = costs.find(id); it != costs.end())
return it->second;
const std::vector<EType>& nodes = egraph[id].nodes;
size_t minCost = std::numeric_limits<size_t>::max();
size_t bestNode = std::numeric_limits<size_t>::max();
// Keeps the cheapest node seen so far; ties keep the earlier node.
const auto updateCost = [&](size_t cost, size_t node)
{
if (cost < minCost)
{
minCost = cost;
bestNode = node;
}
};
// First, quickly scan for a terminal type. If we can find one, it is obviously the best.
for (size_t index = 0; index < nodes.size(); ++index)
{
if (isTerminal(nodes[index]))
{
minCost = 1;
bestNode = index;
costs[id] = 1;
const auto [iter, isFresh] = bestNodes.insert({id, index});
// If we are forcing the cost function to select a specific node,
// then we still need to traverse into that node, even if this
// particular node is the obvious choice under normal circumstances.
if (isFresh || iter->second == index)
return 1;
}
}
// If we recur into this type before this call frame completes, it is
// because this type participates in a cycle.
costs[id] = CYCLE_PENALTY;
// Sums the costs of `parts`, giving up with nullopt as soon as the running
// total exceeds `maxCost`.
auto computeChildren = [&](Slice<const Id> parts, size_t maxCost) -> std::optional<size_t>
{
size_t cost = 0;
for (Id part : parts)
{
cost += computeCost(bestNodes, egraph, costs, part);
// Abandon this node if it is too costly
if (cost > maxCost)
return std::nullopt;
}
return cost;
};
size_t startIndex = 0;
size_t endIndex = nodes.size();
// FFlag::DebugLuauLogSimplification will sometimes stringify an Id and pass
// in a prepopulated bestNodes map. If that mapping already has an index
// for this Id, don't look at the other nodes of this class.
if (auto it = bestNodes.find(id); it != bestNodes.end())
{
LUAU_ASSERT(it->second < nodes.size());
startIndex = it->second;
endIndex = startIndex + 1;
}
for (size_t index = startIndex; index < endIndex; ++index)
{
const auto& node = nodes[index];
if (node.get<TBound>())
updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound.
else if (node.get<TFunction>())
{
// Assigned directly rather than through updateCost, so this replaces
// whatever was selected earlier in the loop.
minCost = 1;
bestNode = index;
}
else if (auto tbl = node.get<TTable>())
{
// TODO: We could make the penalty a parameter to computeChildren.
std::optional<size_t> maybeCost = computeChildren(tbl->operands(), minCost);
if (maybeCost)
updateCost(TABLE_TYPE_PENALTY + *maybeCost, index);
}
else if (node.get<TImportedTable>())
{
// Assigned directly rather than through updateCost, so this replaces
// whatever was selected earlier in the loop, even a cheaper node.
minCost = IMPORTED_TABLE_PENALTY;
bestNode = index;
}
else if (auto u = node.get<Union>())
{
std::optional<size_t> maybeCost = computeChildren(u->operands(), minCost);
if (maybeCost)
updateCost(SET_TYPE_PENALTY + *maybeCost, index);
}
else if (auto i = node.get<Intersection>())
{
std::optional<size_t> maybeCost = computeChildren(i->operands(), minCost);
if (maybeCost)
updateCost(SET_TYPE_PENALTY + *maybeCost, index);
}
else if (auto negation = node.get<Negation>())
{
std::optional<size_t> maybeCost = computeChildren(negation->operands(), minCost);
if (maybeCost)
updateCost(NEGATION_PENALTY + *maybeCost, index);
}
else if (auto tfun = node.get<TTypeFun>())
{
std::optional<size_t> maybeCost = computeChildren(tfun->operands(), minCost);
if (maybeCost)
updateCost(TFUN_PENALTY + *maybeCost, index);
}
}
LUAU_ASSERT(bestNode < nodes.size());
// Replace the cycle sentinel with the real cost. insert() leaves an entry
// that already exists in bestNodes untouched.
costs[id] = minCost;
bestNodes.insert({id, bestNode});
return minCost;
}
// Runs the cost analysis starting at `id` and returns the chosen node index
// for each eclass it recorded. The selection starts out as a copy of
// `forceNodes`, which pins the choice for the eclasses it mentions.
static std::unordered_map<Id, size_t> computeBestResult(const EGraph& egraph, Id id, const std::unordered_map<Id, size_t>& forceNodes)
{
    std::unordered_map<Id, size_t> selection(forceNodes);
    std::unordered_map<Id, size_t> costMemo;

    computeCost(selection, egraph, costMemo, id);

    return selection;
}
// Runs the cost analysis starting at `id` with no forced choices and returns
// the chosen node index for each eclass it recorded.
static std::unordered_map<Id, size_t> computeBestResult(const EGraph& egraph, Id id)
{
    std::unordered_map<Id, size_t> selection;
    std::unordered_map<Id, size_t> costMemo;

    computeCost(selection, egraph, costMemo, id);

    return selection;
}
// Forward declaration: flattenTableNode() and fromId() call each other.
//
// Converts the eclass `rootId` back into a TypeId, using the node chosen for
// each eclass in `bestNodes`. `seen` memoizes conversions that are in progress
// or finished, and `newTypeFunctions` receives every type function instance
// that is created.
TypeId fromId(
EGraph& egraph,
const StringCache& strings,
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
const std::unordered_map<Id, size_t>& bestNodes,
std::unordered_map<Id, TypeId>& seen,
std::vector<TypeId>& newTypeFunctions,
Id rootId
);
// Collapses a chain of TTable nodes, each layered on top of its basis, into a
// single TableType allocated in `arena`.
//
// The chain is followed from `rootId` through each TTable's basis until a
// TImportedTable is reached. The result starts as a shallow copy of that
// imported table, and the properties of each TTable layer are then written
// over it, innermost layer first.
//
// Returns the error type if the chain loops back on itself and no
// TImportedTable can be found in the repeated eclass.
TypeId flattenTableNode(
EGraph& egraph,
const StringCache& strings,
NotNull<BuiltinTypes> builtinTypes,
NotNull<TypeArena> arena,
const std::unordered_map<Id, size_t>& bestNodes,
std::unordered_map<Id, TypeId>& seen,
std::vector<TypeId>& newTypeFunctions,
Id rootId
)
{
// TTable layers from outermost (rootId) to innermost.
std::vector<const TTable*> stack;
// Eclasses already visited while walking the basis chain.
std::unordered_set<Id> seenIds;
Id id = rootId;
const TImportedTable* importedTable = nullptr;
while (true)
{
size_t index = bestNodes.at(id);
const auto& eclass = egraph[id];
const auto [_iter, isFresh] = seenIds.insert(id);
if (!isFresh)
{
// If a TTable is its own basis, it must be the case that some other
// node on this eclass is a TImportedTable. Let's use that.
bool found = false;
for (size_t i = 0; i < eclass.nodes.size(); ++i)
{
if (eclass.nodes[i].get<TImportedTable>())
{
found = true;
index = i;
break;
}
}
if (!found)
{
// If we couldn't find one, we don't know what to do. Use ErrorType.
LUAU_ASSERT(0);
return builtinTypes->errorType;
}
}
const auto& node = eclass.nodes[index];
if (const TTable* ttable = node.get<TTable>())
{
stack.push_back(ttable);
id = ttable->getBasis();
continue;
}
else if (const TImportedTable* ti = node.get<TImportedTable>())
{
importedTable = ti;
break;
}
else
LUAU_ASSERT(0);
// When the assert is not fatal, the loop revisits the same id, which then
// takes the !isFresh path above.
}
TableType resultTable;
if (importedTable)
{
const TableType* t = Luau::get<TableType>(importedTable->value());
LUAU_ASSERT(t);
resultTable = *t; // Intentional shallow clone here
}
// Apply the layers innermost first, so that outer layers overwrite the
// properties of inner ones.
while (!stack.empty())
{
const TTable* t = stack.back();
stack.pop_back();
for (size_t i = 0; i < t->propNames.size(); ++i)
{
StringId propName = t->propNames[i];
const Id propType = t->propTypes()[i];
resultTable.props[strings.asString(propName)] = Property{fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, propType)};
}
}
return arena->addType(std::move(resultTable));
}
// Converts the eclass `rootId` back into a TypeId.
//
// - bestNodes: for each eclass, the index of the enode to extract. It must
//   contain an entry for every eclass that is reached.
// - seen: memoizes the TypeId produced for each eclass. Composite types
//   register a placeholder here before converting their parts, which is what
//   lets cyclic types terminate.
// - newTypeFunctions: receives every TypeFunctionInstanceType that is created.
//
// Returns the extracted type. A TBound node yields the error type, and an
// unrecognized node yields nullptr.
TypeId fromId(
    EGraph& egraph,
    const StringCache& strings,
    NotNull<BuiltinTypes> builtinTypes,
    NotNull<TypeArena> arena,
    const std::unordered_map<Id, size_t>& bestNodes,
    std::unordered_map<Id, TypeId>& seen,
    std::vector<TypeId>& newTypeFunctions,
    Id rootId
)
{
    if (auto it = seen.find(rootId); it != seen.end())
        return it->second;

    size_t index = bestNodes.at(rootId);
    // The index is used to subscript `nodes` directly below, so it must be
    // strictly less than the node count.
    LUAU_ASSERT(index < egraph[rootId].nodes.size());

    const EType& node = egraph[rootId].nodes[index];

    if (node.get<TNil>())
        return builtinTypes->nilType;
    else if (node.get<TBoolean>())
        return builtinTypes->booleanType;
    else if (node.get<TNumber>())
        return builtinTypes->numberType;
    else if (node.get<TString>())
        return builtinTypes->stringType;
    else if (node.get<TThread>())
        return builtinTypes->threadType;
    else if (node.get<TTopFunction>())
        return builtinTypes->functionType;
    else if (node.get<TTopTable>())
        return builtinTypes->tableType;
    else if (node.get<TTopClass>())
        return builtinTypes->classType;
    else if (node.get<TBuffer>())
        return builtinTypes->bufferType;
    else if (auto opaque = node.get<TOpaque>())
        return opaque->value();
    else if (auto b = node.get<SBoolean>())
        return b->value() ? builtinTypes->trueType : builtinTypes->falseType;
    else if (auto s = node.get<SString>())
        return arena->addType(SingletonType{StringSingleton{strings.asString(s->value())}});
    else if (auto fun = node.get<TFunction>())
        return fun->value();
    else if (node.get<TTable>())
    {
        // Register a placeholder first so that properties referring back to
        // this table resolve to it instead of recursing forever.
        TypeId res = arena->addType(BlockedType{});
        seen[rootId] = res;

        TypeId flattened = flattenTableNode(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId);

        asMutable(res)->ty.emplace<BoundType>(flattened);
        return flattened;
    }
    else if (auto tbl = node.get<TImportedTable>())
        return tbl->value();
    else if (auto cls = node.get<TClass>())
        return cls->value();
    else if (node.get<TAny>())
        return builtinTypes->anyType;
    else if (node.get<TError>())
        return builtinTypes->errorType;
    else if (node.get<TUnknown>())
        return builtinTypes->unknownType;
    else if (node.get<TNever>())
        return builtinTypes->neverType;
    else if (auto u = node.get<Union>())
    {
        Slice<const Id> parts = u->operands();

        // The empty union is never.
        if (parts.empty())
            return builtinTypes->neverType;
        else if (parts.size() == 1)
        {
            TypeId placeholder = arena->addType(BlockedType{});
            seen[rootId] = placeholder;

            auto result = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]);

            // A single-part union whose part is itself has no meaningful
            // type; mark it with a recognizable generic.
            if (follow(result) == placeholder)
            {
                emplaceType<GenericType>(asMutable(placeholder), "EGRAPH-SINGLETON-CYCLE");
            }
            else
            {
                emplaceType<BoundType>(asMutable(placeholder), result);
            }

            return result;
        }
        else
        {
            TypeId res = arena->addType(BlockedType{});
            seen[rootId] = res;

            std::vector<TypeId> partTypes;
            partTypes.reserve(parts.size());

            for (Id part : parts)
                partTypes.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part));

            asMutable(res)->ty.emplace<UnionType>(std::move(partTypes));

            return res;
        }
    }
    else if (auto i = node.get<Intersection>())
    {
        Slice<const Id> parts = i->operands();

        // The empty intersection is unknown, matching mkIntersection().
        if (parts.empty())
            return builtinTypes->unknownType;
        else if (parts.size() == 1)
        {
            LUAU_ASSERT(parts[0] != rootId);
            return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]);
        }
        else
        {
            TypeId res = arena->addType(BlockedType{});
            seen[rootId] = res;

            std::vector<TypeId> partTypes;
            partTypes.reserve(parts.size());

            for (Id part : parts)
                partTypes.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part));

            asMutable(res)->ty.emplace<IntersectionType>(std::move(partTypes));

            return res;
        }
    }
    else if (auto negation = node.get<Negation>())
    {
        TypeId res = arena->addType(BlockedType{});
        seen[rootId] = res;

        TypeId ty = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, negation->operands()[0]);

        asMutable(res)->ty.emplace<NegationType>(ty);

        return res;
    }
    else if (auto tfun = node.get<TTypeFun>())
    {
        TypeId res = arena->addType(BlockedType{});
        seen[rootId] = res;

        std::vector<TypeId> args;
        args.reserve(tfun->operands().size());
        for (Id part : tfun->operands())
            args.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part));

        // Rebuild the instance from the stored copy of the original, with the
        // freshly extracted type arguments and no pack arguments.
        auto oldInstance = tfun->value();

        asMutable(res)->ty.emplace<TypeFunctionInstanceType>(
            oldInstance->function, std::move(args), std::vector<TypePackId>(), oldInstance->userFuncName, oldInstance->userFuncData
        );

        newTypeFunctions.push_back(res);

        return res;
    }
    else if (node.get<TBound>())
        return builtinTypes->errorType;
    else if (node.get<TNoRefine>())
        return builtinTypes->noRefineType;
    else
    {
        LUAU_ASSERT(!"Unimplemented");
        return nullptr;
    }
}
// Extracts a TypeId for `rootId`, choosing nodes with the cost analysis while
// honoring the choices pinned in `forceNodes`.
static TypeId fromId(
    EGraph& egraph,
    const StringCache& strings,
    NotNull<BuiltinTypes> builtinTypes,
    NotNull<TypeArena> arena,
    const std::unordered_map<Id, size_t>& forceNodes,
    std::vector<TypeId>& newTypeFunctions,
    Id rootId
)
{
    std::unordered_map<Id, TypeId> converted;
    const std::unordered_map<Id, size_t> selection = computeBestResult(egraph, rootId, forceNodes);

    return fromId(egraph, strings, builtinTypes, arena, selection, converted, newTypeFunctions, rootId);
}
// Extracts a TypeId for `rootId`, choosing nodes with the cost analysis and no
// forced choices.
static TypeId fromId(
    EGraph& egraph,
    const StringCache& strings,
    NotNull<BuiltinTypes> builtinTypes,
    NotNull<TypeArena> arena,
    std::vector<TypeId>& newTypeFunctions,
    Id rootId
)
{
    std::unordered_map<Id, TypeId> converted;
    const std::unordered_map<Id, size_t> selection = computeBestResult(egraph, rootId);

    return fromId(egraph, strings, builtinTypes, arena, selection, converted, newTypeFunctions, rootId);
}
Subst::Subst(Id eclass, Id newClass, std::string desc)
: eclass(std::move(eclass))
, newClass(std::move(newClass))
, desc(std::move(desc))
{
}
// Builds a human-readable description of a rewrite of the form
// "<rule>:<padding><from type> <=> <to type>".
//
// Returns an empty string unless FFlag::DebugLuauLogSimplification is set.
// `forceNodes` pins which enode is used when stringifying particular eclasses.
std::string mkDesc(
    EGraph& egraph,
    const StringCache& strings,
    NotNull<TypeArena> arena,
    NotNull<BuiltinTypes> builtinTypes,
    Id from,
    Id to,
    const std::unordered_map<Id, size_t>& forceNodes,
    const std::string& rule
)
{
    if (!FFlag::DebugLuauLogSimplification)
        return "";

    std::vector<TypeId> newTypeFunctions;

    TypeId fromTy = fromId(egraph, strings, builtinTypes, arena, forceNodes, newTypeFunctions, from);
    TypeId toTy = fromId(egraph, strings, builtinTypes, arena, forceNodes, newTypeFunctions, to);

    ToStringOptions opts;
    opts.useQuestionMarks = false;

    // Pad rule names out to a fixed column. The subtraction is guarded because
    // it is done on unsigned values: a rule name longer than the column width
    // would otherwise wrap around to an enormous padding length.
    const size_t RULE_PADDING = 35;
    const size_t padWidth = rule.size() < RULE_PADDING ? RULE_PADDING - rule.size() : 0;
    const std::string rulePadding(padWidth, ' ');

    // Debugging aid: swap in the commented-out expressions to include the
    // eclass ids in the description.
    const std::string fromIdStr = ""; // "(" + std::to_string(uint32_t(from)) + ") ";
    const std::string toIdStr = "";   // "(" + std::to_string(uint32_t(to)) + ") ";

    return rule + ":" + rulePadding + fromIdStr + toString(fromTy, opts) + " <=> " + toIdStr + toString(toTy, opts);
}
// Convenience overload of mkDesc() that pins no nodes. Returns an empty string
// unless FFlag::DebugLuauLogSimplification is set.
std::string mkDesc(EGraph& egraph, const StringCache& strings, NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, Id from, Id to, const std::string& rule)
{
    if (FFlag::DebugLuauLogSimplification)
        return mkDesc(egraph, strings, arena, builtinTypes, from, to, {}, rule);

    return "";
}
// Returns a short label for `node`, used when rendering the egraph. Node kinds
// without a dedicated label yield "???".
static std::string getNodeName(const StringCache& strings, const EType& node)
{
    // Primitive and top types.
    if (node.get<TNil>())
        return "nil";
    if (node.get<TBoolean>())
        return "boolean";
    if (node.get<TNumber>())
        return "number";
    if (node.get<TString>())
        return "string";
    if (node.get<TThread>())
        return "thread";
    if (node.get<TTopFunction>())
        return "function";
    if (node.get<TTopTable>())
        return "table";
    if (node.get<TTopClass>())
        return "class";
    if (node.get<TBuffer>())
        return "buffer";
    if (node.get<TOpaque>())
        return "opaque";

    // Singletons.
    if (auto boolSingleton = node.get<SBoolean>())
        return boolSingleton->value() ? "true" : "false";
    if (auto stringSingleton = node.get<SString>())
        return "\"" + strings.asString(stringSingleton->value()) + "\"";

    // Set operators, rendered as the UTF-8 union and intersection symbols.
    if (node.get<Union>())
        return "\xe2\x88\xaa";
    if (node.get<Intersection>())
        return "\xe2\x88\xa9";

    if (auto classNode = node.get<TClass>())
    {
        const ClassType* classType = get<ClassType>(classNode->value());
        LUAU_ASSERT(classType);
        return classType->name;
    }

    if (node.get<TAny>())
        return "any";
    if (node.get<TError>())
        return "error";
    if (node.get<TUnknown>())
        return "unknown";
    if (node.get<TNever>())
        return "never";

    if (auto typeFunction = node.get<TTypeFun>())
        return "tfun " + typeFunction->value()->function->name;

    if (node.get<Negation>())
        return "~";
    if (node.get<Invalid>())
        return "invalid?";
    if (node.get<TBound>())
        return "bound";

    return "???";
}
/// Render the e-graph as a Graphviz DOT document.
///
/// Every eclass that has at least one edge (either because one of its nodes has
/// operands, or because it is an operand of some node) becomes a cluster whose
/// nodes are named `n<class>_<index>`. Edges go from each enode to the first
/// node of each operand's eclass.
std::string toDot(const StringCache& strings, const EGraph& egraph)
{
    // Node names are written inside a double-quoted label on a record-shaped
    // node. String singletons are rendered with surrounding quotes and may
    // contain arbitrary characters, so DOT quoting and record-label
    // metacharacters must be backslash-escaped or the output is not valid DOT.
    const auto escapeLabel = [](const std::string& raw)
    {
        std::string escaped;
        escaped.reserve(raw.size());
        for (const char c : raw)
        {
            switch (c)
            {
            case '"':
            case '\\':
            case '{':
            case '}':
            case '|':
            case '<':
            case '>':
                escaped.push_back('\\');
                break;
            default:
                break;
            }
            escaped.push_back(c);
        }
        return escaped;
    };

    std::stringstream ss;
    ss << "digraph G {" << '\n';
    ss << " graph [fontsize=10 fontname=\"Verdana\" compound=true];" << '\n';
    ss << " node [shape=record fontsize=10 fontname=\"Verdana\"];" << '\n';

    // Collect the eclasses that participate in at least one edge.
    std::set<Id> populated;
    for (const auto& [id, eclass] : egraph.getAllClasses())
    {
        for (const auto& node : eclass.nodes)
        {
            if (!node.operands().empty())
                populated.insert(id);
            for (Id op : node.operands())
                populated.insert(op);
        }
    }

    // One cluster per populated eclass, one record per enode.
    for (const auto& [id, eclass] : egraph.getAllClasses())
    {
        if (!populated.count(id))
            continue;

        const std::string className = "cluster_" + std::to_string(uint32_t(id));
        ss << " subgraph " << className << " {" << '\n';
        ss << " node [style=\"rounded,filled\"];" << '\n';
        ss << " label = \"" << uint32_t(id) << "\";" << '\n';
        ss << " color = blue;" << '\n';

        for (size_t index = 0; index < eclass.nodes.size(); ++index)
        {
            const auto& node = eclass.nodes[index];
            const std::string label = escapeLabel(getNodeName(strings, node));
            const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index);
            ss << " " << nodeName << " [label=\"" << label << "\"];" << '\n';
        }

        ss << " }" << '\n';
    }

    // Edges: enode -> first node of the (canonical) operand eclass.
    for (const auto& [id, eclass] : egraph.getAllClasses())
    {
        for (size_t index = 0; index < eclass.nodes.size(); ++index)
        {
            const auto& node = eclass.nodes[index];
            const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index);
            for (Id op : node.operands())
            {
                op = egraph.find(op);
                const std::string destNodeName = "n" + std::to_string(uint32_t(op)) + "_0";
                ss << " " << nodeName << " -> " << destNodeName << " [lhead=cluster_" << uint32_t(op) << "];" << '\n';
            }
        }
    }

    ss << "}" << '\n';
    return ss.str();
}
// Thin wrapper over EType::get<Tag>(): forwards the result of
// node.get<Tag>() so callers can test whether `node` is a Tag.
template<typename Tag>
static Tag const* isTag(const EType& node)
{
return node.get<Tag>();
}
/// Scan the nodes of eclass `id` and return the first one that is a `Tag`,
/// or nullptr if the class contains no such node.
///
/// Important: Only use this to test for leaf node types like TUnknown and
/// TNumber, i.e. things that cannot be simplified any further and are safe to
/// short-circuit on.
///
/// The scan is linear and stops at the first hit, so an eclass that has
/// several "interesting" representations can give a surprising answer.
template<typename Tag>
static Tag const* isTag(const EGraph& egraph, Id id)
{
    const auto& candidates = egraph[id].nodes;
    for (size_t i = 0; i < candidates.size(); ++i)
    {
        Tag const* hit = isTag<Tag>(candidates[i]);
        if (hit)
            return hit;
    }

    return nullptr;
}
/// Base for rewrite rules that operate on a shared EGraph.
///
/// Subclasses implement read(); the protected helpers are shorthand for
/// the corresponding operations on the wrapped e-graph.
struct RewriteRule
{
    explicit RewriteRule(EGraph* egraph)
        : egraph(egraph)
    {
    }

    // This type is polymorphic (read() is pure virtual), so it needs a virtual
    // destructor to make destruction through a RewriteRule* well-defined.
    virtual ~RewriteRule() = default;

    /// Inspect `enode` (a member of `eclass`) and append any substitutions the
    /// rule wants to make to `substs`.
    virtual void read(std::vector<Subst>& substs, Id eclass, const EType* enode) = 0;

protected:
    /// The eclass stored under `id`.
    const EqSat::EClass<EType, Simplify::Data>& get(Id id)
    {
        return (*egraph)[id];
    }

    /// Forwarded to EGraph::find.
    Id find(Id id)
    {
        return egraph->find(id);
    }

    /// Forwarded to EGraph::add.
    Id add(EType enode)
    {
        return egraph->add(std::move(enode));
    }

    /// First node of eclass `id` that is a `Tag`, or nullptr.
    template<typename Tag>
    const Tag* isTag(Id id)
    {
        for (const auto& node : (*egraph)[id].nodes)
        {
            if (auto n = node.get<Tag>())
                return n;
        }
        return nullptr;
    }

    /// Whether `enode` itself is a `Tag`.
    template<typename Tag>
    bool isTag(const EType& enode)
    {
        return enode.get<Tag>();
    }

public:
    EGraph* egraph;
};
// Result of comparing two class types with relateClasses().
enum SubclassRelationship
{
// Returned when the right class is a subclass of the left one.
LeftSuper,
// Returned when the left class is a subclass of the right one.
RightSuper,
// Returned when neither subclass test succeeds.
Unrelated
};
// Classify how two class nodes relate using isSubclass().
//
// The left-is-subclass-of-right test is tried first and yields RightSuper; the
// mirrored test yields LeftSuper; if both fail the classes are Unrelated.
static SubclassRelationship relateClasses(const TClass* leftClass, const TClass* rightClass)
{
    const ClassType* lhs = Luau::get<ClassType>(leftClass->value());
    const ClassType* rhs = Luau::get<ClassType>(rightClass->value());

    if (isSubclass(lhs, rhs))
        return RightSuper;

    return isSubclass(rhs, lhs) ? LeftSuper : Unrelated;
}
// Entirely analogous to NormalizedType except that it operates on eclasses instead of TypeIds.
//
// Each optional member is engaged when the corresponding primitive is part of
// the type being described, and holds the eclass that represents it.
struct CanonicalizedType
{
std::optional<Id> nilPart;
std::optional<Id> truePart;
std::optional<Id> falsePart;
std::optional<Id> numberPart;
std::optional<Id> stringPart;
// String singleton eclasses, in insertion order.
std::vector<Id> stringSingletons;
std::optional<Id> threadPart;
std::optional<Id> functionPart;
std::optional<Id> tablePart;
// Class eclasses; kept sorted by unionClasses().
std::vector<Id> classParts;
std::optional<Id> bufferPart;
std::optional<Id> errorPart;
// Functions that have been union'd into the type
std::unordered_set<Id> functionParts;
// Anything that isn't canonical: Intersections, unions, free types, and so on.
std::unordered_set<Id> otherParts;
// True when every optional primitive part other than errorPart is present.
// Note that classParts and errorPart are not consulted here.
bool isUnknown() const
{
return nilPart && truePart && falsePart && numberPart && stringPart && threadPart && functionPart && tablePart && bufferPart;
}
};
// Make `ct` satisfy CanonicalizedType::isUnknown(): every optional primitive
// part except errorPart is set to a freshly added eclass, and the
// function/other sets are emptied.
//
// NOTE(review): classParts, stringSingletons and errorPart are left untouched,
// and isUnknown() does not look at classParts either, so class types are not
// tracked as part of "unknown" here. Confirm this is intended, since it affects
// what expandNegation() produces for e.g. `~nil`.
void unionUnknown(EGraph& egraph, CanonicalizedType& ct)
{
ct.nilPart = egraph.add(TNil{});
ct.truePart = egraph.add(SBoolean{true});
ct.falsePart = egraph.add(SBoolean{false});
ct.numberPart = egraph.add(TNumber{});
ct.stringPart = egraph.add(TString{});
ct.threadPart = egraph.add(TThread{});
ct.functionPart = egraph.add(TTopFunction{});
ct.tablePart = egraph.add(TTopTable{});
ct.bufferPart = egraph.add(TBuffer{});
ct.functionParts.clear();
ct.otherParts.clear();
}
// Union `any` into `ct`: everything unionUnknown() sets, plus the error part.
void unionAny(EGraph& egraph, CanonicalizedType& ct)
{
unionUnknown(egraph, ct);
ct.errorPart = egraph.add(TError{});
}
// Union the class eclass `there` into the sorted class list `hereParts`.
//
// Nothing happens if `hereParts` is already just the top class type, or if
// `there` contains no TClass node. Otherwise the first existing class that is
// related to `there` decides the outcome: a superclass absorbs it, a subclass
// is replaced by it. If no related class is found, `there` is appended.
void unionClasses(EGraph& egraph, std::vector<Id>& hereParts, Id there)
{
    const bool alreadyTop = hereParts.size() == 1 && isTag<TTopClass>(egraph, hereParts[0]);
    if (alreadyTop)
        return;

    const TClass* incoming = isTag<TClass>(egraph, there);
    if (!incoming)
        return;

    for (Id& existing : hereParts)
    {
        const TClass* existingClass = isTag<TClass>(egraph, existing);
        if (!existingClass)
            continue;

        const SubclassRelationship relation = relateClasses(existingClass, incoming);
        if (relation == LeftSuper)
            return; // An existing entry already covers `there`.

        if (relation == RightSuper)
        {
            // `there` covers the existing entry; swap it in and restore order.
            existing = there;
            std::sort(hereParts.begin(), hereParts.end());
            return;
        }
    }

    hereParts.push_back(there);
    std::sort(hereParts.begin(), hereParts.end());
}
// Union the eclass `part` into `ct`.
//
// The eclass is classified by the first matching tag below (the order of the
// checks matters when an eclass holds several kinds of nodes). Anything that is
// not recognised ends up in ct.otherParts.
void unionWithType(EGraph& egraph, CanonicalizedType& ct, Id part)
{
    if (isTag<TNil>(egraph, part))
    {
        ct.nilPart = part;
        return;
    }

    if (isTag<TBoolean>(egraph, part))
    {
        ct.truePart = ct.falsePart = part;
        return;
    }

    if (auto singleton = isTag<SBoolean>(egraph, part))
    {
        if (singleton->value())
            ct.truePart = part;
        else
            ct.falsePart = part;
        return;
    }

    if (isTag<TNumber>(egraph, part))
    {
        ct.numberPart = part;
        return;
    }

    if (isTag<TString>(egraph, part))
    {
        ct.stringPart = part;
        return;
    }

    if (isTag<SString>(egraph, part))
    {
        ct.stringSingletons.push_back(part);
        return;
    }

    if (isTag<TThread>(egraph, part))
    {
        ct.threadPart = part;
        return;
    }

    if (isTag<TTopFunction>(egraph, part))
    {
        // The top function type subsumes every individual function.
        ct.functionPart = part;
        ct.functionParts.clear();
        return;
    }

    if (isTag<TTopTable>(egraph, part))
    {
        ct.tablePart = part;
        return;
    }

    if (isTag<TTopClass>(egraph, part))
    {
        ct.classParts = {part};
        return;
    }

    if (isTag<TBuffer>(egraph, part))
    {
        ct.bufferPart = part;
        return;
    }

    if (isTag<TFunction>(egraph, part))
    {
        // Individual functions only matter while the top function type is absent.
        if (!ct.functionPart)
            ct.functionParts.insert(part);
        return;
    }

    if (isTag<TClass>(egraph, part))
    {
        unionClasses(egraph, ct.classParts, part);
        return;
    }

    if (isTag<TAny>(egraph, part))
    {
        unionAny(egraph, ct);
        return;
    }

    if (isTag<TError>(egraph, part))
    {
        ct.errorPart = part;
        return;
    }

    if (isTag<TUnknown>(egraph, part))
    {
        unionUnknown(egraph, ct);
        return;
    }

    if (isTag<TNever>(egraph, part))
        return; // never contributes nothing to a union

    ct.otherParts.insert(part);
}
// Find an enode under the given eclass which is simple enough that it could be
// subtracted from a CanonicalizedType easily.
//
// A union is "simple enough" if it is acyclic and is only comprised of terminal
// types and unions that are themselves subtractable
//
// `seen` records the eclasses whose unions are currently being (or have been)
// explored; revisiting one of them yields nullptr. Returns nullptr when no
// suitable enode exists.
const EType* findSubtractableClass(const EGraph& egraph, std::unordered_set<Id>& seen, Id id)
{
if (seen.count(id))
return nullptr;
const EType* bestUnion = nullptr;
std::optional<size_t> unionSize;
for (const auto& node : egraph[id].nodes)
{
// A terminal node wins immediately over any union found so far.
if (isTerminal(node))
return &node;
if (const auto u = node.get<Union>())
{
seen.insert(id);
for (Id part : u->operands())
{
// One unsubtractable operand disqualifies the whole eclass, not just this union.
if (!findSubtractableClass(egraph, seen, part))
return nullptr;
}
// If multiple unions in this class are all simple enough, prefer
// the shortest one.
if (!unionSize || u->operands().size() < unionSize)
{
unionSize = u->operands().size();
bestUnion = &node;
}
}
}
return bestUnion;
}
// Convenience overload: run the search above with a fresh, empty `seen` set.
const EType* findSubtractableClass(const EGraph& egraph, Id id)
{
std::unordered_set<Id> seen;
return findSubtractableClass(egraph, seen, id);
}
// Subtract the type 'part' from 'ct'.
//
// Returns true if the subtraction succeeded. Returns false if 'part' is too
// complicated: no subtractable enode was found, or the enode is a string
// singleton, a class that is not listed in ct.classParts, an intersection, or
// any other kind of node not handled below. Note that 'ct' may already have
// been partially modified when a union operand fails part-way through.
bool subtract(EGraph& egraph, CanonicalizedType& ct, Id part)
{
    const EType* chosen = findSubtractableClass(egraph, part);
    if (!chosen)
        return false;

    if (chosen->get<TNil>())
    {
        ct.nilPart.reset();
        return true;
    }

    if (chosen->get<TBoolean>())
    {
        ct.truePart.reset();
        ct.falsePart.reset();
        return true;
    }

    if (auto singleton = chosen->get<SBoolean>())
    {
        if (singleton->value())
            ct.truePart.reset();
        else
            ct.falsePart.reset();
        return true;
    }

    if (chosen->get<TNumber>())
    {
        ct.numberPart.reset();
        return true;
    }

    if (chosen->get<TString>())
    {
        ct.stringPart.reset();
        return true;
    }

    if (chosen->get<SString>())
        return false;

    if (chosen->get<TThread>())
    {
        ct.threadPart.reset();
        return true;
    }

    if (chosen->get<TTopFunction>())
    {
        ct.functionPart.reset();
        return true;
    }

    if (chosen->get<TTopTable>())
    {
        ct.tablePart.reset();
        return true;
    }

    if (chosen->get<TTopClass>())
    {
        ct.classParts.clear();
        return true;
    }

    if (chosen->get<TClass>())
    {
        // Only an exact match on the eclass id can be removed.
        auto found = std::find(ct.classParts.begin(), ct.classParts.end(), part);
        if (found == ct.classParts.end())
            return false;

        ct.classParts.erase(found);
        return true;
    }

    if (chosen->get<TBuffer>())
    {
        ct.bufferPart.reset();
        return true;
    }

    if (chosen->get<TAny>())
    {
        ct = {};
        return true;
    }

    if (chosen->get<TError>())
    {
        ct.errorPart.reset();
        return true;
    }

    if (chosen->get<TUnknown>())
    {
        // Everything goes except the error part.
        std::optional<Id> keptError = ct.errorPart;
        ct = {};
        ct.errorPart = keptError;
        return true;
    }

    if (chosen->get<TNever>())
        return true; // Subtracting never changes nothing.

    if (auto unionNode = chosen->get<Union>())
    {
        // TODO cycles
        // TODO this is super problematic because 'part' represents a whole group of equivalent enodes.
        for (Id operand : unionNode->operands())
        {
            // TODO: This recursive call will require that we re-traverse this
            // eclass to find the subtractible enode. It would be nice to do the
            // work just once and reuse it.
            if (!subtract(egraph, ct, operand))
                return false;
        }
        return true;
    }

    // Intersections and everything else are too complicated.
    return false;
}
// Turn a CanonicalizedType back into an eclass.
//
// If `ct` reports isUnknown() the result is `any` (error part present) or
// `unknown`. Otherwise the present parts are collected in a fixed order and
// handed to mkUnion().
Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct)
{
    if (ct.isUnknown())
        return ct.errorPart ? egraph.add(TAny{}) : egraph.add(TUnknown{});

    std::vector<Id> members;
    const auto pushIfPresent = [&members](const std::optional<Id>& maybePart)
    {
        if (maybePart)
            members.push_back(*maybePart);
    };

    pushIfPresent(ct.nilPart);

    // true | false collapses back into boolean.
    if (ct.truePart && ct.falsePart)
        members.push_back(egraph.add(TBoolean{}));
    else if (ct.truePart)
        members.push_back(*ct.truePart);
    else if (ct.falsePart)
        members.push_back(*ct.falsePart);

    pushIfPresent(ct.numberPart);

    // `string` subsumes every string singleton.
    if (ct.stringPart)
        members.push_back(*ct.stringPart);
    else if (!ct.stringSingletons.empty())
        members.insert(members.end(), ct.stringSingletons.begin(), ct.stringSingletons.end());

    pushIfPresent(ct.threadPart);
    pushIfPresent(ct.functionPart);
    pushIfPresent(ct.tablePart);
    members.insert(members.end(), ct.classParts.begin(), ct.classParts.end());
    pushIfPresent(ct.bufferPart);
    pushIfPresent(ct.errorPart);
    members.insert(members.end(), ct.functionParts.begin(), ct.functionParts.end());
    members.insert(members.end(), ct.otherParts.begin(), ct.otherParts.end());

    return mkUnion(egraph, std::move(members));
}
// Append every operand of `enode` to the back of `worklist`.
// The `egraph` parameter is not used by this function.
void addChildren(const EGraph& egraph, const EType* enode, VecDeque<Id>& worklist)
{
for (Id id : enode->operands())
worklist.push_back(id);
}
// Returns true if any of `operands` canonicalizes (via EGraph::find) to
// `outerId`. `outerId` itself is compared as given, so callers should pass a
// canonical id.
static bool occurs(EGraph& egraph, Id outerId, Slice<const Id> operands)
{
    return std::any_of(
        operands.begin(),
        operands.end(),
        [&egraph, outerId](const Id operand)
        {
            return egraph.find(operand) == outerId;
        }
    );
}
// Stores the arena and builtin types and starts from an e-graph constructed
// from a default Simplify analysis.
Simplifier::Simplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes)
: arena(arena)
, builtinTypes(builtinTypes)
, egraph(Simplify{})
{
}
// The eclass stored under `id` in this simplifier's e-graph.
const EqSat::EClass<EType, Simplify::Data>& Simplifier::get(Id id) const
{
return egraph[id];
}
// Forwarded to EGraph::find.
Id Simplifier::find(Id id) const
{
return egraph.find(id);
}
// Forwarded to EGraph::add; returns the id it reports for `enode`.
Id Simplifier::add(EType enode)
{
return egraph.add(std::move(enode));
}
// Return the first node of eclass `id` that is a `Tag`, or nullptr if the
// class has none. The scan stops at the first hit.
template<typename Tag>
const Tag* Simplifier::isTag(Id id) const
{
    const auto& candidates = get(id).nodes;
    for (size_t i = 0; i < candidates.size(); ++i)
    {
        const Tag* hit = candidates[i].template get<Tag>();
        if (hit)
            return hit;
    }

    return nullptr;
}
// Single-node variant: forwards the result of enode.get<Tag>().
template<typename Tag>
const Tag* Simplifier::isTag(const EType& enode) const
{
return enode.get<Tag>();
}
// Queue a substitution of `from` by `to` with the placeholder description " - ".
void Simplifier::subst(Id from, Id to)
{
substs.emplace_back(from, to, " - ");
}
// Queue a substitution of `from` by `to` on behalf of rule `ruleName`.
//
// The human-readable description is only built when
// FFlag::DebugLuauLogSimplification is set; otherwise it is left empty.
void Simplifier::subst(Id from, Id to, const std::string& ruleName)
{
    std::string desc;
    if (FFlag::DebugLuauLogSimplification)
    {
        // `ruleName` is a const reference, so it is passed as-is: wrapping it
        // in std::move would still copy and only obscure that fact.
        desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, ruleName);
    }

    // `desc` is not used again, so hand its buffer over instead of copying it.
    substs.emplace_back(from, to, std::move(desc));
}
// Queue a substitution of `from` by `to` on behalf of rule `ruleName`.
//
// `forceNodes` is forwarded to mkDesc() when a description is built, which
// only happens when FFlag::DebugLuauLogSimplification is set.
void Simplifier::subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes)
{
    std::string desc;
    if (FFlag::DebugLuauLogSimplification)
        desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName);

    // `desc` is not used again, so hand its buffer over instead of copying it.
    substs.emplace_back(from, to, std::move(desc));
}
void Simplifier::unionClasses(std::vector<Id>& hereParts, Id there)
{
if (1 == hereParts.size() && isTag<TTopClass>(hereParts[0]))
return;
const auto thereClass = isTag<TClass>(there);
if (!thereClass)
return;
for (size_t index = 0; index < hereParts.size(); ++index)
{
const Id herePart = hereParts[index];
if (auto partClass = isTag<TClass>(herePart))
{
switch (relateClasses(partClass, thereClass))
{
case LeftSuper:
return;
case RightSuper:
hereParts[index] = there;
std::sort(hereParts.begin(), hereParts.end());
return;
case Unrelated:
continue;
}
}
}
hereParts.push_back(there);
std::sort(hereParts.begin(), hereParts.end());
}
// For every Union node in eclass `id`, fold its operands into a
// CanonicalizedType and substitute the rebuilt type for `id`.
//
// Unions that directly contain `id` as an operand are skipped.
void Simplifier::simplifyUnion(Id id)
{
    id = find(id);

    for (const auto [u, unionIndex] : Query<Union>(&egraph, id))
    {
        // Check for the cycle first so no work is done for skipped unions.
        if (occurs(egraph, id, u->operands()))
            continue;

        CanonicalizedType canonicalized;
        for (Id part : u->operands())
            unionWithType(egraph, canonicalized, find(part));

        Id resultId = fromCanonicalized(egraph, canonicalized);
        subst(id, resultId, "simplifyUnion", {{id, unionIndex}});
    }
}
// Test the two nodes against Tag, `here` first.
//
// When one of them is a Tag, the result is the (id, node) pair of the *other*
// node. When neither is, the result is nullopt.
template<typename Tag>
static std::optional<std::pair<Id, const EType*>> matchOne(Id hereId, const EType* hereNode, Id thereId, const EType* thereNode)
{
    if (hereNode->get<Tag>())
        return std::pair{thereId, thereNode};

    if (thereNode->get<Tag>())
        return std::pair{hereId, hereNode};

    return std::nullopt;
}
// If the two nodes can be intersected into a "simple" type, return that, else return nullopt.
//
// The checks below are order-sensitive: each one assumes that every earlier
// check has already failed for both nodes.
std::optional<EType> intersectOne(EGraph& egraph, Id hereId, const EType* hereNode, Id thereId, const EType* thereNode)
{
hereId = egraph.find(hereId);
thereId = egraph.find(thereId);
// Same eclass: the intersection is the type itself.
if (hereId == thereId)
return *hereNode;
if (hereNode->get<TNever>() || thereNode->get<TNever>())
return TNever{};
// Compound and opaque nodes are never handled here.
if (hereNode->get<Union>() || hereNode->get<Intersection>() || hereNode->get<Negation>() || thereNode->get<Union>() ||
thereNode->get<Intersection>() || thereNode->get<Negation>() || hereNode->get<TOpaque>() || thereNode->get<TOpaque>())
return std::nullopt;
// unknown & T == T
if (hereNode->get<TUnknown>())
return *thereNode;
if (thereNode->get<TUnknown>())
return *hereNode;
if (hereNode->get<TTypeFun>() || thereNode->get<TTypeFun>())
return std::nullopt;
// The top class type intersected with any class is that class; with anything else it is never.
if (auto res = matchOne<TTopClass>(hereId, hereNode, thereId, thereNode))
{
const auto [otherId, otherNode] = *res;
if (otherNode->get<TClass>() || otherNode->get<TTopClass>())
return *otherNode;
else
return TNever{};
}
// The top table type intersected with a table is that table. Other combinations fall through to the checks below.
if (auto res = matchOne<TTopTable>(hereId, hereNode, thereId, thereNode))
{
const auto [otherId, otherNode] = *res;
if (otherNode->get<TTopTable>() || otherNode->get<TImportedTable>())
return *otherNode;
}
if (auto res = matchOne<TImportedTable>(hereId, hereNode, thereId, thereNode))
{
const auto [otherId, otherNode] = *res;
if (otherNode->get<TImportedTable>())
return std::nullopt; // TODO
else
return TNever{};
}
// Two classes: the subclass wins, unrelated classes are disjoint.
if (auto hereClass = hereNode->get<TClass>())
{
if (auto thereClass = thereNode->get<TClass>())
{
switch (relateClasses(hereClass, thereClass))
{
case LeftSuper:
return *thereNode;
case RightSuper:
return *hereNode;
case Unrelated:
return TNever{};
}
}
else
return TNever{};
}
// Boolean singletons against other singletons and against boolean.
if (auto hereBool = hereNode->get<SBoolean>())
{
if (auto thereBool = thereNode->get<SBoolean>())
{
if (hereBool->value() == thereBool->value())
return *hereNode;
else
return TNever{};
}
else if (thereNode->get<TBoolean>())
return *hereNode;
else
return TNever{};
}
if (auto thereBool = thereNode->get<SBoolean>())
{
if (auto hereBool = hereNode->get<SBoolean>())
{
if (thereBool->value() == hereBool->value())
return *thereNode;
else
return TNever{};
}
else if (hereNode->get<TBoolean>())
return *thereNode;
else
return TNever{};
}
if (hereNode->get<TBoolean>())
{
if (thereNode->get<TBoolean>())
return TBoolean{};
else if (thereNode->get<SBoolean>())
return *thereNode;
else
return TNever{};
}
if (thereNode->get<TBoolean>())
{
if (hereNode->get<TBoolean>())
return TBoolean{};
else if (hereNode->get<SBoolean>())
return *hereNode;
else
return TNever{};
}
// A string singleton survives only against string.
if (hereNode->get<SString>())
{
if (thereNode->get<TString>())
return *hereNode;
else
return TNever{};
}
if (thereNode->get<SString>())
{
if (hereNode->get<TString>())
return *thereNode;
else
return TNever{};
}
// The top function type against functions.
if (hereNode->get<TTopFunction>())
{
if (thereNode->get<TFunction>() || thereNode->get<TTopFunction>())
return *thereNode;
else
return TNever{};
}
if (thereNode->get<TTopFunction>())
{
if (hereNode->get<TFunction>() || hereNode->get<TTopFunction>())
return *hereNode;
else
return TNever{};
}
// Two concrete functions are not simplified here.
if (hereNode->get<TFunction>() && thereNode->get<TFunction>())
return std::nullopt;
if (hereNode->get<TFunction>() && isTerminal(*thereNode))
return TNever{};
if (thereNode->get<TFunction>() && isTerminal(*hereNode))
return TNever{};
if (isTerminal(*hereNode) && isTerminal(*thereNode))
{
// We already know that 'here' and 'there' are different classes.
return TNever{};
}
return std::nullopt;
}
// Fold the operands of each Intersection node in eclass `id` together with
// intersectOne(). If any pair turns out to be never, `id` is replaced by never.
// Otherwise `id` is replaced by the intersection of the accumulated simple type
// and whatever operands could not be folded.
void Simplifier::uninhabitedIntersection(Id id)
{
for (const auto [intersection, index] : Query<Intersection>(&egraph, id))
{
Slice<const Id> parts = intersection->operands();
// Degenerate intersections: no operands is replaced by never, one operand by that operand.
if (parts.empty())
{
Id never = egraph.add(TNever{});
subst(id, never, "uninhabitedIntersection");
return;
}
else if (1 == parts.size())
{
subst(id, parts[0], "uninhabitedIntersection");
return;
}
// Start from unknown, which intersectOne() treats as the identity.
Id accumulator = egraph.add(TUnknown{});
EType accumulatorNode = TUnknown{};
std::vector<Id> unsimplified;
if (occurs(egraph, id, parts))
continue;
for (Id partId : parts)
{
// An operand tagged TNoRefine stops this rule entirely.
if (isTag<TNoRefine>(partId))
return;
bool found = false;
const auto& partNodes = egraph[partId].nodes;
for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex)
{
const EType& N = partNodes[partIndex];
if (std::optional<EType> intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N))
{
if (isTag<TNever>(*intersection))
{
subst(id, egraph.add(TNever{}), "uninhabitedIntersection", {{id, index}, {partId, partIndex}});
return;
}
accumulator = egraph.add(*intersection);
accumulatorNode = *intersection;
found = true;
break;
}
}
if (!found)
unsimplified.push_back(partId);
}
// Keep the accumulator unless it is still unknown alongside other operands, or it is `id` itself.
if ((unsimplified.empty() || !isTag<TUnknown>(accumulator)) && find(accumulator) != id)
unsimplified.push_back(accumulator);
const Id result = mkIntersection(egraph, std::move(unsimplified));
subst(id, result, "uninhabitedIntersection", {{id, index}});
}
}
// Simplify two-operand intersections of the form I & ~C where C is a class.
// Both operand orders are tried. Intersections with any other operand count
// are left alone.
void Simplifier::intersectWithNegatedClass(Id id)
{
for (const auto pair : Query<Intersection>(&egraph, id))
{
const Intersection* intersection = pair.first;
const size_t intersectionIndex = pair.second;
// Treat operand i as the plain type and operand j as the candidate negation.
auto trySubst = [&](size_t i, size_t j)
{
Id iId = intersection->operands()[i];
Id jId = intersection->operands()[j];
for (const auto [negation, negationIndex] : Query<Negation>(&egraph, jId))
{
const Id negated = negation->operands()[0];
// T & ~T == never
if (iId == negated)
{
subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {jId, negationIndex}});
return;
}
for (const auto [negatedClass, negatedClassIndex] : Query<TClass>(&egraph, negated))
{
const auto& iNodes = egraph[iId].nodes;
for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex)
{
const EType& iNode = iNodes[iIndex];
if (isTag<TNil>(iNode) || isTag<TBoolean>(iNode) || isTag<TNumber>(iNode) || isTag<TString>(iNode) || isTag<TThread>(iNode) ||
isTag<TTopFunction>(iNode) ||
// isTag<TTopTable>(iNode) || // I'm not sure about this one.
isTag<SBoolean>(iNode) || isTag<SString>(iNode) || isTag<TFunction>(iNode) || isTag<TNever>(iNode))
{
// eg string & ~SomeClass
subst(id, iId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}});
return;
}
if (const TClass* class_ = iNode.get<TClass>())
{
switch (relateClasses(class_, negatedClass))
{
case LeftSuper:
// eg Instance & ~Part
// This cannot be meaningfully reduced.
continue;
case RightSuper:
// The class is a subclass of the negated class, so nothing is left.
subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}});
return;
case Unrelated:
// Part & ~Folder == Part
{
std::vector<Id> newParts;
newParts.reserve(intersection->operands().size() - 1);
for (Id part : intersection->operands())
{
if (part != jId)
newParts.push_back(part);
}
Id substId = egraph.add(Intersection{newParts.begin(), newParts.end()});
subst(id, substId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}});
}
}
}
}
}
}
};
if (2 != intersection->operands().size())
continue;
trySubst(0, 1);
trySubst(1, 0);
}
}
void Simplifier::intersectWithNegatedAtom(Id id)
{
// Let I and ~J be two arbitrary distinct operands of an intersection where
// I and J are terminal but are not type variables. (free, generic, or
// otherwise opaque)
//
// If I and J are equal, then the whole intersection is equivalent to never.
//
// If I and J are definitely disjoint, then I & ~J == I, so the ~J operand
// can simply be dropped from the intersection.
for (const auto [intersection, intersectionIndex] : Query<Intersection>(&egraph, id))
{
const Slice<const Id>& intersectionOperands = intersection->operands();
// Operand i is the candidate negation ~J.
for (size_t i = 0; i < intersectionOperands.size(); ++i)
{
for (const auto [negation, negationIndex] : Query<Negation>(&egraph, intersectionOperands[i]))
{
for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex)
{
const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex];
if (!isTerminal(*negationOperand) || negationOperand->get<TOpaque>())
continue;
// Operand j is the candidate plain type I.
for (size_t j = 0; j < intersectionOperands.size(); ++j)
{
if (j == i)
continue;
for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex)
{
const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex];
if (!isTerminal(*jNode) || jNode->get<TOpaque>())
continue;
if (*negationOperand == *jNode)
{
// eg "Hello" & ~"Hello"
// or boolean & ~boolean
subst(
id,
egraph.add(TNever{}),
"intersectWithNegatedAtom",
{{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}}
);
return;
}
else if (areTerminalAndDefinitelyDisjoint(*jNode, *negationOperand))
{
// eg "Hello" & ~"World"
// or boolean & ~string
// Rebuild the intersection without the negation at index i.
std::vector<Id> newOperands(intersectionOperands.begin(), intersectionOperands.end());
newOperands.erase(newOperands.begin() + std::vector<Id>::difference_type(i));
subst(
id,
egraph.add(Intersection{newOperands}),
"intersectWithNegatedAtom",
{{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}}
);
}
}
}
}
}
}
}
}
// Drop operands from an intersection when the operand is tagged TNoRefine, or
// is a negation whose operand is tagged TNoRefine. Each such operand produces
// its own substitution with just that one operand removed.
void Simplifier::intersectWithNoRefine(Id id)
{
for (const auto pair : Query<Intersection>(&egraph, id))
{
const Intersection* intersection = pair.first;
const size_t intersectionIndex = pair.second;
const Slice<const Id> intersectionOperands = intersection->operands();
for (size_t index = 0; index < intersectionOperands.size(); ++index)
{
// Substitute `id` with the same intersection minus the operand at `index`.
const auto replace = [&]()
{
std::vector<Id> newOperands{intersectionOperands.begin(), intersectionOperands.end()};
newOperands.erase(newOperands.begin() + index);
Id substId = egraph.add(Intersection{std::move(newOperands)});
subst(id, substId, "intersectWithNoRefine", {{id, intersectionIndex}});
};
if (isTag<TNoRefine>(intersectionOperands[index]))
replace();
else
{
for (const auto [negation, negationIndex] : Query<Negation>(&egraph, intersectionOperands[index]))
{
if (isTag<TNoRefine>(negation->operands()[0]))
{
replace();
break;
}
}
}
}
}
}
/*
 * Replace x where x = A & (B | x) with A
 *
 * Important subtlety: The egraph is routinely going to create cyclic unions and
 * intersections. We can't arbitrarily remove things from a union just because
 * it can be referred to in a cyclic way. We must only do this for things that
 * can only be expressed in a cyclic way.
 *
 * As an example, we will bind the following type to true:
 *
 * (true | buffer | class | function | number | string | table | thread) &
 * boolean
 *
 * The egraph represented by this type will indeed be cyclic as the 'true' class
 * includes both 'true' itself and the above type, but removing true from the
 * union will result in an incorrect judgment!
 *
 * The solution (for now) is only to consider a type to be cyclic if it was
 * cyclic on its original import. In the code below this is detected through
 * TBound nodes, which are resolved via mappingIdToClass.
 *
 * FIXME: I still don't think this is quite right, but I don't know how to
 * articulate what the actual rule ought to be.
 */
void Simplifier::cyclicIntersectionOfUnion(Id id)
{
// FIXME: This has pretty terrible runtime complexity.
for (const auto [i, intersectionIndex] : Query<Intersection>(&egraph, id))
{
Slice<const Id> intersectionParts = i->operands();
for (size_t intersectionOperandIndex = 0; intersectionOperandIndex < intersectionParts.size(); ++intersectionOperandIndex)
{
const Id intersectionPart = find(intersectionParts[intersectionOperandIndex]);
for (const auto [bound, _boundIndex] : Query<TBound>(&egraph, intersectionPart))
{
// Follow the bound type to the eclass it was mapped to.
const Id pointee = find(mappingIdToClass.at(bound->value()));
for (const auto [u, unionIndex] : Query<Union>(&egraph, pointee))
{
const Slice<const Id>& unionOperands = u->operands();
for (size_t unionOperandIndex = 0; unionOperandIndex < unionOperands.size(); ++unionOperandIndex)
{
Id unionOperand = find(unionOperands[unionOperandIndex]);
// The union refers back to `id`: drop this operand from the intersection.
if (unionOperand == id)
{
std::vector<Id> newIntersectionParts(intersectionParts.begin(), intersectionParts.end());
newIntersectionParts.erase(newIntersectionParts.begin() + intersectionOperandIndex);
subst(
id,
mkIntersection(egraph, std::move(newIntersectionParts)),
"cyclicIntersectionOfUnion",
{{id, intersectionIndex}, {pointee, unionIndex}}
);
}
}
}
}
}
}
}
// Dual of cyclicIntersectionOfUnion: handle x where x = A | (B & x).
//
// For each union operand that is a TBound whose pointee is an intersection
// referring back to `id`, the back-reference is removed from that
// intersection. If anything is left of the intersection, `id` is replaced by
// the union with that operand swapped for the reduced intersection.
void Simplifier::cyclicUnionOfIntersection(Id id)
{
    // FIXME: This has pretty terrible runtime complexity.
    for (const auto [union_, unionIndex] : Query<Union>(&egraph, id))
    {
        Slice<const Id> unionOperands = union_->operands();
        for (size_t unionOperandIndex = 0; unionOperandIndex < unionOperands.size(); ++unionOperandIndex)
        {
            const Id unionPart = find(unionOperands[unionOperandIndex]);

            for (const auto [bound, _boundIndex] : Query<TBound>(&egraph, unionPart))
            {
                // Follow the bound type to the eclass it was mapped to.
                const Id pointee = find(mappingIdToClass.at(bound->value()));

                for (const auto [intersection, intersectionIndex] : Query<Intersection>(&egraph, pointee))
                {
                    Slice<const Id> intersectionOperands = intersection->operands();
                    for (size_t intersectionOperandIndex = 0; intersectionOperandIndex < intersectionOperands.size(); ++intersectionOperandIndex)
                    {
                        const Id intersectionPart = find(intersectionOperands[intersectionOperandIndex]);
                        if (intersectionPart != id)
                            continue;

                        // The intersection without its back-reference to `id`.
                        std::vector<Id> reducedIntersectionParts(intersectionOperands.begin(), intersectionOperands.end());
                        reducedIntersectionParts.erase(reducedIntersectionParts.begin() + intersectionOperandIndex);
                        if (reducedIntersectionParts.empty())
                            continue;

                        Id newIntersection = mkIntersection(egraph, std::move(reducedIntersectionParts));

                        // The union with the cyclic operand replaced by the
                        // reduced intersection. This used to be a second vector
                        // named newIntersectionParts that shadowed the one above.
                        std::vector<Id> newUnionParts(unionOperands.begin(), unionOperands.end());
                        newUnionParts.erase(newUnionParts.begin() + unionOperandIndex);
                        newUnionParts.push_back(newIntersection);

                        subst(
                            id,
                            mkUnion(egraph, std::move(newUnionParts)),
                            "cyclicUnionOfIntersection",
                            {{id, unionIndex}, {pointee, intersectionIndex}}
                        );
                    }
                }
            }
        }
    }
}
// Rewrite ~T as an explicit union: start from the canonical form of unknown and
// subtract T from it. Negations whose operand is too complicated for subtract()
// are skipped; a negation of a TNoRefine-tagged eclass stops the rule entirely.
void Simplifier::expandNegation(Id id)
{
for (const auto [negation, index] : Query<Negation>{&egraph, id})
{
if (isTag<TNoRefine>(negation->operands()[0]))
return;
CanonicalizedType canonicalized;
unionUnknown(egraph, canonicalized);
const bool ok = subtract(egraph, canonicalized, negation->operands()[0]);
if (!ok)
continue;
subst(id, fromCanonicalized(egraph, canonicalized), "expandNegation", {{id, index}});
}
}
/**
 * Let A be a class-node having the form B & C1 & ... & Cn
 * And B be a class-node having the form (D | E)
 *
 * Create a class containing the node (C1 & ... & Cn & D) | (C1 & ... & Cn & E)
 *
 * `never` operands of the inner union are dropped from the result.
 *
 * This function does nothing and returns nullopt if A and B are cyclic, or if
 * any operand of the outer intersection is `never`. In the latter case the
 * whole intersection is uninhabited, so distributing it is pointless and is
 * left to the rules that deal with uninhabited intersections.
 */
static std::optional<Id> distributeIntersectionOfUnion(
    EGraph& egraph,
    Id outerClass,
    const Intersection* outerIntersection,
    Id innerClass,
    const Union* innerUnion
)
{
    Slice<const Id> outerOperands = outerIntersection->operands();

    // Previously a `never` operand merely stopped the copy loop below, which
    // silently dropped it (and every operand after it) and produced a type that
    // is not equivalent to the original intersection.
    for (const Id op : outerOperands)
    {
        if (isTag<TNever>(egraph, op))
            return std::nullopt;
    }

    std::vector<Id> newOperands;
    newOperands.reserve(innerUnion->operands().size());

    for (Id innerOperand : innerUnion->operands())
    {
        if (isTag<TNever>(egraph, innerOperand))
            continue;

        if (innerOperand == outerClass)
        {
            // Skip cyclic intersections of unions. There's a separate
            // rule to get rid of those.
            return std::nullopt;
        }

        // innerOperand & (every outer operand except the union itself)
        std::vector<Id> intersectionParts;
        intersectionParts.reserve(outerOperands.size());
        intersectionParts.push_back(innerOperand);

        for (const Id op : outerOperands)
        {
            if (op != innerClass)
                intersectionParts.push_back(op);
        }

        newOperands.push_back(mkIntersection(egraph, intersectionParts));
    }

    return mkUnion(egraph, std::move(newOperands));
}
// A & (B | C) -> (A & B) | (A & C)
//
// A & B & (C | D) -> A & (B & (C | D))
// -> A & ((B & C) | (B & D))
// -> (A & B & C) | (A & B & D)
//
// Intersections and unions that refer back to `id` are skipped.
void Simplifier::intersectionOfUnion(Id id)
{
id = find(id);
for (const auto [intersection, intersectionIndex] : Query<Intersection>(&egraph, id))
{
// For each operand O
// For each node N
// If N is a union U
// Create a new union comprised of every operand except O intersected with every operand of U
const Slice<const Id> operands = intersection->operands();
// Note: this leaves the rule entirely rather than moving on to the next intersection node.
if (operands.size() < 2)
return;
if (occurs(egraph, id, operands))
continue;
for (Id operand : operands)
{
operand = find(operand);
if (operand == id)
break;
// Optimization: Decline to distribute any unions on an eclass that
// also contains a terminal node.
if (isTerminal(egraph, operand))
continue;
for (const auto [operandUnion, unionIndex] : Query<Union>(&egraph, operand))
{
if (occurs(egraph, id, operandUnion->operands()))
continue;
std::optional<Id> distributed = distributeIntersectionOfUnion(egraph, id, intersection, operand, operandUnion);
if (distributed)
subst(id, *distributed, "intersectionOfUnion", {{id, intersectionIndex}, {operand, unionIndex}});
}
}
}
}
// {"a": b} & {"a": c, ...} => {"a": b & c, ...}
//
// Applies to every ordered pair of distinct operands (i, j) where operand i is
// an imported table with exactly one property and operand j is an imported
// table that also has that property.
void Simplifier::intersectTableProperty(Id id)
{
for (const auto [intersection, intersectionIndex] : Query<Intersection>(&egraph, id))
{
const Slice<const Id> intersectionParts = intersection->operands();
for (size_t i = 0; i < intersection->operands().size(); ++i)
{
const Id iId = intersection->operands()[i];
for (size_t j = 0; j < intersection->operands().size(); ++j)
{
if (i == j)
continue;
const Id jId = intersection->operands()[j];
if (iId == jId)
continue;
for (const auto [table1, table1Index] : Query<TImportedTable>(&egraph, iId))
{
const TableType* table1Ty = Luau::get<TableType>(table1->value());
LUAU_ASSERT(table1Ty);
// Only single-property tables are handled on the left-hand side.
if (table1Ty->props.size() != 1)
continue;
for (const auto [table2, table2Index] : Query<TImportedTable>(&egraph, jId))
{
const TableType* table2Ty = Luau::get<TableType>(table2->value());
LUAU_ASSERT(table2Ty);
auto it = table2Ty->props.find(table1Ty->props.begin()->first);
if (it != table2Ty->props.end())
{
// Keep every operand other than the two tables being merged.
std::vector<Id> newIntersectionParts;
newIntersectionParts.reserve(intersectionParts.size() - 1);
for (size_t index = 0; index < intersectionParts.size(); ++index)
{
if (index != i && index != j)
newIntersectionParts.push_back(intersectionParts[index]);
}
// The merged property type is the intersection of both tables' property types.
Id newTableProp = egraph.add(Intersection{
toId(egraph, builtinTypes, mappingIdToClass, stringCache, it->second.type()),
toId(egraph, builtinTypes, mappingIdToClass, stringCache, table1Ty->props.begin()->second.type())
});
// Table j acts as the basis, with the shared property overridden.
newIntersectionParts.push_back(egraph.add(TTable{jId, {stringCache.add(it->first)}, {newTableProp}}));
subst(
id,
mkIntersection(egraph, std::move(newIntersectionParts)),
"intersectTableProperty",
{{id, intersectionIndex}, {iId, table1Index}, {jId, table2Index}}
);
}
}
}
}
}
}
}
// { prop: never } == never
//
// Applies to imported tables (when a property's read or write type follows to
// never) and to TTable nodes (when a property's eclass contains never). The
// rule stops at the first property that qualifies.
void Simplifier::uninhabitedTable(Id id)
{
    const auto replaceWithNever = [&](size_t nodeIndex)
    {
        subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, nodeIndex}});
    };

    for (const auto [table, tableIndex] : Query<TImportedTable>(&egraph, id))
    {
        const TableType* tt = Luau::get<TableType>(table->value());
        LUAU_ASSERT(tt);

        for (const auto& [propName, prop] : tt->props)
        {
            const bool readsNever = prop.readTy && Luau::get<NeverType>(follow(*prop.readTy));
            if (readsNever || (prop.writeTy && Luau::get<NeverType>(follow(*prop.writeTy))))
            {
                replaceWithNever(tableIndex);
                return;
            }
        }
    }

    for (const auto [table, tableIndex] : Query<TTable>(&egraph, id))
    {
        for (Id propType : table->propTypes())
        {
            if (!isTag<TNever>(propType))
                continue;

            replaceWithNever(tableIndex);
            return;
        }
    }
}
// A TTable node overrides some properties of a basis table.  If every
// overridden property has the same e-class as the corresponding property of
// an imported table in the basis, the modification changes nothing and the
// TTable can be replaced by its basis.
void Simplifier::unneededTableModification(Id id)
{
    for (const auto [tbl, tblIndex] : Query<TTable>(&egraph, id))
    {
        const Id basis = tbl->getBasis();
        for (const auto [importedTbl, importedTblIndex] : Query<TImportedTable>(&egraph, basis))
        {
            const TableType* tt = Luau::get<TableType>(importedTbl->value());
            LUAU_ASSERT(tt);

            bool skip = false;

            for (size_t i = 0; i < tbl->propNames.size(); ++i)
            {
                StringId propName = tbl->propNames[i];
                const Id propType = tbl->propTypes()[i];

                // Look the property up with find() rather than at(): at() throws
                // std::out_of_range when this imported table does not declare the
                // property.  A property absent from the basis means the TTable
                // really does change the table, so the rewrite must not fire.
                const auto importedIt = tt->props.find(stringCache.asString(propName));
                if (importedIt == tt->props.end())
                {
                    skip = true;
                    break;
                }

                Id importedProp = toId(egraph, builtinTypes, mappingIdToClass, stringCache, importedIt->second.type());

                // Compare canonical e-class ids; any difference makes the
                // modification meaningful.
                if (find(importedProp) != find(propType))
                {
                    skip = true;
                    break;
                }
            }

            if (!skip)
                subst(id, basis, "unneededTableModification", {{id, tblIndex}, {basis, importedTblIndex}});
        }
    }
}
// Reduces the binary arithmetic type functions (add, sub, mul, div, idiv,
// pow, mod) to `number` when both of their arguments are tagged as `number`.
void Simplifier::builtinTypeFunctions(Id id)
{
    static constexpr const char* kArithmeticFunctions[] = {"add", "sub", "mul", "div", "idiv", "pow", "mod"};

    for (const auto [typeFun, typeFunIndex] : Query<TTypeFun>(&egraph, id))
    {
        const Slice<const Id> operands = typeFun->operands();

        // Only two-argument applications are considered.
        if (operands.size() != 2)
            continue;

        const std::string& functionName = typeFun->value()->function->name;

        bool isArithmetic = false;
        for (const char* candidate : kArithmeticFunctions)
        {
            if (functionName == candidate)
            {
                isArithmetic = true;
                break;
            }
        }

        if (!isArithmetic)
            continue;

        if (!isTag<TNumber>(operands[0]) || !isTag<TNumber>(operands[1]))
            continue;

        subst(id, add(TNumber{}), "builtinTypeFunctions", {{id, typeFunIndex}});
    }
}
// Replace the union<> and intersect<> type functions with plain unions or
// intersections of their arguments.  These type functions exist primarily to
// cause simplification to defer until particular points in execution, so it
// is safe to get rid of them here.
//
// Only the names "union" and "intersect" are matched by this rule; any other
// type function (refine<> included) is left untouched.
//
// It's not clear that these type functions should exist at all.
void Simplifier::iffyTypeFunctions(Id id)
{
    for (const auto [typeFun, typeFunIndex] : Query<TTypeFun>(&egraph, id))
    {
        const Slice<const Id> operands = typeFun->operands();
        const std::string& functionName = typeFun->value()->function->name;

        const bool isUnion = functionName == "union";
        if (!isUnion && functionName != "intersect")
            continue;

        const Id replacement = isUnion ? add(Union{std::vector<Id>(operands.begin(), operands.end())})
                                       : add(Intersection{std::vector<Id>(operands.begin(), operands.end())});

        subst(id, replacement, "iffyTypeFunctions", {{id, typeFunIndex}});
    }
}
// Replace instances of `lt<X, Y>` and `le<X, Y>` when either X or Y is `number`
// or `string` with `boolean`. Lua semantics are that if we see the expression:
//
// x < y
//
// ... we error if `x` and `y` don't have the same type. We know that for
// `string` and `number`, comparisons will always return a boolean. So if either
// of the arguments to `lt<>` are equivalent to `number` or `string`, then the
// type is effectively `boolean`: either the other type is equivalent, in which
// case we eval to `boolean`, or we diverge (raise an error).
void Simplifier::strictMetamethods(Id id)
{
    for (const auto [tfun, index] : Query<TTypeFun>(&egraph, id))
    {
        const Slice<const Id>& args = tfun->operands();
        const std::string& name = tfun->value()->function->name;

        // Only binary applications of lt<> / le<> are of interest.
        if (!(name == "lt" || name == "le") || args.size() != 2)
            continue;

        if (isTag<TNumber>(args[0]) || isTag<TString>(args[0]) || isTag<TNumber>(args[1]) || isTag<TString>(args[1]))
        {
            // The rule name is spelled out as a literal, like every other rule
            // in this file.  __FUNCTION__ is a compiler extension whose text
            // differs between toolchains, which would make the recorded rule
            // description depend on the compiler.
            subst(id, add(TBoolean{}), "strictMetamethods", {{id, index}});
        }
    }
}
static void deleteSimplifier(Simplifier* s)
{
delete s;
}
// Allocates a Simplifier bound to the given arena and builtin types, and
// returns it in a SimplifierPtr that releases it through deleteSimplifier.
SimplifierPtr newSimplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes)
{
    Simplifier* const instance = new Simplifier(arena, builtinTypes);
    return SimplifierPtr{instance, &deleteSimplifier};
}
} // namespace Luau::EqSatSimplification
namespace Luau
{
// Simplifies `ty` by equality saturation.
//
// The type is imported into the simplifier's e-graph, then the rewrite rules
// are applied over the e-classes reachable from the root until a full pass
// merges nothing new or the merge budget is exhausted.  The type extracted
// from the root e-class is returned along with any type functions created
// during extraction.  Every path visible in this function returns an engaged
// optional.
std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simplifier, TypeId ty)
{
    using namespace Luau::EqSatSimplification;

    // Import the type and remember the id mappings the import produced.
    std::unordered_map<size_t, Id> importedMappings;
    Id rootId = toId(simplifier->egraph, simplifier->builtinTypes, importedMappings, simplifier->stringCache, ty);
    simplifier->mappingIdToClass.insert(importedMappings.begin(), importedMappings.end());

    // Rules are attempted in exactly this order on every visited e-class.
    Simplifier::RewriteRuleFn rules[] = {
        &Simplifier::simplifyUnion,
        &Simplifier::uninhabitedIntersection,
        &Simplifier::intersectWithNegatedClass,
        &Simplifier::intersectWithNegatedAtom,
        &Simplifier::intersectWithNoRefine,
        &Simplifier::cyclicIntersectionOfUnion,
        &Simplifier::cyclicUnionOfIntersection,
        &Simplifier::expandNegation,
        &Simplifier::intersectionOfUnion,
        &Simplifier::intersectTableProperty,
        &Simplifier::uninhabitedTable,
        &Simplifier::unneededTableModification,
        &Simplifier::builtinTypeFunctions,
        &Simplifier::iffyTypeFunctions,
        &Simplifier::strictMetamethods,
    };

    std::unordered_set<Id> visited;
    VecDeque<Id> pending;

    // Upper bound on the number of merge attempts across the whole run.
    const int mergeLimit = 1000;
    int mergeCount = 0;
    bool madeProgress = true;

    if (FFlag::DebugLuauLogSimplificationToDot)
        std::ofstream("begin.dot") << toDot(simplifier->stringCache, simplifier->egraph);

    auto& egraph = simplifier->egraph;
    const auto& builtinTypes = simplifier->builtinTypes;
    auto& arena = simplifier->arena;

    // True when the e-class already holds a terminal node.
    const auto containsTerminal = [&egraph](Id eclass)
    {
        for (const EType& enode : egraph[eclass].nodes)
        {
            if (isTerminal(enode))
                return true;
        }
        return false;
    };

    if (FFlag::DebugLuauLogSimplification)
        printf(">> simplify %s\n", toString(ty).c_str());

    // Each outer iteration is one full pass starting again from the root.
    while (madeProgress && mergeCount < mergeLimit)
    {
        madeProgress = false;
        pending.clear();
        visited.clear();

        rootId = egraph.find(rootId);
        pending.push_back(rootId);

        if (FFlag::DebugLuauLogSimplification)
        {
            std::vector<TypeId> scratchTypeFunctions;
            const TypeId current = fromId(egraph, simplifier->stringCache, builtinTypes, arena, scratchTypeFunctions, rootId);
            std::cout << "Begin (" << uint32_t(egraph.find(rootId)) << ")\t" << toString(current) << '\n';
        }

        while (!pending.empty() && mergeCount < mergeLimit)
        {
            const Id id = egraph.find(pending.front());
            pending.pop_front();

            // Each canonical e-class is processed at most once per pass.
            if (!visited.insert(id).second)
                continue;

            simplifier->substs.clear();

            // Optimization: If this class already has a terminal node, don't
            // try to run any rules on it.
            if (containsTerminal(id))
                continue;

            for (const EType& enode : egraph[id].nodes)
                addChildren(egraph, &enode, pending);

            for (Simplifier::RewriteRuleFn rule : rules)
                (simplifier.get()->*rule)(id);

            if (simplifier->substs.empty())
                continue;

            for (const Subst& rewrite : simplifier->substs)
            {
                // A substitution of a class by itself is a no-op.
                if (rewrite.newClass == rewrite.eclass)
                    continue;

                if (FFlag::DebugLuauExtraEqSatSanityChecks)
                {
                    const Id never = egraph.find(egraph.add(TNever{}));
                    const Id str = egraph.find(egraph.add(TString{}));
                    const Id unk = egraph.find(egraph.add(TUnknown{}));
                    LUAU_ASSERT(never != str);
                    LUAU_ASSERT(never != unk);
                }

                const bool merged = egraph.merge(rewrite.newClass, rewrite.eclass);
                ++mergeCount;

                if (FFlag::DebugLuauLogSimplification && merged)
                    std::cout << "count=" << std::setw(3) << mergeCount << "\t" << rewrite.desc << '\n';

                if (FFlag::DebugLuauLogSimplificationToDot)
                {
                    std::string filename = format("step%03d.dot", mergeCount);
                    std::ofstream(filename) << toDot(simplifier->stringCache, egraph);
                }

                if (FFlag::DebugLuauExtraEqSatSanityChecks)
                {
                    const Id never = egraph.find(egraph.add(TNever{}));
                    const Id str = egraph.find(egraph.add(TString{}));
                    const Id unk = egraph.find(egraph.add(TUnknown{}));
                    const Id trueId = egraph.find(egraph.add(SBoolean{true}));
                    LUAU_ASSERT(never != str);
                    LUAU_ASSERT(never != unk);
                    LUAU_ASSERT(never != trueId);
                }

                // Another pass is scheduled only if this merge joined two
                // previously distinct classes.
                if (merged)
                    madeProgress = true;
            }

            egraph.rebuild();
        }
    }

    EqSatSimplificationResult result;
    result.result = fromId(egraph, simplifier->stringCache, builtinTypes, arena, result.newTypeFunctions, rootId);

    if (FFlag::DebugLuauLogSimplification)
        printf("<< simplify %s\n", toString(result.result).c_str());

    return result;
}
} // namespace Luau