Mercurial > hg
diff tests/testlib/badserverext.py @ 48605:089cb4d6af5a
test-http-bad-server: move the extension in `testlib`
This seems like a better location for it.
Differential Revision: https://phab.mercurial-scm.org/D12036
author | Pierre-Yves David <pierre-yves.david@octobus.net> |
---|---|
date | Tue, 18 Jan 2022 21:29:43 +0100 |
parents | tests/badserverext.py@89a2afe31e82 |
children | ee1235afda4b |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tests/testlib/badserverext.py Tue Jan 18 21:29:43 2022 +0100 @@ -0,0 +1,384 @@ +# badserverext.py - Extension making servers behave badly +# +# Copyright 2017 Gregory Szorc <gregory.szorc@gmail.com> +# +# This software may be used and distributed according to the terms of the +# GNU General Public License version 2 or any later version. + +# no-check-code + +"""Extension to make servers behave badly. + +This extension is useful for testing Mercurial behavior when various network +events occur. + +Various config options in the [badserver] section influence behavior: + +closebeforeaccept + If true, close() the server socket when a new connection arrives before + accept() is called. The server will then exit. + +closeafteraccept + If true, the server will close() the client socket immediately after + accept(). + +closeafterrecvbytes + If defined, close the client socket after receiving this many bytes. + +closeaftersendbytes + If defined, close the client socket after sending this many bytes. +""" + +from __future__ import absolute_import + +import socket + +from mercurial import ( + pycompat, + registrar, +) + +from mercurial.hgweb import server + +configtable = {} +configitem = registrar.configitem(configtable) + +configitem( + b'badserver', + b'closeafteraccept', + default=False, +) +configitem( + b'badserver', + b'closeafterrecvbytes', + default=b'0', +) +configitem( + b'badserver', + b'closeaftersendbytes', + default=b'0', +) +configitem( + b'badserver', + b'closebeforeaccept', + default=False, +) + +# We can't adjust __class__ on a socket instance. So we define a proxy type. +class socketproxy(object): + __slots__ = ( + '_orig', + '_logfp', + '_closeafterrecvbytes', + '_closeaftersendbytes', + ) + + def __init__( + self, obj, logfp, closeafterrecvbytes=0, closeaftersendbytes=0 + ): + object.__setattr__(self, '_orig', obj) + object.__setattr__(self, '_logfp', logfp) + object.__setattr__(self, '_closeafterrecvbytes', closeafterrecvbytes) + object.__setattr__(self, '_closeaftersendbytes', closeaftersendbytes) + + def __getattribute__(self, name): + if name in ('makefile', 'sendall', '_writelog'): + return object.__getattribute__(self, name) + + return getattr(object.__getattribute__(self, '_orig'), name) + + def __delattr__(self, name): + delattr(object.__getattribute__(self, '_orig'), name) + + def __setattr__(self, name, value): + setattr(object.__getattribute__(self, '_orig'), name, value) + + def _writelog(self, msg): + msg = msg.replace(b'\r', b'\\r').replace(b'\n', b'\\n') + + object.__getattribute__(self, '_logfp').write(msg) + object.__getattribute__(self, '_logfp').write(b'\n') + object.__getattribute__(self, '_logfp').flush() + + def makefile(self, mode, bufsize): + f = object.__getattribute__(self, '_orig').makefile(mode, bufsize) + + logfp = object.__getattribute__(self, '_logfp') + closeafterrecvbytes = object.__getattribute__( + self, '_closeafterrecvbytes' + ) + closeaftersendbytes = object.__getattribute__( + self, '_closeaftersendbytes' + ) + + return fileobjectproxy( + f, + logfp, + closeafterrecvbytes=closeafterrecvbytes, + closeaftersendbytes=closeaftersendbytes, + ) + + def sendall(self, data, flags=0): + remaining = object.__getattribute__(self, '_closeaftersendbytes') + + # No read limit. Call original function. + if not remaining: + result = object.__getattribute__(self, '_orig').sendall(data, flags) + self._writelog(b'sendall(%d) -> %s' % (len(data), data)) + return result + + if len(data) > remaining: + newdata = data[0:remaining] + else: + newdata = data + + remaining -= len(newdata) + + result = object.__getattribute__(self, '_orig').sendall(newdata, flags) + + self._writelog( + b'sendall(%d from %d) -> (%d) %s' + % (len(newdata), len(data), remaining, newdata) + ) + + object.__setattr__(self, '_closeaftersendbytes', remaining) + + if remaining <= 0: + self._writelog(b'write limit reached; closing socket') + object.__getattribute__(self, '_orig').shutdown(socket.SHUT_RDWR) + + raise Exception('connection closed after sending N bytes') + + return result + + +# We can't adjust __class__ on socket._fileobject, so define a proxy. +class fileobjectproxy(object): + __slots__ = ( + '_orig', + '_logfp', + '_closeafterrecvbytes', + '_closeaftersendbytes', + ) + + def __init__( + self, obj, logfp, closeafterrecvbytes=0, closeaftersendbytes=0 + ): + object.__setattr__(self, '_orig', obj) + object.__setattr__(self, '_logfp', logfp) + object.__setattr__(self, '_closeafterrecvbytes', closeafterrecvbytes) + object.__setattr__(self, '_closeaftersendbytes', closeaftersendbytes) + + def __getattribute__(self, name): + if name in ('_close', 'read', 'readline', 'write', '_writelog'): + return object.__getattribute__(self, name) + + return getattr(object.__getattribute__(self, '_orig'), name) + + def __delattr__(self, name): + delattr(object.__getattribute__(self, '_orig'), name) + + def __setattr__(self, name, value): + setattr(object.__getattribute__(self, '_orig'), name, value) + + def _writelog(self, msg): + msg = msg.replace(b'\r', b'\\r').replace(b'\n', b'\\n') + + object.__getattribute__(self, '_logfp').write(msg) + object.__getattribute__(self, '_logfp').write(b'\n') + object.__getattribute__(self, '_logfp').flush() + + def _close(self): + # Python 3 uses an io.BufferedIO instance. Python 2 uses some file + # object wrapper. + if pycompat.ispy3: + orig = object.__getattribute__(self, '_orig') + + if hasattr(orig, 'raw'): + orig.raw._sock.shutdown(socket.SHUT_RDWR) + else: + self.close() + else: + self._sock.shutdown(socket.SHUT_RDWR) + + def read(self, size=-1): + remaining = object.__getattribute__(self, '_closeafterrecvbytes') + + # No read limit. Call original function. + if not remaining: + result = object.__getattribute__(self, '_orig').read(size) + self._writelog( + b'read(%d) -> (%d) (%s) %s' % (size, len(result), result) + ) + return result + + origsize = size + + if size < 0: + size = remaining + else: + size = min(remaining, size) + + result = object.__getattribute__(self, '_orig').read(size) + remaining -= len(result) + + self._writelog( + b'read(%d from %d) -> (%d) %s' + % (size, origsize, len(result), result) + ) + + object.__setattr__(self, '_closeafterrecvbytes', remaining) + + if remaining <= 0: + self._writelog(b'read limit reached, closing socket') + self._close() + + # This is the easiest way to abort the current request. + raise Exception('connection closed after receiving N bytes') + + return result + + def readline(self, size=-1): + remaining = object.__getattribute__(self, '_closeafterrecvbytes') + + # No read limit. Call original function. + if not remaining: + result = object.__getattribute__(self, '_orig').readline(size) + self._writelog( + b'readline(%d) -> (%d) %s' % (size, len(result), result) + ) + return result + + origsize = size + + if size < 0: + size = remaining + else: + size = min(remaining, size) + + result = object.__getattribute__(self, '_orig').readline(size) + remaining -= len(result) + + self._writelog( + b'readline(%d from %d) -> (%d) %s' + % (size, origsize, len(result), result) + ) + + object.__setattr__(self, '_closeafterrecvbytes', remaining) + + if remaining <= 0: + self._writelog(b'read limit reached; closing socket') + self._close() + + # This is the easiest way to abort the current request. + raise Exception('connection closed after receiving N bytes') + + return result + + def write(self, data): + remaining = object.__getattribute__(self, '_closeaftersendbytes') + + # No byte limit on this operation. Call original function. + if not remaining: + self._writelog(b'write(%d) -> %s' % (len(data), data)) + result = object.__getattribute__(self, '_orig').write(data) + return result + + if len(data) > remaining: + newdata = data[0:remaining] + else: + newdata = data + + remaining -= len(newdata) + + self._writelog( + b'write(%d from %d) -> (%d) %s' + % (len(newdata), len(data), remaining, newdata) + ) + + result = object.__getattribute__(self, '_orig').write(newdata) + + object.__setattr__(self, '_closeaftersendbytes', remaining) + + if remaining <= 0: + self._writelog(b'write limit reached; closing socket') + self._close() + + raise Exception('connection closed after sending N bytes') + + return result + + +def extsetup(ui): + # Change the base HTTP server class so various events can be performed. + # See SocketServer.BaseServer for how the specially named methods work. + class badserver(server.MercurialHTTPServer): + def __init__(self, ui, *args, **kwargs): + self._ui = ui + super(badserver, self).__init__(ui, *args, **kwargs) + + recvbytes = self._ui.config(b'badserver', b'closeafterrecvbytes') + recvbytes = recvbytes.split(b',') + self.closeafterrecvbytes = [int(v) for v in recvbytes if v] + sendbytes = self._ui.config(b'badserver', b'closeaftersendbytes') + sendbytes = sendbytes.split(b',') + self.closeaftersendbytes = [int(v) for v in sendbytes if v] + + # Need to inherit object so super() works. + class badrequesthandler(self.RequestHandlerClass, object): + def send_header(self, name, value): + # Make headers deterministic to facilitate testing. + if name.lower() == 'date': + value = 'Fri, 14 Apr 2017 00:00:00 GMT' + elif name.lower() == 'server': + value = 'badhttpserver' + + return super(badrequesthandler, self).send_header( + name, value + ) + + self.RequestHandlerClass = badrequesthandler + + # Called to accept() a pending socket. + def get_request(self): + if self._ui.configbool(b'badserver', b'closebeforeaccept'): + self.socket.close() + + # Tells the server to stop processing more requests. + self.__shutdown_request = True + + # Simulate failure to stop processing this request. + raise socket.error('close before accept') + + if self._ui.configbool(b'badserver', b'closeafteraccept'): + request, client_address = super(badserver, self).get_request() + request.close() + raise socket.error('close after accept') + + return super(badserver, self).get_request() + + # Does heavy lifting of processing a request. Invokes + # self.finish_request() which calls self.RequestHandlerClass() which + # is a hgweb.server._httprequesthandler. + def process_request(self, socket, address): + # Wrap socket in a proxy if we need to count bytes. + if self.closeafterrecvbytes: + closeafterrecvbytes = self.closeafterrecvbytes.pop(0) + else: + closeafterrecvbytes = 0 + if self.closeaftersendbytes: + closeaftersendbytes = self.closeaftersendbytes.pop(0) + else: + closeaftersendbytes = 0 + + if closeafterrecvbytes or closeaftersendbytes: + socket = socketproxy( + socket, + self.errorlog, + closeafterrecvbytes=closeafterrecvbytes, + closeaftersendbytes=closeaftersendbytes, + ) + + return super(badserver, self).process_request(socket, address) + + server.MercurialHTTPServer = badserver