Просмотр исходного кода

Refactored handshaking process, moved it to a separate file, implemented HTTP authentication

Taddeus Kroes 12 лет назад
Родитель
Сommit
ba21de4a2c
3 измененных файлов с 870 добавлено и 257 удалено
  1. 331 0
      handshake.py
  2. 519 0
      python_digest.py
  3. 20 257
      websocket.py

+ 331 - 0
handshake.py

@@ -0,0 +1,331 @@
+import os
+import re
+from hashlib import sha1
+from base64 import b64encode
+from urlparse import urlparse
+
+from python_digest import build_authorization_request
+from errors import HandshakeError
+
+
+WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
+WS_VERSION = '13'
+MAX_REDIRECTS = 10
+
+
+def split_stripped(value, delim=','):
+    return map(str.strip, str(value).split(delim)) if value else []
+
+
+class Handshake(object):
+    def __init__(self, wsock):
+        self.wsock = wsock
+        self.sock = wsock.sock
+
+    def fail(self, msg):
+        self.sock.close()
+        raise HandshakeError(msg)
+
+    def receive_request(self):
+        raw, headers = self.receive_headers()
+
+        # Request must be HTTP (at least 1.1) GET request, find the location
+        match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw)
+
+        if match is None:
+            self.fail('not a valid HTTP 1.1 GET request')
+
+        location = match.group(1)
+        return location, headers
+
+    def receive_response(self):
+        raw, headers = self.receive_headers()
+
+        # Response must be HTTP (at least 1.1) with status 101
+        match = re.search(r'^HTTP/1\.1 (\d{3})', raw)
+
+        if match is None:
+            self.fail('not a valid HTTP 1.1 response')
+
+        status = int(match.group(1))
+        return status, headers
+
+    def receive_headers(self):
+        # Receive entire HTTP header
+        raw_headers = ''
+
+        while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
+            raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
+
+        headers = {}
+
+        for key, value in re.findall(r'(.*?): ?(.*?)\r\n', raw_headers):
+            if key in headers:
+                headers[key] += ', ' + value
+            else:
+                headers[key] = value
+
+        return raw_headers, headers
+
+    def send_headers(self, headers):
+        # Send request
+        for hdr in list(headers):
+            if isinstance(hdr, tuple):
+                hdr = '%s: %s' % hdr
+
+            self.sock.sendall(hdr + '\r\n')
+
+        self.sock.sendall('\r\n')
+
+    def perform(self):
+        raise NotImplementedError
+
+
+class ServerHandshake(Handshake):
+    """
+    Executes a handshake as the server end point of the socket. If the HTTP
+    request headers sent by the client are invalid, a HandshakeError is raised.
+    """
+
+    def perform(self):
+        # Receive and validate client handshake
+        location, headers = self.receive_request()
+
+        # Send server handshake in response
+        self.send_headers(self.response_headers(location, headers))
+
+    def response_headers(self, location, headers):
+        # Check if headers that MUST be present are actually present
+        for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
+                     'Sec-WebSocket-Version'):
+            if name not in headers:
+                self.fail('missing "%s" header' % name)
+
+        # Check WebSocket version used by client
+        version = headers['Sec-WebSocket-Version']
+
+        if version != WS_VERSION:
+            self.fail('WebSocket version %s requested (only %s is supported)'
+                      % (version, WS_VERSION))
+
+        # Verify required header keywords
+        if 'websocket' not in headers['Upgrade'].lower():
+            self.fail('"Upgrade" header must contain "websocket"')
+
+        if 'upgrade' not in headers['Connection'].lower():
+            self.fail('"Connection" header must contain "Upgrade"')
+
+        # Origin must be present if browser client, and must match the list of
+        # trusted origins
+        if 'Origin' not in headers:
+            if 'User-Agent' in headers:
+                self.fail('browser client must specify "Origin" header')
+
+            if self.wsock.trusted_origins:
+                self.fail('no "Origin" header specified, assuming untrusted')
+
+            origin = 'null'
+        elif self.wsock.trusted_origins:
+            origin = headers['Origin']
+
+            if origin not in self.wsock.trusted_origins:
+                self.fail('untrusted origin "%s"' % origin)
+
+        # Only a supported protocol can be returned
+        client_proto = split_stripped(headers['Sec-WebSocket-Protocol']) \
+                       if 'Sec-WebSocket-Protocol' in headers else []
+        protocol = 'null'
+
+        for p in client_proto:
+            if p in self.wsock.proto:
+                protocol = p
+                break
+
+        # Only supported extensions are returned
+        if 'Sec-WebSocket-Extensions' in headers:
+            client_ext = split_stripped(headers['Sec-WebSocket-Extensions'])
+            extensions = [e for e in client_ext if e in self.wsock.extensions]
+        else:
+            extensions = []
+
+        # Encode acceptation key using the WebSocket GUID
+        key = headers['Sec-WebSocket-Key'].strip()
+        accept = b64encode(sha1(key + WS_GUID).digest())
+
+        # Location scheme differs for SSL-enabled connections
+        scheme = 'wss' if self.wsock.secure else 'ws'
+
+        if 'Host' in headers:
+            host = headers['Host']
+        else:
+            host, port = self.sock.getpeername()
+            default_port = 443 if self.wsock.secure else 80
+
+            if port != default_port:
+                host += ':%d' % port
+
+        # Construct HTTP response header
+        yield 'HTTP/1.1 101 Web Socket Protocol Handshake'
+        yield 'Upgrade', 'websocket'
+        yield 'Connection', 'Upgrade'
+        yield 'WebSocket-Origin', origin
+        yield 'WebSocket-Location', '%s://%s%s' % (scheme, host, location)
+        yield 'Sec-WebSocket-Accept', accept
+        yield 'Sec-WebSocket-Protocol', protocol
+        yield 'Sec-WebSocket-Extensions', ', '.join(extensions)
+
+
+class ClientHandshake(Handshake):
+    """
+    Executes a handshake as the client end point of the socket. May raise a
+    HandshakeError if the server response is invalid.
+    """
+
+    def __init__(self, wsock):
+        Handshake.__init__(self, wsock)
+        self.redirects = 0
+
+    def perform(self):
+        self.send_headers(self.request_headers())
+        self.handle_response(*self.receive_response())
+
+    def handle_response(self, status, headers):
+        if status == 101:
+            self.handle_handshake(headers)
+        elif status == 401:
+            self.handle_auth(headers)
+        elif status in (301, 302, 303, 307, 308):
+            self.handle_redirect(headers)
+        else:
+            self.fail('invalid HTTP response status %d' % status)
+
+    def handle_handshake(self, headers):
+        # Check if headers that MUST be present are actually present
+        for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
+            if name not in headers:
+                self.fail('missing "%s" header' % name)
+
+        if 'websocket' not in headers['Upgrade'].lower():
+            self.fail('"Upgrade" header must contain "websocket"')
+
+        if 'upgrade' not in headers['Connection'].lower():
+            self.fail('"Connection" header must contain "Upgrade"')
+
+        # Verify accept header
+        accept = headers['Sec-WebSocket-Accept'].strip()
+        required_accept = b64encode(sha1(self.key + WS_GUID).digest())
+
+        if accept != required_accept:
+            self.fail('invalid websocket accept header "%s"' % accept)
+
+        # Compare extensions
+        if 'Sec-WebSocket-Extensions' in headers:
+            server_ext = split_stripped(headers['Sec-WebSocket-Extensions'])
+
+            for e in set(server_ext) - set(self.wsock.extensions):
+                self.fail('server extension "%s" unsupported by client' % e)
+
+            for e in set(self.wsock.extensions) - set(server_ext):
+                self.fail('client extension "%s" unsupported by server' % e)
+
+        # Assert that returned protocol (if any) is supported
+        if 'Sec-WebSocket-Protocol' in headers:
+            protocol = headers['Sec-WebSocket-Protocol']
+
+            if protocol != 'null' and protocol not in self.wsock.protocols:
+                self.fail('unsupported protocol "%s"' % protocol)
+
+            self.wsock.protocol = protocol
+
+
+    def handle_auth(self, headers):
+        # HTTP authentication is required in the request
+        hdr = headers['WWW-Authenticate']
+        authres = dict(re.findall(r'(\w+)[:=] ?"?(\w+)"?', hdr))
+        mode = hdr.lstrip().split(' ', 1)[0]
+
+        if not self.wsock.auth:
+            self.fail('missing username and password for HTTP authentication')
+
+        if mode == 'Basic':
+            auth_hdr = self.http_auth_basic_headers(**authres)
+        elif mode == 'Digest':
+            auth_hdr = self.http_auth_digest_headers(**authres)
+        else:
+            self.fail('unsupported HTTP authentication mode "%s"' % mode)
+
+        # Send new, authenticated handshake
+        self.send_headers(list(self.request_headers()) + list(auth_hdr))
+        self.handle_response(*self.receive_response())
+
+    def handle_redirect(self, headers):
+        self.redirects += 1
+
+        if self.redirects > MAX_REDIRECTS:
+            self.fail('reached maximum number of redirects (%d)'
+                      % MAX_REDIRECTS)
+
+        # Handle HTTP redirect
+        url = urlparse(headers['Location'].strip())
+
+        # Reconnect socket to new host if net location changed
+        if not url.port:
+            url.port = 443 if self.secure else 80
+
+        addr = (url.netloc, url.port)
+
+        if addr != self.sock.getpeername():
+            self.sock.close()
+            self.sock.connect(addr)
+
+        # Update websocket object and send new handshake
+        self.wsock.location = url.path
+        self.perform()
+
+    def request_headers(self):
+        if len(self.wsock.location) == 0:
+            self.fail('request location is empty')
+
+        # Generate a 16-byte random base64-encoded key for this connection
+        self.key = b64encode(os.urandom(16))
+
+        # Send client handshake
+        yield 'GET %s HTTP/1.1' % self.wsock.location
+        yield 'Host', '%s:%d' % self.sock.getpeername()
+        yield 'Upgrade', 'websocket'
+        yield 'Connection', 'keep-alive, Upgrade'
+        yield 'Sec-WebSocket-Key', self.key
+        yield 'Sec-WebSocket-Version', WS_VERSION
+
+        if self.wsock.origin:
+            yield 'Origin', self.wsock.origin
+
+        # These are for eagerly caching webservers
+        yield 'Pragma', 'no-cache'
+        yield 'Cache-Control', 'no-cache'
+
+        # Request protocols and extension, these are later checked with the
+        # actual supported values from the server's response
+        if self.wsock.protocols:
+            yield 'Sec-WebSocket-Protocol', ', '.join(self.wsock.protocols)
+
+        if self.wsock.extensions:
+            yield 'Sec-WebSocket-Extensions', ', '.join(self.wsock.extensions)
+
+    def http_auth_basic_headers(self, **kwargs):
+        u, p = self.wsock.auth
+        u = u.encode('utf-8')
+        p = p.encode('utf-8')
+        yield 'Authorization', 'Basic ' + b64encode(u + ':' + p)
+
+    def http_auth_digest_headers(self, **kwargs):
+        username, password = self.wsock.auth
+        yield 'Authorization', build_authorization_request(
+                                username=username.encode('utf-8'),
+                                method='GET',
+                                uri=self.wsock.location,
+                                nonce_count=0,
+                                realm=kwargs['realm'],
+                                nonce=kwargs['nonce'],
+                                opaque=kwargs['opaque'],
+                                password=password.encode('utf-8'))

+ 519 - 0
python_digest.py

@@ -0,0 +1,519 @@
+'''
+Copyright (c) 2009, Akoha, Inc.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+ * Neither the name of python-digest nor the names of its contributors may be
+   used to endorse or promote products derived from this software without
+   specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+'''
+
+try:
+    import hashlib as md5
+except ImportError: # Python <2.5
+    import md5
+
+try:
+    from cStringIO import StringIO
+except ImportError:
+    from StringIO import StringIO
+
+import random
+import types
+import urllib
+import urlparse
+import logging
+
+# Make sure a NullHandler is available
+# This was added in Python 2.7/3.2
+try:
+    from logging import NullHandler
+except ImportError:
+    class NullHandler(logging.Handler):
+        def emit(self, record):
+            pass
+
+_REQUIRED_DIGEST_RESPONSE_PARTS = ['username', 'realm', 'nonce', 'uri', 'response', 'algorithm',
+                  'opaque', 'qop', 'nc', 'cnonce']
+_REQUIRED_DIGEST_CHALLENGE_PARTS = ['realm', 'nonce', 'stale', 'algorithm',
+                             'opaque', 'qop']
+
+l = logging.getLogger(__name__)
+l.addHandler(NullHandler())
+
+_LWS=[chr(9), ' ', '\r', '\n']
+_ILLEGAL_TOKEN_CHARACTERS = (
+    [chr(n) for n in range(0-31)] + # control characters
+    [chr(127)] + # DEL
+    ['(',')','<','>','@',',',';',':','\\','"','/','[',']','?','=','{','}',' '] +
+    [chr(9)]) # horizontal tab
+
+class State(object):
+    def character(self, c):
+        return self.consume(c)
+
+    def close(self):
+        return self.eof()
+
+    def eof(self):
+        raise ValueError('EOF not permitted in this state.')
+
+    '''
+    Return False to keep the current state, or True to pop it
+    '''
+    def consume(c):
+        raise Exception('Unimplemented')
+
+class ParentState(State):
+    def __init__(self):
+        super(State, self).__init__()
+        self.child = None
+
+    def close(self):
+        if self.child:
+            return self.handle_child_return(self.child.close())
+        else:
+            return self.eof()
+
+    def push_child(self, child, c=None):
+        self.child = child
+        if c is not None:
+            return self.send_to_child(c)
+        else:
+            return False
+
+    def send_to_child(self, c):
+        return self.handle_child_return(self.child.character(c))
+
+    def handle_child_return(self, returned_value):
+        if returned_value:
+            child = self.child
+            self.child = None
+            return self.child_complete(child)
+        return False
+
+    '''
+    Return False to keep the current state, or True to pop it.
+    '''
+    def child_complete(self, child):
+        return False
+
+    def character(self, c):
+        if self.child:
+            return self.send_to_child(c)
+        else:
+            return self.consume(c)
+
+    def consume(self, c):
+        return False
+
+
+class EscapedCharacterState(State):
+    def __init__(self, io):
+        super(EscapedCharacterState, self).__init__()
+        self.io = io
+
+    def consume(self, c):
+        self.io.write(c)
+        return True
+
+class KeyTrailingWhitespaceState(State):
+    def consume(self, c):
+        if c in _LWS:
+            return False
+        elif c == '=':
+            return True
+        else:
+            raise ValueError("Expected whitespace or '='")
+
+class ValueLeadingWhitespaceState(ParentState):
+    def __init__(self, io):
+        super(ValueLeadingWhitespaceState, self).__init__()
+        self.io = io
+
+    def consume(self, c):
+        if c in _LWS:
+            return False
+        elif c == '"':
+            return self.push_child(QuotedValueState(self.io))
+        elif c in _ILLEGAL_TOKEN_CHARACTERS:
+            raise ValueError('The character %r is not a legal token character' % c)
+        else:
+            self.io.write(c)
+            return self.push_child(UnquotedValueState(self.io))
+
+    def child_complete(self, child):
+        return True
+
+class ValueTrailingWhitespaceState(State):
+    def eof(self):
+        return True
+
+    def consume(self, c):
+        if c in _LWS:
+            return False
+        elif c == ',':
+            return True
+        else:
+            raise ValueError("Expected whitespace, ',', or EOF")
+
+class BaseQuotedState(ParentState):
+    def __init__(self, io):
+        super(BaseQuotedState, self).__init__()
+        self.key_io = io
+
+    def consume(self, c):
+        if c == '\\':
+            return self.push_child(EscapedCharacterState(self.key_io))
+        elif c == '"':
+            return self.push_child(self.TrailingWhitespaceState())
+        else:
+            self.key_io.write(c)
+            return False
+
+    def child_complete(self, child):
+        if type(child) == EscapedCharacterState:
+            return False
+        elif type(child) == self.TrailingWhitespaceState:
+            return True
+
+class BaseUnquotedState(ParentState):
+    def __init__(self, io):
+        super(BaseUnquotedState, self).__init__()
+        self.io = io
+
+    def consume(self, c):
+        if c == self.terminating_character:
+            return True
+        elif c in _LWS:
+            return self.push_child(self.TrailingWhitespaceState())
+        elif c in _ILLEGAL_TOKEN_CHARACTERS:
+            raise ValueError('The character %r is not a legal token character' % c)
+        else:
+            self.io.write(c)
+            return False
+
+    def child_complete(self, child):
+        # type(child) == self.TrailingWhitespaceState
+        return True
+
+class QuotedKeyState(BaseQuotedState):
+    TrailingWhitespaceState = KeyTrailingWhitespaceState
+
+class QuotedValueState(BaseQuotedState):
+    TrailingWhitespaceState = ValueTrailingWhitespaceState
+
+class UnquotedKeyState(BaseUnquotedState):
+    TrailingWhitespaceState = KeyTrailingWhitespaceState
+    terminating_character = '='
+
+class UnquotedValueState(BaseUnquotedState):
+    TrailingWhitespaceState = ValueTrailingWhitespaceState
+    terminating_character  = ','
+
+    def eof(self):
+        return True
+
+class NewPartState(ParentState):
+    def __init__(self, parts):
+        super(NewPartState, self).__init__()
+        self.parts = parts
+        self.key_io = StringIO()
+        self.value_io = StringIO()
+
+    def consume(self, c):
+        if c in _LWS:
+            return False
+        elif c == '"':
+            return self.push_child(QuotedKeyState(self.key_io))
+        elif c in _ILLEGAL_TOKEN_CHARACTERS:
+            raise ValueError('The character %r is not a legal token character' % c)
+        else:
+            self.key_io.write(c)
+            return self.push_child(UnquotedKeyState(self.key_io))
+
+    def child_complete(self, child):
+        if type(child) in [QuotedKeyState, UnquotedKeyState]:
+            return self.push_child(ValueLeadingWhitespaceState(self.value_io))
+        else:
+            self.parts[self.key_io.getvalue()] = self.value_io.getvalue()
+            return True
+
+class FoundationState(ParentState):
+    def __init__(self, defaults):
+        super(FoundationState, self).__init__()
+        self.parts = defaults.copy()
+
+    def result(self):
+        return self.parts
+
+    def consume(self, c):
+        return self.push_child(NewPartState(self.parts), c)
+
+def parse_parts(parts_string, defaults={}):
+    state_machine = FoundationState(defaults)
+    index = 0
+    try:
+        for c in parts_string:
+            state_machine.character(c)
+            index += 1
+        state_machine.close()
+        return state_machine.result()
+    except ValueError, e:
+        annotated_parts_string = "%s[%s]%s" % (parts_string[0:index],
+                                               index < len(parts_string) and parts_string[index] or '',
+                                               index + 1 < len(parts_string) and parts_string[index+1:] or '')
+        l.exception("Failed to parse the Digest string "
+                    "(offending character is in []): %r" % annotated_parts_string)
+        return None
+
+def format_parts(**kwargs):
+    return ", ".join(['%s="%s"' % (k,v.encode('utf-8')) for (k,v) in kwargs.items()])
+
+def validate_uri(digest_uri, request_path):
+    digest_url_components = urlparse.urlparse(digest_uri)
+    return urllib.unquote(digest_url_components[2]) == request_path
+
+def validate_nonce(nonce, secret):
+    '''
+    Is the nonce one that was generated by this library using the provided secret?
+    '''
+    nonce_components = nonce.split(':', 2)
+    if not len(nonce_components) == 3:
+        return False
+    timestamp = nonce_components[0]
+    salt = nonce_components[1]
+    nonce_signature = nonce_components[2]
+
+    calculated_nonce = calculate_nonce(timestamp, secret, salt)
+
+    if not nonce == calculated_nonce:
+        return False
+
+    return True
+
+def calculate_partial_digest(username, realm, password):
+    '''
+    Calculate a partial digest that may be stored and used to authenticate future
+    HTTP Digest sessions.
+    '''
+    return md5.md5("%s:%s:%s" % (username.encode('utf-8'), realm, password.encode('utf-8'))).hexdigest()
+
+def build_digest_challenge(timestamp, secret, realm, opaque, stale):
+    '''
+    Builds a Digest challenge that may be sent as the value of the 'WWW-Authenticate' header
+    in a 401 or 403 response.
+
+    'opaque' may be any value - it will be returned by the client.
+
+    'timestamp' will be incorporated and signed in the nonce - it may be retrieved from the
+    client's authentication request using get_nonce_timestamp()
+    '''
+    nonce = calculate_nonce(timestamp, secret)
+
+    return 'Digest %s' % format_parts(realm=realm, qop='auth', nonce=nonce,
+                                      opaque=opaque, algorithm='MD5',
+                                      stale=stale and 'true' or 'false')
+
+def calculate_request_digest(method, partial_digest, digest_response=None,
+                             uri=None, nonce=None, nonce_count=None, client_nonce=None):
+    '''
+    Calculates a value for the 'response' value of the client authentication request.
+    Requires the 'partial_digest' calculated from the realm, username, and password.
+
+    Either call it with a digest_response to use the values from an authentication request,
+    or pass the individual parameters (i.e. to generate an authentication request).
+    '''
+    if digest_response:
+        if uri or nonce or nonce_count or client_nonce:
+            raise Exception("Both digest_response and one or more "
+                            "individual parameters were sent.")
+        uri = digest_response.uri
+        nonce = digest_response.nonce
+        nonce_count = digest_response.nc
+        client_nonce=digest_response.cnonce
+    elif not (uri and nonce and (nonce_count != None) and client_nonce):
+        raise Exception("Neither digest_response nor all individual parameters were sent.")
+
+    ha2 = md5.md5("%s:%s" % (method, uri)).hexdigest()
+    data = "%s:%s:%s:%s:%s" % (nonce, "%08x" % nonce_count, client_nonce, 'auth', ha2)
+    kd = md5.md5("%s:%s" % (partial_digest, data)).hexdigest()
+    return kd
+
+def get_nonce_timestamp(nonce):
+    '''
+    Extract the timestamp from a Nonce. To be sure the timestamp was generated by this site,
+    make sure you validate the nonce using validate_nonce().
+    '''
+    components = nonce.split(':',2)
+    if not len(components) == 3:
+        return None
+
+    try:
+        return float(components[0])
+    except ValueError:
+        return None
+
+def calculate_nonce(timestamp, secret, salt=None):
+    '''
+    Generate a nonce using the provided timestamp, secret, and salt. If the salt is not provided,
+    (and one should only be provided when validating a nonce) one will be generated randomly
+    in order to ensure that two simultaneous requests do not generate identical nonces.
+    '''
+    if not salt:
+        salt = ''.join([random.choice('0123456789ABCDEF') for x in range(4)])
+    return "%s:%s:%s" % (timestamp, salt,
+                         md5.md5("%s:%s:%s" % (timestamp, salt, secret)).hexdigest())
+
+def build_authorization_request(username, method, uri, nonce_count, digest_challenge=None,
+                                realm=None, nonce=None, opaque=None, password=None,
+                                request_digest=None, client_nonce=None):
+    '''
+    Builds an authorization request that may be sent as the value of the 'Authorization'
+    header in an HTTP request.
+
+    Either a digest_challenge object (as returned from parse_digest_challenge) or its required
+    component parameters (nonce, realm, opaque) must be provided.
+
+    The nonce_count should be the last used nonce_count plus one.
+
+    Either the password or the request_digest should be provided - if provided, the password
+    will be used to generate a request digest. The client_nonce is optional - if not provided,
+    a random value will be generated.
+    '''
+    if not client_nonce:
+        client_nonce =  ''.join([random.choice('0123456789ABCDEF') for x in range(32)])
+
+    if digest_challenge and (realm or nonce or opaque):
+        raise Exception("Both digest_challenge and one or more of realm, nonce, and opaque"
+                        "were sent.")
+
+    if digest_challenge:
+        if isinstance(digest_challenge, types.StringType):
+            digest_challenge_header = digest_challenge
+            digest_challenge = parse_digest_challenge(digest_challenge_header)
+            if not digest_challenge:
+                raise Exception("The provided digest challenge header could not be parsed: %s" %
+                                digest_challenge_header)
+        realm = digest_challenge.realm
+        nonce = digest_challenge.nonce
+        opaque = digest_challenge.opaque
+    elif not (realm and nonce and opaque):
+        raise Exception("Either digest_challenge or realm, nonce, and opaque must be sent.")
+
+    if password and request_digest:
+        raise Exception("Both password and calculated request_digest were sent.")
+    elif not request_digest:
+        if not password:
+            raise Exception("Either password or calculated request_digest must be provided.")
+
+        partial_digest = calculate_partial_digest(username, realm, password)
+        request_digest = calculate_request_digest(method, partial_digest, uri=uri, nonce=nonce,
+                                                  nonce_count=nonce_count,
+                                                  client_nonce=client_nonce)
+
+    return 'Digest %s' % format_parts(username=username, realm=realm, nonce=nonce, uri=uri,
+                                      response=request_digest, algorithm='MD5', opaque=opaque,
+                                      qop='auth', nc='%08x' % nonce_count, cnonce=client_nonce)
+
+def _check_required_parts(parts, required_parts):
+    if parts == None:
+        return False
+
+    missing_parts = [part for part in required_parts if not part in parts]
+    return len(missing_parts) == 0
+
+def _build_object_from_parts(parts, names):
+    obj = type("", (), {})()
+    for part_name in names:
+        val = parts[part_name]
+        if isinstance(val, basestring):
+            val = unicode(val, "utf-8")
+        setattr(obj, part_name, val)
+    return obj
+
+def parse_digest_response(digest_response_string):
+    '''
+    Parse the parameters of a Digest response. The input is a comma separated list of
+    token=(token|quoted-string). See RFCs 2616 and 2617 for details.
+
+    Known issue: this implementation will fail if there are commas embedded in quoted-strings.
+    '''
+
+    parts = parse_parts(digest_response_string, defaults={'algorithm': 'MD5'})
+    if not _check_required_parts(parts, _REQUIRED_DIGEST_RESPONSE_PARTS):
+        return None
+
+    if not parts['nc'] or [c for c in parts['nc'] if not c in '0123456789abcdefABCDEF']:
+        return None
+    parts['nc'] = int(parts['nc'], 16)
+
+    digest_response = _build_object_from_parts(parts, _REQUIRED_DIGEST_RESPONSE_PARTS)
+    if ('MD5', 'auth') != (digest_response.algorithm, digest_response.qop):
+        return None
+
+    return digest_response
+
+def is_digest_credential(authorization_header):
+    '''
+    Determines if the header value is potentially a Digest response sent by a client (i.e.
+    if it starts with 'Digest ' (case insensitive).
+    '''
+    return authorization_header[:7].lower() == 'digest '
+
+def parse_digest_credentials(authorization_header):
+    '''
+    Parses the value of an 'Authorization' header. Returns an object with properties
+    corresponding to each of the recognized parameters in the header.
+    '''
+    if not is_digest_credential(authorization_header):
+        return None
+
+    return parse_digest_response(authorization_header[7:])
+
+def is_digest_challenge(authentication_header):
+    '''
+    Determines if the header value is potentially a Digest challenge sent by a server (i.e.
+    if it starts with 'Digest ' (case insensitive).
+    '''
+    return authentication_header[:7].lower() == 'digest '
+
+def parse_digest_challenge(authentication_header):
+    '''
+    Parses the value of a 'WWW-Authenticate' header. Returns an object with properties
+    corresponding to each of the recognized parameters in the header.
+    '''
+    if not is_digest_challenge(authentication_header):
+        return None
+
+    parts = parse_parts(authentication_header[7:], defaults={'algorithm': 'MD5',
+                                                             'stale': 'false'})
+    if not _check_required_parts(parts, _REQUIRED_DIGEST_CHALLENGE_PARTS):
+        return None
+
+    parts['stale'] = parts['stale'].lower() == 'true'
+
+    digest_challenge = _build_object_from_parts(parts, _REQUIRED_DIGEST_CHALLENGE_PARTS)
+    if ('MD5', 'auth') != (digest_challenge.algorithm, digest_challenge.qop):
+        return None
+
+    return digest_challenge
+

+ 20 - 257
websocket.py

@@ -1,21 +1,9 @@
-import os
-import re
 import socket
 import socket
 import ssl
 import ssl
-from hashlib import sha1
-from base64 import b64encode
-from urlparse import urlparse
 
 
 from frame import receive_frame
 from frame import receive_frame
-from errors import HandshakeError, SSLError
-
-
-WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
-WS_VERSION = '13'
-
-
-def split_stripped(value, delim=','):
-    return map(str.strip, str(value).split(delim)) if value else []
+from handshake import ServerHandshake, ClientHandshake
+from errors import SSLError
 
 
 
 
 class websocket(object):
 class websocket(object):
@@ -38,12 +26,13 @@ class websocket(object):
 
 
     Client example:
     Client example:
     >>> import twspy
     >>> import twspy
-    >>> sock = twspy.websocket()
+    >>> sock = twspy.websocket(location='/my/path')
     >>> sock.connect(('', 8000))
     >>> sock.connect(('', 8000))
     >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
     >>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
     """
     """
     def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
     def __init__(self, sock=None, protocols=[], extensions=[], origin=None,
-                 trusted_origins=[], sfamily=socket.AF_INET, sproto=0):
+                 trusted_origins=[], location='/', auth=None,
+                 sfamily=socket.AF_INET, sproto=0):
         """
         """
         Create a regular TCP socket of family `family` and protocol
         Create a regular TCP socket of family `family` and protocol
 
 
@@ -62,15 +51,23 @@ class websocket(object):
         has value not in this list, a HandshakeError is raised. If the list is
         has value not in this list, a HandshakeError is raised. If the list is
         empty (default), all origins are excepted.
         empty (default), all origins are excepted.
 
 
+        `location` is optional, used for the HTTP handshake. In a URL, this
+        would show as ws://host[:port]/path.
+
+        `auth` is optional, used for HTTP Basic or Digest authentication during
+        the handshake. It must be specified as a (username, password) tuple.
+
         `sfamily` and `sproto` are used for the regular socket constructor.
         `sfamily` and `sproto` are used for the regular socket constructor.
         """
         """
         self.protocols = protocols
         self.protocols = protocols
         self.extensions = extensions
         self.extensions = extensions
         self.origin = origin
         self.origin = origin
         self.trusted_origins = trusted_origins
         self.trusted_origins = trusted_origins
+        self.location = location
+        self.auth = auth
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
         self.sock = sock or socket.socket(sfamily, socket.SOCK_STREAM, sproto)
         self.secure = False
         self.secure = False
-        self.handshake_started = False
+        self.handshake_sent = False
 
 
     def bind(self, address):
     def bind(self, address):
         self.sock.bind(address)
         self.sock.bind(address)
@@ -87,24 +84,20 @@ class websocket(object):
         """
         """
         sock, address = self.sock.accept()
         sock, address = self.sock.accept()
         wsock = websocket(sock)
         wsock = websocket(sock)
-        wsock.server_handshake()
+        ServerHandshake(wsock).perform()
+        wsock.handshake_sent = True
         return wsock, address
         return wsock, address
 
 
-    def connect(self, address, path='/', auth=None):
+    def connect(self, address):
         """
         """
         Equivalent to socket.connect(), but sends an client handshake request
         Equivalent to socket.connect(), but sends an client handshake request
         after connecting.
         after connecting.
 
 
         `address` is a (host, port) tuple of the server to connect to.
         `address` is a (host, port) tuple of the server to connect to.
-
-        `path` is optional, used as the *location* part of the HTTP handshake.
-        In a URL, this would show as ws://host[:port]/path.
-
-        `auth` is optional, used for the HTTP "Authorization" header of the
-        handshake request.
         """
         """
         self.sock.connect(address)
         self.sock.connect(address)
-        self.client_handshake(address, path, auth)
+        ClientHandshake(self).perform()
+        self.handshake_sent = True
 
 
     def send(self, *args):
     def send(self, *args):
         """
         """
@@ -145,243 +138,13 @@ class websocket(object):
     def close(self):
     def close(self):
         self.sock.close()
         self.sock.close()
 
 
-    def server_handshake(self):
-        """
-        Execute a handshake as the server end point of the socket. If the HTTP
-        request headers sent by the client are invalid, a HandshakeError
-        is raised.
-        """
-        def fail(msg):
-            self.sock.close()
-            raise HandshakeError(msg)
-
-        # Receive HTTP header
-        raw_headers = ''
-
-        while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
-            raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
-
-        # Request must be HTTP (at least 1.1) GET request, find the location
-        match = re.search(r'^GET (.*) HTTP/1.1\r\n', raw_headers)
-
-        if match is None:
-            fail('not a valid HTTP 1.1 GET request')
-
-        location = match.group(1)
-        headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
-        header_names = [name for name, value in headers]
-
-        def header(name):
-            return ', '.join([v for n, v in headers if n == name])
-
-        # Check if headers that MUST be present are actually present
-        for name in ('Host', 'Upgrade', 'Connection', 'Sec-WebSocket-Key',
-                     'Sec-WebSocket-Version'):
-            if name not in header_names:
-                fail('missing "%s" header' % name)
-
-        # Check WebSocket version used by client
-        version = header('Sec-WebSocket-Version')
-
-        if version != WS_VERSION:
-            fail('WebSocket version %s requested (only %s '
-                                 'is supported)' % (version, WS_VERSION))
-
-        # Verify required header keywords
-        if 'websocket' not in header('Upgrade').lower():
-            fail('"Upgrade" header must contain "websocket"')
-
-        if 'upgrade' not in header('Connection').lower():
-            fail('"Connection" header must contain "Upgrade"')
-
-        # Origin must be present if browser client, and must match the list of
-        # trusted origins
-        if 'Origin' not in header_names:
-            if 'User-Agent' in header_names:
-                fail('browser client must specify "Origin" header')
-
-            if self.trusted_origins:
-                fail('no "Origin" header specified, assuming untrusted')
-        elif self.trusted_origins:
-            origin = header('Origin')
-
-            if origin not in self.trusted_origins:
-                fail('untrusted origin "%s"' % origin)
-
-        # Only supported protocols are returned
-        client_protocols = split_stripped(header('Sec-WebSocket-Extensions'))
-        protocol = 'null'
-
-        for p in client_protocols:
-            if p in self.protocols:
-                protocol = p
-                break
-
-        # Only supported extensions are returned
-        extensions = split_stripped(header('Sec-WebSocket-Extensions'))
-        extensions = [e for e in extensions if e in self.extensions]
-
-        # Encode acceptation key using the WebSocket GUID
-        key = header('Sec-WebSocket-Key').strip()
-        accept = b64encode(sha1(key + WS_GUID).digest())
-
-        # Construct HTTP response header
-        shake = 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n'
-        shake += 'Upgrade: websocket\r\n'
-        shake += 'Connection: Upgrade\r\n'
-        shake += 'WebSocket-Origin: %s\r\n' % header('Origin')
-        shake += 'WebSocket-Location: ws://%s%s\r\n' \
-                 % (header('Host'), location)
-        shake += 'Sec-WebSocket-Accept: %s\r\n' % accept
-        shake += 'Sec-WebSocket-Protocol: %s\r\n' % protocol
-        shake += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(extensions)
-
-        self.sock.sendall(shake + '\r\n')
-        self.handshake_started = True
-
-    def client_handshake(self, address, location, auth):
-        """
-        Executes a handshake as the client end point of the socket. May raise a
-        HandshakeError if the server response is invalid.
-        """
-        def fail(msg):
-            self.sock.close()
-            raise HandshakeError(msg)
-
-        def send_request(location):
-            if len(location) == 0:
-                fail('request location is empty')
-
-            # Generate a 16-byte random base64-encoded key for this connection
-            key = b64encode(os.urandom(16))
-
-            # Send client handshake
-            shake = 'GET %s HTTP/1.1\r\n' % location
-            shake += 'Host: %s:%d\r\n' % address
-            shake += 'Upgrade: websocket\r\n'
-            shake += 'Connection: keep-alive, Upgrade\r\n'
-            shake += 'Sec-WebSocket-Key: %s\r\n' % key
-            shake += 'Sec-WebSocket-Version: %s\r\n' % WS_VERSION
-
-            if self.origin:
-                shake += 'Origin: %s\r\n' % self.origin
-
-            # These are for eagerly caching webservers
-            shake += 'Pragma: no-cache\r\n'
-            shake += 'Cache-Control: no-cache\r\n'
-
-            # Request protocols and extension, these are later checked with the
-            # actual supported values from the server's response
-            if self.protocols:
-                shake += 'Sec-WebSocket-Protocol: %s\r\n' \
-                         % ', '.join(self.protocols)
-
-            if self.extensions:
-                shake += 'Sec-WebSocket-Extensions: %s\r\n' \
-                         % ', '.join(self.extensions)
-
-            if auth:
-                shake += 'Authorization: %s\r\n' % auth
-
-            self.sock.sendall(shake + '\r\n')
-            return key
-
-        def receive_response(key):
-            # Receive and process server handshake
-            raw_headers = ''
-
-            while raw_headers[-4:] not in ('\r\n\r\n', '\n\n'):
-                raw_headers += self.sock.recv(512).decode('utf-8', 'ignore')
-
-            # Response must be HTTP (at least 1.1) with status 101
-            match = re.search(r'^HTTP/1\.1 (\d{3})', raw_headers)
-
-            if match is None:
-                fail('not a valid HTTP 1.1 response')
-
-            status = int(match.group(1))
-            headers = re.findall(r'(.*?): ?(.*?)\r\n', raw_headers)
-            header_names = [name for name, value in headers]
-
-            def header(name):
-                return ', '.join([v for n, v in headers if n == name])
-
-            if status == 401:
-                # HTTP authentication is required in the request
-                raise HandshakeError('HTTP authentication required: %s'
-                                     % header('WWW-Authenticate'))
-
-            if status in (301, 302, 303, 307, 308):
-                # Handle HTTP redirect
-                url = urlparse(header('Location').strip())
-
-                # Reconnect socket if net location changed
-                if not url.port:
-                    url.port = 443 if self.secure else 80
-
-                addr = (url.netloc, url.port)
-
-                if addr != self.sock.getpeername():
-                    self.sock.close()
-                    self.sock.connect(addr)
-
-                # Send new handshake
-                receive_response(send_request(url.path))
-                return
-
-            if status != 101:
-                # 101 means server has accepted the connection and sent
-                # handshake headers
-                fail('invalid HTTP response status %d' % status)
-
-            # Check if headers that MUST be present are actually present
-            for name in ('Upgrade', 'Connection', 'Sec-WebSocket-Accept'):
-                if name not in header_names:
-                    fail('missing "%s" header' % name)
-
-            if 'websocket' not in header('Upgrade').lower():
-                fail('"Upgrade" header must contain "websocket"')
-
-            if 'upgrade' not in header('Connection').lower():
-                fail('"Connection" header must contain "Upgrade"')
-
-            # Verify accept header
-            accept = header('Sec-WebSocket-Accept').strip()
-            required_accept = b64encode(sha1(key + WS_GUID).digest())
-
-            if accept != required_accept:
-                fail('invalid websocket accept header "%s"' % accept)
-
-            # Compare extensions
-            server_ext = split_stripped(header('Sec-WebSocket-Extensions'))
-
-            for e in server_ext:
-                if e not in self.extensions:
-                    fail('server extension "%s" is unsupported by client' % e)
-
-            for e in self.extensions:
-                if e not in server_ext:
-                    fail('client extension "%s" is unsupported by server' % e)
-
-            # Assert that returned protocol (if any) is supported
-            protocol = header('Sec-WebSocket-Protocol')
-
-            if protocol:
-                if protocol != 'null' and protocol not in self.protocols:
-                    fail('unsupported protocol "%s"' % protocol)
-
-                self.protocol = protocol
-
-        self.handshake_started = True
-        receive_response(send_request(location))
-
     def enable_ssl(self, *args, **kwargs):
     def enable_ssl(self, *args, **kwargs):
         """
         """
         Transforms the regular socket.socket to an ssl.SSLSocket for secure
         Transforms the regular socket.socket to an ssl.SSLSocket for secure
         connections. Any arguments are passed to ssl.wrap_socket:
         connections. Any arguments are passed to ssl.wrap_socket:
         http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
         http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
         """
         """
-        if self.handshake_started:
+        if self.handshake_sent:
             raise SSLError('can only enable SSL before handshake')
             raise SSLError('can only enable SSL before handshake')
 
 
         self.secure = True
         self.secure = True