wireprotoserver: extract SSH response handling functions
The lookup/dispatch table was cute. But it isn't needed. Future
refactors will benefit from the handlers for individual response
types living outside the class.
As part of this, I snuck in a change that changes a type compare
from str to bytes. This has no effect on Python 2. But it might
make Python 3 a bit happier.
Differential Revision: https://phab.mercurial-scm.org/D2091
--- a/mercurial/wireprotoserver.py Sat Dec 23 15:13:37 2017 +0530
+++ b/mercurial/wireprotoserver.py Wed Feb 07 21:04:54 2018 -0800
@@ -336,6 +336,24 @@
return ''
+def _sshv1respondbytes(fout, value):
+ """Send a bytes response for protocol version 1."""
+ fout.write('%d\n' % len(value))
+ fout.write(value)
+ fout.flush()
+
+def _sshv1respondstream(fout, source):
+ write = fout.write
+ for chunk in source.gen:
+ write(chunk)
+ fout.flush()
+
+def _sshv1respondooberror(fout, ferr, rsp):
+ ferr.write(b'%s\n-\n' % rsp)
+ ferr.flush()
+ fout.write(b'\n')
+ fout.flush()
+
class sshserver(baseprotocolhandler):
def __init__(self, ui, repo):
self._ui = ui
@@ -376,7 +394,7 @@
return [data[k] for k in keys]
def getfile(self, fpout):
- self._sendresponse('')
+ _sshv1respondbytes(self._fout, b'')
count = int(self._fin.readline())
while count:
fpout.write(self._fin.read(count))
@@ -385,51 +403,34 @@
def redirect(self):
pass
- def _sendresponse(self, v):
- self._fout.write("%d\n" % len(v))
- self._fout.write(v)
- self._fout.flush()
-
- def _sendstream(self, source):
- write = self._fout.write
- for chunk in source.gen:
- write(chunk)
- self._fout.flush()
-
- def _sendpushresponse(self, rsp):
- self._sendresponse('')
- self._sendresponse(str(rsp.res))
-
- def _sendpusherror(self, rsp):
- self._sendresponse(rsp.res)
-
- def _sendooberror(self, rsp):
- self._ui.ferr.write('%s\n-\n' % rsp.message)
- self._ui.ferr.flush()
- self._fout.write('\n')
- self._fout.flush()
-
def serve_forever(self):
while self.serve_one():
pass
sys.exit(0)
- _handlers = {
- str: _sendresponse,
- wireproto.streamres: _sendstream,
- wireproto.streamres_legacy: _sendstream,
- wireproto.pushres: _sendpushresponse,
- wireproto.pusherr: _sendpusherror,
- wireproto.ooberror: _sendooberror,
- }
-
def serve_one(self):
cmd = self._fin.readline()[:-1]
if cmd and wireproto.commands.commandavailable(cmd, self):
rsp = wireproto.dispatch(self._repo, self, cmd)
- self._handlers[rsp.__class__](self, rsp)
+
+ if isinstance(rsp, bytes):
+ _sshv1respondbytes(self._fout, rsp)
+ elif isinstance(rsp, wireproto.streamres):
+ _sshv1respondstream(self._fout, rsp)
+ elif isinstance(rsp, wireproto.streamres_legacy):
+ _sshv1respondstream(self._fout, rsp)
+ elif isinstance(rsp, wireproto.pushres):
+ _sshv1respondbytes(self._fout, b'')
+ _sshv1respondbytes(self._fout, bytes(rsp.res))
+ elif isinstance(rsp, wireproto.pusherr):
+ _sshv1respondbytes(self._fout, rsp.res)
+ elif isinstance(rsp, wireproto.ooberror):
+ _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
+ else:
+ raise error.ProgrammingError('unhandled response type from '
+ 'wire protocol command: %s' % rsp)
elif cmd:
- self._sendresponse("")
+ _sshv1respondbytes(self._fout, b'')
return cmd != ''
def _client(self):
--- a/tests/sshprotoext.py Sat Dec 23 15:13:37 2017 +0530
+++ b/tests/sshprotoext.py Wed Feb 07 21:04:54 2018 -0800
@@ -45,11 +45,11 @@
l = self._fin.readline()
assert l == b'hello\n'
# Respond to unknown commands with an empty reply.
- self._sendresponse(b'')
+ wireprotoserver._sshv1respondbytes(self._fout, b'')
l = self._fin.readline()
assert l == b'between\n'
rsp = wireproto.dispatch(self._repo, self, b'between')
- self._handlers[rsp.__class__](self, rsp)
+ wireprotoserver._sshv1respondbytes(self._fout, rsp)
super(prehelloserver, self).serve_forever()