Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
W
wspy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Taddeüs Kroes
wspy
Commits
9232e5d4
Commit
9232e5d4
authored
Dec 20, 2014
by
Taddeüs Kroes
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Rewrote extensions API + reimplemented deflate-frame
parent
6c79550e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
144 additions
and
126 deletions
+144
-126
__init__.py
__init__.py
+1
-1
deflate_frame.py
deflate_frame.py
+38
-43
extension.py
extension.py
+62
-42
handshake.py
handshake.py
+27
-25
test/client.py
test/client.py
+1
-3
test/server.py
test/server.py
+4
-4
websocket.py
websocket.py
+11
-8
No files found.
__init__.py
View file @
9232e5d4
...
...
@@ -10,5 +10,5 @@ from connection import Connection
from
message
import
Message
,
TextMessage
,
BinaryMessage
from
errors
import
SocketClosed
,
HandshakeError
,
PingError
,
SSLError
from
extension
import
Extension
from
deflate_frame
import
DeflateFrame
,
WebkitDeflateFrame
from
deflate_frame
import
DeflateFrame
from
async
import
AsyncConnection
,
AsyncServer
deflate_frame.py
View file @
9232e5d4
...
...
@@ -17,40 +17,39 @@ class DeflateFrame(Extension):
Note that the deflate and inflate hooks modify the RSV1 bit and payload of
existing `Frame` objects.
"""
name
=
'deflate-frame'
names
=
(
'deflate-frame'
,
'x-webkit-deflate-frame'
)
rsv1
=
True
defaults
=
{
'max_window_bits'
:
zlib
.
MAX_WBITS
,
'no_context_takeover'
:
False
}
COMPRESSION_THRESHOLD
=
64
# minimal payload size for compression
def
init
(
self
):
mwb
=
self
.
defaults
[
'max_window_bits'
]
cto
=
self
.
defaults
[
'no_context_takeover'
]
if
not
isinstance
(
mwb
,
int
)
or
mwb
<
1
or
mwb
>
zlib
.
MAX_WBITS
:
raise
ValueError
(
'"max_window_bits" must be in range 1-15'
)
if
cto
is
not
False
and
cto
is
not
True
:
raise
ValueError
(
'"no_context_takeover" must have no value'
)
class
Hook
(
Extension
.
Hook
):
def
init
(
self
,
extension
):
self
.
defl
=
zlib
.
compressobj
(
zlib
.
Z_DEFAULT_COMPRESSION
,
zlib
.
DEFLATED
,
-
self
.
max_window_bits
)
other_wbits
=
extension
.
request
.
get
(
'max_window_bits'
,
zlib
.
MAX_WBITS
)
self
.
dec
=
zlib
.
decompressobj
(
-
other_wbits
)
defaults
=
{
'max_window_bits'
:
zlib
.
MAX_WBITS
,
'no_context_takeover'
:
False
}
compression_threshold
=
64
# minimal payload size for compression
def
negotiate
(
self
,
name
,
params
):
if
'max_window_bits'
in
params
:
mwb
=
int
(
params
[
'max_window_bits'
])
assert
8
<=
mwb
<=
zlib
.
MAX_WBITS
yield
'max_window_bits'
,
mwb
if
'no_context_takeover'
in
params
:
assert
params
[
'no_context_takeover'
]
is
True
yield
'no_context_takeover'
,
True
class
Instance
(
Extension
.
Instance
):
def
init
(
self
):
if
not
self
.
no_context_takeover
:
self
.
defl
=
zlib
.
compressobj
(
zlib
.
Z_DEFAULT_COMPRESSION
,
zlib
.
DEFLATED
,
-
self
.
max_window_bits
)
self
.
dec
=
zlib
.
decompressobj
(
-
self
.
max_window_bits
)
def
send
(
self
,
frame
):
# FIXME: this does not seem to work properly on Android
def
onsend_frame
(
self
,
frame
):
if
not
frame
.
rsv1
and
not
isinstance
(
frame
,
ControlFrame
)
and
\
len
(
frame
.
payload
)
>
DeflateFrame
.
COMPRESSION_THRESHOLD
:
len
(
frame
.
payload
)
>
self
.
extension
.
compression_threshold
:
frame
.
rsv1
=
True
frame
.
payload
=
self
.
deflate
(
frame
)
return
frame
def
recv
(
self
,
frame
):
def
onrecv_frame
(
self
,
frame
):
if
frame
.
rsv1
:
if
isinstance
(
frame
,
ControlFrame
):
raise
ValueError
(
'received compressed control frame'
)
...
...
@@ -58,26 +57,22 @@ class DeflateFrame(Extension):
frame
.
rsv1
=
False
frame
.
payload
=
self
.
inflate
(
frame
.
payload
)
return
frame
def
deflate
(
self
,
frame
):
compressed
=
self
.
defl
.
compress
(
frame
.
payload
)
if
frame
.
final
or
self
.
no_context_takeover
:
compressed
+=
self
.
defl
.
flush
(
zlib
.
Z_FINISH
)
+
'
\
x00
'
self
.
defl
=
zlib
.
compressobj
(
zlib
.
Z_DEFAULT_COMPRESSION
,
zlib
.
DEFLATED
,
-
self
.
max_window_bits
)
if
self
.
no_context_takeover
:
print
'no_context_takeover'
compressed
=
zlib
.
compress
(
frame
.
payload
)
else
:
compressed
=
self
.
defl
.
compress
(
frame
.
payload
)
compressed
+=
self
.
defl
.
flush
(
zlib
.
Z_SYNC_FLUSH
)
assert
compressed
[
-
4
:]
==
'
\
x00
\
x00
\
xff
\
xff
'
compressed
=
compressed
[:
-
4
]
return
compressed
assert
compressed
[
-
4
:]
==
'
\
x00
\
x00
\
xff
\
xff
'
return
compressed
[:
-
4
]
def
inflate
(
self
,
data
):
return
self
.
dec
.
decompress
(
data
+
'
\
x00
\
x00
\
xff
\
xff
'
)
+
\
self
.
dec
.
flush
(
zlib
.
Z_SYNC_FLUSH
)
data
=
str
(
data
+
'
\
x00
\
x00
\
xff
\
xff
'
)
if
self
.
no_context_takeover
:
dec
=
zlib
.
decompressobj
(
-
self
.
max_window_bits
)
return
dec
.
decompress
(
data
)
+
dec
.
flush
()
class
WebkitDeflateFrame
(
DeflateFrame
):
name
=
'x-webkit-deflate-frame'
return
self
.
dec
.
decompress
(
data
)
extension.py
View file @
9232e5d4
...
...
@@ -3,67 +3,87 @@ class Extension(object):
rsv1
=
False
rsv2
=
False
rsv3
=
False
opcodes
=
[]
opcodes
=
()
defaults
=
{}
request
=
{}
def
__init__
(
self
,
defaults
=
{},
request
=
{}
):
for
param
in
defaults
.
keys
()
+
request
.
keys
():
def
__init__
(
self
,
**
kwargs
):
for
param
in
kwargs
.
iter
keys
():
if
param
not
in
self
.
defaults
:
raise
KeyError
(
'unrecognized parameter "%s"'
%
param
)
# Copy dict first to avoid duplicate references to the same object
self
.
defaults
=
dict
(
self
.
__class__
.
defaults
)
self
.
defaults
.
update
(
defaults
)
self
.
request
=
dict
(
self
.
__class__
.
request
)
self
.
request
.
update
(
request
)
self
.
init
()
self
.
defaults
.
update
(
kwargs
)
def
__str__
(
self
):
return
'<Extension "%s" defaults=%s request=%s>'
\
%
(
self
.
name
,
self
.
defaults
,
self
.
request
)
def
init
(
self
):
return
NotImplemented
@
property
def
names
(
self
):
return
(
self
.
name
,)
if
self
.
name
else
()
def
conflicts
(
self
,
ext
):
"""
Check if the extension conflicts with an already accepted extension.
This may be the case when the two extensions use the same reserved
bits, or have the same name (when the same extension is negotiated
multiple times with different parameters).
"""
return
ext
.
rsv1
and
self
.
rsv1
\
or
ext
.
rsv2
and
self
.
rsv2
\
or
ext
.
rsv3
and
self
.
rsv3
\
or
set
(
ext
.
names
)
&
set
(
self
.
names
)
\
or
set
(
ext
.
opcodes
)
&
set
(
self
.
opcodes
)
def
negotiate
(
self
,
name
,
params
):
"""
Same as `negotiate_safe`, but instead returns an iterator of (param,
value) tuples and raises an exception on error.
"""
raise
NotImplementedError
def
negotiate_safe
(
self
,
name
,
params
):
"""
`name` and `params` are sent in the HTTP request by the client. Check
if the extension name is supported by this extension, and validate the
parameters. Returns a dict with accepted parameters, or None if not
accepted.
"""
for
param
in
params
.
iterkeys
():
if
param
not
in
self
.
defaults
:
return
try
:
return
dict
(
self
.
negotiate
(
name
,
params
))
except
(
KeyError
,
ValueError
,
AssertionError
):
pass
def
create_hook
(
self
,
**
kwargs
):
params
=
{}
params
.
update
(
self
.
defaults
)
params
.
update
(
kwargs
)
hook
=
self
.
Hook
(
**
params
)
hook
.
init
(
self
)
return
hook
class
Instance
:
def
__init__
(
self
,
extension
,
name
,
params
):
self
.
extension
=
extension
self
.
name
=
name
self
.
params
=
params
class
Hook
:
def
__init__
(
self
,
**
kwargs
):
for
param
,
value
in
kwargs
.
iteritems
():
for
param
,
value
in
extension
.
defaults
.
iteritems
():
setattr
(
self
,
param
,
value
)
def
init
(
self
,
extension
):
return
NotImplemented
for
param
,
value
in
params
.
iteritems
(
):
setattr
(
self
,
param
,
value
)
def
send
(
self
,
frame
):
return
frame
self
.
init
()
def
recv
(
self
,
frame
):
return
frame
def
init
(
self
):
return
NotImplemented
def
onsend_frame
(
self
,
frame
):
pass
def
extension_conflicts
(
ext
,
existing
):
rsv1_reserved
=
False
rsv2_reserved
=
False
rsv3_reserved
=
False
reserved_opcodes
=
[]
def
onrecv_frame
(
self
,
frame
):
pass
for
e
in
existing
:
rsv1_reserved
|=
e
.
rsv1
rsv2_reserved
|=
e
.
rsv2
rsv3_reserved
|=
e
.
rsv3
reserved_opcodes
.
extend
(
e
.
opcodes
)
def
onsend_message
(
self
,
message
):
pass
return
ext
.
rsv1
and
rsv1_reserved
\
or
ext
.
rsv2
and
rsv2_reserved
\
or
ext
.
rsv3
and
rsv3_reserved
\
or
len
(
set
(
ext
.
opcodes
)
&
set
(
reserved_opcodes
))
def
onrecv_message
(
self
,
message
):
pass
handshake.py
View file @
9232e5d4
...
...
@@ -7,7 +7,6 @@ from hashlib import sha1
from
urlparse
import
urlparse
from
errors
import
HandshakeError
from
extension
import
extension_conflicts
from
python_digest
import
build_authorization_request
...
...
@@ -172,20 +171,19 @@ class ServerHandshake(Handshake):
# Only supported extensions are returned
if
'Sec-WebSocket-Extensions'
in
headers
:
supported_ext
=
dict
((
e
.
name
,
e
)
for
e
in
ssock
.
extensions
)
self
.
wsock
.
extension_hooks
=
[]
extensions
=
[]
self
.
wsock
.
extension_instances
=
[]
for
ext
in
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
]):
name
,
params
=
parse_param_hdr
(
ext
)
for
hdr
in
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
]):
name
,
params
=
parse_param_hdr
(
hdr
)
if
name
in
supported_ext
:
ext
=
supported_ext
[
name
]
for
ext
in
ssock
.
extensions
:
if
not
any
(
ext
.
conflicts
(
other
.
extension
)
for
other
in
self
.
wsock
.
extension_instances
):
accept_params
=
ext
.
negotiate_safe
(
name
,
params
)
if
not
extension_conflicts
(
ext
,
extensions
):
extensions
.
append
(
ext
)
hook
=
ext
.
create_hook
(
**
params
)
self
.
wsock
.
extension_hooks
.
append
(
hook
)
if
accept_params
is
not
None
:
instance
=
ext
.
Instance
(
ext
,
name
,
accept_params
)
self
.
wsock
.
extension_instances
.
append
(
instance
)
# Check if requested resource location is served by this server
if
ssock
.
locations
:
...
...
@@ -222,9 +220,9 @@ class ServerHandshake(Handshake):
if
self
.
wsock
.
protocol
:
yield
'Sec-WebSocket-Protocol'
,
self
.
wsock
.
protocol
if
self
.
wsock
.
extensions
:
values
=
[
format_param_hdr
(
e
.
name
,
e
.
request
)
for
e
in
self
.
wsock
.
extension
s
]
if
self
.
wsock
.
extension
_instance
s
:
values
=
[
format_param_hdr
(
i
.
name
,
i
.
params
)
for
i
in
self
.
wsock
.
extension_instance
s
]
yield
'Sec-WebSocket-Extensions'
,
', '
.
join
(
values
)
...
...
@@ -273,19 +271,23 @@ class ClientHandshake(Handshake):
# Compare extensions, add hooks only for those returned by server
if
'Sec-WebSocket-Extensions'
in
headers
:
supported_ext
=
dict
((
e
.
name
,
e
)
for
e
in
self
.
wsock
.
extensions
)
self
.
wsock
.
extension_hooks
=
[]
for
ext
in
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
]):
name
,
params
=
parse_param_hdr
(
ext
)
if
name
not
in
supported_ext
:
# FIXME: there is no distinction between server/client extension
# instances, while the extension instance may assume it belongs to
# a server, leading to undefined behavior
self
.
wsock
.
extension_instances
=
[]
for
hdr
in
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
]):
name
,
accept_params
=
parse_param_hdr
(
hdr
)
for
ext
in
self
.
wsock
.
extensions
:
if
name
in
ext
.
names
:
instance
=
ext
.
Instance
(
ext
,
name
,
accept_params
)
self
.
wsock
.
extension_instances
.
append
(
instance
)
break
else
:
raise
HandshakeError
(
'server handshake contains '
'unsupported extension "%s"'
%
name
)
hook
=
supported_ext
[
name
].
create_hook
(
**
params
)
self
.
wsock
.
extension_hooks
.
append
(
hook
)
# Assert that returned protocol (if any) is supported
if
'Sec-WebSocket-Protocol'
in
headers
:
protocol
=
headers
[
'Sec-WebSocket-Protocol'
]
...
...
test/client.py
View file @
9232e5d4
...
...
@@ -9,7 +9,6 @@ sys.path.insert(0, basepath)
from
websocket
import
websocket
from
connection
import
Connection
from
message
import
TextMessage
from
errors
import
SocketClosed
ADDR
=
(
'localhost'
,
8000
)
...
...
@@ -30,9 +29,8 @@ class EchoClient(Connection):
print
'Connection closed'
secure
=
True
if
__name__
==
'__main__'
:
secure
=
'-s'
in
sys
.
argv
[
1
:]
scheme
=
'wss'
if
secure
else
'ws'
print
'Connecting to %s://%s'
%
(
scheme
,
'%s:%d'
%
ADDR
)
sock
=
websocket
()
...
...
test/server.py
View file @
9232e5d4
...
...
@@ -7,7 +7,7 @@ basepath = abspath(dirname(abspath(__file__)) + '/..')
sys
.
path
.
insert
(
0
,
basepath
)
from
server
import
Server
from
deflate_frame
import
Webkit
DeflateFrame
from
deflate_frame
import
DeflateFrame
class
EchoServer
(
Server
):
...
...
@@ -17,8 +17,8 @@ class EchoServer(Server):
if
__name__
==
'__main__'
:
deflate
=
WebkitDeflateFrame
()
#deflate = WebkitDeflateFrame(defaults={'no_context_takeover': True})
EchoServer
((
'localhost'
,
8000
),
extensions
=
[
deflate
],
EchoServer
((
'localhost'
,
8000
),
#extensions=[DeflateFrame(no_context_takeover=True)],
extensions
=
[
DeflateFrame
()
],
#ssl_args=dict(keyfile='cert.pem', certfile='cert.pem'),
loglevel
=
logging
.
DEBUG
).
run
()
websocket.py
View file @
9232e5d4
...
...
@@ -81,7 +81,7 @@ class websocket(object):
"""
self
.
protocols
=
protocols
self
.
extensions
=
extensions
self
.
extension_
hook
s
=
[]
self
.
extension_
instance
s
=
[]
self
.
origin
=
origin
self
.
location
=
location
self
.
trusted_origins
=
trusted_origins
...
...
@@ -99,9 +99,6 @@ class websocket(object):
self
.
sock
=
sock
or
socket
.
socket
(
sfamily
,
socket
.
SOCK_STREAM
,
sproto
)
def
set_extensions
(
self
,
extensions
):
self
.
extensions
=
[
ext
.
Hook
()
for
ext
in
extensions
]
def
__getattr__
(
self
,
name
):
if
name
in
INHERITED_ATTRS
:
return
getattr
(
self
.
sock
,
name
)
...
...
@@ -135,14 +132,20 @@ class websocket(object):
self
.
handshake_sent
=
True
def
apply_send_hooks
(
self
,
frame
):
for
hook
in
self
.
extension_hooks
:
frame
=
hook
.
send
(
frame
)
for
inst
in
self
.
extension_instances
:
replacement
=
inst
.
onsend_frame
(
frame
)
if
replacement
is
not
None
:
frame
=
replacement
return
frame
def
apply_recv_hooks
(
self
,
frame
):
for
hook
in
reversed
(
self
.
extension_hooks
):
frame
=
hook
.
recv
(
frame
)
for
inst
in
reversed
(
self
.
extension_instances
):
replacement
=
inst
.
onrecv_frame
(
frame
)
if
replacement
is
not
None
:
frame
=
replacement
return
frame
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment