mirror of
https://github.com/luau-lang/luau.git
synced 2024-12-13 05:20:38 +00:00
447 lines
14 KiB
C++
447 lines
14 KiB
C++
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
|
#include "ConstantFolding.h"
|
|
|
|
#include "BuiltinFolding.h"
|
|
|
|
#include <vector>
|
|
#include <math.h>
|
|
|
|
namespace Luau
|
|
{
|
|
namespace Compile
|
|
{
|
|
|
|
static bool constantsEqual(const Constant& la, const Constant& ra)
|
|
{
|
|
LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown);
|
|
|
|
switch (la.type)
|
|
{
|
|
case Constant::Type_Nil:
|
|
return ra.type == Constant::Type_Nil;
|
|
|
|
case Constant::Type_Boolean:
|
|
return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean;
|
|
|
|
case Constant::Type_Number:
|
|
return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber;
|
|
|
|
case Constant::Type_String:
|
|
return ra.type == Constant::Type_String && la.stringLength == ra.stringLength && memcmp(la.valueString, ra.valueString, la.stringLength) == 0;
|
|
|
|
default:
|
|
LUAU_ASSERT(!"Unexpected constant type in comparison");
|
|
return false;
|
|
}
|
|
}
|
|
|
|
static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg)
|
|
{
|
|
switch (op)
|
|
{
|
|
case AstExprUnary::Not:
|
|
if (arg.type != Constant::Type_Unknown)
|
|
{
|
|
result.type = Constant::Type_Boolean;
|
|
result.valueBoolean = !arg.isTruthful();
|
|
}
|
|
break;
|
|
|
|
case AstExprUnary::Minus:
|
|
if (arg.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Number;
|
|
result.valueNumber = -arg.valueNumber;
|
|
}
|
|
break;
|
|
|
|
case AstExprUnary::Len:
|
|
if (arg.type == Constant::Type_String)
|
|
{
|
|
result.type = Constant::Type_Number;
|
|
result.valueNumber = double(arg.stringLength);
|
|
}
|
|
break;
|
|
|
|
default:
|
|
LUAU_ASSERT(!"Unexpected unary operation");
|
|
}
|
|
}
|
|
|
|
static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra)
|
|
{
|
|
switch (op)
|
|
{
|
|
case AstExprBinary::Add:
|
|
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Number;
|
|
result.valueNumber = la.valueNumber + ra.valueNumber;
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::Sub:
|
|
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Number;
|
|
result.valueNumber = la.valueNumber - ra.valueNumber;
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::Mul:
|
|
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Number;
|
|
result.valueNumber = la.valueNumber * ra.valueNumber;
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::Div:
|
|
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Number;
|
|
result.valueNumber = la.valueNumber / ra.valueNumber;
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::Mod:
|
|
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Number;
|
|
result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber;
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::Pow:
|
|
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Number;
|
|
result.valueNumber = pow(la.valueNumber, ra.valueNumber);
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::Concat:
|
|
break;
|
|
|
|
case AstExprBinary::CompareNe:
|
|
if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown)
|
|
{
|
|
result.type = Constant::Type_Boolean;
|
|
result.valueBoolean = !constantsEqual(la, ra);
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::CompareEq:
|
|
if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown)
|
|
{
|
|
result.type = Constant::Type_Boolean;
|
|
result.valueBoolean = constantsEqual(la, ra);
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::CompareLt:
|
|
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Boolean;
|
|
result.valueBoolean = la.valueNumber < ra.valueNumber;
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::CompareLe:
|
|
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Boolean;
|
|
result.valueBoolean = la.valueNumber <= ra.valueNumber;
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::CompareGt:
|
|
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Boolean;
|
|
result.valueBoolean = la.valueNumber > ra.valueNumber;
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::CompareGe:
|
|
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
|
|
{
|
|
result.type = Constant::Type_Boolean;
|
|
result.valueBoolean = la.valueNumber >= ra.valueNumber;
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::And:
|
|
if (la.type != Constant::Type_Unknown)
|
|
{
|
|
result = la.isTruthful() ? ra : la;
|
|
}
|
|
break;
|
|
|
|
case AstExprBinary::Or:
|
|
if (la.type != Constant::Type_Unknown)
|
|
{
|
|
result = la.isTruthful() ? la : ra;
|
|
}
|
|
break;
|
|
|
|
default:
|
|
LUAU_ASSERT(!"Unexpected binary operation");
|
|
}
|
|
}
|
|
|
|
struct ConstantVisitor : AstVisitor
|
|
{
|
|
DenseHashMap<AstExpr*, Constant>& constants;
|
|
DenseHashMap<AstLocal*, Variable>& variables;
|
|
DenseHashMap<AstLocal*, Constant>& locals;
|
|
|
|
const DenseHashMap<AstExprCall*, int>* builtins;
|
|
|
|
bool wasEmpty = false;
|
|
|
|
std::vector<Constant> builtinArgs;
|
|
|
|
ConstantVisitor(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
|
|
DenseHashMap<AstLocal*, Constant>& locals, const DenseHashMap<AstExprCall*, int>* builtins)
|
|
: constants(constants)
|
|
, variables(variables)
|
|
, locals(locals)
|
|
, builtins(builtins)
|
|
{
|
|
// since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries
|
|
wasEmpty = constants.empty() && locals.empty();
|
|
}
|
|
|
|
Constant analyze(AstExpr* node)
|
|
{
|
|
Constant result;
|
|
result.type = Constant::Type_Unknown;
|
|
|
|
if (AstExprGroup* expr = node->as<AstExprGroup>())
|
|
{
|
|
result = analyze(expr->expr);
|
|
}
|
|
else if (node->is<AstExprConstantNil>())
|
|
{
|
|
result.type = Constant::Type_Nil;
|
|
}
|
|
else if (AstExprConstantBool* expr = node->as<AstExprConstantBool>())
|
|
{
|
|
result.type = Constant::Type_Boolean;
|
|
result.valueBoolean = expr->value;
|
|
}
|
|
else if (AstExprConstantNumber* expr = node->as<AstExprConstantNumber>())
|
|
{
|
|
result.type = Constant::Type_Number;
|
|
result.valueNumber = expr->value;
|
|
}
|
|
else if (AstExprConstantString* expr = node->as<AstExprConstantString>())
|
|
{
|
|
result.type = Constant::Type_String;
|
|
result.valueString = expr->value.data;
|
|
result.stringLength = unsigned(expr->value.size);
|
|
}
|
|
else if (AstExprLocal* expr = node->as<AstExprLocal>())
|
|
{
|
|
const Constant* l = locals.find(expr->local);
|
|
|
|
if (l)
|
|
result = *l;
|
|
}
|
|
else if (node->is<AstExprGlobal>())
|
|
{
|
|
// nope
|
|
}
|
|
else if (node->is<AstExprVarargs>())
|
|
{
|
|
// nope
|
|
}
|
|
else if (AstExprCall* expr = node->as<AstExprCall>())
|
|
{
|
|
analyze(expr->func);
|
|
|
|
if (const int* bfid = builtins ? builtins->find(expr) : nullptr)
|
|
{
|
|
// since recursive calls to analyze() may reuse the vector we need to be careful and preserve existing contents
|
|
size_t offset = builtinArgs.size();
|
|
bool canFold = true;
|
|
|
|
builtinArgs.reserve(offset + expr->args.size);
|
|
|
|
for (size_t i = 0; i < expr->args.size; ++i)
|
|
{
|
|
Constant ac = analyze(expr->args.data[i]);
|
|
|
|
if (ac.type == Constant::Type_Unknown)
|
|
canFold = false;
|
|
else
|
|
builtinArgs.push_back(ac);
|
|
}
|
|
|
|
if (canFold)
|
|
{
|
|
LUAU_ASSERT(builtinArgs.size() == offset + expr->args.size);
|
|
result = foldBuiltin(*bfid, builtinArgs.data() + offset, expr->args.size);
|
|
}
|
|
|
|
builtinArgs.resize(offset);
|
|
}
|
|
else
|
|
{
|
|
for (size_t i = 0; i < expr->args.size; ++i)
|
|
analyze(expr->args.data[i]);
|
|
}
|
|
}
|
|
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
|
|
{
|
|
analyze(expr->expr);
|
|
}
|
|
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
|
|
{
|
|
analyze(expr->expr);
|
|
analyze(expr->index);
|
|
}
|
|
else if (AstExprFunction* expr = node->as<AstExprFunction>())
|
|
{
|
|
// this is necessary to propagate constant information in all child functions
|
|
expr->body->visit(this);
|
|
}
|
|
else if (AstExprTable* expr = node->as<AstExprTable>())
|
|
{
|
|
for (size_t i = 0; i < expr->items.size; ++i)
|
|
{
|
|
const AstExprTable::Item& item = expr->items.data[i];
|
|
|
|
if (item.key)
|
|
analyze(item.key);
|
|
|
|
analyze(item.value);
|
|
}
|
|
}
|
|
else if (AstExprUnary* expr = node->as<AstExprUnary>())
|
|
{
|
|
Constant arg = analyze(expr->expr);
|
|
|
|
if (arg.type != Constant::Type_Unknown)
|
|
foldUnary(result, expr->op, arg);
|
|
}
|
|
else if (AstExprBinary* expr = node->as<AstExprBinary>())
|
|
{
|
|
Constant la = analyze(expr->left);
|
|
Constant ra = analyze(expr->right);
|
|
|
|
// note: ra doesn't need to be constant to fold and/or
|
|
if (la.type != Constant::Type_Unknown)
|
|
foldBinary(result, expr->op, la, ra);
|
|
}
|
|
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
|
|
{
|
|
Constant arg = analyze(expr->expr);
|
|
|
|
result = arg;
|
|
}
|
|
else if (AstExprIfElse* expr = node->as<AstExprIfElse>())
|
|
{
|
|
Constant cond = analyze(expr->condition);
|
|
Constant trueExpr = analyze(expr->trueExpr);
|
|
Constant falseExpr = analyze(expr->falseExpr);
|
|
|
|
if (cond.type != Constant::Type_Unknown)
|
|
result = cond.isTruthful() ? trueExpr : falseExpr;
|
|
}
|
|
else if (AstExprInterpString* expr = node->as<AstExprInterpString>())
|
|
{
|
|
for (AstExpr* expression : expr->expressions)
|
|
analyze(expression);
|
|
}
|
|
else
|
|
{
|
|
LUAU_ASSERT(!"Unknown expression type");
|
|
}
|
|
|
|
recordConstant(constants, node, result);
|
|
|
|
return result;
|
|
}
|
|
|
|
template<typename T>
|
|
void recordConstant(DenseHashMap<T, Constant>& map, T key, const Constant& value)
|
|
{
|
|
if (value.type != Constant::Type_Unknown)
|
|
map[key] = value;
|
|
else if (wasEmpty)
|
|
;
|
|
else if (Constant* old = map.find(key))
|
|
old->type = Constant::Type_Unknown;
|
|
}
|
|
|
|
void recordValue(AstLocal* local, const Constant& value)
|
|
{
|
|
// note: we rely on trackValues to have been run before us
|
|
Variable* v = variables.find(local);
|
|
LUAU_ASSERT(v);
|
|
|
|
if (!v->written)
|
|
{
|
|
v->constant = (value.type != Constant::Type_Unknown);
|
|
recordConstant(locals, local, value);
|
|
}
|
|
}
|
|
|
|
bool visit(AstExpr* node) override
|
|
{
|
|
// note: we short-circuit the visitor traversal through any expression trees by returning false
|
|
// recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression
|
|
analyze(node);
|
|
|
|
return false;
|
|
}
|
|
|
|
bool visit(AstStatLocal* node) override
|
|
{
|
|
// all values that align wrt indexing are simple - we just match them 1-1
|
|
for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i)
|
|
{
|
|
Constant arg = analyze(node->values.data[i]);
|
|
|
|
recordValue(node->vars.data[i], arg);
|
|
}
|
|
|
|
if (node->vars.size > node->values.size)
|
|
{
|
|
// if we have trailing variables, then depending on whether the last value is capable of returning multiple values
|
|
// (aka call or varargs), we either don't know anything about these vars, or we know they're nil
|
|
AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr;
|
|
bool multRet = last && (last->is<AstExprCall>() || last->is<AstExprVarargs>());
|
|
|
|
if (!multRet)
|
|
{
|
|
for (size_t i = node->values.size; i < node->vars.size; ++i)
|
|
{
|
|
Constant nil = {Constant::Type_Nil};
|
|
recordValue(node->vars.data[i], nil);
|
|
}
|
|
}
|
|
}
|
|
else
|
|
{
|
|
// we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside
|
|
// them
|
|
for (size_t i = node->vars.size; i < node->values.size; ++i)
|
|
analyze(node->values.data[i]);
|
|
}
|
|
|
|
return false;
|
|
}
|
|
};
|
|
|
|
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
|
|
DenseHashMap<AstLocal*, Constant>& locals, const DenseHashMap<AstExprCall*, int>* builtins, AstNode* root)
|
|
{
|
|
ConstantVisitor visitor{constants, variables, locals, builtins};
|
|
root->visit(&visitor);
|
|
}
|
|
|
|
} // namespace Compile
|
|
} // namespace Luau
|