diff mercurial/sshpeer.py @ 35938:80a2b8ae42a1

sshpeer: move handshake outside of sshpeer With the handshake now performed before a peer class is instantiated, we can now instantiate a different peer class depending on the results of the handshake. Our test extension had to change to cope with the new API. Because we now issue the command via raw I/O calls and don't call _callstream(), we no longer have to register the fake command. (_callstream() uses the command registration to see what args to send). Differential Revision: https://phab.mercurial-scm.org/D2034
author Gregory Szorc <gregory.szorc@gmail.com>
date Mon, 05 Feb 2018 09:14:32 -0800
parents a9cffd14aa04
children a622a927fe03
line wrap: on
line diff
--- a/mercurial/sshpeer.py	Sun Feb 04 14:10:56 2018 -0800
+++ b/mercurial/sshpeer.py	Mon Feb 05 09:14:32 2018 -0800
@@ -156,13 +156,69 @@
 
     return proc, stdin, stdout, stderr
 
+def _performhandshake(ui, stdin, stdout, stderr):
+    def badresponse():
+        msg = _('no suitable response from remote hg')
+        hint = ui.config('ui', 'ssherrorhint')
+        raise error.RepoError(msg, hint=hint)
+
+    requestlog = ui.configbool('devel', 'debug.peer-request')
+
+    try:
+        pairsarg = '%s-%s' % ('0' * 40, '0' * 40)
+        handshake = [
+            'hello\n',
+            'between\n',
+            'pairs %d\n' % len(pairsarg),
+            pairsarg,
+        ]
+
+        if requestlog:
+            ui.debug('devel-peer-request: hello\n')
+        ui.debug('sending hello command\n')
+        if requestlog:
+            ui.debug('devel-peer-request: between\n')
+            ui.debug('devel-peer-request:   pairs: %d bytes\n' % len(pairsarg))
+        ui.debug('sending between command\n')
+
+        stdin.write(''.join(handshake))
+        stdin.flush()
+    except IOError:
+        badresponse()
+
+    lines = ['', 'dummy']
+    max_noise = 500
+    while lines[-1] and max_noise:
+        try:
+            l = stdout.readline()
+            _forwardoutput(ui, stderr)
+            if lines[-1] == '1\n' and l == '\n':
+                break
+            if l:
+                ui.debug('remote: ', l)
+            lines.append(l)
+            max_noise -= 1
+        except IOError:
+            badresponse()
+    else:
+        badresponse()
+
+    caps = set()
+    for l in reversed(lines):
+        if l.startswith('capabilities:'):
+            caps.update(l[:-1].split(':')[1].split())
+            break
+
+    return caps
+
 class sshpeer(wireproto.wirepeer):
-    def __init__(self, ui, url, proc, stdin, stdout, stderr):
+    def __init__(self, ui, url, proc, stdin, stdout, stderr, caps):
         """Create a peer from an existing SSH connection.
 
         ``proc`` is a handle on the underlying SSH process.
         ``stdin``, ``stdout``, and ``stderr`` are handles on the stdio
         pipes for that process.
+        ``caps`` is a set of capabilities supported by the remote.
         """
         self._url = url
         self._ui = ui
@@ -172,8 +228,7 @@
         self._pipeo = stdin
         self._pipei = stdout
         self._pipee = stderr
-
-        self._validaterepo()
+        self._caps = caps
 
     # Begin of _basepeer interface.
 
@@ -205,61 +260,6 @@
 
     # End of _basewirecommands interface.
 
-    def _validaterepo(self):
-        def badresponse():
-            msg = _("no suitable response from remote hg")
-            hint = self.ui.config("ui", "ssherrorhint")
-            self._abort(error.RepoError(msg, hint=hint))
-
-        try:
-            pairsarg = '%s-%s' % ('0' * 40, '0' * 40)
-
-            handshake = [
-                'hello\n',
-                'between\n',
-                'pairs %d\n' % len(pairsarg),
-                pairsarg,
-            ]
-
-            requestlog = self.ui.configbool('devel', 'debug.peer-request')
-
-            if requestlog:
-                self.ui.debug('devel-peer-request: hello\n')
-            self.ui.debug('sending hello command\n')
-            if requestlog:
-                self.ui.debug('devel-peer-request: between\n')
-                self.ui.debug('devel-peer-request:   pairs: %d bytes\n' %
-                              len(pairsarg))
-            self.ui.debug('sending between command\n')
-
-            self._pipeo.write(''.join(handshake))
-            self._pipeo.flush()
-        except IOError:
-            badresponse()
-
-        lines = ["", "dummy"]
-        max_noise = 500
-        while lines[-1] and max_noise:
-            try:
-                l = self._pipei.readline()
-                _forwardoutput(self.ui, self._pipee)
-                if lines[-1] == "1\n" and l == "\n":
-                    break
-                if l:
-                    self.ui.debug("remote: ", l)
-                lines.append(l)
-                max_noise -= 1
-            except IOError:
-                badresponse()
-        else:
-            badresponse()
-
-        self._caps = set()
-        for l in reversed(lines):
-            if l.startswith("capabilities:"):
-                self._caps.update(l[:-1].split(":")[1].split())
-                break
-
     def _readerr(self):
         _forwardoutput(self.ui, self._pipee)
 
@@ -414,4 +414,10 @@
     proc, stdin, stdout, stderr = _makeconnection(ui, sshcmd, args, remotecmd,
                                                   remotepath, sshenv)
 
-    return sshpeer(ui, path, proc, stdin, stdout, stderr)
+    try:
+        caps = _performhandshake(ui, stdin, stdout, stderr)
+    except Exception:
+        _cleanuppipes(ui, stdout, stdin, stderr)
+        raise
+
+    return sshpeer(ui, path, proc, stdin, stdout, stderr, caps)