From 66d379d29333f952f305496e46fcf95650663575 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Mon, 27 Nov 2023 09:31:42 -0800 Subject: [PATCH] Add SUBRK and DIVRK bytecode instructions to bytecode v5 Right now, we can compile R*K for all arithmetic instructions, but K*R gets compiled into two instructions (LOADN/LOADK + arithmetic opcode). This is problematic since it leads to reduced performance for some code. However, we'd like to avoid adding reverse variants of ADDK et al for all opcodes to avoid the increase in I$ footprint for the interpreter. Looking at the arithmetic instructions, % and // don't have interesting use cases for K*R; ^ is sometimes used with a constant on the left-hand side but this would need to call pow() by necessity in all cases so it would be slow regardless of the dispatch overhead. This leaves the four basic arithmetic operations. For + and *, we can implement a compiler-side optimization in the future that transforms K*R to R*K automatically. This could either be done unconditionally at -O2, or conditionally based on the type of the value (driven by type annotations / inference) -- this technically changes behavior in the presence of metamethods, although it might be sensible to just always do this because non-commutative +/* are evil. However, for - and / it is impossible for the compiler to optimize this in the future, so we need dedicated opcodes. This only increases the interpreter size by ~300 bytes (~1.5%) on X64. This makes spectral-norm and math-partial-sums 6% faster. To avoid the proliferation of bytecode versions this change piggybacks on the bytecode version bump that was just made in 604 for vector constants; we would still be able to enable these independently but we'll consider v5 complete when both are enabled. 
--- CodeGen/include/Luau/IrVisitUseDef.h | 5 +++ CodeGen/src/EmitCommonX64.cpp | 4 +-- CodeGen/src/EmitCommonX64.h | 2 +- CodeGen/src/IrBuilder.cpp | 6 ++++ CodeGen/src/IrLoweringA64.cpp | 6 +++- CodeGen/src/IrLoweringX64.cpp | 7 ++-- CodeGen/src/IrTranslation.cpp | 41 +++++++++++++++++----- CodeGen/src/IrTranslation.h | 1 + Common/include/Luau/Bytecode.h | 13 ++++--- Compiler/src/BytecodeBuilder.cpp | 22 +++++++++++- Compiler/src/Compiler.cpp | 16 +++++++++ VM/src/lvmexecute.cpp | 51 ++++++++++++++++++++++++---- 12 files changed, 144 insertions(+), 30 deletions(-) diff --git a/CodeGen/include/Luau/IrVisitUseDef.h b/CodeGen/include/Luau/IrVisitUseDef.h index 8134e0d5..603f1ec5 100644 --- a/CodeGen/include/Luau/IrVisitUseDef.h +++ b/CodeGen/include/Luau/IrVisitUseDef.h @@ -41,6 +41,11 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i break; // A <- B, C case IrCmd::DO_ARITH: + visitor.maybeUse(inst.b); // Argument can also be a VmConst + visitor.maybeUse(inst.c); // Argument can also be a VmConst + + visitor.def(inst.a); + break; case IrCmd::GET_TABLE: visitor.use(inst.b); visitor.maybeUse(inst.c); // Argument can also be a VmConst diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 008dadd5..014f5a46 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -148,12 +148,12 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi build.jcc(ConditionX64::NotZero, label); } -void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm) +void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, OperandX64 b, OperandX64 c, TMS tm) { IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::qword, rState); callWrap.addArgument(SizeX64::qword, luauRegAddress(ra)); - callWrap.addArgument(SizeX64::qword, luauRegAddress(rb)); + callWrap.addArgument(SizeX64::qword, b); 
callWrap.addArgument(SizeX64::qword, c); callWrap.addArgument(SizeX64::dword, tm); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 3418a09f..bc1f99c9 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -200,7 +200,7 @@ ConditionX64 getConditionInt(IrCondition cond); void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label); -void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm); +void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, OperandX64 b, OperandX64 c, TMS tm); void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb); void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 56bbf904..765433a4 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -381,6 +381,12 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) case LOP_POWK: translateInstBinaryK(*this, pc, i, TM_POW); break; + case LOP_SUBRK: + translateInstBinaryRK(*this, pc, i, TM_SUB); + break; + case LOP_DIVRK: + translateInstBinaryRK(*this, pc, i, TM_DIV); + break; case LOP_NOT: translateInstNot(*this, pc); break; diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 6a1733a5..5d0be75a 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -1067,7 +1067,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) regs.spill(build, index); 
build.mov(x0, rState); build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); - build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); + + if (inst.b.kind == IrOpKind::VmConst) + emitAddOffset(build, x2, rConstants, vmConstOp(inst.b) * sizeof(TValue)); + else + build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); if (inst.c.kind == IrOpKind::VmConst) emitAddOffset(build, x3, rConstants, vmConstOp(inst.c) * sizeof(TValue)); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 74a5bfd6..21031793 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -962,10 +962,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) break; } case IrCmd::DO_ARITH: - if (inst.c.kind == IrOpKind::VmReg) - callArithHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), TMS(intOp(inst.d))); - else - callArithHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), luauConstantAddress(vmConstOp(inst.c)), TMS(intOp(inst.d))); + callArithHelper(regs, build, vmRegOp(inst.a), + inst.b.kind == IrOpKind::VmReg ? luauRegAddress(vmRegOp(inst.b)) : luauConstantAddress(vmConstOp(inst.b)), + inst.c.kind == IrOpKind::VmReg ? 
luauRegAddress(vmRegOp(inst.c)) : luauConstantAddress(vmConstOp(inst.c)), TMS(intOp(inst.d))); break; case IrCmd::DO_LEN: callLengthHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b)); diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 91e87fdb..76d58265 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -327,13 +327,16 @@ void translateInstJumpxEqS(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(next); } -static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, IrOp opc, int pcpos, TMS tm) +static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, IrOp opb, IrOp opc, int pcpos, TMS tm) { IrOp fallback = build.block(IrBlockKind::Fallback); // fast-path: number - IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); - build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TNUMBER), fallback); + if (rb != -1) + { + IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb)); + build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TNUMBER), fallback); + } if (rc != -1 && rc != rb) // TODO: optimization should handle second check, but we'll test it later { @@ -341,11 +344,23 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, build.inst(IrCmd::CHECK_TAG, tc, build.constTag(LUA_TNUMBER), fallback); } - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb)); - IrOp vc; - + IrOp vb, vc; IrOp result; + if (opb.kind == IrOpKind::VmConst) + { + LUAU_ASSERT(build.function.proto); + TValue protok = build.function.proto->k[vmConstOp(opb)]; + + LUAU_ASSERT(protok.tt == LUA_TNUMBER); + + vb = build.constDouble(protok.value.n); + } + else + { + vb = build.inst(IrCmd::LOAD_DOUBLE, opb); + } + if (opc.kind == IrOpKind::VmConst) { LUAU_ASSERT(build.function.proto); @@ -409,18 +424,26 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, FallbackStreamScope scope(build, fallback, next); 
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), opc, build.constInt(tm)); + build.inst(IrCmd::DO_ARITH, build.vmReg(ra), opb, opc, build.constInt(tm)); build.inst(IrCmd::JUMP, next); } void translateInstBinary(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm) { - translateInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), build.vmReg(LUAU_INSN_C(*pc)), pcpos, tm); + translateInstBinaryNumeric( + build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_C(*pc)), pcpos, tm); } void translateInstBinaryK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm) { - translateInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, build.vmConst(LUAU_INSN_C(*pc)), pcpos, tm); + translateInstBinaryNumeric( + build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, build.vmReg(LUAU_INSN_B(*pc)), build.vmConst(LUAU_INSN_C(*pc)), pcpos, tm); +} + +void translateInstBinaryRK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm) +{ + translateInstBinaryNumeric( + build, LUAU_INSN_A(*pc), -1, LUAU_INSN_C(*pc), build.vmConst(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_C(*pc)), pcpos, tm); } void translateInstNot(IrBuilder& build, const Instruction* pc) diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 29c14356..99c3d1c3 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -35,6 +35,7 @@ void translateInstJumpxEqN(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstJumpxEqS(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstBinary(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm); void translateInstBinaryK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm); +void translateInstBinaryRK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm); void translateInstNot(IrBuilder& build, const Instruction* 
pc); void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos); diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 36dfabdb..e3c20670 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -45,7 +45,7 @@ // Version 2: Adds Proto::linedefined. Supported until 0.544. // Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported. // Version 4: Adds Proto::flags, typeinfo, and floor division opcodes IDIV/IDIVK. Currently supported. -// Version 5: Adds vector constants. Currently supported. +// Version 5: Adds SUBRK/DIVRK and vector constants. Currently supported. // Bytecode opcode, part of the instruction header enum LuauOpcode @@ -219,7 +219,7 @@ enum LuauOpcode // ADDK, SUBK, MULK, DIVK, MODK, POWK: compute arithmetic operation between the source register and a constant and put the result into target register // A: target register // B: source register - // C: constant table index (0..255) + // C: constant table index (0..255); must refer to a number LOP_ADDK, LOP_SUBK, LOP_MULK, @@ -348,9 +348,12 @@ enum LuauOpcode // B: source register (for VAL/REF) or upvalue index (for UPVAL/UPREF) LOP_CAPTURE, - // removed in v3 - LOP_DEP_JUMPIFEQK, - LOP_DEP_JUMPIFNOTEQK, + // SUBRK, DIVRK: compute arithmetic operation between the constant and a source register and put the result into target register + // A: target register + // B: constant table index (0..255); must refer to a number + // C: source register + LOP_SUBRK, + LOP_DIVRK, // FASTCALL1: perform a fast call of a built-in function using 1 register argument // A: builtin function id (see LuauBuiltinFunction) diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 8dc7b88e..ae376657 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ 
b/Compiler/src/BytecodeBuilder.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAG(LuauVectorLiterals) +LUAU_FASTFLAG(LuauCompileRevK) namespace Luau { @@ -1123,7 +1124,7 @@ std::string BytecodeBuilder::getError(const std::string& message) uint8_t BytecodeBuilder::getVersion() { // This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags - return (FFlag::LuauVectorLiterals ? 5 : LBC_VERSION_TARGET); + return (FFlag::LuauVectorLiterals || FFlag::LuauCompileRevK) ? 5 : LBC_VERSION_TARGET; } uint8_t BytecodeBuilder::getTypeEncodingVersion() @@ -1351,6 +1352,13 @@ void BytecodeBuilder::validateInstructions() const VCONST(LUAU_INSN_C(insn), Number); break; + case LOP_SUBRK: + case LOP_DIVRK: + VREG(LUAU_INSN_A(insn)); + VCONST(LUAU_INSN_B(insn), Number); + VREG(LUAU_INSN_C(insn)); + break; + case LOP_AND: case LOP_OR: VREG(LUAU_INSN_A(insn)); @@ -1973,6 +1981,18 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, result.append("]\n"); break; + case LOP_SUBRK: + formatAppend(result, "SUBRK R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn)); + dumpConstant(result, LUAU_INSN_B(insn)); + formatAppend(result, "] R%d\n", LUAU_INSN_C(insn)); + break; + + case LOP_DIVRK: + formatAppend(result, "DIVRK R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn)); + dumpConstant(result, LUAU_INSN_B(insn)); + formatAppend(result, "] R%d\n", LUAU_INSN_C(insn)); + break; + case LOP_AND: formatAppend(result, "AND R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); break; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 0a5463a2..1a6d6e5c 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -29,6 +29,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAGVARIABLE(LuauCompileSideEffects, false) LUAU_FASTFLAGVARIABLE(LuauCompileDeadIf, false) +LUAU_FASTFLAGVARIABLE(LuauCompileRevK, false) + namespace Luau { @@ -1516,6 
+1518,20 @@ struct Compiler } else { + if (FFlag::LuauCompileRevK && (expr->op == AstExprBinary::Sub || expr->op == AstExprBinary::Div)) + { + int32_t lc = getConstantNumber(expr->left); + + if (lc >= 0 && lc <= 255) + { + uint8_t rr = compileExprAuto(expr->right, rs); + LuauOpcode op = (expr->op == AstExprBinary::Sub) ? LOP_SUBRK : LOP_DIVRK; + + bytecode.emitABC(op, target, uint8_t(lc), uint8_t(rr)); + return; + } + } + uint8_t rl = compileExprAuto(expr->left, rs); uint8_t rr = compileExprAuto(expr->right, rs); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 451433ee..9af25af6 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -101,7 +101,7 @@ VM_DISPATCH_OP(LOP_FORGLOOP), VM_DISPATCH_OP(LOP_FORGPREP_INEXT), VM_DISPATCH_OP(LOP_DEP_FORGLOOP_INEXT), VM_DISPATCH_OP(LOP_FORGPREP_NEXT), \ VM_DISPATCH_OP(LOP_NATIVECALL), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ - VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_DEP_JUMPIFEQK), VM_DISPATCH_OP(LOP_DEP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ + VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_SUBRK), VM_DISPATCH_OP(LOP_DIVRK), VM_DISPATCH_OP(LOP_FASTCALL1), \ VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), VM_DISPATCH_OP(LOP_JUMPXEQKNIL), \ VM_DISPATCH_OP(LOP_JUMPXEQKB), VM_DISPATCH_OP(LOP_JUMPXEQKN), VM_DISPATCH_OP(LOP_JUMPXEQKS), VM_DISPATCH_OP(LOP_IDIV), \ VM_DISPATCH_OP(LOP_IDIVK), @@ -2697,16 +2697,53 @@ reentry: LUAU_UNREACHABLE(); } - VM_CASE(LOP_DEP_JUMPIFEQK) + VM_CASE(LOP_SUBRK) { - LUAU_ASSERT(!"Unsupported deprecated opcode"); - LUAU_UNREACHABLE(); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rc)) + { + setnvalue(ra, 
nvalue(kv) - nvalue(rc)); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_SUB)); + VM_NEXT(); + } } - VM_CASE(LOP_DEP_JUMPIFNOTEQK) + VM_CASE(LOP_DIVRK) { - LUAU_ASSERT(!"Unsupported deprecated opcode"); - LUAU_UNREACHABLE(); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (LUAU_LIKELY(ttisnumber(rc))) + { + setnvalue(ra, nvalue(kv) / nvalue(rc)); + VM_NEXT(); + } + else if (ttisvector(rc)) + { + float vb = cast_to(float, nvalue(kv)); + const float* vc = rc->value.v; + setvvalue(ra, vb / vc[0], vb / vc[1], vb / vc[2], vb / vc[3]); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_DIV)); + VM_NEXT(); + } } VM_CASE(LOP_FASTCALL1)