Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
W
wspy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Taddeüs Kroes
wspy
Commits
6c79550e
Commit
6c79550e
authored
Dec 20, 2014
by
Taddeüs Kroes
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'async'
parents
4d9fbb0c
447ee6fa
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
592 additions
and
224 deletions
+592
-224
README.md
README.md
+1
-0
__init__.py
__init__.py
+3
-2
async.py
async.py
+190
-0
connection.py
connection.py
+52
-47
deflate_frame.py
deflate_frame.py
+27
-32
errors.py
errors.py
+8
-4
extension.py
extension.py
+23
-24
frame.py
frame.py
+92
-22
handshake.py
handshake.py
+16
-17
server.py
server.py
+24
-19
test/client.py
test/client.py
+10
-2
test/talk.py
test/talk.py
+45
-0
websocket.py
websocket.py
+101
-55
No files found.
README.md
View file @
6c79550e
...
@@ -22,6 +22,7 @@ Her is a quick overview of the features in this library:
...
@@ -22,6 +22,7 @@ Her is a quick overview of the features in this library:
-
Secure sockets using SSL certificates (for 'wss://...' URLs).
-
Secure sockets using SSL certificates (for 'wss://...' URLs).
-
The possibility to add extensions to the web socket protocol. An included
-
The possibility to add extensions to the web socket protocol. An included
implementation is
[
deflate-frame
](
http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06
)
.
implementation is
[
deflate-frame
](
http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06
)
.
-
Asynchronous sockets with an EPOLL-based server.
Installation
Installation
...
...
__init__.py
View file @
6c79550e
...
@@ -4,10 +4,11 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
...
@@ -4,10 +4,11 @@ from frame import Frame, ControlFrame, OPCODE_CONTINUATION, OPCODE_TEXT, \
OPCODE_BINARY
,
OPCODE_CLOSE
,
OPCODE_PING
,
OPCODE_PONG
,
CLOSE_NORMAL
,
\
OPCODE_BINARY
,
OPCODE_CLOSE
,
OPCODE_PING
,
OPCODE_PONG
,
CLOSE_NORMAL
,
\
CLOSE_GOING_AWAY
,
CLOSE_PROTOCOL_ERROR
,
CLOSE_NOACCEPT_DTYPE
,
\
CLOSE_GOING_AWAY
,
CLOSE_PROTOCOL_ERROR
,
CLOSE_NOACCEPT_DTYPE
,
\
CLOSE_INVALID_DATA
,
CLOSE_POLICY
,
CLOSE_MESSAGE_TOOBIG
,
\
CLOSE_INVALID_DATA
,
CLOSE_POLICY
,
CLOSE_MESSAGE_TOOBIG
,
\
CLOSE_MISSING_EXTENSIONS
,
CLOSE_UNABLE
CLOSE_MISSING_EXTENSIONS
,
CLOSE_UNABLE
,
read_frame
,
pop_frame
,
\
contains_frame
from
connection
import
Connection
from
connection
import
Connection
from
message
import
Message
,
TextMessage
,
BinaryMessage
from
message
import
Message
,
TextMessage
,
BinaryMessage
from
errors
import
SocketClosed
,
HandshakeError
,
PingError
,
SSLError
from
errors
import
SocketClosed
,
HandshakeError
,
PingError
,
SSLError
from
extension
import
Extension
from
extension
import
Extension
from
deflate_frame
import
DeflateFrame
,
WebkitDeflateFrame
from
deflate_frame
import
DeflateFrame
,
WebkitDeflateFrame
#from multiplex import Multiplex
from
async
import
AsyncConnection
,
AsyncServer
async.py
0 → 100644
View file @
6c79550e
import
socket
from
select
import
epoll
,
EPOLLIN
,
EPOLLOUT
,
EPOLLHUP
from
traceback
import
format_exc
import
logging
from
connection
import
Connection
from
frame
import
ControlFrame
,
OPCODE_PING
,
OPCODE_CONTINUATION
,
\
create_close_frame
from
server
import
Server
,
Client
from
errors
import
HandshakeError
,
SocketClosed
class
AsyncConnection
(
Connection
):
def
__init__
(
self
,
sock
):
sock
.
recv_callback
=
self
.
contruct_message
sock
.
recv_close_callback
=
self
.
onclose
self
.
recvbuf
=
[]
Connection
.
__init__
(
self
,
sock
)
def
contruct_message
(
self
,
frame
):
if
isinstance
(
frame
,
ControlFrame
):
self
.
handle_control_frame
(
frame
)
return
self
.
recvbuf
.
append
(
frame
)
if
frame
.
final
:
message
=
self
.
concat_fragments
(
self
.
recvbuf
)
self
.
recvbuf
=
[]
self
.
onmessage
(
message
)
elif
len
(
self
.
recvbuf
)
>
1
and
frame
.
opcode
!=
OPCODE_CONTINUATION
:
raise
ValueError
(
'expected continuation/control frame, got %s '
'instead'
%
frame
)
def
send
(
self
,
message
,
fragment_size
=
None
,
mask
=
False
):
frames
=
list
(
self
.
message_to_frames
(
message
,
fragment_size
,
mask
))
for
frame
in
frames
[:
-
1
]:
self
.
sock
.
queue_send
(
frame
)
self
.
sock
.
queue_send
(
frames
[
-
1
],
lambda
:
self
.
onsend
(
message
))
def
send_frame
(
self
,
frame
,
callback
):
self
.
sock
.
queue_send
(
frame
,
callback
)
def
do_async_send
(
self
):
self
.
execute_controlled
(
self
.
sock
.
do_async_send
)
def
do_async_recv
(
self
,
bufsize
):
self
.
execute_controlled
(
self
.
sock
.
do_async_recv
,
bufsize
)
def
execute_controlled
(
self
,
func
,
*
args
,
**
kwargs
):
try
:
func
(
*
args
,
**
kwargs
)
except
(
KeyboardInterrupt
,
SystemExit
,
SocketClosed
):
raise
except
Exception
as
e
:
self
.
onerror
(
e
)
self
.
onclose
(
None
,
'error: %s'
%
e
)
try
:
self
.
sock
.
close
()
except
socket
.
error
:
pass
raise
e
def
send_close_frame
(
self
,
code
,
reason
):
self
.
sock
.
queue_send
(
create_close_frame
(
code
,
reason
),
self
.
shutdown_write
)
self
.
close_frame_sent
=
True
def
close
(
self
,
code
=
None
,
reason
=
''
):
self
.
send_close_frame
(
code
,
reason
)
def
send_ping
(
self
,
payload
=
''
):
self
.
sock
.
queue_send
(
ControlFrame
(
OPCODE_PING
,
payload
),
lambda
:
self
.
onping
(
payload
))
self
.
ping_payload
=
payload
self
.
ping_sent
=
True
def
onsend
(
self
,
message
):
"""
Called after a message has been written.
"""
return
NotImplemented
class
AsyncServer
(
Server
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
Server
.
__init__
(
self
,
*
args
,
**
kwargs
)
self
.
recvbuf_size
=
kwargs
.
get
(
'recvbuf_size'
,
2048
)
self
.
epoll
=
epoll
()
self
.
epoll
.
register
(
self
.
sock
.
fileno
(),
EPOLLIN
)
self
.
conns
=
{}
@
property
def
clients
(
self
):
return
self
.
conns
.
values
()
def
remove_client
(
self
,
client
,
code
,
reason
):
self
.
epoll
.
unregister
(
client
.
fno
)
del
self
.
conns
[
client
.
fno
]
self
.
onclose
(
client
,
code
,
reason
)
def
handle_events
(
self
):
for
fileno
,
event
in
self
.
epoll
.
poll
(
1
):
if
fileno
==
self
.
sock
.
fileno
():
try
:
sock
,
addr
=
self
.
sock
.
accept
()
except
HandshakeError
as
e
:
logging
.
error
(
'Invalid request: %s'
,
e
.
message
)
continue
client
=
AsyncClient
(
self
,
sock
)
client
.
fno
=
sock
.
fileno
()
sock
.
setblocking
(
0
)
self
.
epoll
.
register
(
client
.
fno
,
EPOLLIN
)
self
.
conns
[
client
.
fno
]
=
client
logging
.
debug
(
'Registered client %s'
,
client
)
elif
event
&
EPOLLHUP
:
self
.
epoll
.
unregister
(
fileno
)
del
self
.
conns
[
fileno
]
else
:
conn
=
self
.
conns
[
fileno
]
try
:
if
event
&
EPOLLOUT
:
conn
.
do_async_send
()
elif
event
&
EPOLLIN
:
conn
.
do_async_recv
(
self
.
recvbuf_size
)
except
(
KeyboardInterrupt
,
SystemExit
):
raise
except
SocketClosed
:
continue
except
Exception
as
e
:
logging
.
error
(
format_exc
(
e
).
rstrip
())
continue
self
.
update_mask
(
conn
)
def
run
(
self
):
try
:
while
True
:
self
.
handle_events
()
except
(
KeyboardInterrupt
,
SystemExit
):
logging
.
info
(
'Received interrupt, stopping server...'
)
finally
:
self
.
epoll
.
unregister
(
self
.
sock
.
fileno
())
self
.
epoll
.
close
()
self
.
sock
.
close
()
def
update_mask
(
self
,
conn
):
mask
=
0
if
conn
.
sock
.
can_send
():
mask
|=
EPOLLOUT
if
conn
.
sock
.
can_recv
():
mask
|=
EPOLLIN
self
.
epoll
.
modify
(
conn
.
sock
.
fileno
(),
mask
)
def
onsend
(
self
,
client
,
message
):
return
NotImplemented
class
AsyncClient
(
Client
,
AsyncConnection
):
def
__init__
(
self
,
server
,
sock
):
self
.
server
=
server
AsyncConnection
.
__init__
(
self
,
sock
)
def
send
(
self
,
message
,
fragment_size
=
None
,
mask
=
False
):
logging
.
debug
(
'Enqueueing %s to %s'
,
message
,
self
)
AsyncConnection
.
send
(
self
,
message
,
fragment_size
,
mask
)
self
.
server
.
update_mask
(
self
)
def
onsend
(
self
,
message
):
logging
.
debug
(
'Finished sending %s to %s'
,
message
,
self
)
self
.
server
.
onsend
(
self
,
message
)
if
__name__
==
'__main__'
:
import
sys
port
=
int
(
sys
.
argv
[
1
])
if
len
(
sys
.
argv
)
>
1
else
8000
AsyncServer
((
''
,
port
),
loglevel
=
logging
.
DEBUG
).
run
()
connection.py
View file @
6c79550e
import
struct
import
socket
import
socket
from
frame
import
ControlFrame
,
OPCODE_CLOSE
,
OPCODE_PING
,
OPCODE_PONG
,
\
from
frame
import
ControlFrame
,
OPCODE_CLOSE
,
OPCODE_PING
,
OPCODE_PONG
,
\
OPCODE_CONTINUATION
OPCODE_CONTINUATION
,
create_close_frame
from
message
import
create_message
from
message
import
create_message
from
errors
import
SocketClosed
,
PingError
from
errors
import
SocketClosed
,
PingError
...
@@ -54,19 +53,30 @@ class Connection(object):
...
@@ -54,19 +53,30 @@ class Connection(object):
self
.
onopen
()
self
.
onopen
()
def
message_to_frames
(
self
,
message
,
fragment_size
=
None
,
mask
=
False
):
for
hook
in
self
.
hooks_send
:
message
=
hook
(
message
)
if
fragment_size
is
None
:
yield
message
.
frame
(
mask
=
mask
)
else
:
for
frame
in
message
.
fragment
(
fragment_size
,
mask
=
mask
):
yield
frame
def
send
(
self
,
message
,
fragment_size
=
None
,
mask
=
False
):
def
send
(
self
,
message
,
fragment_size
=
None
,
mask
=
False
):
"""
"""
Send a message. If `fragment_size` is specified, the message is
Send a message. If `fragment_size` is specified, the message is
fragmented into multiple frames whose payload size does not extend
fragmented into multiple frames whose payload size does not extend
`fragment_size`.
`fragment_size`.
"""
"""
for
hook
in
self
.
hooks_send
:
for
frame
in
self
.
message_to_frames
(
message
,
fragment_size
,
mask
)
:
message
=
hook
(
messag
e
)
self
.
send_frame
(
fram
e
)
if
fragment_size
is
None
:
def
send_frame
(
self
,
frame
,
callback
=
None
):
self
.
sock
.
send
(
message
.
frame
(
mask
=
mask
))
self
.
sock
.
send
(
frame
)
else
:
self
.
sock
.
send
(
*
message
.
fragment
(
fragment_size
,
mask
=
mask
))
if
callback
:
callback
()
def
recv
(
self
):
def
recv
(
self
):
"""
"""
...
@@ -82,12 +92,15 @@ class Connection(object):
...
@@ -82,12 +92,15 @@ class Connection(object):
if
isinstance
(
frame
,
ControlFrame
):
if
isinstance
(
frame
,
ControlFrame
):
self
.
handle_control_frame
(
frame
)
self
.
handle_control_frame
(
frame
)
elif
len
(
fragments
)
and
frame
.
opcode
!=
OPCODE_CONTINUATION
:
elif
len
(
fragments
)
>
0
and
frame
.
opcode
!=
OPCODE_CONTINUATION
:
raise
ValueError
(
'expected continuation/control frame, got %s '
raise
ValueError
(
'expected continuation/control frame, got %s '
'instead'
%
frame
)
'instead'
%
frame
)
else
:
else
:
fragments
.
append
(
frame
)
fragments
.
append
(
frame
)
return
self
.
concat_fragments
(
fragments
)
def
concat_fragments
(
self
,
fragments
):
payload
=
bytearray
()
payload
=
bytearray
()
for
f
in
fragments
:
for
f
in
fragments
:
...
@@ -105,16 +118,20 @@ class Connection(object):
...
@@ -105,16 +118,20 @@ class Connection(object):
Handle a control frame as defined by RFC 6455.
Handle a control frame as defined by RFC 6455.
"""
"""
if
frame
.
opcode
==
OPCODE_CLOSE
:
if
frame
.
opcode
==
OPCODE_CLOSE
:
# Close the connection from this end as well
self
.
close_frame_received
=
True
self
.
close_frame_received
=
True
code
,
reason
=
frame
.
unpack_close
()
code
,
reason
=
frame
.
unpack_close
()
# No more receiving data after a close message
if
self
.
close_frame_sent
:
raise
SocketClosed
(
code
,
reason
)
self
.
onclose
(
code
,
reason
)
self
.
sock
.
close
()
raise
SocketClosed
(
True
)
else
:
self
.
close_params
=
(
code
,
reason
)
self
.
send_close_frame
(
code
,
reason
)
elif
frame
.
opcode
==
OPCODE_PING
:
elif
frame
.
opcode
==
OPCODE_PING
:
# Respond with a pong message with identical payload
# Respond with a pong message with identical payload
self
.
s
ock
.
send
(
ControlFrame
(
OPCODE_PONG
,
frame
.
payload
))
self
.
s
end_frame
(
ControlFrame
(
OPCODE_PONG
,
frame
.
payload
))
elif
frame
.
opcode
==
OPCODE_PONG
:
elif
frame
.
opcode
==
OPCODE_PONG
:
# Assert that the PONG payload is identical to that of the PING
# Assert that the PONG payload is identical to that of the PING
...
@@ -138,38 +155,40 @@ class Connection(object):
...
@@ -138,38 +155,40 @@ class Connection(object):
while
True
:
while
True
:
try
:
try
:
self
.
onmessage
(
self
.
recv
())
self
.
onmessage
(
self
.
recv
())
except
SocketClosed
as
e
:
except
(
KeyboardInterrupt
,
SystemExit
,
SocketClosed
):
self
.
close
(
e
.
code
,
e
.
reason
)
break
break
except
socket
.
error
as
e
:
except
Exception
as
e
:
self
.
onerror
(
e
)
self
.
onerror
(
e
)
self
.
onclose
(
None
,
'error: %s'
%
e
)
try
:
try
:
self
.
sock
.
close
()
self
.
sock
.
close
()
except
socket
.
error
:
except
socket
.
error
:
pass
pass
self
.
onclose
(
None
,
''
)
raise
e
break
except
Exception
as
e
:
self
.
onerror
(
e
)
def
send_ping
(
self
,
payload
=
''
):
def
send_ping
(
self
,
payload
=
''
):
"""
"""
Send a PING control frame with an optional payload.
Send a PING control frame with an optional payload.
"""
"""
self
.
sock
.
send
(
ControlFrame
(
OPCODE_PING
,
payload
))
self
.
send_frame
(
ControlFrame
(
OPCODE_PING
,
payload
),
lambda
:
self
.
onping
(
payload
))
self
.
ping_payload
=
payload
self
.
ping_payload
=
payload
self
.
ping_sent
=
True
self
.
ping_sent
=
True
self
.
onping
(
payload
)
def
send_close_frame
(
self
,
code
=
None
,
reason
=
''
):
def
send_close_frame
(
self
,
code
,
reason
):
"""
self
.
send_frame
(
create_close_frame
(
code
,
reason
))
Send a CLOSE control frame.
"""
payload
=
''
if
code
is
None
else
struct
.
pack
(
'!H'
,
code
)
+
reason
self
.
sock
.
send
(
ControlFrame
(
OPCODE_CLOSE
,
payload
))
self
.
close_frame_sent
=
True
self
.
close_frame_sent
=
True
self
.
shutdown_write
()
def
shutdown_write
(
self
):
if
self
.
close_frame_received
:
self
.
onclose
(
*
self
.
close_params
)
self
.
sock
.
close
()
raise
SocketClosed
(
False
)
else
:
self
.
sock
.
shutdown
(
socket
.
SHUT_WR
)
def
close
(
self
,
code
=
None
,
reason
=
''
):
def
close
(
self
,
code
=
None
,
reason
=
''
):
"""
"""
...
@@ -179,28 +198,14 @@ class Connection(object):
...
@@ -179,28 +198,14 @@ class Connection(object):
called after the response has been received, but before the socket is
called after the response has been received, but before the socket is
actually closed.
actually closed.
"""
"""
# Send CLOSE frame
if
not
self
.
close_frame_sent
:
self
.
send_close_frame
(
code
,
reason
)
self
.
send_close_frame
(
code
,
reason
)
# Receive CLOSE frame
if
not
self
.
close_frame_received
:
frame
=
self
.
sock
.
recv
()
frame
=
self
.
sock
.
recv
()
if
frame
.
opcode
!=
OPCODE_CLOSE
:
if
frame
.
opcode
!=
OPCODE_CLOSE
:
raise
ValueError
(
'expected CLOSE frame, got %s'
%
frame
)
raise
ValueError
(
'expected CLOSE frame, got %s'
%
frame
)
self
.
close_frame_received
=
True
self
.
handle_control_frame
(
frame
)
res_code
,
res_reason
=
frame
.
unpack_close
()
# FIXME: check if res_code == code and res_reason == reason?
# FIXME: alternatively, keep receiving frames in a loop until a
# CLOSE frame is received, so that a fragmented chain may arrive
# fully first
self
.
onclose
(
code
,
reason
)
self
.
sock
.
close
()
def
add_hook
(
self
,
send
=
None
,
recv
=
None
,
prepend
=
False
):
def
add_hook
(
self
,
send
=
None
,
recv
=
None
,
prepend
=
False
):
"""
"""
...
...
deflate_frame.py
View file @
6c79550e
...
@@ -20,38 +20,33 @@ class DeflateFrame(Extension):
...
@@ -20,38 +20,33 @@ class DeflateFrame(Extension):
name
=
'deflate-frame'
name
=
'deflate-frame'
rsv1
=
True
rsv1
=
True
defaults
=
{
'max_window_bits'
:
15
,
'no_context_takeover'
:
False
}
defaults
=
{
'max_window_bits'
:
zlib
.
MAX_WBITS
,
'no_context_takeover'
:
False
}
def
__init__
(
self
,
defaults
=
{},
request
=
{}):
COMPRESSION_THRESHOLD
=
64
# minimal payload size for compression
Extension
.
__init__
(
self
,
defaults
,
request
)
def
init
(
self
):
mwb
=
self
.
defaults
[
'max_window_bits'
]
mwb
=
self
.
defaults
[
'max_window_bits'
]
cto
=
self
.
defaults
[
'no_context_takeover'
]
cto
=
self
.
defaults
[
'no_context_takeover'
]
if
not
isinstance
(
mwb
,
int
):
if
not
isinstance
(
mwb
,
int
)
or
mwb
<
1
or
mwb
>
zlib
.
MAX_WBITS
:
raise
ValueError
(
'"max_window_bits" must be an integer'
)
raise
ValueError
(
'"max_window_bits" must be in range 1-15'
)
elif
mwb
>
15
:
raise
ValueError
(
'"max_window_bits" may not be larger than 15'
)
if
cto
is
not
False
and
cto
is
not
True
:
if
cto
is
not
False
and
cto
is
not
True
:
raise
ValueError
(
'"no_context_takeover" must have no value'
)
raise
ValueError
(
'"no_context_takeover" must have no value'
)
class
Hook
(
Extension
.
Hook
):
class
Hook
(
Extension
.
Hook
):
def
__init__
(
self
,
extension
,
**
kwargs
):
def
init
(
self
,
extension
):
Extension
.
Hook
.
__init__
(
self
,
extension
,
**
kwargs
)
if
not
self
.
no_context_takeover
:
self
.
defl
=
zlib
.
compressobj
(
zlib
.
Z_DEFAULT_COMPRESSION
,
self
.
defl
=
zlib
.
compressobj
(
zlib
.
Z_DEFAULT_COMPRESSION
,
zlib
.
DEFLATED
,
zlib
.
DEFLATED
,
-
self
.
max_window_bits
)
-
self
.
max_window_bits
)
other_wbits
=
extension
.
request
.
get
(
'max_window_bits'
,
zlib
.
MAX_WBITS
)
other_wbits
=
self
.
extension
.
request
.
get
(
'max_window_bits'
,
15
)
self
.
dec
=
zlib
.
decompressobj
(
-
other_wbits
)
self
.
dec
=
zlib
.
decompressobj
(
-
other_wbits
)
def
send
(
self
,
frame
):
def
send
(
self
,
frame
):
if
not
frame
.
rsv1
and
not
isinstance
(
frame
,
ControlFrame
):
# FIXME: this does not seem to work properly on Android
if
not
frame
.
rsv1
and
not
isinstance
(
frame
,
ControlFrame
)
and
\
len
(
frame
.
payload
)
>
DeflateFrame
.
COMPRESSION_THRESHOLD
:
frame
.
rsv1
=
True
frame
.
rsv1
=
True
frame
.
payload
=
self
.
deflate
(
frame
.
payload
)
frame
.
payload
=
self
.
deflate
(
frame
)
return
frame
return
frame
...
@@ -65,23 +60,23 @@ class DeflateFrame(Extension):
...
@@ -65,23 +60,23 @@ class DeflateFrame(Extension):
return
frame
return
frame
def
deflate
(
self
,
data
):
def
deflate
(
self
,
frame
):
if
self
.
no_context_takeover
:
compressed
=
self
.
defl
.
compress
(
frame
.
payload
)
defl
=
zlib
.
compressobj
(
zlib
.
Z_DEFAULT_COMPRESSION
,
zlib
.
DEFLATED
,
-
self
.
max_window_bits
)
# FIXME: why the '\x00' below? This was borrowed from
# https://github.com/fancycode/tornado/blob/bc317b6dcf63608ff004ff1f57073be0504b6550/tornado/websocket.py#L91
return
defl
.
compress
(
data
)
+
defl
.
flush
(
zlib
.
Z_FINISH
)
+
'
\
x00
'
compressed
=
self
.
defl
.
compress
(
data
)
if
frame
.
final
or
self
.
no_context_takeover
:
compressed
+=
self
.
defl
.
flush
(
zlib
.
Z_FINISH
)
+
'
\
x00
'
self
.
defl
=
zlib
.
compressobj
(
zlib
.
Z_DEFAULT_COMPRESSION
,
zlib
.
DEFLATED
,
-
self
.
max_window_bits
)
else
:
compressed
+=
self
.
defl
.
flush
(
zlib
.
Z_SYNC_FLUSH
)
compressed
+=
self
.
defl
.
flush
(
zlib
.
Z_SYNC_FLUSH
)
assert
compressed
[
-
4
:]
==
'
\
x00
\
x00
\
xff
\
xff
'
assert
compressed
[
-
4
:]
==
'
\
x00
\
x00
\
xff
\
xff
'
return
compressed
[:
-
4
]
compressed
=
compressed
[:
-
4
]
return
compressed
def
inflate
(
self
,
data
):
def
inflate
(
self
,
data
):
data
=
self
.
dec
.
decompress
(
str
(
data
+
'
\
x00
\
x00
\
xff
\
xff
'
))
return
self
.
dec
.
decompress
(
data
+
'
\
x00
\
x00
\
xff
\
xff
'
)
+
\
assert
not
self
.
dec
.
unused_data
self
.
dec
.
flush
(
zlib
.
Z_SYNC_FLUSH
)
return
data
class
WebkitDeflateFrame
(
DeflateFrame
):
class
WebkitDeflateFrame
(
DeflateFrame
):
...
...
errors.py
View file @
6c79550e
class
SocketClosed
(
Exception
):
class
SocketClosed
(
Exception
):
def
__init__
(
self
,
code
=
None
,
reason
=
''
):
def
__init__
(
self
,
initialized
):
self
.
code
=
code
self
.
initialized
=
initialized
self
.
reason
=
reason
@
property
@
property
def
message
(
self
):
def
message
(
self
):
return
(
''
if
self
.
code
is
None
else
'[%d] '
%
self
.
code
)
+
self
.
reason
s
=
'socket closed'
if
self
.
initialized
:
s
+=
' (initialized)'
return
s
class
HandshakeError
(
Exception
):
class
HandshakeError
(
Exception
):
...
...
extension.py
View file @
6c79550e
...
@@ -19,23 +19,31 @@ class Extension(object):
...
@@ -19,23 +19,31 @@ class Extension(object):
self
.
request
=
dict
(
self
.
__class__
.
request
)
self
.
request
=
dict
(
self
.
__class__
.
request
)
self
.
request
.
update
(
request
)
self
.
request
.
update
(
request
)
self
.
init
()
def
__str__
(
self
):
def
__str__
(
self
):
return
'<Extension "%s" defaults=%s request=%s>'
\
return
'<Extension "%s" defaults=%s request=%s>'
\
%
(
self
.
name
,
self
.
defaults
,
self
.
request
)
%
(
self
.
name
,
self
.
defaults
,
self
.
request
)
def
init
(
self
):
return
NotImplemented
def
create_hook
(
self
,
**
kwargs
):
def
create_hook
(
self
,
**
kwargs
):
params
=
{}
params
=
{}
params
.
update
(
self
.
defaults
)
params
.
update
(
self
.
defaults
)
params
.
update
(
kwargs
)
params
.
update
(
kwargs
)
return
self
.
Hook
(
self
,
**
params
)
hook
=
self
.
Hook
(
**
params
)
hook
.
init
(
self
)
return
hook
class
Hook
:
class
Hook
:
def
__init__
(
self
,
extension
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
self
.
extension
=
extension
for
param
,
value
in
kwargs
.
iteritems
():
for
param
,
value
in
kwargs
.
iteritems
():
setattr
(
self
,
param
,
value
)
setattr
(
self
,
param
,
value
)
def
init
(
self
,
extension
):
return
NotImplemented
def
send
(
self
,
frame
):
def
send
(
self
,
frame
):
return
frame
return
frame
...
@@ -43,28 +51,19 @@ class Extension(object):
...
@@ -43,28 +51,19 @@ class Extension(object):
return
frame
return
frame
def
filter_extensions
(
extensions
):
def
extension_conflicts
(
ext
,
existing
):
"""
Remove extensions that use conflicting rsv bits and/or opcodes, with the
first options being the most preferable.
"""
rsv1_reserved
=
False
rsv1_reserved
=
False
rsv2_reserved
=
False
rsv2_reserved
=
False
rsv3_reserved
=
False
rsv3_reserved
=
False
opcodes_reserved
=
[]
reserved_opcodes
=
[]
compat
=
[]
for
ext
in
extensions
:
for
e
in
existing
:
if
ext
.
rsv1
and
rsv1_reserved
\
rsv1_reserved
|=
e
.
rsv1
rsv2_reserved
|=
e
.
rsv2
rsv3_reserved
|=
e
.
rsv3
reserved_opcodes
.
extend
(
e
.
opcodes
)
return
ext
.
rsv1
and
rsv1_reserved
\
or
ext
.
rsv2
and
rsv2_reserved
\
or
ext
.
rsv2
and
rsv2_reserved
\
or
ext
.
rsv3
and
rsv3_reserved
\
or
ext
.
rsv3
and
rsv3_reserved
\
or
len
(
set
(
ext
.
opcodes
)
&
set
(
opcodes_reserved
)):
or
len
(
set
(
ext
.
opcodes
)
&
set
(
reserved_opcodes
))
continue
rsv1_reserved
|=
ext
.
rsv1
rsv2_reserved
|=
ext
.
rsv2
rsv3_reserved
|=
ext
.
rsv3
opcodes_reserved
.
extend
(
ext
.
opcodes
)
compat
.
append
(
ext
)
return
compat
frame.py
View file @
6c79550e
...
@@ -21,9 +21,11 @@ CLOSE_MESSAGE_TOOBIG = 1009
...
@@ -21,9 +21,11 @@ CLOSE_MESSAGE_TOOBIG = 1009
CLOSE_MISSING_EXTENSIONS
=
1010
CLOSE_MISSING_EXTENSIONS
=
1010
CLOSE_UNABLE
=
1011
CLOSE_UNABLE
=
1011
line_printable
=
[
c
for
c
in
printable
if
c
not
in
'
\
r
\
n
\
x0b
\
x0c
'
]
def
printstr
(
s
):
def
printstr
(
s
):
return
''
.
join
(
c
if
c
in
printable
else
'.'
for
c
in
s
)
return
''
.
join
(
c
if
c
in
line_printable
else
'.'
for
c
in
str
(
s
)
)
class
Frame
(
object
):
class
Frame
(
object
):
...
@@ -154,7 +156,18 @@ class Frame(object):
...
@@ -154,7 +156,18 @@ class Frame(object):
if
len
(
self
.
payload
)
>
max_pl_disp
:
if
len
(
self
.
payload
)
>
max_pl_disp
:
pl
+=
'...'
pl
+=
'...'
return
s
+
' payload=%s>'
%
pl
s
+=
' payload=%s'
%
pl
if
self
.
rsv1
:
s
+=
' rsv1'
if
self
.
rsv2
:
s
+=
' rsv2'
if
self
.
rsv3
:
s
+=
' rsv3'
return
s
+
'>'
class
ControlFrame
(
Frame
):
class
ControlFrame
(
Frame
):
...
@@ -194,12 +207,8 @@ class ControlFrame(Frame):
...
@@ -194,12 +207,8 @@ class ControlFrame(Frame):
return
code
,
reason
return
code
,
reason
def
receive_frame
(
sock
):
def
decode_frame
(
reader
):
"""
b1
,
b2
=
struct
.
unpack
(
'!BB'
,
reader
.
readn
(
2
))
Receive a single frame on socket `sock`. The frame scheme is explained in
the docs of Frame.pack().
"""
b1
,
b2
=
struct
.
unpack
(
'!BB'
,
recvn
(
sock
,
2
))
final
=
bool
(
b1
&
0x80
)
final
=
bool
(
b1
&
0x80
)
rsv1
=
bool
(
b1
&
0x40
)
rsv1
=
bool
(
b1
&
0x40
)
...
@@ -211,16 +220,16 @@ def receive_frame(sock):
...
@@ -211,16 +220,16 @@ def receive_frame(sock):
payload_len
=
b2
&
0x7F
payload_len
=
b2
&
0x7F
if
payload_len
==
126
:
if
payload_len
==
126
:
payload_len
=
struct
.
unpack
(
'!H'
,
re
cvn
(
sock
,
2
))
payload_len
=
struct
.
unpack
(
'!H'
,
re
ader
.
readn
(
2
))
elif
payload_len
==
127
:
elif
payload_len
==
127
:
payload_len
=
struct
.
unpack
(
'!Q'
,
re
cvn
(
sock
,
8
))
payload_len
=
struct
.
unpack
(
'!Q'
,
re
ader
.
readn
(
8
))
if
masked
:
if
masked
:
masking_key
=
re
cvn
(
sock
,
4
)
masking_key
=
re
ader
.
readn
(
4
)
payload
=
mask
(
masking_key
,
re
cvn
(
sock
,
payload_len
))
payload
=
mask
(
masking_key
,
re
ader
.
readn
(
payload_len
))
else
:
else
:
masking_key
=
''
masking_key
=
''
payload
=
re
cvn
(
sock
,
payload_len
)
payload
=
re
ader
.
readn
(
payload_len
)
# Control frames have most significant bit 1
# Control frames have most significant bit 1
cls
=
ControlFrame
if
opcode
&
0x8
else
Frame
cls
=
ControlFrame
if
opcode
&
0x8
else
Frame
...
@@ -229,14 +238,44 @@ def receive_frame(sock):
...
@@ -229,14 +238,44 @@ def receive_frame(sock):
rsv1
=
rsv1
,
rsv2
=
rsv2
,
rsv3
=
rsv3
)
rsv1
=
rsv1
,
rsv2
=
rsv2
,
rsv3
=
rsv3
)
def
recvn
(
sock
,
n
):
def
receive_frame
(
sock
):
return
decode_frame
(
SocketReader
(
sock
))
def
read_frame
(
data
):
reader
=
BufferReader
(
data
)
frame
=
decode_frame
(
reader
)
return
frame
,
reader
.
offset
def
pop_frame
(
data
):
frame
,
size
=
read_frame
(
data
)
return
frame
,
data
[
size
:]
class
BufferReader
(
object
):
def
__init__
(
self
,
data
):
self
.
data
=
data
self
.
offset
=
0
def
readn
(
self
,
n
):
assert
len
(
self
.
data
)
-
self
.
offset
>=
n
self
.
offset
+=
n
return
self
.
data
[
self
.
offset
-
n
:
self
.
offset
]
class
SocketReader
(
object
):
def
__init__
(
self
,
sock
):
self
.
sock
=
sock
def
readn
(
self
,
n
):
"""
"""
Keep receiving data from `sock`
until exactly `n` bytes have been read.
Keep receiving data
until exactly `n` bytes have been read.
"""
"""
data
=
''
data
=
''
while
len
(
data
)
<
n
:
while
len
(
data
)
<
n
:
received
=
sock
.
recv
(
n
-
len
(
data
))
received
=
self
.
sock
.
recv
(
n
-
len
(
data
))
if
not
len
(
received
):
if
not
len
(
received
):
raise
socket
.
error
(
'no data read from socket'
)
raise
socket
.
error
(
'no data read from socket'
)
...
@@ -246,6 +285,32 @@ def recvn(sock, n):
...
@@ -246,6 +285,32 @@ def recvn(sock, n):
return
data
return
data
def
contains_frame
(
data
):
"""
Read the frame length from the start of `data` and check if the data is
long enough to contain the entire frame.
"""
if
len
(
data
)
<
2
:
return
False
b2
=
struct
.
unpack
(
'!B'
,
data
[
1
])[
0
]
payload_len
=
b2
&
0x7F
payload_start
=
2
if
payload_len
==
126
:
if
len
(
data
)
>
4
:
payload_len
=
struct
.
unpack
(
'!H'
,
data
[
2
:
4
])
payload_start
=
4
elif
payload_len
==
127
:
if
len
(
data
)
>
12
:
payload_len
=
struct
.
unpack
(
'!Q'
,
data
[
4
:
12
])
payload_start
=
12
return
len
(
data
)
>=
payload_len
+
payload_start
def
mask
(
key
,
original
):
def
mask
(
key
,
original
):
"""
"""
Mask an octet string using the given masking key.
Mask an octet string using the given masking key.
...
@@ -265,3 +330,8 @@ def mask(key, original):
...
@@ -265,3 +330,8 @@ def mask(key, original):
masked
[
i
]
^=
key
[
i
%
4
]
masked
[
i
]
^=
key
[
i
%
4
]
return
masked
return
masked
def
create_close_frame
(
code
,
reason
):
payload
=
''
if
code
is
None
else
struct
.
pack
(
'!H'
,
code
)
+
reason
return
ControlFrame
(
OPCODE_CLOSE
,
payload
)
handshake.py
View file @
6c79550e
...
@@ -7,7 +7,7 @@ from hashlib import sha1
...
@@ -7,7 +7,7 @@ from hashlib import sha1
from
urlparse
import
urlparse
from
urlparse
import
urlparse
from
errors
import
HandshakeError
from
errors
import
HandshakeError
from
extension
import
filter_extension
s
from
extension
import
extension_conflict
s
from
python_digest
import
build_authorization_request
from
python_digest
import
build_authorization_request
...
@@ -15,7 +15,7 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
...
@@ -15,7 +15,7 @@ WS_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
WS_VERSION
=
'13'
WS_VERSION
=
'13'
MAX_REDIRECTS
=
10
MAX_REDIRECTS
=
10
HDR_TIMEOUT
=
5
HDR_TIMEOUT
=
5
MAX_HDR_LEN
=
512
MAX_HDR_LEN
=
1024
class
Handshake
(
object
):
class
Handshake
(
object
):
...
@@ -65,7 +65,11 @@ class Handshake(object):
...
@@ -65,7 +65,11 @@ class Handshake(object):
start_time
=
time
.
time
()
start_time
=
time
.
time
()
while
hdr
[
-
4
:]
!=
'
\
r
\
n
\
r
\
n
'
and
len
(
hdr
)
<
MAX_HDR_LEN
:
while
hdr
[
-
4
:]
!=
'
\
r
\
n
\
r
\
n
'
:
if
len
(
hdr
)
==
MAX_HDR_LEN
:
raise
HandshakeError
(
'request exceeds maximum header '
'length of %d'
%
MAX_HDR_LEN
)
hdr
+=
self
.
sock
.
recv
(
1
)
hdr
+=
self
.
sock
.
recv
(
1
)
time_diff
=
time
.
time
()
-
start_time
time_diff
=
time
.
time
()
-
start_time
...
@@ -169,23 +173,19 @@ class ServerHandshake(Handshake):
...
@@ -169,23 +173,19 @@ class ServerHandshake(Handshake):
# Only supported extensions are returned
# Only supported extensions are returned
if
'Sec-WebSocket-Extensions'
in
headers
:
if
'Sec-WebSocket-Extensions'
in
headers
:
supported_ext
=
dict
((
e
.
name
,
e
)
for
e
in
ssock
.
extensions
)
supported_ext
=
dict
((
e
.
name
,
e
)
for
e
in
ssock
.
extensions
)
self
.
wsock
.
extension_hooks
=
[]
extensions
=
[]
extensions
=
[]
all_params
=
[]
for
ext
in
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
]):
for
ext
in
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
]):
name
,
params
=
parse_param_hdr
(
ext
)
name
,
params
=
parse_param_hdr
(
ext
)
if
name
in
supported_ext
:
if
name
in
supported_ext
:
extensions
.
append
(
supported_ext
[
name
])
ext
=
supported_ext
[
name
]
all_params
.
append
(
params
)
self
.
wsock
.
extensions
=
filter_extensions
(
extensions
)
for
ext
,
params
in
zip
(
self
.
wsock
.
extensions
,
all_params
):
if
not
extension_conflicts
(
ext
,
extensions
):
extensions
.
append
(
ext
)
hook
=
ext
.
create_hook
(
**
params
)
hook
=
ext
.
create_hook
(
**
params
)
self
.
wsock
.
add_hook
(
send
=
hook
.
send
,
recv
=
hook
.
recv
)
self
.
wsock
.
extension_hooks
.
append
(
hook
)
else
:
self
.
wsock
.
extensions
=
[]
# Check if requested resource location is served by this server
# Check if requested resource location is served by this server
if
ssock
.
locations
:
if
ssock
.
locations
:
...
@@ -212,7 +212,7 @@ class ServerHandshake(Handshake):
...
@@ -212,7 +212,7 @@ class ServerHandshake(Handshake):
location
=
'%s://%s%s'
%
(
scheme
,
host
,
self
.
wsock
.
location
)
location
=
'%s://%s%s'
%
(
scheme
,
host
,
self
.
wsock
.
location
)
# Construct HTTP response header
# Construct HTTP response header
yield
'HTTP/1.1 101
Web Socket Protocol Handshake
'
yield
'HTTP/1.1 101
Switching Protocols
'
yield
'Upgrade'
,
'websocket'
yield
'Upgrade'
,
'websocket'
yield
'Connection'
,
'Upgrade'
yield
'Connection'
,
'Upgrade'
yield
'Sec-WebSocket-Origin'
,
origin
yield
'Sec-WebSocket-Origin'
,
origin
...
@@ -274,7 +274,7 @@ class ClientHandshake(Handshake):
...
@@ -274,7 +274,7 @@ class ClientHandshake(Handshake):
# Compare extensions, add hooks only for those returned by server
# Compare extensions, add hooks only for those returned by server
if
'Sec-WebSocket-Extensions'
in
headers
:
if
'Sec-WebSocket-Extensions'
in
headers
:
supported_ext
=
dict
((
e
.
name
,
e
)
for
e
in
self
.
wsock
.
extensions
)
supported_ext
=
dict
((
e
.
name
,
e
)
for
e
in
self
.
wsock
.
extensions
)
self
.
wsock
.
extensions
=
[]
self
.
wsock
.
extension
_hook
s
=
[]
for
ext
in
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
]):
for
ext
in
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
]):
name
,
params
=
parse_param_hdr
(
ext
)
name
,
params
=
parse_param_hdr
(
ext
)
...
@@ -284,8 +284,7 @@ class ClientHandshake(Handshake):
...
@@ -284,8 +284,7 @@ class ClientHandshake(Handshake):
'unsupported extension "%s"'
%
name
)
'unsupported extension "%s"'
%
name
)
hook
=
supported_ext
[
name
].
create_hook
(
**
params
)
hook
=
supported_ext
[
name
].
create_hook
(
**
params
)
self
.
wsock
.
extensions
.
append
(
supported_ext
[
name
])
self
.
wsock
.
extension_hooks
.
append
(
hook
)
self
.
wsock
.
add_hook
(
send
=
hook
.
send
,
recv
=
hook
.
recv
)
# Assert that returned protocol (if any) is supported
# Assert that returned protocol (if any) is supported
if
'Sec-WebSocket-Protocol'
in
headers
:
if
'Sec-WebSocket-Protocol'
in
headers
:
...
...
server.py
View file @
6c79550e
...
@@ -32,7 +32,7 @@ class Server(object):
...
@@ -32,7 +32,7 @@ class Server(object):
"""
"""
def
__init__
(
self
,
address
,
loglevel
=
logging
.
INFO
,
ssl_args
=
None
,
def
__init__
(
self
,
address
,
loglevel
=
logging
.
INFO
,
ssl_args
=
None
,
max_join_time
=
2.0
,
**
kwargs
):
max_join_time
=
2.0
,
backlog_size
=
32
,
**
kwargs
):
"""
"""
Constructor for a simple web socket server.
Constructor for a simple web socket server.
...
@@ -53,6 +53,8 @@ class Server(object):
...
@@ -53,6 +53,8 @@ class Server(object):
`max_join_time` is the maximum time (in seconds) to wait for client
`max_join_time` is the maximum time (in seconds) to wait for client
responses after sending CLOSE frames, it defaults to 2 seconds.
responses after sending CLOSE frames, it defaults to 2 seconds.
`backlog_size` is directly passed to `websocket.listen`.
"""
"""
logging
.
basicConfig
(
level
=
loglevel
,
logging
.
basicConfig
(
level
=
loglevel
,
format
=
'%(asctime)s: %(levelname)s: %(message)s'
,
format
=
'%(asctime)s: %(levelname)s: %(message)s'
,
...
@@ -69,14 +71,14 @@ class Server(object):
...
@@ -69,14 +71,14 @@ class Server(object):
self
.
sock
.
enable_ssl
(
server_side
=
True
,
**
ssl_args
)
self
.
sock
.
enable_ssl
(
server_side
=
True
,
**
ssl_args
)
self
.
sock
.
bind
(
address
)
self
.
sock
.
bind
(
address
)
self
.
sock
.
listen
(
5
)
self
.
sock
.
listen
(
backlog_size
)
self
.
clients
=
[]
self
.
client_threads
=
[]
self
.
max_join_time
=
max_join_time
self
.
max_join_time
=
max_join_time
def
run
(
self
):
def
run
(
self
):
self
.
clients
=
[]
self
.
client_threads
=
[]
while
True
:
while
True
:
try
:
try
:
sock
,
address
=
self
.
sock
.
accept
()
sock
,
address
=
self
.
sock
.
accept
()
...
@@ -134,30 +136,22 @@ class Server(object):
...
@@ -134,30 +136,22 @@ class Server(object):
self
.
onclose
(
client
,
code
,
reason
)
self
.
onclose
(
client
,
code
,
reason
)
def
onopen
(
self
,
client
):
def
onopen
(
self
,
client
):
logging
.
debug
(
'Opened socket to %s'
,
client
)
return
NotImplemented
def
onmessage
(
self
,
client
,
message
):
def
onmessage
(
self
,
client
,
message
):
logging
.
debug
(
'Received %s from %s'
,
message
,
client
)
return
NotImplemented
def
onping
(
self
,
client
,
payload
):
def
onping
(
self
,
client
,
payload
):
logging
.
debug
(
'Sent ping "%s" to %s'
,
payload
,
client
)
return
NotImplemented
def
onpong
(
self
,
client
,
payload
):
def
onpong
(
self
,
client
,
payload
):
logging
.
debug
(
'Received pong "%s" from %s'
,
payload
,
client
)
return
NotImplemented
def
onclose
(
self
,
client
,
code
,
reason
):
def
onclose
(
self
,
client
,
code
,
reason
):
msg
=
'Closed socket to %s'
%
client
return
NotImplemented
if
code
is
not
None
:
msg
+=
' [%d]'
%
code
if
len
(
reason
):
msg
+=
': '
+
reason
logging
.
debug
(
msg
)
def
onerror
(
self
,
client
,
e
):
def
onerror
(
self
,
client
,
e
):
logging
.
error
(
format_exc
(
e
))
return
NotImplemented
class
Client
(
Connection
):
class
Client
(
Connection
):
...
@@ -176,21 +170,32 @@ class Client(Connection):
...
@@ -176,21 +170,32 @@ class Client(Connection):
Connection
.
send
(
self
,
message
,
fragment_size
=
fragment_size
,
mask
=
mask
)
Connection
.
send
(
self
,
message
,
fragment_size
=
fragment_size
,
mask
=
mask
)
def
onopen
(
self
):
def
onopen
(
self
):
logging
.
debug
(
'Opened socket to %s'
,
self
)
self
.
server
.
onopen
(
self
)
self
.
server
.
onopen
(
self
)
def
onmessage
(
self
,
message
):
def
onmessage
(
self
,
message
):
logging
.
debug
(
'Received %s from %s'
,
message
,
self
)
self
.
server
.
onmessage
(
self
,
message
)
self
.
server
.
onmessage
(
self
,
message
)
def
onping
(
self
,
payload
):
def
onping
(
self
,
payload
):
logging
.
debug
(
'Sent ping "%s" to %s'
,
payload
,
self
)
self
.
server
.
onping
(
self
,
payload
)
self
.
server
.
onping
(
self
,
payload
)
def
onpong
(
self
,
payload
):
def
onpong
(
self
,
payload
):
logging
.
debug
(
'Received pong "%s" from %s'
,
payload
,
self
)
self
.
server
.
onpong
(
self
,
payload
)
self
.
server
.
onpong
(
self
,
payload
)
def
onclose
(
self
,
code
,
reason
):
def
onclose
(
self
,
code
,
reason
):
msg
=
'Closed socket to %s'
%
self
if
code
is
not
None
:
msg
+=
': [%d] %s'
%
(
code
,
reason
)
logging
.
debug
(
msg
)
self
.
server
.
remove_client
(
self
,
code
,
reason
)
self
.
server
.
remove_client
(
self
,
code
,
reason
)
def
onerror
(
self
,
e
):
def
onerror
(
self
,
e
):
logging
.
error
(
format_exc
(
e
))
self
.
server
.
onerror
(
self
,
e
)
self
.
server
.
onerror
(
self
,
e
)
...
...
test/client.py
View file @
6c79550e
#!/usr/bin/env python
#!/usr/bin/env python
import
sys
import
sys
import
ssl
from
os.path
import
abspath
,
dirname
from
os.path
import
abspath
,
dirname
basepath
=
abspath
(
dirname
(
abspath
(
__file__
))
+
'/..'
)
basepath
=
abspath
(
dirname
(
abspath
(
__file__
))
+
'/..'
)
...
@@ -20,7 +21,7 @@ class EchoClient(Connection):
...
@@ -20,7 +21,7 @@ class EchoClient(Connection):
def
onmessage
(
self
,
msg
):
def
onmessage
(
self
,
msg
):
print
'Received'
,
msg
print
'Received'
,
msg
raise
SocketClosed
(
None
,
'response received'
)
self
.
close
(
None
,
'response received'
)
def
onerror
(
self
,
e
):
def
onerror
(
self
,
e
):
print
'Error:'
,
e
print
'Error:'
,
e
...
@@ -29,8 +30,15 @@ class EchoClient(Connection):
...
@@ -29,8 +30,15 @@ class EchoClient(Connection):
print
'Connection closed'
print
'Connection closed'
secure
=
True
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
print
'Connecting to ws://%s:%d'
%
ADDR
scheme
=
'wss'
if
secure
else
'ws'
print
'Connecting to %s://%s'
%
(
scheme
,
'%s:%d'
%
ADDR
)
sock
=
websocket
()
sock
=
websocket
()
if
secure
:
sock
.
enable_ssl
(
ca_certs
=
'cert.pem'
,
cert_reqs
=
ssl
.
CERT_REQUIRED
)
sock
.
connect
(
ADDR
)
sock
.
connect
(
ADDR
)
EchoClient
(
sock
).
receive_forever
()
EchoClient
(
sock
).
receive_forever
()
test/talk.py
0 → 100755
View file @
6c79550e
#!/usr/bin/env python
import
sys
import
socket
from
os.path
import
abspath
,
dirname
basepath
=
abspath
(
dirname
(
abspath
(
__file__
))
+
'/..'
)
sys
.
path
.
insert
(
0
,
basepath
)
from
websocket
import
websocket
from
connection
import
Connection
from
message
import
TextMessage
from
errors
import
SocketClosed
if
__name__
==
'__main__'
:
if
len
(
sys
.
argv
)
<
3
:
print
>>
sys
.
stderr
,
'Usage: python %s HOST PORT'
%
sys
.
argv
[
0
]
sys
.
exit
(
1
)
host
=
sys
.
argv
[
1
]
port
=
int
(
sys
.
argv
[
2
])
sock
=
websocket
()
sock
.
connect
((
host
,
port
))
sock
.
settimeout
(
1.0
)
conn
=
Connection
(
sock
)
try
:
try
:
while
True
:
msg
=
TextMessage
(
raw_input
())
print
'send:'
,
msg
conn
.
send
(
msg
)
try
:
print
'recv:'
,
conn
.
recv
()
except
socket
.
timeout
:
print
'no response'
except
EOFError
:
conn
.
close
()
except
SocketClosed
as
e
:
if
e
.
initialized
:
print
'closed connection'
else
:
print
'other side closed connection'
websocket.py
View file @
6c79550e
import
socket
import
socket
import
ssl
import
ssl
from
frame
import
receive_frame
from
frame
import
receive_frame
,
pop_frame
,
contains_frame
from
handshake
import
ServerHandshake
,
ClientHandshake
from
handshake
import
ServerHandshake
,
ClientHandshake
from
errors
import
SSLError
from
errors
import
SSLError
...
@@ -11,7 +11,6 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername',
...
@@ -11,7 +11,6 @@ INHERITED_ATTRS = ['bind', 'close', 'listen', 'fileno', 'getpeername',
'settimeout'
,
'gettimeout'
,
'shutdown'
,
'family'
,
'type'
,
'settimeout'
,
'gettimeout'
,
'shutdown'
,
'family'
,
'type'
,
'proto'
]
'proto'
]
class
websocket
(
object
):
class
websocket
(
object
):
"""
"""
Implementation of web socket, upgrades a regular TCP socket to a websocket
Implementation of web socket, upgrades a regular TCP socket to a websocket
...
@@ -36,22 +35,23 @@ class websocket(object):
...
@@ -36,22 +35,23 @@ class websocket(object):
>>> sock.connect(('', 8000))
>>> sock.connect(('', 8000))
>>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!'))
>>> sock.send(wspy.Frame(wspy.OPCODE_TEXT, 'Hello, Server!'))
"""
"""
def
__init__
(
self
,
sock
=
None
,
protocols
=
[],
extensions
=
[],
origin
=
None
,
def
__init__
(
self
,
sock
=
None
,
origin
=
None
,
protocols
=
[],
extensions
=
[]
,
location
=
'/'
,
trusted_origins
=
[],
locations
=
[],
auth
=
None
,
location
=
'/'
,
trusted_origins
=
[],
locations
=
[],
auth
=
None
,
sfamily
=
socket
.
AF_INET
,
sproto
=
0
):
recv_callback
=
None
,
sfamily
=
socket
.
AF_INET
,
sproto
=
0
):
"""
"""
Create a regular TCP socket of family `family` and protocol
Create a regular TCP socket of family `family` and protocol
`sock` is an optional regular TCP socket to be used for sending binary
`sock` is an optional regular TCP socket to be used for sending binary
data. If not specified, a new socket is created.
data. If not specified, a new socket is created.
`protocols` is a list of supported protocol names.
`extensions` is a list of supported extensions (`Extension` instances).
`origin` (for client sockets) is the value for the "Origin" header sent
`origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake .
in a client handshake .
`protocols` is a list of supported protocol names.
`extensions` (for server sockets) is a list of supported extensions
(`Extension` instances).
`location` (for client sockets) is optional, used to request a
`location` (for client sockets) is optional, used to request a
particular resource in the HTTP handshake. In a URL, this would show as
particular resource in the HTTP handshake. In a URL, this would show as
ws://host[:port]/<location>. Use this when the server serves multiple
ws://host[:port]/<location>. Use this when the server serves multiple
...
@@ -71,10 +71,17 @@ class websocket(object):
...
@@ -71,10 +71,17 @@ class websocket(object):
`auth` is optional, used for HTTP Basic or Digest authentication during
`auth` is optional, used for HTTP Basic or Digest authentication during
the handshake. It must be specified as a (username, password) tuple.
the handshake. It must be specified as a (username, password) tuple.
`recv_callback` is the callback for received frames in asynchronous
sockets. Use in conjunction with setblocking(0). The callback itself
may for example change the recv_callback attribute to change the
behaviour for the next received message. Can be set when calling
`queue_send`.
`sfamily` and `sproto` are used for the regular socket constructor.
`sfamily` and `sproto` are used for the regular socket constructor.
"""
"""
self
.
protocols
=
protocols
self
.
protocols
=
protocols
self
.
extensions
=
extensions
self
.
extensions
=
extensions
self
.
extension_hooks
=
[]
self
.
origin
=
origin
self
.
origin
=
origin
self
.
location
=
location
self
.
location
=
location
self
.
trusted_origins
=
trusted_origins
self
.
trusted_origins
=
trusted_origins
...
@@ -85,11 +92,16 @@ class websocket(object):
...
@@ -85,11 +92,16 @@ class websocket(object):
self
.
handshake_sent
=
False
self
.
handshake_sent
=
False
self
.
hooks_send
=
[]
self
.
sendbuf_frames
=
[]
self
.
hooks_recv
=
[]
self
.
sendbuf
=
''
self
.
recvbuf
=
''
self
.
recv_callback
=
recv_callback
self
.
sock
=
sock
or
socket
.
socket
(
sfamily
,
socket
.
SOCK_STREAM
,
sproto
)
self
.
sock
=
sock
or
socket
.
socket
(
sfamily
,
socket
.
SOCK_STREAM
,
sproto
)
def
set_extensions
(
self
,
extensions
):
self
.
extensions
=
[
ext
.
Hook
()
for
ext
in
extensions
]
def
__getattr__
(
self
,
name
):
def
__getattr__
(
self
,
name
):
if
name
in
INHERITED_ATTRS
:
if
name
in
INHERITED_ATTRS
:
return
getattr
(
self
.
sock
,
name
)
return
getattr
(
self
.
sock
,
name
)
...
@@ -122,29 +134,31 @@ class websocket(object):
...
@@ -122,29 +134,31 @@ class websocket(object):
ClientHandshake
(
self
).
perform
()
ClientHandshake
(
self
).
perform
()
self
.
handshake_sent
=
True
self
.
handshake_sent
=
True
def
apply_send_hooks
(
self
,
frame
):
for
hook
in
self
.
extension_hooks
:
frame
=
hook
.
send
(
frame
)
return
frame
def
apply_recv_hooks
(
self
,
frame
):
for
hook
in
reversed
(
self
.
extension_hooks
):
frame
=
hook
.
recv
(
frame
)
return
frame
def
send
(
self
,
*
args
):
def
send
(
self
,
*
args
):
"""
"""
Send a number of frames.
Send a number of frames.
"""
"""
for
frame
in
args
:
for
frame
in
args
:
for
hook
in
self
.
hooks_send
:
self
.
sock
.
sendall
(
self
.
apply_send_hooks
(
frame
).
pack
())
frame
=
hook
(
frame
)
#print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
self
.
sock
.
sendall
(
frame
.
pack
())
def
recv
(
self
):
def
recv
(
self
):
"""
"""
Receive a single frames. This can be either a data frame or a control
Receive a single frames. This can be either a data frame or a control
frame.
frame.
"""
"""
frame
=
receive_frame
(
self
.
sock
)
return
self
.
apply_recv_hooks
(
receive_frame
(
self
.
sock
))
for
hook
in
self
.
hooks_recv
:
frame
=
hook
(
frame
)
#print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return
frame
def
recvn
(
self
,
n
):
def
recvn
(
self
,
n
):
"""
"""
...
@@ -153,47 +167,79 @@ class websocket(object):
...
@@ -153,47 +167,79 @@ class websocket(object):
"""
"""
return
[
self
.
recv
()
for
i
in
xrange
(
n
)]
return
[
self
.
recv
()
for
i
in
xrange
(
n
)]
def
enable_ssl
(
self
,
*
args
,
**
kwargs
):
def
queue_send
(
self
,
frame
,
callback
=
None
,
recv_callback
=
None
):
"""
"""
Transforms the regular socket.socket to an ssl.SSLSocket for secure
Enqueue `frame` to the send buffer so that it is send on the next
connections. Any arguments are passed to ssl.wrap_socket:
`do_async_send`. `callback` is an optional callable to call when the
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
frame has been fully written. `recv_callback` is an optional callable
to quickly set the `recv_callback` attribute to.
"""
"""
if
self
.
handshake_sent
:
frame
=
self
.
apply_send_hooks
(
frame
)
raise
SSLError
(
'can only enable SSL before handshake'
)
self
.
sendbuf
+=
frame
.
pack
()
self
.
sendbuf_frames
.
append
([
frame
,
len
(
self
.
sendbuf
),
callback
])
self
.
secure
=
True
if
recv_callback
:
self
.
sock
=
ssl
.
wrap_socket
(
self
.
sock
,
*
args
,
**
kwargs
)
self
.
recv_callback
=
recv_callback
def
add_hook
(
self
,
send
=
None
,
recv
=
None
,
prepend
=
False
):
def
do_async_send
(
self
):
"""
"""
Add a pair of send and receive hooks that are called for each frame
Send any queued data. This function should only be called after a write
that is sent or received. A hook is a function that receives a single
event on a file descriptor.
argument - a Frame instance - and returns a `Frame` instance as well.
"""
assert
len
(
self
.
sendbuf
)
`prepend` is a flag indicating whether the send hook is prepended to
nwritten
=
self
.
sock
.
send
(
self
.
sendbuf
)
the other send hooks. This is expecially useful when a program uses
nframes
=
0
extensions such as the built-in `DeflateFrame` extension. These
extensions are installed using these hooks as well.
For example, the following code creates a `Frame` instance for data
for
entry
in
self
.
sendbuf_frames
:
being sent and removes the instance for received data. This way, data
frame
,
offset
,
callback
=
entry
can be sent and received as if on a regular socket.
>>> import wspy
>>> sock.add_hook(lambda data: tswpy.Frame(tswpy.OPCODE_TEXT, data),
>>> lambda frame: frame.payload)
To add base64 encoding to the example above:
if
offset
<=
nwritten
:
>>> import base64
nframes
+=
1
>>> sock.add_hook(base64.encodestring, base64.decodestring, True)
Note that here `prepend=True`, so that data passed to `send()` is first
if
callback
:
encoded and then packed into a frame. Of course, one could also decide
callback
()
to add the base64 hook first, or to return a new `Frame` instance with
else
:
base64-encoded data.
entry
[
1
]
-=
nwritten
self
.
sendbuf
=
self
.
sendbuf
[
nwritten
:]
self
.
sendbuf_frames
=
self
.
sendbuf_frames
[
nframes
:]
def
do_async_recv
(
self
,
bufsize
):
"""
Receive any completed frames from the socket. This function should only
be called after a read event on a file descriptor.
"""
"""
if
send
:
data
=
self
.
sock
.
recv
(
bufsize
)
self
.
hooks_send
.
insert
(
0
if
prepend
else
-
1
,
send
)
if
len
(
data
)
==
0
:
raise
socket
.
error
(
'no data to receive'
)
self
.
recvbuf
+=
data
while
contains_frame
(
self
.
recvbuf
):
frame
,
self
.
recvbuf
=
pop_frame
(
self
.
recvbuf
)
frame
=
self
.
apply_recv_hooks
(
frame
)
if
not
self
.
recv_callback
:
raise
ValueError
(
'no callback installed for %s'
%
frame
)
self
.
recv_callback
(
frame
)
def
can_send
(
self
):
return
len
(
self
.
sendbuf
)
>
0
if
recv
:
def
can_recv
(
self
):
self
.
hooks_recv
.
insert
(
-
1
if
prepend
else
0
,
recv
)
return
self
.
recv_callback
is
not
None
def
enable_ssl
(
self
,
*
args
,
**
kwargs
):
"""
Transforms the regular socket.socket to an ssl.SSLSocket for secure
connections. Any arguments are passed to ssl.wrap_socket:
http://docs.python.org/dev/library/ssl.html#ssl.wrap_socket
"""
if
self
.
handshake_sent
:
raise
SSLError
(
'can only enable SSL before handshake'
)
self
.
secure
=
True
self
.
sock
=
ssl
.
wrap_socket
(
self
.
sock
,
*
args
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment