Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
W
wspy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Taddeüs Kroes
wspy
Commits
e465862f
Commit
e465862f
authored
Aug 22, 2013
by
Taddeüs Kroes
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Revised extension instantiation, now 'hooks' are installed which are cleaner and more flexible
parent
6efb8807
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
84 additions
and
80 deletions
+84
-80
extension.py
extension.py
+56
-70
handshake.py
handshake.py
+14
-5
websocket.py
websocket.py
+14
-5
No files found.
extension.py
View file @
e465862f
from
errors
import
HandshakeError
class
Extension
(
object
):
name
=
''
rsv1
=
False
rsv2
=
False
rsv3
=
False
opcodes
=
[]
parameters
=
[]
defaults
=
{}
request
=
{}
def
__init__
(
self
,
**
kwargs
):
for
param
in
self
.
parameters
:
setattr
(
self
,
param
,
None
)
def
__init__
(
self
,
defaults
=
{},
request
=
{}):
for
param
in
defaults
.
keys
()
+
request
.
keys
():
if
param
not
in
self
.
defaults
:
raise
KeyError
(
'unrecognized parameter "%s"'
%
param
)
for
param
,
value
in
kwargs
.
items
():
if
param
not
in
self
.
parameters
:
raise
HandshakeError
(
'unrecognized parameter "%s"'
%
param
)
# Copy dict first to avoid duplicate references to the same object
self
.
defaults
=
dict
(
self
.
__class__
.
defaults
)
self
.
defaults
.
update
(
defaults
)
if
value
is
None
:
value
=
True
setattr
(
self
,
param
,
value
)
self
.
request
=
dict
(
self
.
__class__
.
request
)
self
.
request
.
update
(
request
)
def
__str__
(
self
,
frame
):
if
len
(
self
.
parameters
):
params
=
' '
+
', '
.
join
(
p
+
'='
+
str
(
getattr
(
self
,
p
))
for
p
in
self
.
parameters
)
else
:
params
=
''
return
'<Extension "%s" defaults=%s request=%s>'
\
%
(
self
.
name
,
self
.
defaults
,
self
.
request
)
return
'<Extension "%s"%s>'
%
(
self
.
name
,
params
)
def
header_params
(
self
,
frame
):
return
{}
class
Hook
:
def
__init__
(
self
,
**
kwargs
):
for
param
,
value
in
kwargs
.
iteritems
(
):
setattr
(
self
,
param
,
value
)
def
hook_
send
(
self
,
frame
):
def
send
(
self
,
frame
):
return
frame
def
hook_receive
(
self
,
frame
):
def
recv
(
self
,
frame
):
return
frame
...
...
@@ -57,44 +51,38 @@ class DeflateFrame(Extension):
name
=
'deflate-frame'
rsv1
=
True
parameters
=
[
'max_window_bits'
,
'no_context_takeover'
]
# FIXME: is 32768 (below) correct?
defaults
=
{
'max_window_bits'
:
32768
,
'no_context_takeover'
:
True
}
# FIXME: is this correct?
default_max_window_bits
=
32768
def
__init__
(
self
,
defaults
=
{},
request
=
{}):
Extension
.
__init__
(
self
,
defaults
,
request
)
def
__init__
(
self
,
**
kwargs
):
super
(
DeflateFrame
,
self
).
__init__
(
**
kwargs
)
if
self
.
max_window_bits
is
None
:
self
.
max_window_bits
=
self
.
default_max_window_bits
elif
not
isinstance
(
self
.
max_window_bits
,
int
):
raise
HandshakeError
(
'"max_window_bits" must be an integer'
)
elif
self
.
max_window_bits
>
32768
:
raise
HandshakeError
(
'"max_window_bits" may not be larger than '
'32768'
)
if
self
.
no_context_takeover
is
None
:
self
.
no_context_takeover
=
False
elif
self
.
no_context_takeover
is
not
True
:
raise
HandshakeError
(
'"no_context_takeover" must have no value'
)
def
hook_send
(
self
,
frame
):
mwb
=
self
.
defaults
[
'max_window_bits'
]
cto
=
self
.
defaults
[
'no_context_takeover'
]
if
not
isinstance
(
mwb
,
int
):
raise
ValueError
(
'"max_window_bits" must be an integer'
)
elif
mwb
>
32768
:
raise
ValueError
(
'"max_window_bits" may not be larger than 32768'
)
if
cto
is
not
False
and
cto
is
not
True
:
raise
ValueError
(
'"no_context_takeover" must have no value'
)
class
Hook
:
def
send
(
self
,
frame
):
if
not
frame
.
rsv1
:
frame
.
rsv1
=
True
frame
.
payload
=
self
.
deflate
(
frame
.
payload
)
return
frame
def
hook_
recv
(
self
,
frame
):
def
recv
(
self
,
frame
):
if
frame
.
rsv1
:
frame
.
rsv1
=
False
frame
.
payload
=
self
.
inflate
(
frame
.
payload
)
return
frame
def
header_params
(
self
):
raise
NotImplementedError
# TODO
def
deflate
(
self
,
data
):
raise
NotImplementedError
# TODO
...
...
@@ -115,20 +103,18 @@ class Multiplex(Extension):
rsv1
=
True
# FIXME
rsv2
=
True
# FIXME
rsv3
=
True
# FIXME
parameters
=
[
'quota'
]
defaults
=
{
'quota'
:
None
}
def
__init__
(
self
,
**
kwargs
):
super
(
Multiplex
,
self
).
__init__
(
**
kwargs
)
def
__init__
(
self
,
defaults
=
{},
request
=
{}
):
Extension
.
__init__
(
self
,
defaults
,
request
)
# TODO: check "quota" value
def
hook_send
(
self
,
frame
):
raise
NotImplementedError
# TODO
def
hook_recv
(
self
,
frame
):
class
Hook
:
def
send
(
self
,
frame
):
raise
NotImplementedError
# TODO
def
header_params
(
self
):
def
recv
(
self
,
frame
):
raise
NotImplementedError
# TODO
...
...
handshake.py
View file @
e465862f
...
...
@@ -142,14 +142,20 @@ class ServerHandshake(Handshake):
if
'Sec-WebSocket-Extensions'
in
headers
:
supported_ext
=
dict
((
e
.
name
,
e
)
for
e
in
self
.
wsock
.
extensions
)
extensions
=
[]
all_params
=
[]
for
ext
in
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
]):
name
,
params
=
parse_param_hdr
(
ext
)
if
name
in
supported_ext
:
extensions
.
append
(
supported_ext
[
name
](
**
params
))
extensions
.
append
(
supported_ext
[
name
])
all_params
.
append
(
params
)
self
.
wsock
.
extensions
=
filter_extensions
(
extensions
)
for
ext
,
params
in
zip
(
self
.
wsock
.
extensions
,
all_params
):
hook
=
ext
.
Hook
(
**
params
)
self
.
wsock
.
add_hook
(
send
=
hook
.
send
,
recv
=
hook
.
recv
)
else
:
self
.
wsock
.
extensions
=
[]
...
...
@@ -183,10 +189,11 @@ class ServerHandshake(Handshake):
yield
'Sec-WebSocket-Protocol'
,
self
.
wsock
.
protocol
if
self
.
wsock
.
extensions
:
values
=
[
format_param_hdr
(
e
.
name
,
e
.
header_params
()
)
values
=
[
format_param_hdr
(
e
.
name
,
e
.
request
)
for
e
in
self
.
wsock
.
extensions
]
yield
'Sec-WebSocket-Extensions'
,
', '
.
join
(
values
)
class
ClientHandshake
(
Handshake
):
"""
Executes a handshake as the client end point of the socket. May raise a
...
...
@@ -230,7 +237,7 @@ class ClientHandshake(Handshake):
if
accept
!=
required_accept
:
self
.
fail
(
'invalid websocket accept header "%s"'
%
accept
)
# Compare extensions
# Compare extensions
, add hooks only for those returned by server
if
'Sec-WebSocket-Extensions'
in
headers
:
supported_ext
=
dict
((
e
.
name
,
e
)
for
e
in
self
.
wsock
.
extensions
)
self
.
wsock
.
extensions
=
[]
...
...
@@ -242,7 +249,9 @@ class ClientHandshake(Handshake):
raise
HandshakeError
(
'server handshake contains '
'unsupported extension "%s"'
%
name
)
self
.
wsock
.
extensions
.
append
(
supported_ext
[
name
](
**
params
))
hook
=
supported_ext
[
name
].
Hook
(
**
params
)
self
.
wsock
.
extensions
.
append
(
supported_ext
[
name
])
self
.
wsock
.
add_hook
(
send
=
hook
.
send
,
recv
=
hook
.
recv
)
# Assert that returned protocol (if any) is supported
if
'Sec-WebSocket-Protocol'
in
headers
:
...
...
@@ -325,7 +334,7 @@ class ClientHandshake(Handshake):
yield
'Sec-WebSocket-Protocol'
,
', '
.
join
(
self
.
wsock
.
protocols
)
if
self
.
wsock
.
extensions
:
values
=
[
format_param_hdr
(
e
.
name
,
e
.
header_params
()
)
values
=
[
format_param_hdr
(
e
.
name
,
e
.
request
)
for
e
in
self
.
wsock
.
extensions
]
yield
'Sec-WebSocket-Extensions'
,
', '
.
join
(
values
)
...
...
websocket.py
View file @
e465862f
...
...
@@ -41,7 +41,7 @@ class websocket(object):
`protocols` is a list of supported protocol names.
`extensions` is a list of supported extension
classes
.
`extensions` is a list of supported extension
s (`Extension` instances)
.
`origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake .
...
...
@@ -68,6 +68,8 @@ class websocket(object):
self
.
sock
=
sock
or
socket
.
socket
(
sfamily
,
socket
.
SOCK_STREAM
,
sproto
)
self
.
secure
=
False
self
.
handshake_sent
=
False
self
.
hooks_send
=
[]
self
.
hooks_recv
=
[]
def
bind
(
self
,
address
):
self
.
sock
.
bind
(
address
)
...
...
@@ -104,8 +106,8 @@ class websocket(object):
Send a number of frames.
"""
for
frame
in
args
:
for
ext
in
self
.
extensions
:
frame
=
ext
.
hook_send
(
frame
)
for
hook
in
self
.
hooks_send
:
frame
=
hook
(
frame
)
#print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
self
.
sock
.
sendall
(
frame
.
pack
())
...
...
@@ -117,8 +119,8 @@ class websocket(object):
"""
frame
=
receive_frame
(
self
.
sock
)
for
ext
in
reversed
(
self
.
extensions
)
:
frame
=
ext
.
hook_recv
(
frame
)
for
hook
in
self
.
hooks_recv
:
frame
=
hook
(
frame
)
#print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return
frame
...
...
@@ -156,3 +158,10 @@ class websocket(object):
self
.
secure
=
True
self
.
sock
=
ssl
.
wrap_socket
(
self
.
sock
,
*
args
,
**
kwargs
)
def
add_hook
(
self
,
send
=
None
,
recv
=
None
):
if
send
:
self
.
hooks_send
.
append
(
send
)
if
recv
:
self
.
hooks_recv
.
prepend
(
recv
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment