luau/Compiler/src/ConstantFolding.cpp

406 lines
12 KiB
C++
Raw Normal View History

2022-01-14 16:20:09 +00:00
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "ConstantFolding.h"

#include <math.h>
#include <string.h>
namespace Luau
{
namespace Compile
{
// Returns true when two known constants are equal under Lua `==` semantics.
// Values of different types are never equal. Numbers use C++ double comparison, so NaN is unequal to itself.
// Strings compare by length and then by bytes.
// Both arguments must be known (not Type_Unknown).
static bool constantsEqual(const Constant& la, const Constant& ra)
{
    LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown);

    switch (la.type)
    {
    case Constant::Type_Nil:
        return ra.type == Constant::Type_Nil;

    case Constant::Type_Boolean:
        return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean;

    case Constant::Type_Number:
        return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber;

    case Constant::Type_String:
        // Two empty strings are equal without touching their data pointers.
        // memcmp on a null pointer is undefined behaviour even when the length is 0.
        return ra.type == Constant::Type_String && la.stringLength == ra.stringLength &&
               (la.stringLength == 0 || memcmp(la.valueString, ra.valueString, la.stringLength) == 0);

    default:
        LUAU_ASSERT(!"Unexpected constant type in comparison");
        return false;
    }
}
// Folds the unary operator `op` applied to `arg` into `result`.
// Only the following combinations are evaluated:
//   not  -> any known constant; yields a boolean based on truthiness
//   -    -> numbers only
//   #    -> strings only; yields the string length as a number
// In every other case `result` is left exactly as the caller passed it in.
static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg)
{
    switch (op)
    {
    case AstExprUnary::Not:
        if (arg.type == Constant::Type_Unknown)
            return;

        result.type = Constant::Type_Boolean;
        result.valueBoolean = !arg.isTruthful();
        return;

    case AstExprUnary::Minus:
        if (arg.type != Constant::Type_Number)
            return;

        result.type = Constant::Type_Number;
        result.valueNumber = -arg.valueNumber;
        return;

    case AstExprUnary::Len:
        if (arg.type != Constant::Type_String)
            return;

        result.type = Constant::Type_Number;
        result.valueNumber = double(arg.stringLength);
        return;

    default:
        LUAU_ASSERT(!"Unexpected unary operation");
        return;
    }
}
static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra)
{
switch (op)
{
case AstExprBinary::Add:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber + ra.valueNumber;
}
break;
case AstExprBinary::Sub:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber - ra.valueNumber;
}
break;
case AstExprBinary::Mul:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber * ra.valueNumber;
}
break;
case AstExprBinary::Div:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber / ra.valueNumber;
}
break;
case AstExprBinary::Mod:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber;
}
break;
case AstExprBinary::Pow:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Number;
result.valueNumber = pow(la.valueNumber, ra.valueNumber);
}
break;
case AstExprBinary::Concat:
break;
case AstExprBinary::CompareNe:
if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = !constantsEqual(la, ra);
}
break;
case AstExprBinary::CompareEq:
if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = constantsEqual(la, ra);
}
break;
case AstExprBinary::CompareLt:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = la.valueNumber < ra.valueNumber;
}
break;
case AstExprBinary::CompareLe:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = la.valueNumber <= ra.valueNumber;
}
break;
case AstExprBinary::CompareGt:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = la.valueNumber > ra.valueNumber;
}
break;
case AstExprBinary::CompareGe:
if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number)
{
result.type = Constant::Type_Boolean;
result.valueBoolean = la.valueNumber >= ra.valueNumber;
}
break;
case AstExprBinary::And:
if (la.type != Constant::Type_Unknown)
{
result = la.isTruthful() ? ra : la;
}
break;
case AstExprBinary::Or:
if (la.type != Constant::Type_Unknown)
{
result = la.isTruthful() ? la : ra;
}
break;
default:
LUAU_ASSERT(!"Unexpected binary operation");
}
}
struct ConstantVisitor : AstVisitor
{
DenseHashMap<AstExpr*, Constant>& constants;
DenseHashMap<AstLocal*, Variable>& variables;
2022-04-21 22:44:27 +01:00
DenseHashMap<AstLocal*, Constant>& locals;
2022-01-14 16:20:09 +00:00
2022-05-13 20:36:37 +01:00
bool wasEmpty = false;
2022-04-21 22:44:27 +01:00
ConstantVisitor(
DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, DenseHashMap<AstLocal*, Constant>& locals)
2022-01-14 16:20:09 +00:00
: constants(constants)
, variables(variables)
2022-04-21 22:44:27 +01:00
, locals(locals)
2022-01-14 16:20:09 +00:00
{
2022-05-13 20:36:37 +01:00
// since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries
wasEmpty = constants.empty() && locals.empty();
2022-01-14 16:20:09 +00:00
}
Constant analyze(AstExpr* node)
{
Constant result;
result.type = Constant::Type_Unknown;
if (AstExprGroup* expr = node->as<AstExprGroup>())
{
result = analyze(expr->expr);
}
else if (node->is<AstExprConstantNil>())
{
result.type = Constant::Type_Nil;
}
else if (AstExprConstantBool* expr = node->as<AstExprConstantBool>())
{
result.type = Constant::Type_Boolean;
result.valueBoolean = expr->value;
}
else if (AstExprConstantNumber* expr = node->as<AstExprConstantNumber>())
{
result.type = Constant::Type_Number;
result.valueNumber = expr->value;
}
else if (AstExprConstantString* expr = node->as<AstExprConstantString>())
{
result.type = Constant::Type_String;
result.valueString = expr->value.data;
result.stringLength = unsigned(expr->value.size);
}
else if (AstExprLocal* expr = node->as<AstExprLocal>())
{
const Constant* l = locals.find(expr->local);
if (l)
result = *l;
}
else if (node->is<AstExprGlobal>())
{
// nope
}
else if (node->is<AstExprVarargs>())
{
// nope
}
else if (AstExprCall* expr = node->as<AstExprCall>())
{
analyze(expr->func);
for (size_t i = 0; i < expr->args.size; ++i)
analyze(expr->args.data[i]);
}
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
{
analyze(expr->expr);
}
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
{
analyze(expr->expr);
analyze(expr->index);
}
else if (AstExprFunction* expr = node->as<AstExprFunction>())
{
// this is necessary to propagate constant information in all child functions
expr->body->visit(this);
}
else if (AstExprTable* expr = node->as<AstExprTable>())
{
for (size_t i = 0; i < expr->items.size; ++i)
{
const AstExprTable::Item& item = expr->items.data[i];
if (item.key)
analyze(item.key);
analyze(item.value);
}
}
else if (AstExprUnary* expr = node->as<AstExprUnary>())
{
Constant arg = analyze(expr->expr);
if (arg.type != Constant::Type_Unknown)
foldUnary(result, expr->op, arg);
}
else if (AstExprBinary* expr = node->as<AstExprBinary>())
{
Constant la = analyze(expr->left);
Constant ra = analyze(expr->right);
2022-04-07 22:29:01 +01:00
// note: ra doesn't need to be constant to fold and/or
if (la.type != Constant::Type_Unknown)
2022-01-14 16:20:09 +00:00
foldBinary(result, expr->op, la, ra);
}
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
{
Constant arg = analyze(expr->expr);
result = arg;
}
else if (AstExprIfElse* expr = node->as<AstExprIfElse>())
{
Constant cond = analyze(expr->condition);
Constant trueExpr = analyze(expr->trueExpr);
Constant falseExpr = analyze(expr->falseExpr);
if (cond.type != Constant::Type_Unknown)
result = cond.isTruthful() ? trueExpr : falseExpr;
}
else
{
LUAU_ASSERT(!"Unknown expression type");
}
2022-05-06 01:03:43 +01:00
recordConstant(constants, node, result);
2022-01-14 16:20:09 +00:00
return result;
}
2022-05-06 01:03:43 +01:00
template<typename T>
void recordConstant(DenseHashMap<T, Constant>& map, T key, const Constant& value)
{
if (value.type != Constant::Type_Unknown)
map[key] = value;
2022-05-20 01:02:24 +01:00
else if (wasEmpty)
2022-05-06 01:03:43 +01:00
;
else if (Constant* old = map.find(key))
old->type = Constant::Type_Unknown;
}
void recordValue(AstLocal* local, const Constant& value)
{
// note: we rely on trackValues to have been run before us
Variable* v = variables.find(local);
LUAU_ASSERT(v);
if (!v->written)
{
v->constant = (value.type != Constant::Type_Unknown);
recordConstant(locals, local, value);
}
}
2022-01-14 16:20:09 +00:00
bool visit(AstExpr* node) override
{
// note: we short-circuit the visitor traversal through any expression trees by returning false
// recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression
analyze(node);
return false;
}
bool visit(AstStatLocal* node) override
{
// all values that align wrt indexing are simple - we just match them 1-1
for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i)
{
Constant arg = analyze(node->values.data[i]);
2022-05-06 01:03:43 +01:00
recordValue(node->vars.data[i], arg);
2022-01-14 16:20:09 +00:00
}
if (node->vars.size > node->values.size)
{
// if we have trailing variables, then depending on whether the last value is capable of returning multiple values
// (aka call or varargs), we either don't know anything about these vars, or we know they're nil
AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr;
bool multRet = last && (last->is<AstExprCall>() || last->is<AstExprVarargs>());
if (!multRet)
{
for (size_t i = node->values.size; i < node->vars.size; ++i)
{
2022-05-06 01:03:43 +01:00
Constant nil = {Constant::Type_Nil};
recordValue(node->vars.data[i], nil);
2022-01-14 16:20:09 +00:00
}
}
}
else
{
// we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside
// them
for (size_t i = node->vars.size; i < node->values.size; ++i)
analyze(node->values.data[i]);
}
return false;
}
};
2022-04-21 22:44:27 +01:00
// Runs one constant-propagation pass over `root`.
// Fills `constants` with the folded value of every expression that can be evaluated.
// Fills `locals` with the value of every never-written local initialized to a known constant.
// `variables` must already contain an entry for every local declared in the tree.
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
    DenseHashMap<AstLocal*, Constant>& locals, AstNode* root)
{
    ConstantVisitor folder(constants, variables, locals);
    root->visit(&folder);
}
} // namespace Compile
} // namespace Luau