luau/EqSat/include/Luau/EGraph.h
Hunter Goldstein c759cd5581
Some checks failed
benchmark / callgrind (map[branch:main name:luau-lang/benchmark-data], ubuntu-22.04) (push) Has been cancelled
build / macos (push) Has been cancelled
build / macos-arm (push) Has been cancelled
build / ubuntu (push) Has been cancelled
build / windows (Win32) (push) Has been cancelled
build / windows (x64) (push) Has been cancelled
build / coverage (push) Has been cancelled
build / web (push) Has been cancelled
release / macos (push) Has been cancelled
release / ubuntu (push) Has been cancelled
release / windows (push) Has been cancelled
release / web (push) Has been cancelled
Sync to upstream/release/656 (#1612)
# General

All code has been re-formatted by `clang-format`; this is not
mechanically enforced, so Luau may go out-of-sync over the course of the
year.

# New Solver

* Track free types interior to a block of code on `Scope`, which should
reduce the number of free types that remain un-generalized after type
checking is complete (e.g.: fewer errors like `'a <: number is
incompatible with number`).

# Autocomplete

* Fragment autocomplete now does *not* provide suggestions within
comments (matching non-fragment autocomplete behavior).
* Autocomplete now respects iteration and recursion limits (some hangs
will now early exit with a "unification too complex" error, some crashes
will now become internal compiler exceptions).

# Runtime

* Add a limit to how many Luau codegen slot nodes addresses can be in
use at the same time (fixes #1605, fixes #1558).
* Added constant folding for vector arithmetic (fixes #1553).
* Added support for `buffer.readbits` and `buffer.writebits` (see:
https://github.com/luau-lang/rfcs/pull/18).

---

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: David Cope <dcope@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
2025-01-10 11:34:39 -08:00

311 lines
8.1 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include "Luau/Id.h"
#include "Luau/Language.h"
#include "Luau/UnionFind.h"

#include <algorithm>
#include <iterator>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace Luau::EqSat
{
template<typename L, typename N>
struct EGraph;

/// Adapter around a user-supplied analysis `N` for the language `L`.
///
/// `N` is expected to expose a `Data` type, a `make(egraph, alternative)` member
/// that is callable for each alternative of the language, and a `join(a, b)` member.
template<typename L, typename N>
struct Analysis final
{
    /// The wrapped user analysis.
    N analysis;

    /// The per-e-class data type produced by the analysis.
    using D = typename N::Data;

    Analysis() = default;

    // Intentionally left non-explicit so existing implicit conversions from `N` keep working.
    Analysis(N a)
        : analysis(std::move(a))
    {
    }

    /// Dispatch target for one alternative `T` of the language: extracts the `T`
    /// held by `enode` and forwards it to `N::make`.
    ///
    /// `enode.get<T>()` is dereferenced without a check, so `enode` must currently hold a `T`.
    template<typename T>
    static D fnMake(const N& analysis, const EGraph<L, N>& egraph, const L& enode)
    {
        return analysis.make(egraph, *enode.template get<T>());
    }

    /// Computes the analysis data for `enode` by dispatching on `enode.index()`
    /// into a table holding one `fnMake<T>` instantiation per alternative in `Ts...`.
    template<typename... Ts>
    D make(const EGraph<L, N>& egraph, const Language<Ts...>& enode) const
    {
        using FnMake = D (*)(const N&, const EGraph<L, N>&, const L&);
        static constexpr FnMake tableMake[sizeof...(Ts)] = {&fnMake<Ts>...};

        return tableMake[enode.index()](analysis, egraph, enode);
    }

    /// Joins `b` into `a` by forwarding to `N::join`.
    ///
    /// Any value `N::join` returns is discarded; not using `return` here keeps this
    /// well-formed regardless of whether `N::join` returns `void` or a value.
    void join(D& a, const D& b) const
    {
        analysis.join(a, b);
    }
};
/// Each e-class is a set of e-nodes representing equivalent terms from a given language,
/// and an e-node is a function symbol paired with a list of children e-classes.
///
/// `L` is the e-node (language) type and `D` is the analysis data type stored per class.
/// This is a plain aggregate: `EGraph` brace-initializes it positionally, so the
/// order of the fields below is significant.
template<typename L, typename D>
struct EClass final
{
// The id this e-class was created under.
Id id;
// The e-nodes belonging to this e-class. `EGraph` re-sorts this by `L::index()`
// whenever it merges or repairs the class.
std::vector<L> nodes;
// Analysis data for this e-class: produced by the analysis' `make` when the class
// is created and updated through the analysis' `join` when classes are merged.
D data;
// (e-node, e-class id) pairs: each entry is an e-node that has this class as one of
// its operands, together with the id of the e-class that e-node was added to.
std::vector<std::pair<L, Id>> parents;
};
/// An e-graph over the language `L`, carrying per-e-class analysis data computed by `N`.
///
/// See <https://arxiv.org/pdf/2004.03082>.
template<typename L, typename N>
struct EGraph final
{
    using EClassT = EClass<L, typename N::Data>;

    EGraph() = default;

    explicit EGraph(N analysis)
        : analysis(std::move(analysis))
    {
    }

    /// Returns the representative id of `id` according to the union-find.
    Id find(Id id) const
    {
        return unionfind.find(id);
    }

    /// Returns the e-class id recorded in the hashcons for `enode`, or `std::nullopt`
    /// when there is no entry. `enode` must already be canonical (asserted).
    std::optional<Id> lookup(const L& enode) const
    {
        LUAU_ASSERT(isCanonical(enode));

        if (auto it = hashcons.find(enode); it != hashcons.end())
            return it->second;

        return std::nullopt;
    }

    /// Canonicalizes `enode` and returns the id of the e-class that holds it,
    /// creating a fresh singleton e-class when the hashcons has no entry for it.
    Id add(L enode)
    {
        canonicalize(enode);

        if (auto id = lookup(enode))
            return *id;

        return makeEClass(enode);
    }

    /// Merges the e-classes referred to by `id1` and `id2`.
    ///
    /// Returns true if the two IDs were not previously merged.
    /// Congruence is not restored here; ids are queued on the worklist and
    /// processed by `rebuild()`.
    bool merge(Id id1, Id id2)
    {
        id1 = find(id1);
        id2 = find(id2);
        if (id1 == id2)
            return false;

        const Id mergedId = unionfind.merge(id1, id2);

        // Ensure that id1 is the Id that we keep, and id2 is the id that we drop.
        if (mergedId == id2)
            std::swap(id1, id2);

        // Erasing a different key from an unordered_map does not invalidate this reference.
        EClassT& eclass1 = get(id1);
        EClassT eclass2 = std::move(get(id2));
        classes.erase(id2);

        // `eclass2` is a local that dies at the end of this function, so its elements
        // can be moved rather than copied. Only `eclass2.data` is read afterwards.
        eclass1.nodes.insert(
            eclass1.nodes.end(), std::make_move_iterator(eclass2.nodes.begin()), std::make_move_iterator(eclass2.nodes.end())
        );
        eclass1.parents.insert(
            eclass1.parents.end(), std::make_move_iterator(eclass2.parents.begin()), std::make_move_iterator(eclass2.parents.end())
        );

        sortByTag(eclass1.nodes);

        // Queue the e-class id recorded with every parent entry of the merged class.
        worklist.reserve(worklist.size() + eclass1.parents.size());
        for (const auto& parent : eclass1.parents)
            worklist.push_back(parent.second);

        analysis.join(eclass1.data, eclass2.data);

        return true;
    }

    /// Drains the worklist, calling `repair` on the representative of each queued id.
    ///
    /// An id value that has already been taken off the worklist during this call is
    /// skipped. Ids queued by merges that happen during repair are processed by the
    /// same loop.
    void rebuild()
    {
        std::unordered_set<Id> seen;

        while (!worklist.empty())
        {
            Id id = worklist.back();
            worklist.pop_back();

            const bool isFresh = seen.insert(id).second;
            if (!isFresh)
                continue;

            repair(find(id));
        }
    }

    /// Number of e-classes currently stored.
    size_t size() const
    {
        return classes.size();
    }

    /// Returns the e-class that `id` refers to, i.e. the one stored under `find(id)`.
    EClassT& operator[](Id id)
    {
        return get(find(id));
    }

    /// Const counterpart of `operator[]`.
    const EClassT& operator[](Id id) const
    {
        return get(find(id));
    }

    const std::unordered_map<Id, EClassT>& getAllClasses() const
    {
        return classes;
    }

private:
    Analysis<L, N> analysis;

    /// A union-find data structure U stores an equivalence relation over e-class ids.
    UnionFind unionfind;

    /// The e-class map M maps e-class ids to e-classes. All equivalent e-class ids map to the same
    /// e-class, i.e., a and b are equivalent ids iff M[a] is the same set as M[b]. An e-class id a
    /// is said to refer to the e-class M[find(a)].
    std::unordered_map<Id, EClassT> classes;

    /// The hashcons H is a map from e-nodes to e-class ids.
    std::unordered_map<L, Id, typename L::Hash> hashcons;

    /// Ids queued by `makeEClass` and `merge`, consumed by `rebuild`.
    std::vector<Id> worklist;

    /// Sorts `nodes` by `L::index()`.
    static void sortByTag(std::vector<L>& nodes)
    {
        std::sort(
            nodes.begin(),
            nodes.end(),
            [](const L& left, const L& right)
            {
                return left.index() < right.index();
            }
        );
    }

    void canonicalize(L& enode)
    {
        // An e-node n is canonical iff n = canonicalize(n), where
        // canonicalize(f(a1, a2, ...)) = f(find(a1), find(a2), ...).
        Luau::EqSat::canonicalize(
            enode,
            [&](Id id)
            {
                return find(id);
            }
        );
    }

    /// True when every operand of `enode` is already its own representative.
    bool isCanonical(const L& enode) const
    {
        for (Id id : enode.operands())
        {
            if (!(id == find(id)))
                return false;
        }

        return true;
    }

    /// Creates a new singleton e-class for the canonical `enode`, registers it as a
    /// parent of each of its operand classes, queues it on the worklist and records
    /// it in the hashcons. Returns the new class id.
    Id makeEClass(const L& enode)
    {
        LUAU_ASSERT(isCanonical(enode));

        Id id = unionfind.makeSet();

        classes.insert_or_assign(
            id,
            EClassT{
                id,
                {enode},
                analysis.make(*this, enode),
                {},
            }
        );

        for (Id operand : enode.operands())
            get(operand).parents.emplace_back(enode, id);

        worklist.emplace_back(id);
        hashcons.insert_or_assign(enode, id);

        return id;
    }

    // Looks up the e-class stored under exactly `id`; this function does not canonicalize
    // `id` itself, so `id` must be a key of `classes` (asserted).
    // To resolve an arbitrary id, use `get(find(id))` or `egraph[id]`.
    EClassT& get(Id id)
    {
        LUAU_ASSERT(classes.count(id));
        return classes.at(id);
    }

    // Const overload of `get`, so that const accessors do not need a `const_cast`.
    const EClassT& get(Id id) const
    {
        LUAU_ASSERT(classes.count(id));
        return classes.at(id);
    }

    void repair(Id id)
    {
        // In the egg paper, the `repair` function makes use of two loops over the `eclass.parents`
        // by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id.
        // And then in another loop that follows, deduplicate it.
        //
        // Here, we unify the two loops. I think it's equivalent?

        // After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent.
        std::unordered_map<L, Id, typename L::Hash> newParents;

        // The eclass can be deallocated if it is merged into another eclass, so
        // we take what we need from it and avoid retaining a pointer.
        std::vector<std::pair<L, Id>> parents = get(id).parents;
        for (auto& parent : parents)
        {
            L& parentNode = parent.first;
            const Id parentId = parent.second;

            // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id.
            hashcons.erase(parentNode);
            canonicalize(parentNode);
            hashcons.insert_or_assign(parentNode, find(parentId));

            if (auto it = newParents.find(parentNode); it != newParents.end())
                merge(parentId, it->second);

            newParents.insert_or_assign(parentNode, find(parentId));
        }

        // We look the eclass up again because the prior loop potentially merges
        // the eclass into another, which might move it around in memory.
        // Nothing below merges, so holding a reference from here on is fine.
        EClassT& eclass = get(find(id));

        eclass.parents.clear();
        // Map keys are const, so the entries are copied out.
        for (const auto& [node, parentId] : newParents)
            eclass.parents.emplace_back(node, parentId);

        // Canonicalize the class's own nodes and drop the duplicates that creates.
        std::unordered_set<L, typename L::Hash> newNodes;
        for (L node : eclass.nodes)
        {
            canonicalize(node);
            newNodes.insert(std::move(node));
        }

        eclass.nodes.assign(newNodes.begin(), newNodes.end());
        sortByTag(eclass.nodes);
    }
};
} // namespace Luau::EqSat