--- a/tests/sshprotoext.py Sun Feb 04 14:10:56 2018 -0800
+++ b/tests/sshprotoext.py Mon Feb 05 09:14:32 2018 -0800
@@ -12,6 +12,7 @@
from mercurial import (
error,
+ extensions,
registrar,
sshpeer,
wireproto,
@@ -52,30 +53,26 @@
super(prehelloserver, self).serve_forever()
-class extrahandshakecommandspeer(sshpeer.sshpeer):
- """An ssh peer that sends extra commands as part of initial handshake."""
- def _validaterepo(self):
- mode = self._ui.config(b'sshpeer', b'handshake-mode')
- if mode == b'pre-no-args':
- self._callstream(b'no-args')
- return super(extrahandshakecommandspeer, self)._validaterepo()
- elif mode == b'pre-multiple-no-args':
- self._callstream(b'unknown1')
- self._callstream(b'unknown2')
- self._callstream(b'unknown3')
- return super(extrahandshakecommandspeer, self)._validaterepo()
- else:
- raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
- mode)
-
-def registercommands():
- def dummycommand(repo, proto):
- raise error.ProgrammingError('this should never be called')
-
- wireproto.wireprotocommand(b'no-args', b'')(dummycommand)
- wireproto.wireprotocommand(b'unknown1', b'')(dummycommand)
- wireproto.wireprotocommand(b'unknown2', b'')(dummycommand)
- wireproto.wireprotocommand(b'unknown3', b'')(dummycommand)
+def performhandshake(orig, ui, stdin, stdout, stderr):
+ """Wrapped version of sshpeer._performhandshake to send extra commands."""
+ mode = ui.config(b'sshpeer', b'handshake-mode')
+ if mode == b'pre-no-args':
+ ui.debug(b'sending no-args command\n')
+ stdin.write(b'no-args\n')
+ stdin.flush()
+ return orig(ui, stdin, stdout, stderr)
+ elif mode == b'pre-multiple-no-args':
+ ui.debug(b'sending unknown1 command\n')
+ stdin.write(b'unknown1\n')
+ ui.debug(b'sending unknown2 command\n')
+ stdin.write(b'unknown2\n')
+ ui.debug(b'sending unknown3 command\n')
+ stdin.write(b'unknown3\n')
+ stdin.flush()
+ return orig(ui, stdin, stdout, stderr)
+ else:
+ raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' %
+ mode)
def extsetup(ui):
# It's easier for tests to define the server behavior via environment
@@ -94,7 +91,6 @@
peermode = ui.config(b'sshpeer', b'mode')
if peermode == b'extra-handshake-commands':
- sshpeer.sshpeer = extrahandshakecommandspeer
- registercommands()
+ extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake)
elif peermode:
raise error.ProgrammingError(b'unknown peer mode: %s' % peermode)