test-http-bad-server: track close condition in an object
authorPierre-Yves David <pierre-yves.david@octobus.net>
Fri, 21 Jan 2022 01:07:50 +0100
changeset 48632 caa6694dac45
parent 48631 8039cca948f8
child 48633 f91f98e9834a
test-http-bad-server: track close condition in an object In order to make the logic more advanced, we need to unify it. To unify it, we introduce a small object that will be responsible for tracking and enforcing "premature socket close" conditions for both proxy object (socketproxy and fileobjectproxy). More logic will be moved into the object in later changesets. Differential Revision: https://phab.mercurial-scm.org/D12041
tests/testlib/badserverext.py
--- a/tests/testlib/badserverext.py	Wed Jan 19 19:14:17 2022 +0100
+++ b/tests/testlib/badserverext.py	Fri Jan 21 01:07:50 2022 +0100
@@ -64,29 +64,47 @@
     default=False,
 )
 
+
+class ConditionTracker(object):
+    def __init__(self, close_after_recv_bytes, close_after_send_bytes):
+        self._all_close_after_recv_bytes = close_after_recv_bytes
+        self._all_close_after_send_bytes = close_after_send_bytes
+
+    def start_next_request(self):
+        """move to the next set of close condition"""
+        if self._all_close_after_recv_bytes:
+            self.target_recv_bytes = self._all_close_after_recv_bytes.pop(0)
+            self.remaining_recv_bytes = self.target_recv_bytes
+        else:
+            self.target_recv_bytes = None
+            self.remaining_recv_bytes = None
+        if self._all_close_after_send_bytes:
+            self.target_send_bytes = self._all_close_after_send_bytes.pop(0)
+            self.remaining_send_bytes = self.target_send_bytes
+        else:
+            self.target_send_bytes = None
+            self.remaining_send_bytes = None
+
+    def might_close(self):
+        """True, if any processing will be needed"""
+        if self.remaining_recv_bytes is not None:
+            return True
+        if self.remaining_send_bytes is not None:
+            return True
+        return False
+
+
 # We can't adjust __class__ on a socket instance. So we define a proxy type.
 class socketproxy(object):
-    __slots__ = (
-        '_orig',
-        '_logfp',
-        '_close_after_recv_bytes',
-        '_close_after_send_bytes',
-    )
+    __slots__ = ('_orig', '_logfp', '_cond')
 
-    def __init__(
-        self, obj, logfp, close_after_recv_bytes=0, close_after_send_bytes=0
-    ):
+    def __init__(self, obj, logfp, condition_tracked):
         object.__setattr__(self, '_orig', obj)
         object.__setattr__(self, '_logfp', logfp)
-        object.__setattr__(
-            self, '_close_after_recv_bytes', close_after_recv_bytes
-        )
-        object.__setattr__(
-            self, '_close_after_send_bytes', close_after_send_bytes
-        )
+        object.__setattr__(self, '_cond', condition_tracked)
 
     def __getattribute__(self, name):
-        if name in ('makefile', 'sendall', '_writelog'):
+        if name in ('makefile', 'sendall', '_writelog', '_cond_close'):
             return object.__getattribute__(self, name)
 
         return getattr(object.__getattribute__(self, '_orig'), name)
@@ -108,22 +126,12 @@
         f = object.__getattribute__(self, '_orig').makefile(mode, bufsize)
 
         logfp = object.__getattribute__(self, '_logfp')
-        close_after_recv_bytes = object.__getattribute__(
-            self, '_close_after_recv_bytes'
-        )
-        close_after_send_bytes = object.__getattribute__(
-            self, '_close_after_send_bytes'
-        )
+        cond = object.__getattribute__(self, '_cond')
 
-        return fileobjectproxy(
-            f,
-            logfp,
-            close_after_recv_bytes=close_after_recv_bytes,
-            close_after_send_bytes=close_after_send_bytes,
-        )
+        return fileobjectproxy(f, logfp, cond)
 
     def sendall(self, data, flags=0):
-        remaining = object.__getattribute__(self, '_close_after_send_bytes')
+        remaining = object.__getattribute__(self, '_cond').remaining_send_bytes
 
         # No read limit. Call original function.
         if not remaining:
@@ -145,7 +153,7 @@
             % (len(newdata), len(data), remaining, newdata)
         )
 
-        object.__setattr__(self, '_close_after_send_bytes', remaining)
+        object.__getattribute__(self, '_cond').remaining_send_bytes = remaining
 
         if remaining <= 0:
             self._writelog(b'write limit reached; closing socket')
@@ -158,24 +166,12 @@
 
 # We can't adjust __class__ on socket._fileobject, so define a proxy.
 class fileobjectproxy(object):
-    __slots__ = (
-        '_orig',
-        '_logfp',
-        '_close_after_recv_bytes',
-        '_close_after_send_bytes',
-    )
+    __slots__ = ('_orig', '_logfp', '_cond')
 
-    def __init__(
-        self, obj, logfp, close_after_recv_bytes=0, close_after_send_bytes=0
-    ):
+    def __init__(self, obj, logfp, condition_tracked):
         object.__setattr__(self, '_orig', obj)
         object.__setattr__(self, '_logfp', logfp)
-        object.__setattr__(
-            self, '_close_after_recv_bytes', close_after_recv_bytes
-        )
-        object.__setattr__(
-            self, '_close_after_send_bytes', close_after_send_bytes
-        )
+        object.__setattr__(self, '_cond', condition_tracked)
 
     def __getattribute__(self, name):
         if name in ('_close', 'read', 'readline', 'write', '_writelog'):
@@ -210,7 +206,7 @@
             self._sock.shutdown(socket.SHUT_RDWR)
 
     def read(self, size=-1):
-        remaining = object.__getattribute__(self, '_close_after_recv_bytes')
+        remaining = object.__getattribute__(self, '_cond').remaining_recv_bytes
 
         # No read limit. Call original function.
         if not remaining:
@@ -235,7 +231,7 @@
             % (size, origsize, len(result), result)
         )
 
-        object.__setattr__(self, '_close_after_recv_bytes', remaining)
+        object.__getattribute__(self, '_cond').remaining_recv_bytes = remaining
 
         if remaining <= 0:
             self._writelog(b'read limit reached; closing socket')
@@ -247,7 +243,7 @@
         return result
 
     def readline(self, size=-1):
-        remaining = object.__getattribute__(self, '_close_after_recv_bytes')
+        remaining = object.__getattribute__(self, '_cond').remaining_recv_bytes
 
         # No read limit. Call original function.
         if not remaining:
@@ -272,7 +268,7 @@
             % (size, origsize, len(result), result)
         )
 
-        object.__setattr__(self, '_close_after_recv_bytes', remaining)
+        object.__getattribute__(self, '_cond').remaining_recv_bytes = remaining
 
         if remaining <= 0:
             self._writelog(b'read limit reached; closing socket')
@@ -284,7 +280,7 @@
         return result
 
     def write(self, data):
-        remaining = object.__getattribute__(self, '_close_after_send_bytes')
+        remaining = object.__getattribute__(self, '_cond').remaining_send_bytes
 
         # No byte limit on this operation. Call original function.
         if not remaining:
@@ -306,7 +302,7 @@
 
         result = object.__getattribute__(self, '_orig').write(newdata)
 
-        object.__setattr__(self, '_close_after_send_bytes', remaining)
+        object.__getattribute__(self, '_cond').remaining_send_bytes = remaining
 
         if remaining <= 0:
             self._writelog(b'write limit reached; closing socket')
@@ -317,6 +313,12 @@
         return result
 
 
+def process_config(value):
+    parts = value.split(b',')
+    integers = [int(v) for v in parts if v]
+    return [v if v else None for v in integers]
+
+
 def extsetup(ui):
     # Change the base HTTP server class so various events can be performed.
     # See SocketServer.BaseServer for how the specially named methods work.
@@ -325,12 +327,15 @@
             self._ui = ui
             super(badserver, self).__init__(ui, *args, **kwargs)
 
-            recvbytes = self._ui.config(b'badserver', b'close-after-recv-bytes')
-            recvbytes = recvbytes.split(b',')
-            self.close_after_recv_bytes = [int(v) for v in recvbytes if v]
-            sendbytes = self._ui.config(b'badserver', b'close-after-send-bytes')
-            sendbytes = sendbytes.split(b',')
-            self.close_after_send_bytes = [int(v) for v in sendbytes if v]
+            all_recv_bytes = self._ui.config(
+                b'badserver', b'close-after-recv-bytes'
+            )
+            all_recv_bytes = process_config(all_recv_bytes)
+            all_send_bytes = self._ui.config(
+                b'badserver', b'close-after-send-bytes'
+            )
+            all_send_bytes = process_config(all_send_bytes)
+            self._cond = ConditionTracker(all_recv_bytes, all_send_bytes)
 
             # Need to inherit object so super() works.
             class badrequesthandler(self.RequestHandlerClass, object):
@@ -370,21 +375,11 @@
         # is a hgweb.server._httprequesthandler.
         def process_request(self, socket, address):
             # Wrap socket in a proxy if we need to count bytes.
-            if self.close_after_recv_bytes:
-                close_after_recv_bytes = self.close_after_recv_bytes.pop(0)
-            else:
-                close_after_recv_bytes = 0
-            if self.close_after_send_bytes:
-                close_after_send_bytes = self.close_after_send_bytes.pop(0)
-            else:
-                close_after_send_bytes = 0
+            self._cond.start_next_request()
 
-            if close_after_recv_bytes or close_after_send_bytes:
+            if self._cond.might_close():
                 socket = socketproxy(
-                    socket,
-                    self.errorlog,
-                    close_after_recv_bytes=close_after_recv_bytes,
-                    close_after_send_bytes=close_after_send_bytes,
+                    socket, self.errorlog, condition_tracked=self._cond
                 )
 
             return super(badserver, self).process_request(socket, address)