wireprotoserver: split ssh protocol handler and server
authorGregory Szorc <gregory.szorc@gmail.com>
Wed, 07 Feb 2018 20:17:05 -0800
changeset 36065 bf676267f64f
parent 36064 5767664d39a5
child 36066 2ad145fbde54
wireprotoserver: split ssh protocol handler and server We want to formalize the interface for protocol handlers. Today, server functionality (which is domain specific) is interleaved with protocol handling functionality (which conforms to a generic interface) in the sshserver class. This commit splits the ssh protocol handling code out of the sshserver class. Differential Revision: https://phab.mercurial-scm.org/D2080
mercurial/wireprotoserver.py
tests/sshprotoext.py
tests/test-sshserver.py
--- a/mercurial/wireprotoserver.py	Wed Feb 07 21:04:54 2018 -0800
+++ b/mercurial/wireprotoserver.py	Wed Feb 07 20:17:05 2018 -0800
@@ -354,19 +354,12 @@
     fout.write(b'\n')
     fout.flush()
 
-class sshserver(baseprotocolhandler):
-    def __init__(self, ui, repo):
+class sshv1protocolhandler(baseprotocolhandler):
+    """Handler for requests services via version 1 of SSH protocol."""
+    def __init__(self, ui, fin, fout):
         self._ui = ui
-        self._repo = repo
-        self._fin = ui.fin
-        self._fout = ui.fout
-
-        hook.redirect(True)
-        ui.fout = repo.ui.fout = ui.ferr
-
-        # Prevent insertion/deletion of CRs
-        util.setbinary(self._fin)
-        util.setbinary(self._fout)
+        self._fin = fin
+        self._fout = fout
 
     @property
     def name(self):
@@ -403,6 +396,26 @@
     def redirect(self):
         pass
 
+    def _client(self):
+        client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
+        return 'remote:ssh:' + client
+
+class sshserver(object):
+    def __init__(self, ui, repo):
+        self._ui = ui
+        self._repo = repo
+        self._fin = ui.fin
+        self._fout = ui.fout
+
+        hook.redirect(True)
+        ui.fout = repo.ui.fout = ui.ferr
+
+        # Prevent insertion/deletion of CRs
+        util.setbinary(self._fin)
+        util.setbinary(self._fout)
+
+        self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
+
     def serve_forever(self):
         while self.serve_one():
             pass
@@ -410,8 +423,8 @@
 
     def serve_one(self):
         cmd = self._fin.readline()[:-1]
-        if cmd and wireproto.commands.commandavailable(cmd, self):
-            rsp = wireproto.dispatch(self._repo, self, cmd)
+        if cmd and wireproto.commands.commandavailable(cmd, self._proto):
+            rsp = wireproto.dispatch(self._repo, self._proto, cmd)
 
             if isinstance(rsp, bytes):
                 _sshv1respondbytes(self._fout, rsp)
@@ -432,7 +445,3 @@
         elif cmd:
             _sshv1respondbytes(self._fout, b'')
         return cmd != ''
-
-    def _client(self):
-        client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
-        return 'remote:ssh:' + client
--- a/tests/sshprotoext.py	Wed Feb 07 21:04:54 2018 -0800
+++ b/tests/sshprotoext.py	Wed Feb 07 20:17:05 2018 -0800
@@ -48,7 +48,7 @@
         wireprotoserver._sshv1respondbytes(self._fout, b'')
         l = self._fin.readline()
         assert l == b'between\n'
-        rsp = wireproto.dispatch(self._repo, self, b'between')
+        rsp = wireproto.dispatch(self._repo, self._proto, b'between')
         wireprotoserver._sshv1respondbytes(self._fout, rsp)
 
         super(prehelloserver, self).serve_forever()
@@ -73,7 +73,7 @@
 
         # Send the upgrade response.
         self._fout.write(b'upgraded %s %s\n' % (token, name))
-        servercaps = wireproto.capabilities(self._repo, self)
+        servercaps = wireproto.capabilities(self._repo, self._proto)
         rsp = b'capabilities: %s' % servercaps
         self._fout.write(b'%d\n' % len(rsp))
         self._fout.write(rsp)
--- a/tests/test-sshserver.py	Wed Feb 07 21:04:54 2018 -0800
+++ b/tests/test-sshserver.py	Wed Feb 07 20:17:05 2018 -0800
@@ -24,7 +24,7 @@
     def assertparse(self, cmd, input, expected):
         server = mockserver(input)
         _func, spec = wireproto.commands[cmd]
-        self.assertEqual(server.getargs(spec), expected)
+        self.assertEqual(server._proto.getargs(spec), expected)
 
 def mockserver(inbytes):
     ui = mockui(inbytes)