Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
W
wspy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Taddeüs Kroes
wspy
Commits
ba21de4a
Commit
ba21de4a
authored
Jul 26, 2013
by
Taddeüs Kroes
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactored handshaking process, moved it to a separate file, implemented HTTP authentication
parent
a215384a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
870 additions
and
257 deletions
+870
-257
handshake.py
handshake.py
+331
-0
python_digest.py
python_digest.py
+519
-0
websocket.py
websocket.py
+20
-257
No files found.
handshake.py
0 → 100644
View file @
ba21de4a
import
os
import
re
from
hashlib
import
sha1
from
base64
import
b64encode
from
urlparse
import
urlparse
from
python_digest
import
build_authorization_request
from
errors
import
HandshakeError
WS_GUID
=
'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
WS_VERSION
=
'13'
MAX_REDIRECTS
=
10
def
split_stripped
(
value
,
delim
=
','
):
return
map
(
str
.
strip
,
str
(
value
).
split
(
delim
))
if
value
else
[]
class
Handshake
(
object
):
def
__init__
(
self
,
wsock
):
self
.
wsock
=
wsock
self
.
sock
=
wsock
.
sock
def
fail
(
self
,
msg
):
self
.
sock
.
close
()
raise
HandshakeError
(
msg
)
def
receive_request
(
self
):
raw
,
headers
=
self
.
receive_headers
()
# Request must be HTTP (at least 1.1) GET request, find the location
match
=
re
.
search
(
r'^GET (.*) HTTP/1.1\r\n'
,
raw
)
if
match
is
None
:
self
.
fail
(
'not a valid HTTP 1.1 GET request'
)
location
=
match
.
group
(
1
)
return
location
,
headers
def
receive_response
(
self
):
raw
,
headers
=
self
.
receive_headers
()
# Response must be HTTP (at least 1.1) with status 101
match
=
re
.
search
(
r'^HTTP/1\
.
1 (\
d{
3})'
,
raw
)
if
match
is
None
:
self
.
fail
(
'not a valid HTTP 1.1 response'
)
status
=
int
(
match
.
group
(
1
))
return
status
,
headers
def
receive_headers
(
self
):
# Receive entire HTTP header
raw_headers
=
''
while
raw_headers
[
-
4
:]
not
in
(
'
\
r
\
n
\
r
\
n
'
,
'
\
n
\
n
'
):
raw_headers
+=
self
.
sock
.
recv
(
512
).
decode
(
'utf-8'
,
'ignore'
)
headers
=
{}
for
key
,
value
in
re
.
findall
(
r'(.*?): ?(.*?)\r\n'
,
raw_headers
):
if
key
in
headers
:
headers
[
key
]
+=
', '
+
value
else
:
headers
[
key
]
=
value
return
raw_headers
,
headers
def
send_headers
(
self
,
headers
):
# Send request
for
hdr
in
list
(
headers
):
if
isinstance
(
hdr
,
tuple
):
hdr
=
'%s: %s'
%
hdr
self
.
sock
.
sendall
(
hdr
+
'
\
r
\
n
'
)
self
.
sock
.
sendall
(
'
\
r
\
n
'
)
def
perform
(
self
):
raise
NotImplementedError
class
ServerHandshake
(
Handshake
):
"""
Executes a handshake as the server end point of the socket. If the HTTP
request headers sent by the client are invalid, a HandshakeError is raised.
"""
def
perform
(
self
):
# Receive and validate client handshake
location
,
headers
=
self
.
receive_request
()
# Send server handshake in response
self
.
send_headers
(
self
.
response_headers
(
location
,
headers
))
def
response_headers
(
self
,
location
,
headers
):
# Check if headers that MUST be present are actually present
for
name
in
(
'Host'
,
'Upgrade'
,
'Connection'
,
'Sec-WebSocket-Key'
,
'Sec-WebSocket-Version'
):
if
name
not
in
headers
:
self
.
fail
(
'missing "%s" header'
%
name
)
# Check WebSocket version used by client
version
=
headers
[
'Sec-WebSocket-Version'
]
if
version
!=
WS_VERSION
:
self
.
fail
(
'WebSocket version %s requested (only %s is supported)'
%
(
version
,
WS_VERSION
))
# Verify required header keywords
if
'websocket'
not
in
headers
[
'Upgrade'
].
lower
():
self
.
fail
(
'"Upgrade" header must contain "websocket"'
)
if
'upgrade'
not
in
headers
[
'Connection'
].
lower
():
self
.
fail
(
'"Connection" header must contain "Upgrade"'
)
# Origin must be present if browser client, and must match the list of
# trusted origins
if
'Origin'
not
in
headers
:
if
'User-Agent'
in
headers
:
self
.
fail
(
'browser client must specify "Origin" header'
)
if
self
.
wsock
.
trusted_origins
:
self
.
fail
(
'no "Origin" header specified, assuming untrusted'
)
origin
=
'null'
elif
self
.
wsock
.
trusted_origins
:
origin
=
headers
[
'Origin'
]
if
origin
not
in
self
.
wsock
.
trusted_origins
:
self
.
fail
(
'untrusted origin "%s"'
%
origin
)
# Only a supported protocol can be returned
client_proto
=
split_stripped
(
headers
[
'Sec-WebSocket-Protocol'
])
\
if
'Sec-WebSocket-Protocol'
in
headers
else
[]
protocol
=
'null'
for
p
in
client_proto
:
if
p
in
self
.
wsock
.
proto
:
protocol
=
p
break
# Only supported extensions are returned
if
'Sec-WebSocket-Extensions'
in
headers
:
client_ext
=
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
])
extensions
=
[
e
for
e
in
client_ext
if
e
in
self
.
wsock
.
extensions
]
else
:
extensions
=
[]
# Encode acceptation key using the WebSocket GUID
key
=
headers
[
'Sec-WebSocket-Key'
].
strip
()
accept
=
b64encode
(
sha1
(
key
+
WS_GUID
).
digest
())
# Location scheme differs for SSL-enabled connections
scheme
=
'wss'
if
self
.
wsock
.
secure
else
'ws'
if
'Host'
in
headers
:
host
=
headers
[
'Host'
]
else
:
host
,
port
=
self
.
sock
.
getpeername
()
default_port
=
443
if
self
.
wsock
.
secure
else
80
if
port
!=
default_port
:
host
+=
':%d'
%
port
# Construct HTTP response header
yield
'HTTP/1.1 101 Web Socket Protocol Handshake'
yield
'Upgrade'
,
'websocket'
yield
'Connection'
,
'Upgrade'
yield
'WebSocket-Origin'
,
origin
yield
'WebSocket-Location'
,
'%s://%s%s'
%
(
scheme
,
host
,
location
)
yield
'Sec-WebSocket-Accept'
,
accept
yield
'Sec-WebSocket-Protocol'
,
protocol
yield
'Sec-WebSocket-Extensions'
,
', '
.
join
(
extensions
)
class
ClientHandshake
(
Handshake
):
"""
Executes a handshake as the client end point of the socket. May raise a
HandshakeError if the server response is invalid.
"""
def
__init__
(
self
,
wsock
):
Handshake
.
__init__
(
self
,
wsock
)
self
.
redirects
=
0
def
perform
(
self
):
self
.
send_headers
(
self
.
request_headers
())
self
.
handle_response
(
*
self
.
receive_response
())
def
handle_response
(
self
,
status
,
headers
):
if
status
==
101
:
self
.
handle_handshake
(
headers
)
elif
status
==
401
:
self
.
handle_auth
(
headers
)
elif
status
in
(
301
,
302
,
303
,
307
,
308
):
self
.
handle_redirect
(
headers
)
else
:
self
.
fail
(
'invalid HTTP response status %d'
%
status
)
def
handle_handshake
(
self
,
headers
):
# Check if headers that MUST be present are actually present
for
name
in
(
'Upgrade'
,
'Connection'
,
'Sec-WebSocket-Accept'
):
if
name
not
in
headers
:
self
.
fail
(
'missing "%s" header'
%
name
)
if
'websocket'
not
in
headers
[
'Upgrade'
].
lower
():
self
.
fail
(
'"Upgrade" header must contain "websocket"'
)
if
'upgrade'
not
in
headers
[
'Connection'
].
lower
():
self
.
fail
(
'"Connection" header must contain "Upgrade"'
)
# Verify accept header
accept
=
headers
[
'Sec-WebSocket-Accept'
].
strip
()
required_accept
=
b64encode
(
sha1
(
self
.
key
+
WS_GUID
).
digest
())
if
accept
!=
required_accept
:
self
.
fail
(
'invalid websocket accept header "%s"'
%
accept
)
# Compare extensions
if
'Sec-WebSocket-Extensions'
in
headers
:
server_ext
=
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
])
for
e
in
set
(
server_ext
)
-
set
(
self
.
wsock
.
extensions
):
self
.
fail
(
'server extension "%s" unsupported by client'
%
e
)
for
e
in
set
(
self
.
wsock
.
extensions
)
-
set
(
server_ext
):
self
.
fail
(
'client extension "%s" unsupported by server'
%
e
)
# Assert that returned protocol (if any) is supported
if
'Sec-WebSocket-Protocol'
in
headers
:
protocol
=
headers
[
'Sec-WebSocket-Protocol'
]
if
protocol
!=
'null'
and
protocol
not
in
self
.
wsock
.
protocols
:
self
.
fail
(
'unsupported protocol "%s"'
%
protocol
)
self
.
wsock
.
protocol
=
protocol
def
handle_auth
(
self
,
headers
):
# HTTP authentication is required in the request
hdr
=
headers
[
'WWW-Authenticate'
]
authres
=
dict
(
re
.
findall
(
r'(\
w+)[:=] ?
"?(\
w+)
"?'
,
hdr
))
mode
=
hdr
.
lstrip
().
split
(
' '
,
1
)[
0
]
if
not
self
.
wsock
.
auth
:
self
.
fail
(
'missing username and password for HTTP authentication'
)
if
mode
==
'Basic'
:
auth_hdr
=
self
.
http_auth_basic_headers
(
**
authres
)
elif
mode
==
'Digest'
:
auth_hdr
=
self
.
http_auth_digest_headers
(
**
authres
)
else
:
self
.
fail
(
'unsupported HTTP authentication mode "%s"'
%
mode
)
# Send new, authenticated handshake
self
.
send_headers
(
list
(
self
.
request_headers
())
+
list
(
auth_hdr
))
self
.
handle_response
(
*
self
.
receive_response
())
def
handle_redirect
(
self
,
headers
):
self
.
redirects
+=
1
if
self
.
redirects
>
MAX_REDIRECTS
:
self
.
fail
(
'reached maximum number of redirects (%d)'
%
MAX_REDIRECTS
)
# Handle HTTP redirect
url
=
urlparse
(
headers
[
'Location'
].
strip
())
# Reconnect socket to new host if net location changed
if
not
url
.
port
:
url
.
port
=
443
if
self
.
secure
else
80
addr
=
(
url
.
netloc
,
url
.
port
)
if
addr
!=
self
.
sock
.
getpeername
():
self
.
sock
.
close
()
self
.
sock
.
connect
(
addr
)
# Update websocket object and send new handshake
self
.
wsock
.
location
=
url
.
path
self
.
perform
()
def
request_headers
(
self
):
if
len
(
self
.
wsock
.
location
)
==
0
:
self
.
fail
(
'request location is empty'
)
# Generate a 16-byte random base64-encoded key for this connection
self
.
key
=
b64encode
(
os
.
urandom
(
16
))
# Send client handshake
yield
'GET %s HTTP/1.1'
%
self
.
wsock
.
location
yield
'Host'
,
'%s:%d'
%
self
.
sock
.
getpeername
()
yield
'Upgrade'
,
'websocket'
yield
'Connection'
,
'keep-alive, Upgrade'
yield
'Sec-WebSocket-Key'
,
self
.
key
yield
'Sec-WebSocket-Version'
,
WS_VERSION
if
self
.
wsock
.
origin
:
yield
'Origin'
,
self
.
wsock
.
origin
# These are for eagerly caching webservers
yield
'Pragma'
,
'no-cache'
yield
'Cache-Control'
,
'no-cache'
# Request protocols and extension, these are later checked with the
# actual supported values from the server's response
if
self
.
wsock
.
protocols
:
yield
'Sec-WebSocket-Protocol'
,
', '
.
join
(
self
.
wsock
.
protocols
)
if
self
.
wsock
.
extensions
:
yield
'Sec-WebSocket-Extensions'
,
', '
.
join
(
self
.
wsock
.
extensions
)
def
http_auth_basic_headers
(
self
,
**
kwargs
):
u
,
p
=
self
.
wsock
.
auth
u
=
u
.
encode
(
'utf-8'
)
p
=
p
.
encode
(
'utf-8'
)
yield
'Authorization'
,
'Basic '
+
b64encode
(
u
+
':'
+
p
)
def
http_auth_digest_headers
(
self
,
**
kwargs
):
username
,
password
=
self
.
wsock
.
auth
yield
'Authorization'
,
build_authorization_request
(
username
=
username
.
encode
(
'utf-8'
),
method
=
'GET'
,
uri
=
self
.
wsock
.
location
,
nonce_count
=
0
,
realm
=
kwargs
[
'realm'
],
nonce
=
kwargs
[
'nonce'
],
opaque
=
kwargs
[
'opaque'
],
password
=
password
.
encode
(
'utf-8'
))
python_digest.py
0 → 100644
View file @
ba21de4a
'''
Copyright (c) 2009, Akoha, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of python-digest nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
try
:
import
hashlib
as
md5
except
ImportError
:
# Python <2.5
import
md5
try
:
from
cStringIO
import
StringIO
except
ImportError
:
from
StringIO
import
StringIO
import
random
import
types
import
urllib
import
urlparse
import
logging
# Make sure a NullHandler is available
# This was added in Python 2.7/3.2
try
:
from
logging
import
NullHandler
except
ImportError
:
class
NullHandler
(
logging
.
Handler
):
def
emit
(
self
,
record
):
pass
_REQUIRED_DIGEST_RESPONSE_PARTS
=
[
'username'
,
'realm'
,
'nonce'
,
'uri'
,
'response'
,
'algorithm'
,
'opaque'
,
'qop'
,
'nc'
,
'cnonce'
]
_REQUIRED_DIGEST_CHALLENGE_PARTS
=
[
'realm'
,
'nonce'
,
'stale'
,
'algorithm'
,
'opaque'
,
'qop'
]
l
=
logging
.
getLogger
(
__name__
)
l
.
addHandler
(
NullHandler
())
_LWS
=
[
chr
(
9
),
' '
,
'
\
r
'
,
'
\
n
'
]
_ILLEGAL_TOKEN_CHARACTERS
=
(
[
chr
(
n
)
for
n
in
range
(
0
-
31
)]
+
# control characters
[
chr
(
127
)]
+
# DEL
[
'('
,
')'
,
'<'
,
'>'
,
'@'
,
','
,
';'
,
':'
,
'
\
\
'
,
'"'
,
'/'
,
'['
,
']'
,
'?'
,
'='
,
'{'
,
'}'
,
' '
]
+
[
chr
(
9
)])
# horizontal tab
class
State
(
object
):
def
character
(
self
,
c
):
return
self
.
consume
(
c
)
def
close
(
self
):
return
self
.
eof
()
def
eof
(
self
):
raise
ValueError
(
'EOF not permitted in this state.'
)
'''
Return False to keep the current state, or True to pop it
'''
def
consume
(
c
):
raise
Exception
(
'Unimplemented'
)
class
ParentState
(
State
):
def
__init__
(
self
):
super
(
State
,
self
).
__init__
()
self
.
child
=
None
def
close
(
self
):
if
self
.
child
:
return
self
.
handle_child_return
(
self
.
child
.
close
())
else
:
return
self
.
eof
()
def
push_child
(
self
,
child
,
c
=
None
):
self
.
child
=
child
if
c
is
not
None
:
return
self
.
send_to_child
(
c
)
else
:
return
False
def
send_to_child
(
self
,
c
):
return
self
.
handle_child_return
(
self
.
child
.
character
(
c
))
def
handle_child_return
(
self
,
returned_value
):
if
returned_value
:
child
=
self
.
child
self
.
child
=
None
return
self
.
child_complete
(
child
)
return
False
'''
Return False to keep the current state, or True to pop it.
'''
def
child_complete
(
self
,
child
):
return
False
def
character
(
self
,
c
):
if
self
.
child
:
return
self
.
send_to_child
(
c
)
else
:
return
self
.
consume
(
c
)
def
consume
(
self
,
c
):
return
False
class
EscapedCharacterState
(
State
):
def
__init__
(
self
,
io
):
super
(
EscapedCharacterState
,
self
).
__init__
()
self
.
io
=
io
def
consume
(
self
,
c
):
self
.
io
.
write
(
c
)
return
True
class
KeyTrailingWhitespaceState
(
State
):
def
consume
(
self
,
c
):
if
c
in
_LWS
:
return
False
elif
c
==
'='
:
return
True
else
:
raise
ValueError
(
"Expected whitespace or '='"
)
class
ValueLeadingWhitespaceState
(
ParentState
):
def
__init__
(
self
,
io
):
super
(
ValueLeadingWhitespaceState
,
self
).
__init__
()
self
.
io
=
io
def
consume
(
self
,
c
):
if
c
in
_LWS
:
return
False
elif
c
==
'"'
:
return
self
.
push_child
(
QuotedValueState
(
self
.
io
))
elif
c
in
_ILLEGAL_TOKEN_CHARACTERS
:
raise
ValueError
(
'The character %r is not a legal token character'
%
c
)
else
:
self
.
io
.
write
(
c
)
return
self
.
push_child
(
UnquotedValueState
(
self
.
io
))
def
child_complete
(
self
,
child
):
return
True
class
ValueTrailingWhitespaceState
(
State
):
def
eof
(
self
):
return
True
def
consume
(
self
,
c
):
if
c
in
_LWS
:
return
False
elif
c
==
','
:
return
True
else
:
raise
ValueError
(
"Expected whitespace, ',', or EOF"
)
class
BaseQuotedState
(
ParentState
):
def
__init__
(
self
,
io
):
super
(
BaseQuotedState
,
self
).
__init__
()
self
.
key_io
=
io
def
consume
(
self
,
c
):
if
c
==
'
\
\
'
:
return
self
.
push_child
(
EscapedCharacterState
(
self
.
key_io
))
elif
c
==
'"'
:
return
self
.
push_child
(
self
.
TrailingWhitespaceState
())
else
:
self
.
key_io
.
write
(
c
)
return
False
def
child_complete
(
self
,
child
):
if
type
(
child
)
==
EscapedCharacterState
:
return
False
elif
type
(
child
)
==
self
.
TrailingWhitespaceState
:
return
True
class
BaseUnquotedState
(
ParentState
):
def
__init__
(
self
,
io
):
super
(
BaseUnquotedState
,
self
).
__init__
()
self
.
io
=
io
def
consume
(
self
,
c
):
if
c
==
self
.
terminating_character
:
return
True
elif
c
in
_LWS
:
return
self
.
push_child
(
self
.
TrailingWhitespaceState
())
elif
c
in
_ILLEGAL_TOKEN_CHARACTERS
:
raise
ValueError
(
'The character %r is not a legal token character'
%
c
)
else
:
self
.
io
.
write
(
c
)
return
False
def
child_complete
(
self
,
child
):
# type(child) == self.TrailingWhitespaceState
return
True
class
QuotedKeyState
(
BaseQuotedState
):
TrailingWhitespaceState
=
KeyTrailingWhitespaceState
class
QuotedValueState
(
BaseQuotedState
):
TrailingWhitespaceState
=
ValueTrailingWhitespaceState
class
UnquotedKeyState
(
BaseUnquotedState
):
TrailingWhitespaceState
=
KeyTrailingWhitespaceState
terminating_character
=
'='
class
UnquotedValueState
(
BaseUnquotedState
):
TrailingWhitespaceState
=
ValueTrailingWhitespaceState
terminating_character
=
','
def
eof
(
self
):
return
True
class
NewPartState
(
ParentState
):
def
__init__
(
self
,
parts
):
super
(
NewPartState
,
self
).
__init__
()
self
.
parts
=
parts
self
.
key_io
=
StringIO
()
self
.
value_io
=
StringIO
()
def
consume
(
self
,
c
):
if
c
in
_LWS
:
return
False
elif
c
==
'"'
:
return
self
.
push_child
(
QuotedKeyState
(
self
.
key_io
))
elif
c
in
_ILLEGAL_TOKEN_CHARACTERS
:
raise
ValueError
(
'The character %r is not a legal token character'
%
c
)
else
:
self
.
key_io
.
write
(
c
)
return
self
.
push_child
(
UnquotedKeyState
(
self
.
key_io
))
def
child_complete
(
self
,
child
):
if
type
(
child
)
in
[
QuotedKeyState
,
UnquotedKeyState
]:
return
self
.
push_child
(
ValueLeadingWhitespaceState
(
self
.
value_io
))
else
:
self
.
parts
[
self
.
key_io
.
getvalue
()]
=
self
.
value_io
.
getvalue
()
return
True
class
FoundationState
(
ParentState
):
def
__init__
(
self
,
defaults
):
super
(
FoundationState
,
self
).
__init__
()
self
.
parts
=
defaults
.
copy
()
def
result
(
self
):
return
self
.
parts
def
consume
(
self
,
c
):
return
self
.
push_child
(
NewPartState
(
self
.
parts
),
c
)
def
parse_parts
(
parts_string
,
defaults
=
{}):
state_machine
=
FoundationState
(
defaults
)
index
=
0
try
:
for
c
in
parts_string
:
state_machine
.
character
(
c
)
index
+=
1
state_machine
.
close
()
return
state_machine
.
result
()
except
ValueError
,
e
:
annotated_parts_string
=
"%s[%s]%s"
%
(
parts_string
[
0
:
index
],
index
<
len
(
parts_string
)
and
parts_string
[
index
]
or
''
,
index
+
1
<
len
(
parts_string
)
and
parts_string
[
index
+
1
:]
or
''
)
l
.
exception
(
"Failed to parse the Digest string "
"(offending character is in []): %r"
%
annotated_parts_string
)
return
None
def
format_parts
(
**
kwargs
):
return
", "
.
join
([
'%s="%s"'
%
(
k
,
v
.
encode
(
'utf-8'
))
for
(
k
,
v
)
in
kwargs
.
items
()])
def
validate_uri
(
digest_uri
,
request_path
):
digest_url_components
=
urlparse
.
urlparse
(
digest_uri
)
return
urllib
.
unquote
(
digest_url_components
[
2
])
==
request_path
def
validate_nonce
(
nonce
,
secret
):
'''
Is the nonce one that was generated by this library using the provided secret?
'''
nonce_components
=
nonce
.
split
(
':'
,
2
)
if
not
len
(
nonce_components
)
==
3
:
return
False
timestamp
=
nonce_components
[
0
]
salt
=
nonce_components
[
1
]
nonce_signature
=
nonce_components
[
2
]
calculated_nonce
=
calculate_nonce
(
timestamp
,
secret
,
salt
)
if
not
nonce
==
calculated_nonce
:
return
False
return
True
def
calculate_partial_digest
(
username
,
realm
,
password
):
'''
Calculate a partial digest that may be stored and used to authenticate future
HTTP Digest sessions.
'''
return
md5
.
md5
(
"%s:%s:%s"
%
(
username
.
encode
(
'utf-8'
),
realm
,
password
.
encode
(
'utf-8'
))).
hexdigest
()
def
build_digest_challenge
(
timestamp
,
secret
,
realm
,
opaque
,
stale
):
'''
Builds a Digest challenge that may be sent as the value of the 'WWW-Authenticate' header
in a 401 or 403 response.
'opaque' may be any value - it will be returned by the client.
'timestamp' will be incorporated and signed in the nonce - it may be retrieved from the
client's authentication request using get_nonce_timestamp()
'''
nonce
=
calculate_nonce
(
timestamp
,
secret
)
return
'Digest %s'
%
format_parts
(
realm
=
realm
,
qop
=
'auth'
,
nonce
=
nonce
,
opaque
=
opaque
,
algorithm
=
'MD5'
,
stale
=
stale
and
'true'
or
'false'
)
def
calculate_request_digest
(
method
,
partial_digest
,
digest_response
=
None
,
uri
=
None
,
nonce
=
None
,
nonce_count
=
None
,
client_nonce
=
None
):
'''
Calculates a value for the 'response' value of the client authentication request.
Requires the 'partial_digest' calculated from the realm, username, and password.
Either call it with a digest_response to use the values from an authentication request,
or pass the individual parameters (i.e. to generate an authentication request).
'''
if
digest_response
:
if
uri
or
nonce
or
nonce_count
or
client_nonce
:
raise
Exception
(
"Both digest_response and one or more "
"individual parameters were sent."
)
uri
=
digest_response
.
uri
nonce
=
digest_response
.
nonce
nonce_count
=
digest_response
.
nc
client_nonce
=
digest_response
.
cnonce
elif
not
(
uri
and
nonce
and
(
nonce_count
!=
None
)
and
client_nonce
):
raise
Exception
(
"Neither digest_response nor all individual parameters were sent."
)
ha2
=
md5
.
md5
(
"%s:%s"
%
(
method
,
uri
)).
hexdigest
()
data
=
"%s:%s:%s:%s:%s"
%
(
nonce
,
"%08x"
%
nonce_count
,
client_nonce
,
'auth'
,
ha2
)
kd
=
md5
.
md5
(
"%s:%s"
%
(
partial_digest
,
data
)).
hexdigest
()
return
kd
def
get_nonce_timestamp
(
nonce
):
'''
Extract the timestamp from a Nonce. To be sure the timestamp was generated by this site,
make sure you validate the nonce using validate_nonce().
'''
components
=
nonce
.
split
(
':'
,
2
)
if
not
len
(
components
)
==
3
:
return
None
try
:
return
float
(
components
[
0
])
except
ValueError
:
return
None
def
calculate_nonce
(
timestamp
,
secret
,
salt
=
None
):
'''
Generate a nonce using the provided timestamp, secret, and salt. If the salt is not provided,
(and one should only be provided when validating a nonce) one will be generated randomly
in order to ensure that two simultaneous requests do not generate identical nonces.
'''
if
not
salt
:
salt
=
''
.
join
([
random
.
choice
(
'0123456789ABCDEF'
)
for
x
in
range
(
4
)])
return
"%s:%s:%s"
%
(
timestamp
,
salt
,
md5
.
md5
(
"%s:%s:%s"
%
(
timestamp
,
salt
,
secret
)).
hexdigest
())
def
build_authorization_request
(
username
,
method
,
uri
,
nonce_count
,
digest_challenge
=
None
,
realm
=
None
,
nonce
=
None
,
opaque
=
None
,
password
=
None
,
request_digest
=
None
,
client_nonce
=
None
):
'''
Builds an authorization request that may be sent as the value of the 'Authorization'
header in an HTTP request.
Either a digest_challenge object (as returned from parse_digest_challenge) or its required
component parameters (nonce, realm, opaque) must be provided.
The nonce_count should be the last used nonce_count plus one.
Either the password or the request_digest should be provided - if provided, the password
will be used to generate a request digest. The client_nonce is optional - if not provided,
a random value will be generated.
'''
if
not
client_nonce
:
client_nonce
=
''
.
join
([
random
.
choice
(
'0123456789ABCDEF'
)
for
x
in
range
(
32
)])
if
digest_challenge
and
(
realm
or
nonce
or
opaque
):
raise
Exception
(
"Both digest_challenge and one or more of realm, nonce, and opaque"
"were sent."
)
if
digest_challenge
:
if
isinstance
(
digest_challenge
,
types
.
StringType
):
digest_challenge_header
=
digest_challenge
digest_challenge
=
parse_digest_challenge
(
digest_challenge_header
)
if
not
digest_challenge
:
raise
Exception
(
"The provided digest challenge header could not be parsed: %s"
%
digest_challenge_header
)
realm
=
digest_challenge
.
realm
nonce
=
digest_challenge
.
nonce
opaque
=
digest_challenge
.
opaque
elif
not
(
realm
and
nonce
and
opaque
):
raise
Exception
(
"Either digest_challenge or realm, nonce, and opaque must be sent."
)
if
password
and
request_digest
:
raise
Exception
(
"Both password and calculated request_digest were sent."
)
elif
not
request_digest
:
if
not
password
:
raise
Exception
(
"Either password or calculated request_digest must be provided."
)
partial_digest
=
calculate_partial_digest
(
username
,
realm
,
password
)
request_digest
=
calculate_request_digest
(
method
,
partial_digest
,
uri
=
uri
,
nonce
=
nonce
,
nonce_count
=
nonce_count
,
client_nonce
=
client_nonce
)
return
'Digest %s'
%
format_parts
(
username
=
username
,
realm
=
realm
,
nonce
=
nonce
,
uri
=
uri
,
response
=
request_digest
,
algorithm
=
'MD5'
,
opaque
=
opaque
,
qop
=
'auth'
,
nc
=
'%08x'
%
nonce_count
,
cnonce
=
client_nonce
)
def
_check_required_parts
(
parts
,
required_parts
):
if
parts
==
None
:
return
False
missing_parts
=
[
part
for
part
in
required_parts
if
not
part
in
parts
]
return
len
(
missing_parts
)
==
0
def
_build_object_from_parts
(
parts
,
names
):
obj
=
type
(
""
,
(),
{})()
for
part_name
in
names
:
val
=
parts
[
part_name
]
if
isinstance
(
val
,
basestring
):
val
=
unicode
(
val
,
"utf-8"
)
setattr
(
obj
,
part_name
,
val
)
return
obj
def
parse_digest_response
(
digest_response_string
):
'''
Parse the parameters of a Digest response. The input is a comma separated list of
token=(token|quoted-string). See RFCs 2616 and 2617 for details.
Known issue: this implementation will fail if there are commas embedded in quoted-strings.
'''
parts
=
parse_parts
(
digest_response_string
,
defaults
=
{
'algorithm'
:
'MD5'
})
if
not
_check_required_parts
(
parts
,
_REQUIRED_DIGEST_RESPONSE_PARTS
):
return
None
if
not
parts
[
'nc'
]
or
[
c
for
c
in
parts
[
'nc'
]
if
not
c
in
'0123456789abcdefABCDEF'
]:
return
None
parts
[
'nc'
]
=
int
(
parts
[
'nc'
],
16
)
digest_response
=
_build_object_from_parts
(
parts
,
_REQUIRED_DIGEST_RESPONSE_PARTS
)
if
(
'MD5'
,
'auth'
)
!=
(
digest_response
.
algorithm
,
digest_response
.
qop
):
return
None
return
digest_response
def
is_digest_credential
(
authorization_header
):
'''
Determines if the header value is potentially a Digest response sent by a client (i.e.
if it starts with 'Digest ' (case insensitive).
'''
return
authorization_header
[:
7
].
lower
()
==
'digest '
def
parse_digest_credentials
(
authorization_header
):
'''
Parses the value of an 'Authorization' header. Returns an object with properties
corresponding to each of the recognized parameters in the header.
'''
if
not
is_digest_credential
(
authorization_header
):
return
None
return
parse_digest_response
(
authorization_header
[
7
:])
def
is_digest_challenge
(
authentication_header
):
'''
Determines if the header value is potentially a Digest challenge sent by a server (i.e.
if it starts with 'Digest ' (case insensitive).
'''
return
authentication_header
[:
7
].
lower
()
==
'digest '
def
parse_digest_challenge
(
authentication_header
):
'''
Parses the value of a 'WWW-Authenticate' header. Returns an object with properties
corresponding to each of the recognized parameters in the header.
'''
if
not
is_digest_challenge
(
authentication_header
):
return
None
parts
=
parse_parts
(
authentication_header
[
7
:],
defaults
=
{
'algorithm'
:
'MD5'
,
'stale'
:
'false'
})
if
not
_check_required_parts
(
parts
,
_REQUIRED_DIGEST_CHALLENGE_PARTS
):
return
None
parts
[
'stale'
]
=
parts
[
'stale'
].
lower
()
==
'true'
digest_challenge
=
_build_object_from_parts
(
parts
,
_REQUIRED_DIGEST_CHALLENGE_PARTS
)
if
(
'MD5'
,
'auth'
)
!=
(
digest_challenge
.
algorithm
,
digest_challenge
.
qop
):
return
None
return
digest_challenge
websocket.py
View file @
ba21de4a
import
os
import
re
import
socket
import
ssl
from
hashlib
import
sha1
from
base64
import
b64encode
from
urlparse
import
urlparse
from
frame
import
receive_frame
from
errors
import
HandshakeError
,
SSLError
WS_GUID
=
'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
WS_VERSION
=
'13'
def
split_stripped
(
value
,
delim
=
','
):
return
map
(
str
.
strip
,
str
(
value
).
split
(
delim
))
if
value
else
[]
from
handshake
import
ServerHandshake
,
ClientHandshake
from
errors
import
SSLError
class
websocket
(
object
):
...
...
@@ -38,12 +26,13 @@ class websocket(object):
Client example:
>>> import twspy
>>> sock = twspy.websocket()
>>> sock = twspy.websocket(
location='/my/path'
)
>>> sock.connect(('', 8000))
>>> sock.send(twspy.Frame(twspy.OPCODE_TEXT, 'Hello, Server!'))
"""
def
__init__
(
self
,
sock
=
None
,
protocols
=
[],
extensions
=
[],
origin
=
None
,
trusted_origins
=
[],
sfamily
=
socket
.
AF_INET
,
sproto
=
0
):
trusted_origins
=
[],
location
=
'/'
,
auth
=
None
,
sfamily
=
socket
.
AF_INET
,
sproto
=
0
):
"""
Create a regular TCP socket of family `family` and protocol
...
...
@@ -62,15 +51,23 @@ class websocket(object):
has value not in this list, a HandshakeError is raised. If the list is
empty (default), all origins are excepted.
`location` is optional, used for the HTTP handshake. In a URL, this
would show as ws://host[:port]/path.
`auth` is optional, used for HTTP Basic or Digest authentication during
the handshake. It must be specified as a (username, password) tuple.
`sfamily` and `sproto` are used for the regular socket constructor.
"""
self
.
protocols
=
protocols
self
.
extensions
=
extensions
self
.
origin
=
origin
self
.
trusted_origins
=
trusted_origins
self
.
location
=
location
self
.
auth
=
auth
self
.
sock
=
sock
or
socket
.
socket
(
sfamily
,
socket
.
SOCK_STREAM
,
sproto
)
self
.
secure
=
False
self
.
handshake_s
tarted
=
False
self
.
handshake_s
ent
=
False
def
bind
(
self
,
address
):
self
.
sock
.
bind
(
address
)
...
...
@@ -87,24 +84,20 @@ class websocket(object):
"""
sock
,
address
=
self
.
sock
.
accept
()
wsock
=
websocket
(
sock
)
wsock
.
server_handshake
()
ServerHandshake
(
wsock
).
perform
()
wsock
.
handshake_sent
=
True
return
wsock
,
address
def
connect
(
self
,
address
,
path
=
'/'
,
auth
=
None
):
def
connect
(
self
,
address
):
"""
Equivalent to socket.connect(), but sends an client handshake request
after connecting.
`address` is a (host, port) tuple of the server to connect to.
`path` is optional, used as the *location* part of the HTTP handshake.
In a URL, this would show as ws://host[:port]/path.
`auth` is optional, used for the HTTP "Authorization" header of the
handshake request.
"""
self
.
sock
.
connect
(
address
)
self
.
client_handshake
(
address
,
path
,
auth
)
ClientHandshake
(
self
).
perform
()
self
.
handshake_sent
=
True
def
send
(
self
,
*
args
):
"""
...
...
@@ -145,243 +138,13 @@ class websocket(object):
def
close
(
self
):
self
.
sock
.
close
()
def
server_handshake
(
self
):
"""
Execute a handshake as the server end point of the socket. If the HTTP
request headers sent by the client are invalid, a HandshakeError
is raised.
"""
def
fail
(
msg
):
self
.
sock
.
close
()
raise
HandshakeError
(
msg
)
# Receive HTTP header
raw_headers
=
''
while
raw_headers
[
-
4
:]
not
in
(
'
\
r
\
n
\
r
\
n
'
,
'
\
n
\
n
'
):
raw_headers
+=
self
.
sock
.
recv
(
512
).
decode
(
'utf-8'
,
'ignore'
)
# Request must be HTTP (at least 1.1) GET request, find the location
match
=
re
.
search
(
r'^GET (.*) HTTP/1.1\r\n'
,
raw_headers
)
if
match
is
None
:
fail
(
'not a valid HTTP 1.1 GET request'
)
location
=
match
.
group
(
1
)
headers
=
re
.
findall
(
r'(.*?): ?(.*?)\r\n'
,
raw_headers
)
header_names
=
[
name
for
name
,
value
in
headers
]
def
header
(
name
):
return
', '
.
join
([
v
for
n
,
v
in
headers
if
n
==
name
])
# Check if headers that MUST be present are actually present
for
name
in
(
'Host'
,
'Upgrade'
,
'Connection'
,
'Sec-WebSocket-Key'
,
'Sec-WebSocket-Version'
):
if
name
not
in
header_names
:
fail
(
'missing "%s" header'
%
name
)
# Check WebSocket version used by client
version
=
header
(
'Sec-WebSocket-Version'
)
if
version
!=
WS_VERSION
:
fail
(
'WebSocket version %s requested (only %s '
'is supported)'
%
(
version
,
WS_VERSION
))
# Verify required header keywords
if
'websocket'
not
in
header
(
'Upgrade'
).
lower
():
fail
(
'"Upgrade" header must contain "websocket"'
)
if
'upgrade'
not
in
header
(
'Connection'
).
lower
():
fail
(
'"Connection" header must contain "Upgrade"'
)
# Origin must be present if browser client, and must match the list of
# trusted origins
if
'Origin'
not
in
header_names
:
if
'User-Agent'
in
header_names
:
fail
(
'browser client must specify "Origin" header'
)
if
self
.
trusted_origins
:
fail
(
'no "Origin" header specified, assuming untrusted'
)
elif
self
.
trusted_origins
:
origin
=
header
(
'Origin'
)
if
origin
not
in
self
.
trusted_origins
:
fail
(
'untrusted origin "%s"'
%
origin
)
# Only supported protocols are returned
client_protocols
=
split_stripped
(
header
(
'Sec-WebSocket-Extensions'
))
protocol
=
'null'
for
p
in
client_protocols
:
if
p
in
self
.
protocols
:
protocol
=
p
break
# Only supported extensions are returned
extensions
=
split_stripped
(
header
(
'Sec-WebSocket-Extensions'
))
extensions
=
[
e
for
e
in
extensions
if
e
in
self
.
extensions
]
# Encode acceptation key using the WebSocket GUID
key
=
header
(
'Sec-WebSocket-Key'
).
strip
()
accept
=
b64encode
(
sha1
(
key
+
WS_GUID
).
digest
())
# Construct HTTP response header
shake
=
'HTTP/1.1 101 Web Socket Protocol Handshake
\
r
\
n
'
shake
+=
'Upgrade: websocket
\
r
\
n
'
shake
+=
'Connection: Upgrade
\
r
\
n
'
shake
+=
'WebSocket-Origin: %s
\
r
\
n
'
%
header
(
'Origin'
)
shake
+=
'WebSocket-Location: ws://%s%s
\
r
\
n
'
\
%
(
header
(
'Host'
),
location
)
shake
+=
'Sec-WebSocket-Accept: %s
\
r
\
n
'
%
accept
shake
+=
'Sec-WebSocket-Protocol: %s
\
r
\
n
'
%
protocol
shake
+=
'Sec-WebSocket-Extensions: %s
\
r
\
n
'
%
', '
.
join
(
extensions
)
self
.
sock
.
sendall
(
shake
+
'
\
r
\
n
'
)
self
.
handshake_started
=
True
def
client_handshake
(
self
,
address
,
location
,
auth
):
"""
Executes a handshake as the client end point of the socket. May raise a
HandshakeError if the server response is invalid.
"""
def
fail
(
msg
):
self
.
sock
.
close
()
raise
HandshakeError
(
msg
)
def
send_request
(
location
):
if
len
(
location
)
==
0
:
fail
(
'request location is empty'
)
# Generate a 16-byte random base64-encoded key for this connection
key
=
b64encode
(
os
.
urandom
(
16
))
# Send client handshake
shake
=
'GET %s HTTP/1.1
\
r
\
n
'
%
location
shake
+=
'Host: %s:%d
\
r
\
n
'
%
address
shake
+=
'Upgrade: websocket
\
r
\
n
'
shake
+=
'Connection: keep-alive, Upgrade
\
r
\
n
'
shake
+=
'Sec-WebSocket-Key: %s
\
r
\
n
'
%
key
shake
+=
'Sec-WebSocket-Version: %s
\
r
\
n
'
%
WS_VERSION
if
self
.
origin
:
shake
+=
'Origin: %s
\
r
\
n
'
%
self
.
origin
# These are for eagerly caching webservers
shake
+=
'Pragma: no-cache
\
r
\
n
'
shake
+=
'Cache-Control: no-cache
\
r
\
n
'
# Request protocols and extension, these are later checked with the
# actual supported values from the server's response
if
self
.
protocols
:
shake
+=
'Sec-WebSocket-Protocol: %s
\
r
\
n
'
\
%
', '
.
join
(
self
.
protocols
)
if
self
.
extensions
:
shake
+=
'Sec-WebSocket-Extensions: %s
\
r
\
n
'
\
%
', '
.
join
(
self
.
extensions
)
if
auth
:
shake
+=
'Authorization: %s
\
r
\
n
'
%
auth
self
.
sock
.
sendall
(
shake
+
'
\
r
\
n
'
)
return
key
def
receive_response
(
key
):
# Receive and process server handshake
raw_headers
=
''
while
raw_headers
[
-
4
:]
not
in
(
'
\
r
\
n
\
r
\
n
'
,
'
\
n
\
n
'
):
raw_headers
+=
self
.
sock
.
recv
(
512
).
decode
(
'utf-8'
,
'ignore'
)
# Response must be HTTP (at least 1.1) with status 101
match
=
re
.
search
(
r'^HTTP/1\
.
1 (\
d{
3})'
,
raw_headers
)
if
match
is
None
:
fail
(
'not a valid HTTP 1.1 response'
)
status
=
int
(
match
.
group
(
1
))
headers
=
re
.
findall
(
r'(.*?): ?(.*?)\r\n'
,
raw_headers
)
header_names
=
[
name
for
name
,
value
in
headers
]
def
header
(
name
):
return
', '
.
join
([
v
for
n
,
v
in
headers
if
n
==
name
])
if
status
==
401
:
# HTTP authentication is required in the request
raise
HandshakeError
(
'HTTP authentication required: %s'
%
header
(
'WWW-Authenticate'
))
if
status
in
(
301
,
302
,
303
,
307
,
308
):
# Handle HTTP redirect
url
=
urlparse
(
header
(
'Location'
).
strip
())
# Reconnect socket if net location changed
if
not
url
.
port
:
url
.
port
=
443
if
self
.
secure
else
80
addr
=
(
url
.
netloc
,
url
.
port
)
if
addr
!=
self
.
sock
.
getpeername
():
self
.
sock
.
close
()
self
.
sock
.
connect
(
addr
)
# Send new handshake
receive_response
(
send_request
(
url
.
path
))
return
if
status
!=
101
:
# 101 means server has accepted the connection and sent
# handshake headers
fail
(
'invalid HTTP response status %d'
%
status
)
# Check if headers that MUST be present are actually present
for
name
in
(
'Upgrade'
,
'Connection'
,
'Sec-WebSocket-Accept'
):
if
name
not
in
header_names
:
fail
(
'missing "%s" header'
%
name
)
if
'websocket'
not
in
header
(
'Upgrade'
).
lower
():
fail
(
'"Upgrade" header must contain "websocket"'
)
if
'upgrade'
not
in
header
(
'Connection'
).
lower
():
fail
(
'"Connection" header must contain "Upgrade"'
)
# Verify accept header
accept
=
header
(
'Sec-WebSocket-Accept'
).
strip
()
required_accept
=
b64encode
(
sha1
(
key
+
WS_GUID
).
digest
())
if
accept
!=
required_accept
:
fail
(
'invalid websocket accept header "%s"'
%
accept
)
# Compare extensions
server_ext
=
split_stripped
(
header
(
'Sec-WebSocket-Extensions'
))
for
e
in
server_ext
:
if
e
not
in
self
.
extensions
:
fail
(
'server extension "%s" is unsupported by client'
%
e
)
for
e
in
self
.
extensions
:
if
e
not
in
server_ext
:
fail
(
'client extension "%s" is unsupported by server'
%
e
)
# Assert that returned protocol (if any) is supported
protocol
=
header
(
'Sec-WebSocket-Protocol'
)
if
protocol
:
if
protocol
!=
'null'
and
protocol
not
in
self
.
protocols
:
fail
(
'unsupported protocol "%s"'
%
protocol
)
self
.
protocol
=
protocol
self
.
handshake_started
=
True
receive_response
(
send_request
(
location
))
def
enable_ssl
(
self
,
*
args
,
**
kwargs
):
"""
Transforms the regular socket.socket to an ssl.SSLSocket for secure
connections. Any arguments are passed to ssl.wrap_socket:
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
"""
if
self
.
handshake_s
tarted
:
if
self
.
handshake_s
ent
:
raise
SSLError
(
'can only enable SSL before handshake'
)
self
.
secure
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment