changeset 48612:11e5cb170d36

test-http-bad-server: factor code dealing with "write" in the new object This will make sure both `sendall` and `write` do the same processing and make it simpler to update that processing in the future. Differential Revision: https://phab.mercurial-scm.org/D12043
author Pierre-Yves David <pierre-yves.david@octobus.net>
date Fri, 21 Jan 2022 00:54:15 +0100
parents f91f98e9834a
children b060e305d79f
files tests/testlib/badserverext.py
diffstat 1 files changed, 69 insertions(+), 61 deletions(-) [+]
line wrap: on
line diff
--- a/tests/testlib/badserverext.py	Sun Jan 23 21:25:01 2022 +0100
+++ b/tests/testlib/badserverext.py	Fri Jan 21 00:54:15 2022 +0100
@@ -70,6 +70,11 @@
         self._all_close_after_recv_bytes = close_after_recv_bytes
         self._all_close_after_send_bytes = close_after_send_bytes
 
+        self.target_recv_bytes = None
+        self.remaining_recv_bytes = None
+        self.target_send_bytes = None
+        self.remaining_send_bytes = None
+
     def start_next_request(self):
         """move to the next set of close condition"""
         if self._all_close_after_recv_bytes:
@@ -93,6 +98,54 @@
             return True
         return False
 
+    def forward_write(self, obj, method, data, *args, **kwargs):
+        """call an underlying write function until condition are met
+
+        When the condition are met the socket is closed
+        """
+        remaining = self.remaining_send_bytes
+
+        orig = object.__getattribute__(obj, '_orig')
+        bmethod = method.encode('ascii')
+        func = getattr(orig, method)
+        # No byte limit on this operation. Call original function.
+        if not remaining:
+            result = func(data, *args, **kwargs)
+            obj._writelog(b'%s(%d) -> %s' % (bmethod, len(data), data))
+            return result
+
+        remaining = max(0, remaining)
+
+        if remaining > 0:
+            if remaining < len(data):
+                newdata = data[0:remaining]
+            else:
+                newdata = data
+
+            remaining -= len(newdata)
+
+            obj._writelog(
+                b'%s(%d from %d) -> (%d) %s'
+                % (
+                    bmethod,
+                    len(newdata),
+                    len(data),
+                    remaining,
+                    newdata,
+                )
+            )
+
+            result = func(newdata, *args, **kwargs)
+
+        self.remaining_send_bytes = remaining
+
+        if remaining <= 0:
+            obj._writelog(b'write limit reached; closing socket')
+            object.__getattribute__(obj, '_cond_close')()
+            raise Exception('connection closed after sending N bytes')
+
+        return result
+
 
 # We can't adjust __class__ on a socket instance. So we define a proxy type.
 class socketproxy(object):
@@ -131,37 +184,11 @@
         return fileobjectproxy(f, logfp, cond)
 
     def sendall(self, data, flags=0):
-        remaining = object.__getattribute__(self, '_cond').remaining_send_bytes
-
-        # No read limit. Call original function.
-        if not remaining:
-            result = object.__getattribute__(self, '_orig').sendall(data, flags)
-            self._writelog(b'sendall(%d) -> %s' % (len(data), data))
-            return result
-
-        if len(data) > remaining:
-            newdata = data[0:remaining]
-        else:
-            newdata = data
-
-        remaining -= len(newdata)
+        cond = object.__getattribute__(self, '_cond')
+        return cond.forward_write(self, 'sendall', data, flags)
 
-        result = object.__getattribute__(self, '_orig').sendall(newdata, flags)
-
-        self._writelog(
-            b'sendall(%d from %d) -> (%d) %s'
-            % (len(newdata), len(data), remaining, newdata)
-        )
-
-        object.__getattribute__(self, '_cond').remaining_send_bytes = remaining
-
-        if remaining <= 0:
-            self._writelog(b'write limit reached; closing socket')
-            object.__getattribute__(self, '_orig').shutdown(socket.SHUT_RDWR)
-
-            raise Exception('connection closed after sending N bytes')
-
-        return result
+    def _cond_close(self):
+        object.__getattribute__(self, '_orig').shutdown(socket.SHUT_RDWR)
 
 
 # We can't adjust __class__ on socket._fileobject, so define a proxy.
@@ -174,7 +201,14 @@
         object.__setattr__(self, '_cond', condition_tracked)
 
     def __getattribute__(self, name):
-        if name in ('_close', 'read', 'readline', 'write', '_writelog'):
+        if name in (
+            '_close',
+            'read',
+            'readline',
+            'write',
+            '_writelog',
+            '_cond_close',
+        ):
             return object.__getattribute__(self, name)
 
         return getattr(object.__getattribute__(self, '_orig'), name)
@@ -280,37 +314,11 @@
         return result
 
     def write(self, data):
-        remaining = object.__getattribute__(self, '_cond').remaining_send_bytes
-
-        # No byte limit on this operation. Call original function.
-        if not remaining:
-            result = object.__getattribute__(self, '_orig').write(data)
-            self._writelog(b'write(%d) -> %s' % (len(data), data))
-            return result
-
-        if len(data) > remaining:
-            newdata = data[0:remaining]
-        else:
-            newdata = data
-
-        remaining -= len(newdata)
+        cond = object.__getattribute__(self, '_cond')
+        return cond.forward_write(self, 'write', data)
 
-        result = object.__getattribute__(self, '_orig').write(newdata)
-
-        self._writelog(
-            b'write(%d from %d) -> (%d) %s'
-            % (len(newdata), len(data), remaining, newdata)
-        )
-
-        object.__getattribute__(self, '_cond').remaining_send_bytes = remaining
-
-        if remaining <= 0:
-            self._writelog(b'write limit reached; closing socket')
-            self._close()
-
-            raise Exception('connection closed after sending N bytes')
-
-        return result
+    def _cond_close(self):
+        self._close()
 
 
 def process_config(value):