changeset 36065:bf676267f64f

wireprotoserver: split ssh protocol handler and server We want to formalize the interface for protocol handlers. Today, server functionality (which is domain specific) is interleaved with protocol handling functionality (which conforms to a generic interface) in the sshserver class. This commit splits the ssh protocol handling code out of the sshserver class. Differential Revision: https://phab.mercurial-scm.org/D2080
author Gregory Szorc <gregory.szorc@gmail.com>
date Wed, 07 Feb 2018 20:17:05 -0800
parents 5767664d39a5
children 2ad145fbde54
files mercurial/wireprotoserver.py tests/sshprotoext.py tests/test-sshserver.py
diffstat 3 files changed, 30 insertions(+), 21 deletions(-) [+]
line wrap: on
line diff
--- a/mercurial/wireprotoserver.py	Wed Feb 07 21:04:54 2018 -0800
+++ b/mercurial/wireprotoserver.py	Wed Feb 07 20:17:05 2018 -0800
@@ -354,19 +354,12 @@
     fout.write(b'\n')
     fout.flush()
 
-class sshserver(baseprotocolhandler):
-    def __init__(self, ui, repo):
+class sshv1protocolhandler(baseprotocolhandler):
+    """Handler for requests services via version 1 of SSH protocol."""
+    def __init__(self, ui, fin, fout):
         self._ui = ui
-        self._repo = repo
-        self._fin = ui.fin
-        self._fout = ui.fout
-
-        hook.redirect(True)
-        ui.fout = repo.ui.fout = ui.ferr
-
-        # Prevent insertion/deletion of CRs
-        util.setbinary(self._fin)
-        util.setbinary(self._fout)
+        self._fin = fin
+        self._fout = fout
 
     @property
     def name(self):
@@ -403,6 +396,26 @@
     def redirect(self):
         pass
 
+    def _client(self):
+        client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
+        return 'remote:ssh:' + client
+
+class sshserver(object):
+    def __init__(self, ui, repo):
+        self._ui = ui
+        self._repo = repo
+        self._fin = ui.fin
+        self._fout = ui.fout
+
+        hook.redirect(True)
+        ui.fout = repo.ui.fout = ui.ferr
+
+        # Prevent insertion/deletion of CRs
+        util.setbinary(self._fin)
+        util.setbinary(self._fout)
+
+        self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
+
     def serve_forever(self):
         while self.serve_one():
             pass
@@ -410,8 +423,8 @@
 
     def serve_one(self):
         cmd = self._fin.readline()[:-1]
-        if cmd and wireproto.commands.commandavailable(cmd, self):
-            rsp = wireproto.dispatch(self._repo, self, cmd)
+        if cmd and wireproto.commands.commandavailable(cmd, self._proto):
+            rsp = wireproto.dispatch(self._repo, self._proto, cmd)
 
             if isinstance(rsp, bytes):
                 _sshv1respondbytes(self._fout, rsp)
@@ -432,7 +445,3 @@
         elif cmd:
             _sshv1respondbytes(self._fout, b'')
         return cmd != ''
-
-    def _client(self):
-        client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
-        return 'remote:ssh:' + client
--- a/tests/sshprotoext.py	Wed Feb 07 21:04:54 2018 -0800
+++ b/tests/sshprotoext.py	Wed Feb 07 20:17:05 2018 -0800
@@ -48,7 +48,7 @@
         wireprotoserver._sshv1respondbytes(self._fout, b'')
         l = self._fin.readline()
         assert l == b'between\n'
-        rsp = wireproto.dispatch(self._repo, self, b'between')
+        rsp = wireproto.dispatch(self._repo, self._proto, b'between')
         wireprotoserver._sshv1respondbytes(self._fout, rsp)
 
         super(prehelloserver, self).serve_forever()
@@ -73,7 +73,7 @@
 
         # Send the upgrade response.
         self._fout.write(b'upgraded %s %s\n' % (token, name))
-        servercaps = wireproto.capabilities(self._repo, self)
+        servercaps = wireproto.capabilities(self._repo, self._proto)
         rsp = b'capabilities: %s' % servercaps
         self._fout.write(b'%d\n' % len(rsp))
         self._fout.write(rsp)
--- a/tests/test-sshserver.py	Wed Feb 07 21:04:54 2018 -0800
+++ b/tests/test-sshserver.py	Wed Feb 07 20:17:05 2018 -0800
@@ -24,7 +24,7 @@
     def assertparse(self, cmd, input, expected):
         server = mockserver(input)
         _func, spec = wireproto.commands[cmd]
-        self.assertEqual(server.getargs(spec), expected)
+        self.assertEqual(server._proto.getargs(spec), expected)
 
 def mockserver(inbytes):
     ui = mockui(inbytes)