#!/usr/bin/env python3

import argparse
import enum
import os
import socket
import struct
import subprocess
import sys
import tempfile

DEBUG = False  # set by main() from the --debug flag


def debug(*args):
    """Print *args* to stderr, but only while DEBUG is true."""
    if not DEBUG:
        return
    print(*args, file=sys.stderr)


@enum.unique
class SshTypes(enum.IntEnum):
    """Message numbers of the SSH agent protocol.

    As an IntEnum, each member compares equal to the raw type byte of a
    message, e.g. ``contents[0] == SshTypes.SSH_AGENTC_EXTENSION``.
    """

    # --- agent -> client replies ---
    SSH_AGENT_FAILURE = 5
    SSH_AGENT_SUCCESS = 6
    SSH_AGENT_EXTENSION_FAILURE = 28
    SSH_AGENT_IDENTITIES_ANSWER = 12
    SSH_AGENT_SIGN_RESPONSE = 14

    # --- client -> agent requests ---
    SSH_AGENTC_REQUEST_IDENTITIES = 11
    SSH_AGENTC_SIGN_REQUEST = 13
    SSH_AGENTC_ADD_IDENTITY = 17
    SSH_AGENTC_REMOVE_IDENTITY = 18
    SSH_AGENTC_REMOVE_ALL_IDENTITIES = 19
    SSH_AGENTC_ADD_ID_CONSTRAINED = 25
    SSH_AGENTC_ADD_SMARTCARD_KEY = 20
    SSH_AGENTC_REMOVE_SMARTCARD_KEY = 21
    SSH_AGENTC_LOCK = 22
    SSH_AGENTC_UNLOCK = 23
    SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED = 26
    SSH_AGENTC_EXTENSION = 27


def pack_msg(msg):
    """Frame *msg* for the agent wire protocol.

    The result is a 4-byte big-endian unsigned length followed by *msg*.
    """
    length_prefix = struct.pack('>I', len(msg))
    return b''.join((length_prefix, msg))


def _recv_exact(conn, size, chunk_size=65536):
    """Read up to *size* bytes from *conn*, looping over short reads.

    Fewer than *size* bytes are returned only when the peer closes the
    connection first.  Each recv() asks for at most *chunk_size* bytes, so a
    huge declared length does not force a single huge buffer allocation.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = conn.recv(min(remaining, chunk_size))
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def recv_msg(conn):
    """Read one length-prefixed message from *conn* and return its body.

    A message is a 4-byte big-endian length followed by that many bytes.
    A stream socket's recv() may return less than requested, so both the
    header and the body are read in a loop until complete.

    Returns b'' when the peer closes the connection before a complete
    message has arrived, including a clean close between messages.  A
    message whose declared length is 0 is also returned as b''.
    """
    header = _recv_exact(conn, 4)
    if len(header) < 4:
        # EOF before (or inside) the header: the connection is finished.
        return b''
    (length,) = struct.unpack('>I', header)
    contents = _recv_exact(conn, length)
    if len(contents) < length:
        # Peer hung up mid-message; never hand back a truncated body.
        return b''
    return contents


def forward_msg(msg, forward_sock):
    """Relay one framed request to the real agent and return its framed reply.

    Args:
        msg: a complete request including its 4-byte length prefix.
        forward_sock: filesystem path of the real agent's Unix socket.

    Returns:
        The agent's reply body, re-framed with pack_msg().

    A fresh connection is opened per request and is always closed, even
    when connect() itself fails.
    """
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
        client.connect(forward_sock)
        # send() may write only part of msg; sendall() loops until done.
        client.sendall(msg)
        contents = recv_msg(client)
    resp = pack_msg(contents)
    debug('fwd response', resp)
    return resp


# Short account alias (the last word of a token extension request) ->
# service account e-mail that gcloud is asked to impersonate.
SERVICE_ACCOUNTS = dict(
    [
        ('registry-ro', 'my-gar-readonly@p.iam.gserviceaccount.com'),
        ('registry-rw', 'my-gar-admin@p.iam.gserviceaccount.com'),
    ]
)


def get_token(account_id, timeout=None):
    """Return an access token for the service account aliased *account_id*.

    Runs ``gcloud auth print-access-token --impersonate-service-account``
    for the e-mail mapped in SERVICE_ACCOUNTS and returns its stdout with
    surrounding whitespace stripped.

    Args:
        account_id: a key of SERVICE_ACCOUNTS.
        timeout: optional limit in seconds for the gcloud call; None (the
            default) waits indefinitely, as before.

    Returns:
        The token as bytes.  b'' is returned when gcloud cannot be started,
        exits with a non-zero status, or exceeds *timeout*.  The reason,
        including gcloud's stderr, is reported through debug().

    Raises:
        KeyError: if *account_id* is not a key of SERVICE_ACCOUNTS.
    """
    service_email = SERVICE_ACCOUNTS[account_id]
    try:
        proc = subprocess.run(
            [
                'gcloud',
                'auth',
                'print-access-token',
                '--impersonate-service-account',
                service_email,
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
            check=False,
        )
    except (OSError, subprocess.TimeoutExpired) as err:
        # e.g. gcloud not on PATH, or it did not finish within *timeout*.
        # Returning b'' keeps one failed lookup from stopping the server.
        debug('gcloud could not produce a token:', err)
        return b''
    if proc.returncode != 0:
        # Surface the failure instead of silently discarding stderr.
        debug('gcloud exited with status', proc.returncode, proc.stderr)
        return b''
    return proc.stdout.strip()


def _token_reply(contents):
    """Build the framed reply to a token@py-ssh-agent.python.ca request.

    The account alias is the text after the last space in the request body
    (everything after the message type byte).  The reply body is the raw
    token returned by get_token().  A body that is not valid UTF-8, or that
    names an alias missing from SERVICE_ACCOUNTS, gets an empty token (the
    same reply sent when gcloud yields no token).  This replaces an
    exception that would stop the whole server.
    """
    try:
        msg = contents[1:].decode('utf-8')
    except UnicodeDecodeError:
        debug('token request is not valid UTF-8')
        return pack_msg(b'')
    _prefix, _sep, account_id = msg.rpartition(' ')
    if account_id not in SERVICE_ACCOUNTS:
        debug('unknown account id', account_id)
        return pack_msg(b'')
    return pack_msg(get_token(account_id))


def _handle_request(contents, forward_sock):
    """Return the framed reply for one request body received from a client.

    Token extension requests are answered locally.  Every other request is
    relayed unchanged to the real agent at *forward_sock*.
    """
    resp = None
    if contents[0] == SshTypes.SSH_AGENTC_EXTENSION:
        debug('extension', contents[1:])
        if b'token@py-ssh-agent.python.ca' in contents:
            resp = _token_reply(contents)
    debug(contents)
    if resp is None:
        resp = forward_msg(pack_msg(contents), forward_sock)
    return resp


def _serve_connection(conn, forward_sock):
    """Answer requests on *conn* until the client closes it."""
    while True:
        contents = recv_msg(conn)
        if not contents:
            break
        conn.sendall(_handle_request(contents, forward_sock))


def run_server(sock, forward_sock):
    """Accept clients on *sock* forever, serving one connection at a time.

    Args:
        sock: a bound AF_UNIX stream socket.
        forward_sock: path of the real agent's socket for relayed requests.

    A socket or framing error on one connection ends only that connection,
    not the server.  Examples are a client that went away, an unreachable
    real agent, or a malformed length header.  The error is reported
    through debug().
    """
    print("Listening...")
    sock.listen()
    while True:
        conn, _addr = sock.accept()
        with conn:
            try:
                _serve_connection(conn, forward_sock)
            except (OSError, struct.error) as err:
                debug('connection error:', err)


def _detach_from_terminal():
    """Fork into the background; only the child process returns.

    The parent exits at once.  The child starts a new session, so it has no
    controlling terminal, and points file descriptors 0-2 at /dev/null.
    They are redirected rather than closed for two reasons:
    - A closed fd 1 or 2 makes any later print() or debug() fail with EBADF.
    - Sockets opened later would otherwise reuse those descriptor numbers.
    """
    # Flush before forking so buffered text is written once, by the parent,
    # and is not left duplicated in the child's copy of the buffers.
    sys.stdout.flush()
    sys.stderr.flush()
    if os.fork() != 0:
        os._exit(0)  # exit parent
    os.setsid()
    devnull = os.open(os.devnull, os.O_RDWR)
    for fd in (0, 1, 2):
        os.dup2(devnull, fd)
    if devnull > 2:
        os.close(devnull)


def main():
    """Parse arguments, open the proxy socket and serve until interrupted.

    Prints the SSH_AUTH_SOCK value clients should use.  With --background
    the process forks once the socket is listening, and the parent exits.
    """
    global DEBUG
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--debug',
        '-d',
        action='store_true',
        default=False,
        help="Enable debugging output",
    )
    parser.add_argument(
        '--sock', '-s', default=None, help="Path to SSH auth socket"
    )
    parser.add_argument(
        '--background',
        action='store_true',
        default=False,
        help="Run as background process.",
    )
    args = parser.parse_args()
    DEBUG = args.debug

    # Real ssh-agent socket.  Resolve it before creating the temp directory,
    # so a missing SSH_AUTH_SOCK gives a usage error instead of a bare
    # KeyError, and leaves no empty directory behind.
    forward_sock = args.sock or os.environ.get('SSH_AUTH_SOCK')
    if not forward_sock:
        parser.error('SSH_AUTH_SOCK is not set; pass --sock')

    # Securely create folder to hold Unix socket file (mkdtemp makes it
    # accessible only to the current user).
    tmp_dir = os.environ.get('XDG_RUNTIME_DIR') or '/tmp'
    sock_dir = tempfile.mkdtemp(prefix='py_ssh_agent_', dir=tmp_dir)
    sock_filename = os.path.join(sock_dir, 'sock')

    print(
        f'To use agent: SSH_AUTH_SOCK={sock_filename}; export SSH_AUTH_SOCK'
    )

    print("Opening socket...")
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        sock.bind(sock_filename)
        # Listen before forking so the socket already accepts connections
        # when a --background parent exits.  run_server() calling listen()
        # again on the same socket only re-sets the backlog.
        sock.listen()
        if args.background:
            # The parent leaves inside this call via os._exit(), which skips
            # the cleanup below; only the serving child removes the socket.
            _detach_from_terminal()
        run_server(sock, forward_sock)
    finally:
        sock.close()
        try:
            os.unlink(sock_filename)
        except FileNotFoundError:
            pass  # bind() failed before the socket file was created
        os.rmdir(sock_dir)
        print("Done")


if __name__ == '__main__':
    main()
