mirror of
https://github.com/luau-lang/luau.git
synced 2024-12-13 21:40:43 +00:00
364 lines
8.3 KiB
C
364 lines
8.3 KiB
C
|
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||
|
|
||
|
#pragma once
|
||
|
|
||
|
#include "Luau/EGraph.h"
|
||
|
#include "Luau/Id.h"
|
||
|
#include "Luau/Language.h"
|
||
|
#include "Luau/Lexer.h" // For Allocator
|
||
|
#include "Luau/NotNull.h"
|
||
|
#include "Luau/TypeArena.h"
|
||
|
#include "Luau/TypeFwd.h"
|
||
|
|
||
|
namespace Luau
|
||
|
{
|
||
|
struct TypeFunction;
|
||
|
}
|
||
|
|
||
|
namespace Luau::EqSatSimplification
|
||
|
{
|
||
|
|
||
|
using StringId = uint32_t;
|
||
|
using Id = Luau::EqSat::Id;
|
||
|
|
||
|
LUAU_EQSAT_UNIT(TNil);
|
||
|
LUAU_EQSAT_UNIT(TBoolean);
|
||
|
LUAU_EQSAT_UNIT(TNumber);
|
||
|
LUAU_EQSAT_UNIT(TString);
|
||
|
LUAU_EQSAT_UNIT(TThread);
|
||
|
LUAU_EQSAT_UNIT(TTopFunction);
|
||
|
LUAU_EQSAT_UNIT(TTopTable);
|
||
|
LUAU_EQSAT_UNIT(TTopClass);
|
||
|
LUAU_EQSAT_UNIT(TBuffer);
|
||
|
|
||
|
// Used for any type that eqsat can't do anything interesting with.
|
||
|
LUAU_EQSAT_ATOM(TOpaque, TypeId);
|
||
|
|
||
|
LUAU_EQSAT_ATOM(SBoolean, bool);
|
||
|
LUAU_EQSAT_ATOM(SString, StringId);
|
||
|
|
||
|
LUAU_EQSAT_ATOM(TFunction, TypeId);
|
||
|
|
||
|
LUAU_EQSAT_ATOM(TImportedTable, TypeId);
|
||
|
|
||
|
LUAU_EQSAT_ATOM(TClass, TypeId);
|
||
|
|
||
|
LUAU_EQSAT_UNIT(TAny);
|
||
|
LUAU_EQSAT_UNIT(TError);
|
||
|
LUAU_EQSAT_UNIT(TUnknown);
|
||
|
LUAU_EQSAT_UNIT(TNever);
|
||
|
|
||
|
LUAU_EQSAT_NODE_SET(Union);
|
||
|
LUAU_EQSAT_NODE_SET(Intersection);
|
||
|
|
||
|
LUAU_EQSAT_NODE_ARRAY(Negation, 1);
|
||
|
|
||
|
LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*);
|
||
|
|
||
|
LUAU_EQSAT_UNIT(TNoRefine);
|
||
|
LUAU_EQSAT_UNIT(Invalid);
|
||
|
|
||
|
// enodes are immutable, but types are cyclic. We need a way to tie the knot.
|
||
|
// We handle this by generating TBound nodes at points where we encounter cycles.
|
||
|
// Each TBound has an ordinal that we later map onto the type.
|
||
|
// We use a substitution rule to replace all TBound nodes with their referrent.
|
||
|
LUAU_EQSAT_ATOM(TBound, size_t);
|
||
|
|
||
|
// Tables are sufficiently unlike other enodes that the Language.h macros won't cut it.
|
||
|
struct TTable
|
||
|
{
|
||
|
explicit TTable(Id basis);
|
||
|
TTable(Id basis, std::vector<StringId> propNames_, std::vector<Id> propTypes_);
|
||
|
|
||
|
// All TTables extend some other table. This may be TTopTable.
|
||
|
//
|
||
|
// It will frequently be a TImportedTable, in which case we can reuse things
|
||
|
// like source location and documentation info.
|
||
|
Id getBasis() const;
|
||
|
EqSat::Slice<const Id> propTypes() const;
|
||
|
// TODO: Also support read-only table props
|
||
|
// TODO: Indexer type, index result type.
|
||
|
|
||
|
std::vector<StringId> propNames;
|
||
|
|
||
|
// The enode interface
|
||
|
EqSat::Slice<Id> mutableOperands();
|
||
|
EqSat::Slice<const Id> operands() const;
|
||
|
bool operator==(const TTable& rhs) const;
|
||
|
bool operator!=(const TTable& rhs) const
|
||
|
{
|
||
|
return !(*this == rhs);
|
||
|
}
|
||
|
|
||
|
struct Hash
|
||
|
{
|
||
|
size_t operator()(const TTable& value) const;
|
||
|
};
|
||
|
|
||
|
private:
|
||
|
// The first element of this vector is the basis. Subsequent elements are
|
||
|
// property types. As we add other things like read-only properties and
|
||
|
// indexers, the structure of this array is likely to change.
|
||
|
//
|
||
|
// We encode our data in this way so that the operands() method can properly
|
||
|
// return a Slice<Id>.
|
||
|
std::vector<Id> storage;
|
||
|
};
|
||
|
|
||
|
using EType = EqSat::Language<
|
||
|
TNil,
|
||
|
TBoolean,
|
||
|
TNumber,
|
||
|
TString,
|
||
|
TThread,
|
||
|
TTopFunction,
|
||
|
TTopTable,
|
||
|
TTopClass,
|
||
|
TBuffer,
|
||
|
|
||
|
TOpaque,
|
||
|
|
||
|
SBoolean,
|
||
|
SString,
|
||
|
|
||
|
TFunction,
|
||
|
TTable,
|
||
|
TImportedTable,
|
||
|
TClass,
|
||
|
|
||
|
TAny,
|
||
|
TError,
|
||
|
TUnknown,
|
||
|
TNever,
|
||
|
|
||
|
Union,
|
||
|
Intersection,
|
||
|
|
||
|
Negation,
|
||
|
|
||
|
TTypeFun,
|
||
|
|
||
|
Invalid,
|
||
|
TNoRefine,
|
||
|
TBound>;
|
||
|
|
||
|
|
||
|
struct StringCache
|
||
|
{
|
||
|
Allocator allocator;
|
||
|
DenseHashMap<size_t, StringId> strings{{}};
|
||
|
std::vector<std::string_view> views;
|
||
|
|
||
|
StringId add(std::string_view s);
|
||
|
std::string_view asStringView(StringId id) const;
|
||
|
std::string asString(StringId id) const;
|
||
|
};
|
||
|
|
||
|
using EGraph = Luau::EqSat::EGraph<EType, struct Simplify>;
|
||
|
|
||
|
struct Simplify
|
||
|
{
|
||
|
using Data = bool;
|
||
|
|
||
|
template<typename T>
|
||
|
Data make(const EGraph&, const T&) const;
|
||
|
|
||
|
void join(Data& left, const Data& right) const;
|
||
|
};
|
||
|
|
||
|
struct Subst
|
||
|
{
|
||
|
Id eclass;
|
||
|
Id newClass;
|
||
|
|
||
|
std::string desc;
|
||
|
|
||
|
Subst(Id eclass, Id newClass, std::string desc = "");
|
||
|
};
|
||
|
|
||
|
struct Simplifier
|
||
|
{
|
||
|
NotNull<TypeArena> arena;
|
||
|
NotNull<BuiltinTypes> builtinTypes;
|
||
|
EGraph egraph;
|
||
|
StringCache stringCache;
|
||
|
|
||
|
// enodes are immutable but types can be cyclic, so we need some way to
|
||
|
// encode the cycle. This map is used to connect TBound nodes to the right
|
||
|
// eclass.
|
||
|
//
|
||
|
// The cyclicIntersection rewrite rule uses this to sense when a cycle can
|
||
|
// be deleted from an intersection or union.
|
||
|
std::unordered_map<size_t, Id> mappingIdToClass;
|
||
|
|
||
|
std::vector<Subst> substs;
|
||
|
|
||
|
using RewriteRuleFn = void (Simplifier::*)(Id id);
|
||
|
|
||
|
Simplifier(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes);
|
||
|
|
||
|
// Utilities
|
||
|
const EqSat::EClass<EType, Simplify::Data>& get(Id id) const;
|
||
|
Id find(Id id) const;
|
||
|
Id add(EType enode);
|
||
|
|
||
|
template<typename Tag>
|
||
|
const Tag* isTag(Id id) const;
|
||
|
|
||
|
template<typename Tag>
|
||
|
const Tag* isTag(const EType& enode) const;
|
||
|
|
||
|
void subst(Id from, Id to);
|
||
|
void subst(Id from, Id to, const std::string& ruleName);
|
||
|
void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map<Id, size_t>& forceNodes);
|
||
|
|
||
|
void unionClasses(std::vector<Id>& hereParts, Id there);
|
||
|
|
||
|
// Rewrite rules
|
||
|
void simplifyUnion(Id id);
|
||
|
void uninhabitedIntersection(Id id);
|
||
|
void intersectWithNegatedClass(Id id);
|
||
|
void intersectWithNoRefine(Id id);
|
||
|
void cyclicIntersectionOfUnion(Id id);
|
||
|
void cyclicUnionOfIntersection(Id id);
|
||
|
void expandNegation(Id id);
|
||
|
void intersectionOfUnion(Id id);
|
||
|
void intersectTableProperty(Id id);
|
||
|
void uninhabitedTable(Id id);
|
||
|
void unneededTableModification(Id id);
|
||
|
void builtinTypeFunctions(Id id);
|
||
|
void iffyTypeFunctions(Id id);
|
||
|
};
|
||
|
|
||
|
template<typename Tag>
|
||
|
struct QueryIterator
|
||
|
{
|
||
|
QueryIterator();
|
||
|
QueryIterator(EGraph* egraph, Id eclass);
|
||
|
|
||
|
bool operator==(const QueryIterator& other) const;
|
||
|
bool operator!=(const QueryIterator& other) const;
|
||
|
|
||
|
std::pair<const Tag*, size_t> operator*() const;
|
||
|
|
||
|
QueryIterator& operator++();
|
||
|
QueryIterator& operator++(int);
|
||
|
|
||
|
private:
|
||
|
EGraph* egraph = nullptr;
|
||
|
Id eclass;
|
||
|
size_t index = 0;
|
||
|
};
|
||
|
|
||
|
template<typename Tag>
|
||
|
struct Query
|
||
|
{
|
||
|
EGraph* egraph;
|
||
|
Id eclass;
|
||
|
|
||
|
Query(EGraph* egraph, Id eclass)
|
||
|
: egraph(egraph)
|
||
|
, eclass(eclass)
|
||
|
{
|
||
|
}
|
||
|
|
||
|
QueryIterator<Tag> begin()
|
||
|
{
|
||
|
return QueryIterator<Tag>{egraph, eclass};
|
||
|
}
|
||
|
|
||
|
QueryIterator<Tag> end()
|
||
|
{
|
||
|
return QueryIterator<Tag>{};
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template<typename Tag>
|
||
|
QueryIterator<Tag>::QueryIterator()
|
||
|
: egraph(nullptr)
|
||
|
, eclass(Id{0})
|
||
|
, index(0)
|
||
|
{
|
||
|
}
|
||
|
|
||
|
template<typename Tag>
|
||
|
QueryIterator<Tag>::QueryIterator(EGraph* egraph_, Id eclass)
|
||
|
: egraph(egraph_)
|
||
|
, eclass(eclass)
|
||
|
, index(0)
|
||
|
{
|
||
|
const auto& ecl = (*egraph)[eclass];
|
||
|
|
||
|
static constexpr const int idx = EType::VariantTy::getTypeId<Tag>();
|
||
|
|
||
|
for (const auto& enode : ecl.nodes)
|
||
|
{
|
||
|
if (enode.index() < idx)
|
||
|
++index;
|
||
|
else
|
||
|
break;
|
||
|
}
|
||
|
|
||
|
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx)
|
||
|
{
|
||
|
egraph = nullptr;
|
||
|
index = 0;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
template<typename Tag>
|
||
|
bool QueryIterator<Tag>::operator==(const QueryIterator<Tag>& rhs) const
|
||
|
{
|
||
|
if (egraph == nullptr && rhs.egraph == nullptr)
|
||
|
return true;
|
||
|
|
||
|
return egraph == rhs.egraph && eclass == rhs.eclass && index == rhs.index;
|
||
|
}
|
||
|
|
||
|
template<typename Tag>
|
||
|
bool QueryIterator<Tag>::operator!=(const QueryIterator<Tag>& rhs) const
|
||
|
{
|
||
|
return !(*this == rhs);
|
||
|
}
|
||
|
|
||
|
template<typename Tag>
|
||
|
std::pair<const Tag*, size_t> QueryIterator<Tag>::operator*() const
|
||
|
{
|
||
|
LUAU_ASSERT(egraph != nullptr);
|
||
|
|
||
|
EGraph::EClassT& ecl = (*egraph)[eclass];
|
||
|
|
||
|
LUAU_ASSERT(index < ecl.nodes.size());
|
||
|
auto& enode = ecl.nodes[index];
|
||
|
Tag* result = enode.template get<Tag>();
|
||
|
LUAU_ASSERT(result);
|
||
|
return {result, index};
|
||
|
}
|
||
|
|
||
|
// pre-increment
|
||
|
template<typename Tag>
|
||
|
QueryIterator<Tag>& QueryIterator<Tag>::operator++()
|
||
|
{
|
||
|
const auto& ecl = (*egraph)[eclass];
|
||
|
|
||
|
++index;
|
||
|
if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId<Tag>())
|
||
|
{
|
||
|
egraph = nullptr;
|
||
|
index = 0;
|
||
|
}
|
||
|
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
// post-increment
|
||
|
template<typename Tag>
|
||
|
QueryIterator<Tag>& QueryIterator<Tag>::operator++(int)
|
||
|
{
|
||
|
QueryIterator<Tag> res = *this;
|
||
|
++res;
|
||
|
return res;
|
||
|
}
|
||
|
|
||
|
} // namespace Luau::EqSatSimplification
|