Mercurial > hg
changeset 36064:5767664d39a5
wireprotoserver: extract SSH response handling functions
The lookup/dispatch table was cute. But it isn't needed. Future
refactors will benefit from the handlers for individual response
types living outside the class.
As part of this, I snuck in a change that changes a type compare
from str to bytes. This has no effect on Python 2. But it might
make Python 3 a bit happier.
Differential Revision: https://phab.mercurial-scm.org/D2091
author | Gregory Szorc <gregory.szorc@gmail.com> |
---|---|
date | Wed, 07 Feb 2018 21:04:54 -0800 |
parents | 5a53af7d09aa |
children | bf676267f64f |
files | mercurial/wireprotoserver.py tests/sshprotoext.py |
diffstat | 2 files changed, 39 insertions(+), 38 deletions(-) [+] |
line wrap: on
line diff
--- a/mercurial/wireprotoserver.py Sat Dec 23 15:13:37 2017 +0530 +++ b/mercurial/wireprotoserver.py Wed Feb 07 21:04:54 2018 -0800 @@ -336,6 +336,24 @@ return '' +def _sshv1respondbytes(fout, value): + """Send a bytes response for protocol version 1.""" + fout.write('%d\n' % len(value)) + fout.write(value) + fout.flush() + +def _sshv1respondstream(fout, source): + write = fout.write + for chunk in source.gen: + write(chunk) + fout.flush() + +def _sshv1respondooberror(fout, ferr, rsp): + ferr.write(b'%s\n-\n' % rsp) + ferr.flush() + fout.write(b'\n') + fout.flush() + class sshserver(baseprotocolhandler): def __init__(self, ui, repo): self._ui = ui @@ -376,7 +394,7 @@ return [data[k] for k in keys] def getfile(self, fpout): - self._sendresponse('') + _sshv1respondbytes(self._fout, b'') count = int(self._fin.readline()) while count: fpout.write(self._fin.read(count)) @@ -385,51 +403,34 @@ def redirect(self): pass - def _sendresponse(self, v): - self._fout.write("%d\n" % len(v)) - self._fout.write(v) - self._fout.flush() - - def _sendstream(self, source): - write = self._fout.write - for chunk in source.gen: - write(chunk) - self._fout.flush() - - def _sendpushresponse(self, rsp): - self._sendresponse('') - self._sendresponse(str(rsp.res)) - - def _sendpusherror(self, rsp): - self._sendresponse(rsp.res) - - def _sendooberror(self, rsp): - self._ui.ferr.write('%s\n-\n' % rsp.message) - self._ui.ferr.flush() - self._fout.write('\n') - self._fout.flush() - def serve_forever(self): while self.serve_one(): pass sys.exit(0) - _handlers = { - str: _sendresponse, - wireproto.streamres: _sendstream, - wireproto.streamres_legacy: _sendstream, - wireproto.pushres: _sendpushresponse, - wireproto.pusherr: _sendpusherror, - wireproto.ooberror: _sendooberror, - } - def serve_one(self): cmd = self._fin.readline()[:-1] if cmd and wireproto.commands.commandavailable(cmd, self): rsp = wireproto.dispatch(self._repo, self, cmd) - self._handlers[rsp.__class__](self, rsp) + + if isinstance(rsp, bytes): + _sshv1respondbytes(self._fout, rsp) + elif isinstance(rsp, wireproto.streamres): + _sshv1respondstream(self._fout, rsp) + elif isinstance(rsp, wireproto.streamres_legacy): + _sshv1respondstream(self._fout, rsp) + elif isinstance(rsp, wireproto.pushres): + _sshv1respondbytes(self._fout, b'') + _sshv1respondbytes(self._fout, bytes(rsp.res)) + elif isinstance(rsp, wireproto.pusherr): + _sshv1respondbytes(self._fout, rsp.res) + elif isinstance(rsp, wireproto.ooberror): + _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message) + else: + raise error.ProgrammingError('unhandled response type from ' + 'wire protocol command: %s' % rsp) elif cmd: - self._sendresponse("") + _sshv1respondbytes(self._fout, b'') return cmd != '' def _client(self):
--- a/tests/sshprotoext.py Sat Dec 23 15:13:37 2017 +0530 +++ b/tests/sshprotoext.py Wed Feb 07 21:04:54 2018 -0800 @@ -45,11 +45,11 @@ l = self._fin.readline() assert l == b'hello\n' # Respond to unknown commands with an empty reply. - self._sendresponse(b'') + wireprotoserver._sshv1respondbytes(self._fout, b'') l = self._fin.readline() assert l == b'between\n' rsp = wireproto.dispatch(self._repo, self, b'between') - self._handlers[rsp.__class__](self, rsp) + wireprotoserver._sshv1respondbytes(self._fout, rsp) super(prehelloserver, self).serve_forever()