Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
W
wspy
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Taddeüs Kroes
wspy
Commits
e465862f
Commit
e465862f
authored
Aug 22, 2013
by
Taddeüs Kroes
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Revised extension instantiation, now 'hooks' are installed which are cleaner and more flexible
parent
6efb8807
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
84 additions
and
80 deletions
+84
-80
extension.py
extension.py
+56
-70
handshake.py
handshake.py
+14
-5
websocket.py
websocket.py
+14
-5
No files found.
extension.py
View file @
e465862f
from
errors
import
HandshakeError
class
Extension
(
object
):
name
=
''
rsv1
=
False
rsv2
=
False
rsv3
=
False
opcodes
=
[]
parameters
=
[]
def
__init__
(
self
,
**
kwargs
):
for
param
in
self
.
parameters
:
setattr
(
self
,
param
,
None
)
defaults
=
{}
request
=
{}
for
param
,
value
in
kwargs
.
items
():
if
param
not
in
self
.
parameters
:
raise
HandshakeError
(
'unrecognized parameter "%s"'
%
param
)
def
__init__
(
self
,
defaults
=
{},
request
=
{}):
for
param
in
defaults
.
keys
()
+
request
.
keys
():
if
param
not
in
self
.
defaults
:
raise
KeyError
(
'unrecognized parameter "%s"'
%
param
)
if
value
is
None
:
value
=
True
# Copy dict first to avoid duplicate references to the same object
self
.
defaults
=
dict
(
self
.
__class__
.
defaults
)
self
.
defaults
.
update
(
defaults
)
setattr
(
self
,
param
,
value
)
self
.
request
=
dict
(
self
.
__class__
.
request
)
self
.
request
.
update
(
request
)
def
__str__
(
self
,
frame
):
if
len
(
self
.
parameters
):
params
=
' '
+
', '
.
join
(
p
+
'='
+
str
(
getattr
(
self
,
p
))
for
p
in
self
.
parameters
)
else
:
params
=
''
return
'<Extension "%s" defaults=%s request=%s>'
\
%
(
self
.
name
,
self
.
defaults
,
self
.
request
)
return
'<Extension "%s"%s>'
%
(
self
.
name
,
params
)
class
Hook
:
def
__init__
(
self
,
**
kwargs
):
for
param
,
value
in
kwargs
.
iteritems
():
setattr
(
self
,
param
,
value
)
def
header_params
(
self
,
frame
):
return
{}
def
send
(
self
,
frame
):
return
frame
def
hook_send
(
self
,
frame
):
return
frame
def
hook_receive
(
self
,
frame
):
return
frame
def
recv
(
self
,
frame
):
return
frame
class
DeflateFrame
(
Extension
):
...
...
@@ -57,49 +51,43 @@ class DeflateFrame(Extension):
name
=
'deflate-frame'
rsv1
=
True
parameters
=
[
'max_window_bits'
,
'no_context_takeover'
]
# FIXME: is this correct?
default_max_window_bits
=
32768
# FIXME: is 32768 (below) correct?
defaults
=
{
'max_window_bits'
:
32768
,
'no_context_takeover'
:
True
}
def
__init__
(
self
,
**
kwargs
):
super
(
DeflateFrame
,
self
).
__init__
(
**
kwargs
)
def
__init__
(
self
,
defaults
=
{},
request
=
{}
):
Extension
.
__init__
(
self
,
defaults
,
request
)
if
self
.
max_window_bits
is
None
:
self
.
max_window_bits
=
self
.
default_max_window_bits
elif
not
isinstance
(
self
.
max_window_bits
,
int
):
raise
HandshakeError
(
'"max_window_bits" must be an integer'
)
elif
self
.
max_window_bits
>
32768
:
raise
HandshakeError
(
'"max_window_bits" may not be larger than '
'32768'
)
mwb
=
self
.
defaults
[
'max_window_bits'
]
cto
=
self
.
defaults
[
'no_context_takeover'
]
if
self
.
no_context_takeover
is
None
:
self
.
no_context_takeover
=
False
elif
self
.
no_context_takeover
is
not
True
:
raise
HandshakeError
(
'"no_context_takeover" must have no value
'
)
if
not
isinstance
(
mwb
,
int
)
:
raise
ValueError
(
'"max_window_bits" must be an integer'
)
elif
mwb
>
32768
:
raise
ValueError
(
'"max_window_bits" may not be larger than 32768
'
)
def
hook_send
(
self
,
frame
):
if
not
frame
.
rsv1
:
frame
.
rsv1
=
True
frame
.
payload
=
self
.
deflate
(
frame
.
payload
)
if
cto
is
not
False
and
cto
is
not
True
:
raise
ValueError
(
'"no_context_takeover" must have no value'
)
return
frame
class
Hook
:
def
send
(
self
,
frame
):
if
not
frame
.
rsv1
:
frame
.
rsv1
=
True
frame
.
payload
=
self
.
deflate
(
frame
.
payload
)
def
hook_recv
(
self
,
frame
):
if
frame
.
rsv1
:
frame
.
rsv1
=
False
frame
.
payload
=
self
.
inflate
(
frame
.
payload
)
return
frame
return
frame
def
recv
(
self
,
frame
):
if
frame
.
rsv1
:
frame
.
rsv1
=
False
frame
.
payload
=
self
.
inflate
(
frame
.
payload
)
def
header_params
(
self
):
raise
NotImplementedError
# TODO
return
frame
def
deflate
(
self
,
data
):
raise
NotImplementedError
# TODO
def
deflate
(
self
,
data
):
raise
NotImplementedError
# TODO
def
inflate
(
self
,
data
):
raise
NotImplementedError
# TODO
def
inflate
(
self
,
data
):
raise
NotImplementedError
# TODO
class
Multiplex
(
Extension
):
...
...
@@ -115,21 +103,19 @@ class Multiplex(Extension):
rsv1
=
True
# FIXME
rsv2
=
True
# FIXME
rsv3
=
True
# FIXME
parameters
=
[
'quota'
]
defaults
=
{
'quota'
:
None
}
def
__init__
(
self
,
**
kwargs
):
super
(
Multiplex
,
self
).
__init__
(
**
kwargs
)
def
__init__
(
self
,
defaults
=
{},
request
=
{}
):
Extension
.
__init__
(
self
,
defaults
,
request
)
# TODO: check "quota" value
def
hook_send
(
self
,
frame
):
raise
NotImplementedError
# TODO
def
hook_recv
(
self
,
frame
):
raise
NotImplementedError
# TODO
class
Hook
:
def
send
(
self
,
frame
):
raise
NotImplementedError
# TODO
def
header_params
(
self
):
raise
NotImplementedError
# TODO
def
recv
(
self
,
frame
):
raise
NotImplementedError
# TODO
def
filter_extensions
(
extensions
):
...
...
handshake.py
View file @
e465862f
...
...
@@ -142,14 +142,20 @@ class ServerHandshake(Handshake):
if
'Sec-WebSocket-Extensions'
in
headers
:
supported_ext
=
dict
((
e
.
name
,
e
)
for
e
in
self
.
wsock
.
extensions
)
extensions
=
[]
all_params
=
[]
for
ext
in
split_stripped
(
headers
[
'Sec-WebSocket-Extensions'
]):
name
,
params
=
parse_param_hdr
(
ext
)
if
name
in
supported_ext
:
extensions
.
append
(
supported_ext
[
name
](
**
params
))
extensions
.
append
(
supported_ext
[
name
])
all_params
.
append
(
params
)
self
.
wsock
.
extensions
=
filter_extensions
(
extensions
)
for
ext
,
params
in
zip
(
self
.
wsock
.
extensions
,
all_params
):
hook
=
ext
.
Hook
(
**
params
)
self
.
wsock
.
add_hook
(
send
=
hook
.
send
,
recv
=
hook
.
recv
)
else
:
self
.
wsock
.
extensions
=
[]
...
...
@@ -183,10 +189,11 @@ class ServerHandshake(Handshake):
yield
'Sec-WebSocket-Protocol'
,
self
.
wsock
.
protocol
if
self
.
wsock
.
extensions
:
values
=
[
format_param_hdr
(
e
.
name
,
e
.
header_params
()
)
values
=
[
format_param_hdr
(
e
.
name
,
e
.
request
)
for
e
in
self
.
wsock
.
extensions
]
yield
'Sec-WebSocket-Extensions'
,
', '
.
join
(
values
)
class
ClientHandshake
(
Handshake
):
"""
Executes a handshake as the client end point of the socket. May raise a
...
...
@@ -230,7 +237,7 @@ class ClientHandshake(Handshake):
if
accept
!=
required_accept
:
self
.
fail
(
'invalid websocket accept header "%s"'
%
accept
)
# Compare extensions
# Compare extensions
, add hooks only for those returned by server
if
'Sec-WebSocket-Extensions'
in
headers
:
supported_ext
=
dict
((
e
.
name
,
e
)
for
e
in
self
.
wsock
.
extensions
)
self
.
wsock
.
extensions
=
[]
...
...
@@ -242,7 +249,9 @@ class ClientHandshake(Handshake):
raise
HandshakeError
(
'server handshake contains '
'unsupported extension "%s"'
%
name
)
self
.
wsock
.
extensions
.
append
(
supported_ext
[
name
](
**
params
))
hook
=
supported_ext
[
name
].
Hook
(
**
params
)
self
.
wsock
.
extensions
.
append
(
supported_ext
[
name
])
self
.
wsock
.
add_hook
(
send
=
hook
.
send
,
recv
=
hook
.
recv
)
# Assert that returned protocol (if any) is supported
if
'Sec-WebSocket-Protocol'
in
headers
:
...
...
@@ -325,7 +334,7 @@ class ClientHandshake(Handshake):
yield
'Sec-WebSocket-Protocol'
,
', '
.
join
(
self
.
wsock
.
protocols
)
if
self
.
wsock
.
extensions
:
values
=
[
format_param_hdr
(
e
.
name
,
e
.
header_params
()
)
values
=
[
format_param_hdr
(
e
.
name
,
e
.
request
)
for
e
in
self
.
wsock
.
extensions
]
yield
'Sec-WebSocket-Extensions'
,
', '
.
join
(
values
)
...
...
websocket.py
View file @
e465862f
...
...
@@ -41,7 +41,7 @@ class websocket(object):
`protocols` is a list of supported protocol names.
`extensions` is a list of supported extension
classes
.
`extensions` is a list of supported extension
s (`Extension` instances)
.
`origin` (for client sockets) is the value for the "Origin" header sent
in a client handshake .
...
...
@@ -68,6 +68,8 @@ class websocket(object):
self
.
sock
=
sock
or
socket
.
socket
(
sfamily
,
socket
.
SOCK_STREAM
,
sproto
)
self
.
secure
=
False
self
.
handshake_sent
=
False
self
.
hooks_send
=
[]
self
.
hooks_recv
=
[]
def
bind
(
self
,
address
):
self
.
sock
.
bind
(
address
)
...
...
@@ -104,8 +106,8 @@ class websocket(object):
Send a number of frames.
"""
for
frame
in
args
:
for
ext
in
self
.
extensions
:
frame
=
ext
.
hook_send
(
frame
)
for
hook
in
self
.
hooks_send
:
frame
=
hook
(
frame
)
#print 'send frame:', frame, 'to %s:%d' % self.sock.getpeername()
self
.
sock
.
sendall
(
frame
.
pack
())
...
...
@@ -117,8 +119,8 @@ class websocket(object):
"""
frame
=
receive_frame
(
self
.
sock
)
for
ext
in
reversed
(
self
.
extensions
)
:
frame
=
ext
.
hook_recv
(
frame
)
for
hook
in
self
.
hooks_recv
:
frame
=
hook
(
frame
)
#print 'receive frame:', frame, 'from %s:%d' % self.sock.getpeername()
return
frame
...
...
@@ -156,3 +158,10 @@ class websocket(object):
self
.
secure
=
True
self
.
sock
=
ssl
.
wrap_socket
(
self
.
sock
,
*
args
,
**
kwargs
)
def
add_hook
(
self
,
send
=
None
,
recv
=
None
):
if
send
:
self
.
hooks_send
.
append
(
send
)
if
recv
:
self
.
hooks_recv
.
prepend
(
recv
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment