Prechádzať zdrojové kódy

Improved frame masking accessibily in various high-level functions

Taddeus Kroes 13 rokov pred
rodič
commit
c7572ccb01
3 zmenil súbory, kde vykonal 16 pridanie a 9 odobranie
  1. 3 3
      connection.py
  2. 11 4
      frame.py
  3. 2 2
      message.py

+ 3 - 3
connection.py

@@ -28,16 +28,16 @@ class Connection(object):
 
         self.onopen()
 
-    def send(self, message, fragment_size=None):
+    def send(self, message, fragment_size=None, mask=False):
         """
         Send a message. If `fragment_size` is specified, the message is
         fragmented into multiple frames whose payload size does not extend
         `fragment_size`.
         """
         if fragment_size is None:
-            self.sock.send(message.frame())
+            self.sock.send(message.frame(mask=mask))
         else:
-            self.sock.send(*message.fragment(fragment_size))
+            self.sock.send(*message.fragment(fragment_size, mask=mask))
 
     def receive(self):
         """

+ 11 - 4
frame.py

@@ -28,8 +28,8 @@ class Frame(object):
     To encoding a frame for sending it over a socket, use Frame.pack(). To
     receive and decode a frame from a socket, use receive_frame().
     """
-    def __init__(self, opcode, payload, masking_key='', final=True, rsv1=False,
-            rsv2=False, rsv3=False):
+    def __init__(self, opcode, payload, masking_key='', mask=False, final=True,
+            rsv1=False, rsv2=False, rsv3=False):
         """
         Create a new frame.
 
@@ -37,12 +37,19 @@ class Frame(object):
 
         `payload` is a string of bytes containing the data sendt in the frame.
 
+        `masking_key` is an optional custom key to use for masking, or `mask`
+        can be used instead to let this constructor generate a random masking
+        key.
+
         `final` is a boolean indicating whether this frame is the last in a
         chain of fragments.
 
         `rsv1`, `rsv2` and `rsv3` are booleans indicating bit values for RSV1,
         RVS2 and RSV3, which are only non-zero if defined so by extensions.
         """
+        if mask:
+            masking_key = urandom(4)
+
         if len(masking_key)  not in (0, 4):
             raise ValueError('invalid masking key "%s"' % masking_key)
 
@@ -122,8 +129,8 @@ class Frame(object):
 
         for start in range(0, len(self.payload), fragment_size):
             payload = self.payload[start:start + fragment_size]
-            key = urandom(4) if mask else ''
-            frames.append(Frame(OPCODE_CONTINUATION, payload, key, False))
+            frames.append(Frame(OPCODE_CONTINUATION, payload, mask=mask,
+                                final=False))
 
         frames[0].opcode = self.opcode
         frames[-1].final = True

+ 2 - 2
message.py

@@ -9,8 +9,8 @@ class Message(object):
         self.opcode = opcode
         self.payload = payload
 
-    def frame(self):
-        return Frame(self.opcode, self.payload)
+    def frame(self, mask=False):
+        return Frame(self.opcode, self.payload, mask=mask)
 
     def fragment(self, fragment_size, mask=False):
         return self.frame().fragment(fragment_size, mask)