#!/usr/bin/env python3
#
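# Print the disassembly of one function in an executable, annotating each
# instruction with the source file and line number it came from, and summarize
# the source lines that account for the most instructions.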

import collections
import os
import re
import io
import subprocess


def dump(exe):
    """Return the output of ``objdump -d exe`` split into lines.

    Raises subprocess.CalledProcessError if objdump exits with a non-zero
    status (e.g. *exe* does not exist or is not an object file). This avoids
    silently returning an empty disassembly.
    """
    rv = subprocess.run(
        ['objdump', '-d', exe],
        encoding='utf-8',
        stdout=subprocess.PIPE,
        check=True,
    )
    return rv.stdout.split('\n')


def read_context(line, context=4):
    """Return the source lines surrounding an addr2line ``file:lineno`` result.

    Shows lines ``lineno - context`` through ``lineno + context - 1``, clipped
    to the start and end of the file. Each line is prefixed with ``>>>`` for
    the requested line and with spaces otherwise. A trailing addr2line
    annotation such as ``" (discriminator 2)"`` after the line number is
    ignored.

    Returns '' when the line number cannot be parsed or is < 1 (e.g. ``??:?``
    or ``??:0``), or when the file does not exist as a regular file.
    """
    fn, _, lineno_text = line.rpartition(':')
    # addr2line may append " (discriminator N)" after the number.
    lineno_text = lineno_text.strip().partition(' ')[0]
    try:
        lineno = int(lineno_text)
    except ValueError:
        return ''
    if lineno < 1 or not os.path.isfile(fn):
        return ''
    first = max(1, lineno - context)
    stop = lineno + context  # exclusive upper bound
    lines = []
    with open(fn, encoding='utf-8', errors='replace') as fp:
        # Number the lines as they are read so the marker always lines up,
        # even when the window is clipped at the top of the file.
        for cur_line, text in enumerate(fp, start=1):
            if cur_line >= stop:
                break
            if cur_line < first:
                continue
            marker = '>>>' if cur_line == lineno else '   '
            lines.append(f' {marker} {text.rstrip()}')
    return '\n'.join(lines)


def main(target_exe='./python', target_func='_PyEval_EvalFrameDefault',
         print_sizes=False):
    """Print *target_func*'s disassembly annotated with source locations.

    Disassembles *target_exe* with ``objdump -d``. For every instruction in
    *target_func*, it asks a single long-lived ``addr2line`` process (run on
    the same executable) for the source location, with the current working
    directory prefix stripped.

    Output, on stdout:
      1. "Most common source lines:", then the 100 most frequent source lines,
         each with its instruction count and surrounding source context.
      2. If *print_sizes* is true, a "function size" line for every function
         except the last. Each size is the distance between consecutive
         function start addresses.
      3. The annotated disassembly of *target_func*, starting with its header.
    """
    out = io.StringIO()
    cwd_prefix = os.getcwd() + '/'
    counts = collections.Counter()
    # Function headers look like "0000000000401000 <name>:".
    func_pat = re.compile(r'([0-9a-f]+) <(.*)>:')
    base = ''  # zero-padded hex start address of the current function
    func = ''  # name of the current function
    # The context manager closes addr2line's pipes and waits for it to exit.
    with subprocess.Popen(
        ['addr2line', '-e', target_exe],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        encoding='utf-8',
    ) as addr2line:
        for line in dump(target_exe):
            if m := func_pat.match(line):
                if base and print_sizes:
                    # `func` still names the previous function here.
                    size = int(m.group(1), 16) - int(base, 16)
                    print('function size', func, size, file=out)
                base, func = m.group(1), m.group(2)
                if func == target_func:
                    print(line, file=out)
                continue
            if func != target_func:
                continue
            try:
                addr, _ = line.strip().split(':\t', 1)
            except ValueError:
                # Not an instruction line (blank line, "..." elision, etc.).
                print(line, file=out)
                continue
            # objdump prints instruction addresses without leading zeros;
            # re-pad them using the function's zero-padded start address.
            full_addr = '0x' + base[: -len(addr)] + addr
            # Query one address at a time and read its one-line answer, so
            # the request and response stay in lock-step.
            print(full_addr, file=addr2line.stdin, flush=True)
            code_line = addr2line.stdout.readline().strip()
            code_line = code_line.replace(cwd_prefix, '')
            counts[code_line] += 1
            print(line, file=out)
            print('  ', code_line, file=out)
            print(file=out)

    print('Most common source lines:')
    for line, n in counts.most_common(100):
        chunk = read_context(line)
        print(n, line)
        print(chunk)
    print(out.getvalue(), end='')


if __name__ == '__main__':
    # Produce the report only when run as a script, not when imported.
    main()
