Commit c7572ccb authored by Taddeüs Kroes's avatar Taddeüs Kroes

Improved frame masking accessibily in various high-level functions

parent 53d9f210
...@@ -28,16 +28,16 @@ class Connection(object): ...@@ -28,16 +28,16 @@ class Connection(object):
self.onopen() self.onopen()
def send(self, message, fragment_size=None): def send(self, message, fragment_size=None, mask=False):
""" """
Send a message. If `fragment_size` is specified, the message is Send a message. If `fragment_size` is specified, the message is
fragmented into multiple frames whose payload size does not extend fragmented into multiple frames whose payload size does not extend
`fragment_size`. `fragment_size`.
""" """
if fragment_size is None: if fragment_size is None:
self.sock.send(message.frame()) self.sock.send(message.frame(mask=mask))
else: else:
self.sock.send(*message.fragment(fragment_size)) self.sock.send(*message.fragment(fragment_size, mask=mask))
def receive(self): def receive(self):
""" """
......
...@@ -28,8 +28,8 @@ class Frame(object): ...@@ -28,8 +28,8 @@ class Frame(object):
To encoding a frame for sending it over a socket, use Frame.pack(). To To encoding a frame for sending it over a socket, use Frame.pack(). To
receive and decode a frame from a socket, use receive_frame(). receive and decode a frame from a socket, use receive_frame().
""" """
def __init__(self, opcode, payload, masking_key='', final=True, rsv1=False, def __init__(self, opcode, payload, masking_key='', mask=False, final=True,
rsv2=False, rsv3=False): rsv1=False, rsv2=False, rsv3=False):
""" """
Create a new frame. Create a new frame.
...@@ -37,12 +37,19 @@ class Frame(object): ...@@ -37,12 +37,19 @@ class Frame(object):
`payload` is a string of bytes containing the data sendt in the frame. `payload` is a string of bytes containing the data sendt in the frame.
`masking_key` is an optional custom key to use for masking, or `mask`
can be used instead to let this constructor generate a random masking
key.
`final` is a boolean indicating whether this frame is the last in a `final` is a boolean indicating whether this frame is the last in a
chain of fragments. chain of fragments.
`rsv1`, `rsv2` and `rsv3` are booleans indicating bit values for RSV1, `rsv1`, `rsv2` and `rsv3` are booleans indicating bit values for RSV1,
RVS2 and RSV3, which are only non-zero if defined so by extensions. RVS2 and RSV3, which are only non-zero if defined so by extensions.
""" """
if mask:
masking_key = urandom(4)
if len(masking_key) not in (0, 4): if len(masking_key) not in (0, 4):
raise ValueError('invalid masking key "%s"' % masking_key) raise ValueError('invalid masking key "%s"' % masking_key)
...@@ -122,8 +129,8 @@ class Frame(object): ...@@ -122,8 +129,8 @@ class Frame(object):
for start in range(0, len(self.payload), fragment_size): for start in range(0, len(self.payload), fragment_size):
payload = self.payload[start:start + fragment_size] payload = self.payload[start:start + fragment_size]
key = urandom(4) if mask else '' frames.append(Frame(OPCODE_CONTINUATION, payload, mask=mask,
frames.append(Frame(OPCODE_CONTINUATION, payload, key, False)) final=False))
frames[0].opcode = self.opcode frames[0].opcode = self.opcode
frames[-1].final = True frames[-1].final = True
......
...@@ -9,8 +9,8 @@ class Message(object): ...@@ -9,8 +9,8 @@ class Message(object):
self.opcode = opcode self.opcode = opcode
self.payload = payload self.payload = payload
def frame(self): def frame(self, mask=False):
return Frame(self.opcode, self.payload) return Frame(self.opcode, self.payload, mask=mask)
def fragment(self, fragment_size, mask=False): def fragment(self, fragment_size, mask=False):
return self.frame().fragment(fragment_size, mask) return self.frame().fragment(fragment_size, mask)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment