--- a/tests/testlib/badserverext.py Sun Jan 23 21:25:01 2022 +0100
+++ b/tests/testlib/badserverext.py Fri Jan 21 00:54:15 2022 +0100
@@ -70,6 +70,11 @@
self._all_close_after_recv_bytes = close_after_recv_bytes
self._all_close_after_send_bytes = close_after_send_bytes
+ self.target_recv_bytes = None
+ self.remaining_recv_bytes = None
+ self.target_send_bytes = None
+ self.remaining_send_bytes = None
+
def start_next_request(self):
"""move to the next set of close condition"""
if self._all_close_after_recv_bytes:
@@ -93,6 +98,54 @@
return True
return False
+ def forward_write(self, obj, method, data, *args, **kwargs):
+ """call an underlying write function until condition are met
+
+ When the condition are met the socket is closed
+ """
+ remaining = self.remaining_send_bytes
+
+ orig = object.__getattribute__(obj, '_orig')
+ bmethod = method.encode('ascii')
+ func = getattr(orig, method)
+ # No byte limit on this operation. Call original function.
+ if not remaining:
+ result = func(data, *args, **kwargs)
+ obj._writelog(b'%s(%d) -> %s' % (bmethod, len(data), data))
+ return result
+
+ remaining = max(0, remaining)
+
+ if remaining > 0:
+ if remaining < len(data):
+ newdata = data[0:remaining]
+ else:
+ newdata = data
+
+ remaining -= len(newdata)
+
+ obj._writelog(
+ b'%s(%d from %d) -> (%d) %s'
+ % (
+ bmethod,
+ len(newdata),
+ len(data),
+ remaining,
+ newdata,
+ )
+ )
+
+ result = func(newdata, *args, **kwargs)
+
+ self.remaining_send_bytes = remaining
+
+ if remaining <= 0:
+ obj._writelog(b'write limit reached; closing socket')
+ object.__getattribute__(obj, '_cond_close')()
+ raise Exception('connection closed after sending N bytes')
+
+ return result
+
# We can't adjust __class__ on a socket instance. So we define a proxy type.
class socketproxy(object):
@@ -131,37 +184,11 @@
return fileobjectproxy(f, logfp, cond)
def sendall(self, data, flags=0):
- remaining = object.__getattribute__(self, '_cond').remaining_send_bytes
-
- # No read limit. Call original function.
- if not remaining:
- result = object.__getattribute__(self, '_orig').sendall(data, flags)
- self._writelog(b'sendall(%d) -> %s' % (len(data), data))
- return result
-
- if len(data) > remaining:
- newdata = data[0:remaining]
- else:
- newdata = data
-
- remaining -= len(newdata)
+ cond = object.__getattribute__(self, '_cond')
+ return cond.forward_write(self, 'sendall', data, flags)
- result = object.__getattribute__(self, '_orig').sendall(newdata, flags)
-
- self._writelog(
- b'sendall(%d from %d) -> (%d) %s'
- % (len(newdata), len(data), remaining, newdata)
- )
-
- object.__getattribute__(self, '_cond').remaining_send_bytes = remaining
-
- if remaining <= 0:
- self._writelog(b'write limit reached; closing socket')
- object.__getattribute__(self, '_orig').shutdown(socket.SHUT_RDWR)
-
- raise Exception('connection closed after sending N bytes')
-
- return result
+ def _cond_close(self):
+ object.__getattribute__(self, '_orig').shutdown(socket.SHUT_RDWR)
# We can't adjust __class__ on socket._fileobject, so define a proxy.
@@ -174,7 +201,14 @@
object.__setattr__(self, '_cond', condition_tracked)
def __getattribute__(self, name):
- if name in ('_close', 'read', 'readline', 'write', '_writelog'):
+ if name in (
+ '_close',
+ 'read',
+ 'readline',
+ 'write',
+ '_writelog',
+ '_cond_close',
+ ):
return object.__getattribute__(self, name)
return getattr(object.__getattribute__(self, '_orig'), name)
@@ -280,37 +314,11 @@
return result
def write(self, data):
- remaining = object.__getattribute__(self, '_cond').remaining_send_bytes
-
- # No byte limit on this operation. Call original function.
- if not remaining:
- result = object.__getattribute__(self, '_orig').write(data)
- self._writelog(b'write(%d) -> %s' % (len(data), data))
- return result
-
- if len(data) > remaining:
- newdata = data[0:remaining]
- else:
- newdata = data
-
- remaining -= len(newdata)
+ cond = object.__getattribute__(self, '_cond')
+ return cond.forward_write(self, 'write', data)
- result = object.__getattribute__(self, '_orig').write(newdata)
-
- self._writelog(
- b'write(%d from %d) -> (%d) %s'
- % (len(newdata), len(data), remaining, newdata)
- )
-
- object.__getattribute__(self, '_cond').remaining_send_bytes = remaining
-
- if remaining <= 0:
- self._writelog(b'write limit reached; closing socket')
- self._close()
-
- raise Exception('connection closed after sending N bytes')
-
- return result
+ def _cond_close(self):
+ self._close()
def process_config(value):