#!/usr/bin/python3
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details

# The purpose of this script is to analyze disassembly generated by objdump or
# dumpbin to print (or to compare) the stack usage of functions/methods.
# This is a quickly written script, so it is quite possible it may not handle
# all code properly.
#
# The script expects the user to create a text assembly dump to be passed to
# the script.
#
# objdump Example
#   objdump --demangle --disassemble objfile.o > objfile.s
#
# dumpbin Example
#   dumpbin /disasm objfile.obj > objfile.s
#
# If the script is passed a single file, then all stack size information that
# is found is printed.  If two files are passed, then the script compares the
# stack usage of the two files (useful for A/B comparisons).
# Currently more than two input files are not supported. (But adding support shouldn't
# be very difficult.)
#
# Note: The script only handles x64 disassembly.  Supporting x86 is likely
# trivial, but ARM support could be difficult.
# Thus far the script has been tested with MSVC on Win64 and clang on OSX.

import argparse
import re

# Matches a line made up only of whitespace (including '\n'), or an empty
# string.  Raw string: '\s' in a plain literal is an invalid escape sequence
# (a SyntaxWarning since Python 3.12).
blank_re = re.compile(r'\s*')

class LineReader:
    """Sequential, peekable cursor over a list of text lines.

    The lines are stored in reverse order so that taking the next line is a
    cheap pop from the end of the list.
    """

    def __init__(self, lines):
        # Reverse once so the first input line sits at the end of the list.
        self.lines = list(reversed(lines))

    def get_line(self):
        """Remove and return the next line (raises IndexError when empty)."""
        return self.lines.pop(-1)

    def peek_line(self):
        """Return the next line without consuming it (IndexError when empty)."""
        return self.lines[-1]

    def consume_blank_lines(self):
        """Skip whitespace-only lines, stopping safely at the end of input."""
        # Check for emptiness first: peeking past the last line would raise
        # IndexError, e.g. when the input ends with trailing blank lines.
        while not self.is_empty() and blank_re.fullmatch(self.peek_line()):
            self.get_line()

    def is_empty(self):
        """Return True when no lines remain."""
        return len(self.lines) == 0

def parse_objdump_assembly(in_file):
    """Collect per-symbol stack allocation sizes from objdump disassembly.

    Skips everything up to the "Disassembly of section __TEXT,__text:" header
    (the Mach-O section name, i.e. clang/macOS output). Then, for every
    "<symbol>:" header line, it records the decimal immediate of the first
    "subq $N, %rsp" that appears before the next blank line.

    Args:
        in_file: text file object containing the disassembly dump.

    Returns:
        dict mapping symbol name -> stack allocation size (decimal integer).
        Symbols with no matching "subq ..., %rsp" before the next blank line
        are omitted.

    Raises:
        ValueError: if the text-section header is never found.
    """
    results = {}
    text_section_re = re.compile(r'Disassembly of section __TEXT,__text:\s*')
    symbol_re = re.compile(r'[^<]*<(.*)>:\s*')
    # Require at least one digit so int() never receives an empty string.
    stack_alloc = re.compile(r'.*subq\s*\$(\d+), %rsp\s*')

    lr = LineReader(in_file.readlines())

    def find_stack_alloc_size():
        # Scan the lines following a symbol header, up to (not including) the
        # next blank line, for the first stack allocation.
        while not lr.is_empty():
            if blank_re.fullmatch(lr.peek_line()):
                return None
            mo = stack_alloc.fullmatch(lr.get_line())
            if mo:
                return int(mo.group(1))
        return None

    # Find beginning of disassembly
    while True:
        if lr.is_empty():
            raise ValueError('no "Disassembly of section __TEXT,__text:" '
                             'header found in objdump input')
        if text_section_re.fullmatch(lr.get_line()):
            break

    # Scan for symbols.  Blank lines can never match symbol_re (it needs a
    # '<'), so they are simply skipped by this loop.
    while not lr.is_empty():
        mo = symbol_re.fullmatch(lr.get_line())
        if mo:
            stack_size = find_stack_alloc_size()
            if stack_size is not None:
                results[mo.group(1)] = stack_size

    return results

def parse_dumpbin_assembly(in_file):
    """Collect per-symbol stack allocation sizes from dumpbin /disasm output.

    Skips everything up to the "File Type: COFF OBJECT" line. Then, for every
    "...(symbol):" header line, it records the hexadecimal immediate of the
    first "sub rsp,NNh" that appears before the next blank line, converted to
    decimal. Scanning stops at the "Summary" line.

    Args:
        in_file: text file object containing the disassembly dump.

    Returns:
        dict mapping symbol name -> stack allocation size (decimal integer).
        Symbols with no matching "sub rsp,NNh" before the next blank line are
        omitted.

    Raises:
        ValueError: if the "File Type: COFF OBJECT" line is never found.
    """
    results = {}

    file_type_re = re.compile(r'File Type: COFF OBJECT\s*')
    symbol_re = re.compile(r'[^(]*\((.*)\):\s*')
    summary_re = re.compile(r'\s*Summary\s*')
    # Only real hex digits, at least one, so int(x, 16) cannot fail.
    stack_alloc = re.compile(r'.*sub\s*rsp,([0-9A-Fa-f]+)h\s*')

    lr = LineReader(in_file.readlines())

    def find_stack_alloc_size():
        # Scan the lines following a symbol header, up to (not including) the
        # next blank line, for the first stack allocation.
        while not lr.is_empty():
            if blank_re.fullmatch(lr.peek_line()):
                return None
            mo = stack_alloc.fullmatch(lr.get_line())
            if mo:
                return int(mo.group(1), 16)  # return value in decimal
        return None

    # Find beginning of disassembly
    while True:
        if lr.is_empty():
            raise ValueError('no "File Type: COFF OBJECT" line found in '
                             'dumpbin input')
        if file_type_re.fullmatch(lr.get_line()):
            break

    # Scan for symbols.  Blank lines match neither summary_re nor symbol_re,
    # so they are simply skipped by this loop.
    while not lr.is_empty():
        line = lr.get_line()
        if summary_re.fullmatch(line):
            break
        mo = symbol_re.fullmatch(line)
        if mo:
            stack_size = find_stack_alloc_size()
            if stack_size is not None:
                results[mo.group(1)] = stack_size
    return results

def main(argv=None):
    """Command-line entry point: parse the input dumps and print a report.

    With one --input, every symbol found is printed in descending order of
    stack size. With two, the symbols common to both files are printed side
    by side; --only-diffs and --md-output apply to this comparison mode.

    Args:
        argv: argument list to parse; None (the default) means sys.argv[1:].
    """
    parser = argparse.ArgumentParser(description='Tool used for reporting or comparing the stack usage of functions/methods')
    parser.add_argument('--format', choices=['dumpbin', 'objdump'], required=True, help='Specifies the program used to generate the input files')
    parser.add_argument('--input', action='append', required=True, help='Input assembly file.  This option may be specified multiple times.')
    parser.add_argument('--md-output', action='store_true', help='Show table output in markdown format')
    parser.add_argument('--only-diffs', action='store_true', help='Only show stack info when it differs between the input files')
    args = parser.parse_args(argv)

    # Reject unsupported input counts before spending time parsing files.
    if len(args.input) > 2:
        parser.error('more than two --input files are not supported yet')

    parsers = {'dumpbin': parse_dumpbin_assembly, 'objdump': parse_objdump_assembly}
    parse_func = parsers[args.format]

    input_results = []
    for input_name in args.input:
        with open(input_name) as in_file:
            input_results.append(parse_func(in_file))

    if len(input_results) == 1:
        # Print out the results sorted by size (largest first)
        results = input_results[0]
        size_sorted = sorted(((size, symbol) for symbol, size in results.items()), reverse=True)
        print(args.input[0])
        for size, symbol in size_sorted:
            print(f'{size:10}\t{symbol}')
        print()
    else:
        before, after = input_results
        common_symbols = before.keys() & after.keys()
        print(f'Found {len(common_symbols)} common symbols')
        stack_sizes = sorted(((before[sym], after[sym], sym) for sym in common_symbols), reverse=True)
        if args.md_output:
            print('Before | After | Symbol')
            print('-- | -- | --')
        for size0, size1, symbol in stack_sizes:
            if args.only_diffs and size0 == size1:
                continue
            if args.md_output:
                print(f'{size0} | {size1} | {symbol}')
            else:
                print(f'{size0:10}\t{size1:10}\t{symbol}')

if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()