#!/usr/bin/env python
"""
Copyright (C) 2005 Adam Bregenzer <adam@bregenzer.net>

TypeKey Authentication Service
    An authentication class for Six Apart's TypeKey service located at <http://www.typekey.com/>.

License: GPL-2
Example:
    tk = TypeKey('my_type_key_token')
    tk.verify(email='test@test.com', name='test', nick='test', ts=1106084427, sig='214012410412:12412412512')
"""

import base64
import binascii
import os
import sha
import stat
import time
import urllib

class TypeKey:
    """This class handles TypeKey logins.

    It builds the login/logout URLs that send a user to TypeKey, and checks
    the DSA-signed credentials TypeKey hands back (email, name, nick, ts,
    sig) against TypeKey's published public key.
    """

    # Base url for generating login and logout urls.
    base_url = 'https://www.typekey.com/t/typekey/'

    # Url used to download the public key.
    # NOTE(review): plain http, so anyone on the network path can substitute
    # their own key and thereby forge logins. Use an https URL if available.
    key_url = 'http://www.typekey.com/extras/regkeys.txt'

    # Location for caching the public key.
    # NOTE(review): a fixed name in a world-writable directory. Another local
    # user can pre-create this file with a key of their choosing and getKey()
    # will trust it. Point this at a private directory in production.
    key_cache_path = '/tmp/tk_key_cache'

    # Length of time to wait before refreshing the public key cache, in seconds.
    # Defaults to two days.
    key_cache_timeout = 60 * 60 * 48

    # Length of time logins remain valid, in seconds.
    # Defaults to five minutes.
    login_timeout = 60 * 5


    def __init__(self, token, version = '1.1'):
        """Create a verifier for one TypeKey site token.

        token   -- the site's TypeKey token; it is sent as the t= parameter of
                   the login URL and, for version '1.1', is also part of the
                   signed message.
        version -- protocol version sent as v= in the login URL. Only '1.1'
                   appends the token to the signed message.
        """
        self.token = token
        self.version = version


    def verify(self, email, name, nick, ts, sig, key = None):
        """Verify a TypeKey login.

        email, name, nick, ts, sig -- the values TypeKey returned for the
            login. ts may be an int or a decimal string. It is signed in its
            str() form and must be no older than login_timeout seconds.
        key -- a public key dict as returned by getKey(). Fetched (and cached)
               via getKey() when omitted.

        Returns True only if the login has not expired and the signature is
        valid. Returns False otherwise.
        """
        # ts often arrives as a string straight from the return URL's query
        # string. Convert it for the age check (int() also rejects 'nan'/'inf')
        # while still signing the original textual form below.
        try:
            age = time.time() - int(ts)
        except (TypeError, ValueError):
            return False
        # Check expiry first: it is cheap and avoids fetching the key for
        # logins that would be rejected anyway.
        if age > self.login_timeout:
            return False

        if key is None:
            key = self.getKey()

        if self.version == '1.1':
            message = '::'.join((email, name, nick, str(ts), self.token))
        else:
            message = '::'.join((email, name, nick, str(ts)))

        return self._dsaVerify(message, sig, key)


    def getLoginUrl(self, return_url, email = False):
        """Return a URL to login to TypeKey.

        return_url -- where TypeKey should send the user afterwards. It is
                      URL-encoded into the _return parameter.
        email      -- if true, adds need_email=1 to the query.
        """
        url = self.base_url + 'login?t=' + self.token
        if email:
            url += '&need_email=1'
        url += '&_return=' + urllib.quote_plus(return_url)
        url += '&v=' + self.version
        return url


    def getLogoutUrl(self, return_url):
        """Return a URL to logout of TypeKey.

        return_url is URL-encoded into the _return parameter.
        """
        return self.base_url + 'logout?_return=' + urllib.quote_plus(return_url)


    def getKey(self, url = None):
        """Return the TypeKey public key as a dict of integers.

        The dict contains at least 'p', 'q', 'g' and 'pub_key'.

        url -- if given, download the key from this URL, bypassing the cache
               (key_cache_path is neither read nor written). Otherwise the copy
               at key_cache_path is used while it is younger than
               key_cache_timeout seconds. It is refreshed from key_url when it
               is stale, missing, unreadable or unparsable.

        Raises ValueError if a downloaded key cannot be parsed. I/O errors
        from urllib/open propagate.
        """
        if url is not None:
            return self._parseKey(self._readFirstLine(urllib.urlopen(url)))

        try:
            mod_time = os.stat(self.key_cache_path)[stat.ST_MTIME]
        except OSError:
            mod_time = 0

        if (time.time() - mod_time) < self.key_cache_timeout:
            try:
                return self._parseKey(
                    self._readFirstLine(open(self.key_cache_path, 'r')))
            except (IOError, ValueError):
                # A vanished or corrupt cache file would otherwise break every
                # login until it expired; fall through and refresh it instead.
                pass

        key_string = self._readFirstLine(urllib.urlopen(self.key_url))
        # Parse before writing, so an empty or garbled download raises here
        # instead of being cached for key_cache_timeout seconds.
        tk_key = self._parseKey(key_string)

        fh = open(self.key_cache_path, 'w')
        try:
            fh.write(key_string)
        finally:
            fh.close()
        return tk_key


    def _readFirstLine(self, fh):
        """Return the first line of an open file-like object, always closing it."""
        try:
            return fh.readline()
        finally:
            fh.close()


    def _parseKey(self, key_string):
        """Parse a 'name=value name=value ...' key line into a dict of ints.

        Raises ValueError if a pair has no '=', a value is not a decimal
        integer, or one of the entries _dsaVerify() needs is missing.
        """
        tk_key = dict()
        for pair in key_string.split():
            (name, value) = pair.split('=', 1)
            # int() transparently returns a long for big values on Python 2.
            tk_key[name] = int(value)

        for name in ('p', 'q', 'g', 'pub_key'):
            if name not in tk_key:
                raise ValueError('TypeKey public key is missing %r' % name)
        return tk_key


    def _b64ToLong(self, text):
        """Decode base64 text into a non-negative big-endian integer.

        Raises ValueError (or TypeError/binascii.Error, depending on the
        Python version) for empty or malformed input.
        """
        return int(binascii.hexlify(base64.b64decode(text)), 16)


    def _dsaVerify(self, message, sig, key):
        """Verify a DSA signature.

        message -- the signed text. It is hashed with SHA-1.
        sig     -- 'r:s', each half a base64-encoded big-endian integer.
        key     -- dict with integer entries 'p', 'q', 'g' and 'pub_key'.

        Returns True if the signature is valid. Returns False for invalid or
        malformed signatures.
        """
        p = key['p']
        q = key['q']

        try:
            (r_text, s_text) = sig.split(':')
            r_sig = self._b64ToLong(r_text)
            s_sig = self._b64ToLong(s_text)
        except (TypeError, ValueError, binascii.Error):
            return False

        # Required by the DSA verification algorithm (FIPS 186). Without it,
        # s == 0 (mod q) has no inverse and used to send _powerModulus into
        # unbounded recursion.
        if not (0 < r_sig < q and 0 < s_sig < q):
            return False

        w = self._invert(s_sig, q)
        if w is False:
            # Only possible if q is not prime, i.e. a bogus key.
            return False

        hash_m = int(binascii.hexlify(sha.new(message).digest()), 16)

        u1 = (hash_m * w) % q
        u2 = (r_sig * w) % q

        v = ((self._powerModulus(key['g'], u1, p) *
              self._powerModulus(key['pub_key'], u2, p)) % p) % q

        return v == r_sig


    def _invert(self, x, y):
        """Return the multiplicative inverse of x modulo y.

        The result is in range [0, y). Returns False if x has no inverse
        (gcd(x, y) != 1). y must be positive.
        """
        (a, b, g) = self._gcd(x % y, y)
        if g != 1:
            return False
        return a % y


    def _gcd(self, x, y):
        """Extended Euclidean algorithm.

        Returns (a, b, g) where g is the greatest common divisor of x and y
        and a*x + b*y == g. Expects non-negative integers.
        """
        (a, a_last) = (1, 0)
        (b, b_last) = (0, 1)

        while y > 0:
            # Floor division: the algorithm needs integer quotients, which
            # '/' only gives under Python 2's classic division.
            q = x // y
            (x, y) = (y, x % y)
            (a, a_last) = (a_last, a - (q * a_last))
            (b, b_last) = (b_last, b - (q * b_last))

        return (a, b, x)


    def _powerModulus(self, base, exp, mod):
        """Return the result of (base ** exp) % mod.

        Delegates to the builtin three-argument pow(). It is iterative, handles
        exp == 0 (which the old recursive version never terminated on), and
        never builds base ** exp in full.
        """
        return pow(base, exp, mod)

