#!/usr/bin/env python
#
# WARNING: this code is almost certainly not correct.  Run it at your
# own risk.  If it kills your Rio Karma, don't blame me.
#
# Released to the public domain, by Neil Schemenauer, 2004/09/09.

import os
import sys
import socket
import struct
import time
from binascii import hexlify
import md5
from cStringIO import StringIO

# Message format:
#
#   <magic>
#   <message type>
#   <zero padding to 32-bit boundary (?)>
#
# Most fields seem to be u32 integers packed in little-endian format.
#

# prefixed to every message
# Four bytes (they read as "Rio" with a macron over the o when decoded as
# UTF-8).  make_header() puts them in front of every outgoing message and
# Connection._recv_header() rejects any reply that does not start with them.
MAGIC = 'Ri\xc5\x8d'


# message types (packed as little endian u32, after MAGIC)
#
# The req=/resp= notes below describe only what the Connection methods in
# this file actually send and read; they are not a protocol specification.
GET_VERSION = 0 # major=2, minor=0 (?)
NAK = 1
PROGRESS = 2 # numerator=0..0x20, denominator=0x20 (8 byte payload)
LOGIN_PHASE_1 = 3 # req="" resp=<salt>
LOGIN_PHASE_2 = 4 # req=md5(<salt><password>)
GET_DEVICE_INFO = 5
GET_STORAGE_INFO = 6
GET_DEVICE_SETTINGS = 7
CHANGE_DEVICE_SETTINGS = 8
LOCK = 9 # req=u32(1) resp=u32 status, 0 on success
UNLOCK = 10 # resp=u32 status, may be preceded by PROGRESS messages
PREPARE = 11
WRITE = 12 # req=u64 offset, u64 size, u64 fid, then data padded to 32 bits
GET_ALL_FILE_INFO = 13 # resp=u32 status, then NUL-terminated key=value text
GET_FILE_INFO = 14
CHANGE_FILE_INFO = 15 # req=u32 fid, key=value text, 1..4 NUL bytes
READ = 16
DELETE = 17 # req=u32 fid resp=u32 status
FORMAT = 18
DEVICE_OPERATION = 19

# Fallback file attributes.  Connection.change_file_info() copies each of
# these into the attribute set it sends whenever the caller did not supply
# a value for that key.
DEFAULT_INFO = {
    'title': 'unnamed',
    'type': 'taxi',
    'codec': 'taxi',
    'bitrate': 'fs128',
    'duration': 0,
    'offset': 0,
    'samplerate': 0,
}


def p32(n):
    """Pack the integer n as a 4-byte little-endian unsigned value."""
    packed = struct.pack('<I', n)
    return packed

def u32(s):
    """Decode the 4-byte string s as a little-endian unsigned integer."""
    (value,) = struct.unpack('<I', s)
    return value

def p64(n):
    """Pack the integer n as an 8-byte little-endian unsigned value."""
    packed = struct.pack('<Q', n)
    return packed

def pad32(n):
    """Return how many bytes must follow n bytes to reach a multiple of 4.

    The result is 0 when n is already a multiple of 4, otherwise 1..3.
    """
    remainder = n % 4
    if remainder:
        return 4 - remainder
    return 0

def make_header(mtype):
    """Build a message header: MAGIC followed by mtype packed by p32()."""
    header = MAGIC
    header += p32(mtype)
    return header

# Two little-endian u32 values packed back to back.  Nothing else in the
# visible code references this constant; presumably it is the payload that
# goes with GET_VERSION -- TODO confirm.
VERSION_STRING = p32(2) + p32(0) # major, minor?


class ProtocolError(Exception):
    """Raised when the device answers a request with a failure status."""


class FileInfo:
    """A bag of file attributes, kept as a dict in self.attrs.

    str() of an instance gives one 'key=value' line per attribute, each
    terminated by a newline, in the dict's own iteration order.
    """

    def __init__(self, **attrs):
        # Every keyword argument becomes one attribute of the file.
        self.attrs = attrs

    def __str__(self):
        lines = ['%s=%s\n' % pair for pair in self.attrs.items()]
        return ''.join(lines)


def _parse_file_info(info):
    file_info = {}
    attrs = {}
    for line in info:
        if not line.strip():
            if 'fid' in attrs:
                file_info[int(attrs['fid'])] = FileInfo(**attrs)
            else:
                print 'ignoring attrs', attrs
            attrs.clear()
        else:
            i = line.find('=')
            if i > 0:
                attrs[line[:i]] = line[i+1:].strip()
            else:
                print 'ignoring line', `line`
    return file_info


class Connection:
    """Client side of the Rio Karma network protocol.

    Typical use, as in main(): connect(), login(), lock(), perform file
    operations, then close() (which unlocks first if needed).

    Attributes:
      ip, port   -- address handed to socket.connect()
      password   -- hashed together with the device's salt in login()
      s          -- the socket, or None while not connected
      locked     -- True after a successful lock() until unlock() succeeds
      file_info  -- {fid : FileInfo}, filled in by get_all_file_info()
    """

    def __init__(self, ip, port=8302, password=""):
        self.ip = ip
        self.port = port
        self.password = password
        self.s = None
        self.locked = False
        self.file_info = {} # {fid : FileInfo}

    def debug(self, *args):
        """Print the str() of each argument, separated by single spaces."""
        print(' '.join(map(str, args)))

    def connect(self):
        """Open the socket to (self.ip, self.port)."""
        self.s = socket.socket()
        self.s.connect((self.ip, self.port))

    def close(self):
        """Unlock the device if we hold the lock, then close the socket.

        The socket is closed even when unlocking fails; the unlock error
        still propagates.  Does nothing if there is no socket.
        """
        if self.s is None:
            return
        try:
            if self.locked:
                self.unlock()
        finally:
            self.s.close()
            self.s = None

    def _send_message(self, mtype, *args):
        """Send a header for mtype followed by the already packed args."""
        msg = make_header(mtype)
        if args:
            msg += ''.join(args)
        self.s.sendall(msg)

    def _recvn(self, n):
        """Read exactly n bytes from the socket.

        Raises IOError if the peer closes the connection first.
        """
        hunks = []
        while n > 0:
            hunk = self.s.recv(n)
            if not hunk:
                raise IOError('connection reset by peer')
            n -= len(hunk)
            hunks.append(hunk)
        return ''.join(hunks)

    def _recv_header(self):
        """Read an 8 byte header and return its message type.

        Raises ValueError if the header does not start with MAGIC.
        """
        m = self._recvn(8)
        if m[:4] != MAGIC:
            raise ValueError('message missing magic, got %s' % hexlify(m[:4]))
        return u32(m[4:8])

    def _expect_header(self, mtype):
        """Read a header; raise ValueError unless its type is mtype."""
        n = self._recv_header()
        if mtype != n:
            raise ValueError('expecting mtype %s, got %s' % (mtype, n))

    def lock(self):
        """Send LOCK unless already locked.

        Raises ProtocolError if the device reports a non-zero status.
        """
        if not self.locked:
            # NOTE(review): the meaning of the u32 argument 1 is not
            # established anywhere in this file.
            self._send_message(LOCK, p32(1))
            self._expect_header(LOCK)
            status = self._recvn(4)
            self.debug('LOCK status', hexlify(status))
            if u32(status) != 0:
                raise ProtocolError('lock failed')
            self.locked = True

    def unlock(self):
        """Send UNLOCK if we hold the lock and wait for the reply.

        Any PROGRESS messages that arrive first are read and discarded.
        Raises ValueError on an unexpected reply type and ProtocolError
        on a non-zero status.
        """
        if self.locked:
            self._send_message(UNLOCK)
            while 1:
                mtype = self._recv_header()
                if mtype != PROGRESS:
                    break
                # Skip the 8 byte PROGRESS payload.
                self._recvn(8)
            # An explicit check instead of assert, so it still runs
            # under "python -O"; same error style as _expect_header().
            if mtype != UNLOCK:
                raise ValueError('expecting mtype %s, got %s' % (UNLOCK, mtype))
            status = self._recvn(4)
            self.debug('UNLOCK', hexlify(status))
            if u32(status) != 0:
                raise ProtocolError('unlock failed')
            self.locked = False

    def login(self):
        """Authenticate with md5(salt + password).

        The 8 byte LOGIN_PHASE_2 reply is printed but not interpreted.
        """
        self._send_message(LOGIN_PHASE_1)
        self._expect_header(LOGIN_PHASE_1)
        # _recvn() returns exactly 16 bytes or raises, so no length check.
        salt = self._recvn(16)
        password_hash = md5.new(salt + self.password).digest()
        self._send_message(LOGIN_PHASE_2, password_hash)
        self._expect_header(LOGIN_PHASE_2)
        status = self._recvn(8)
        self.debug('LOGIN', hexlify(status))

    def get_all_file_info(self):
        """Fetch the info of every file and store it in self.file_info.

        The reply text is read until a received hunk ends in a NUL byte;
        trailing NULs are removed before parsing.

        Raises IOError if the connection closes before that happens.
        """
        self._send_message(GET_ALL_FILE_INFO)
        self._expect_header(GET_ALL_FILE_INFO)
        status = self._recvn(4)
        self.debug('GET_ALL_FILE_INFO status', hexlify(status))
        info = StringIO()
        # NOTE(review): if the trailing NUL bytes get split across two
        # recv() calls, the remainder stays unread in the socket.
        while 1:
            hunk = self.s.recv(4000)
            if not hunk:
                # recv() returns '' once the peer has closed the socket.
                raise IOError('connection reset by peer')
            if hunk[-1] == '\0':
                info.write(hunk.rstrip('\0'))
                break
            info.write(hunk)
        info.seek(0)
        self.file_info = _parse_file_info(info)

    def next_fid(self):
        """Return a fid 16 above the largest one in self.file_info.

        Raises ValueError when no fids are known.
        """
        if not self.file_info:
            raise ValueError('cannot pick a new fid: no existing fids known')
        # max() over the dict iterates its keys; unpacking them with *
        # would break when there is exactly one fid.
        fid = max(self.file_info)
        fid += 16 # why? I don't know
        return fid

    def _write(self, fid, fp, offset, size):
        """Send one WRITE message with size bytes read from fp.

        The data is sent in blocks and followed by NUL bytes up to a
        32-bit boundary.

        Raises IOError if fp runs out of data before size bytes were
        read, and ProtocolError on a non-zero status from the device.
        """
        self._send_message(WRITE, p64(offset), p64(size), p64(fid))
        blocksize = 20000
        remaining = size
        while remaining > 0:
            chunk = fp.read(min(blocksize, remaining))
            if not chunk:
                # The header already announced size bytes.
                raise IOError('unexpected end of file, %s bytes missing'
                              % remaining)
            self.s.sendall(chunk)
            remaining -= len(chunk)
        padding = pad32(size)
        if padding:
            self.s.sendall('\0' * padding)
        self._expect_header(WRITE)
        status = u32(self._recvn(4))
        if status != 0:
            raise ProtocolError('file WRITE failed (%s)' % status)
        self.debug('wrote', size, 'bytes')

    def write(self, fid, fp, length=None, title='unnamed'):
        """Upload length bytes from fp as file fid, then set its info.

        length defaults to the size reported by os.fstat() for fp.  The
        data goes out in WRITE messages of at most 10240000 bytes each,
        with a '.' printed after every full one.
        """
        if length is None:
            length = os.fstat(fp.fileno()).st_size
        remaining = length
        maxwrite = 10240000
        offset = 0
        while remaining > maxwrite:
            self._write(fid, fp, offset, maxwrite)
            offset += maxwrite
            remaining -= maxwrite
            sys.stdout.write('.')
            sys.stdout.flush()
        self._write(fid, fp, offset, remaining)
        self.debug('WRITE okay', length, 'bytes')
        self.change_file_info(fid, title=title, length=length)
        self.debug('CHANGE_FILE_INFO okay')

    def change_file_info(self, fid, **attrs):
        """Send attrs as the info for file fid.

        Keys missing from attrs are filled in from DEFAULT_INFO, and
        'fid_generation' and 'ctime' default to the current time.

        Raises ProtocolError on a non-zero status from the device.
        """
        attrs['fid'] = fid
        for name, value in DEFAULT_INFO.items():
            if name not in attrs:
                attrs[name] = value
        now = int(time.time())
        if 'fid_generation' not in attrs:
            attrs['fid_generation'] = now
        if 'ctime' not in attrs:
            attrs['ctime'] = now
        info = str(FileInfo(**attrs))
        self._send_message(CHANGE_FILE_INFO, p32(fid))
        self.s.sendall(info)
        # At least one NUL always follows the text: four of them when
        # the text is already 32-bit aligned.
        padding = pad32(len(info)) or 4
        self.debug('padding', padding)
        # sendall(), because send() may transmit only part of the data.
        self.s.sendall('\0'*padding)
        self._expect_header(CHANGE_FILE_INFO)
        status = u32(self._recvn(4))
        if status != 0:
            raise ProtocolError('CHANGE_FILE_INFO failed (%s)' % status)
        self.debug('CHANGE_FILE_INFO', status)

    def delete(self, fid):
        """Send DELETE for fid; raise ProtocolError on non-zero status."""
        self._send_message(DELETE, p32(fid))
        self._expect_header(DELETE)
        status = u32(self._recvn(4))
        if status != 0:
            raise ProtocolError('DELETE failed (%s)' % status)
        
    
def main():
    """Upload the local file 'test2' to the device given in sys.argv[1].

    Connects, logs in with the password 'secret', locks the device,
    fetches the file list, and writes 'test2' under a newly chosen fid
    with the title 'test'.  The connection is closed afterwards even if
    a step fails; a failure while closing is reported on stdout.

    Exits with a usage message when no address is given.
    """
    # The connection is stored in a module-level name, as before.
    global c

    if len(sys.argv) < 2:
        sys.exit('usage: %s <ip>' % sys.argv[0])

    c = Connection(sys.argv[1], password='secret')
    c.connect()
    try:
        c.login()
        c.lock()
        c.get_all_file_info()
        fid = c.next_fid()
        # Binary mode, so the bytes sent match the size from os.fstat().
        fp = open('test2', 'rb')
        try:
            c.write(fid, fp, title='test')
        finally:
            fp.close()
        # c.delete(fid) here would remove the uploaded file again.
    finally:
        try:
            c.close()
        except Exception:
            print('close failed')

# Start the upload only when run as a script, so that importing this
# module does not open a connection to the device.
if __name__ == '__main__':
    main()
