Mercurial > hg
changeset 36214:3b3a987bbbaa
wireprotoserver: move SSH server operation to a standalone function
The server-side processing logic will soon get a bit more complicated
in order to handle protocol switches. We will use a state machine
to help make the transitions clearer.
To prepare for this, we move SSH server operation into a standalone
function. We structure it as a very simple state machine. It only
has two states for now, with one state containing the bulk of the
logic. But things will evolve shortly.
Differential Revision: https://phab.mercurial-scm.org/D2203
author | Gregory Szorc <gregory.szorc@gmail.com> |
---|---|
date | Thu, 08 Feb 2018 15:09:59 -0800 |
parents | b67d4b7e8235 |
children | 464bedc0fdb4 |
files | mercurial/wireprotoserver.py tests/sshprotoext.py tests/test-sshserver.py |
diffstat | 3 files changed, 61 insertions(+), 34 deletions(-) [+] |
line wrap: on
line diff
--- a/mercurial/wireprotoserver.py Wed Feb 14 17:35:13 2018 -0700 +++ b/mercurial/wireprotoserver.py Thu Feb 08 15:09:59 2018 -0800 @@ -409,6 +409,56 @@ client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0] return 'remote:ssh:' + client +def _runsshserver(ui, repo, fin, fout): + state = 'protov1-serving' + proto = sshv1protocolhandler(ui, fin, fout) + + while True: + if state == 'protov1-serving': + # Commands are issued on new lines. + request = fin.readline()[:-1] + + # Empty lines signal to terminate the connection. + if not request: + state = 'shutdown' + continue + + available = wireproto.commands.commandavailable(request, proto) + + # This command isn't available. Send an empty response and go + # back to waiting for a new command. + if not available: + _sshv1respondbytes(fout, b'') + continue + + rsp = wireproto.dispatch(repo, proto, request) + + if isinstance(rsp, bytes): + _sshv1respondbytes(fout, rsp) + elif isinstance(rsp, wireprototypes.bytesresponse): + _sshv1respondbytes(fout, rsp.data) + elif isinstance(rsp, wireprototypes.streamres): + _sshv1respondstream(fout, rsp) + elif isinstance(rsp, wireprototypes.streamreslegacy): + _sshv1respondstream(fout, rsp) + elif isinstance(rsp, wireprototypes.pushres): + _sshv1respondbytes(fout, b'') + _sshv1respondbytes(fout, b'%d' % rsp.res) + elif isinstance(rsp, wireprototypes.pusherr): + _sshv1respondbytes(fout, rsp.res) + elif isinstance(rsp, wireprototypes.ooberror): + _sshv1respondooberror(fout, ui.ferr, rsp.message) + else: + raise error.ProgrammingError('unhandled response type from ' + 'wire protocol command: %s' % rsp) + + elif state == 'shutdown': + break + + else: + raise error.ProgrammingError('unhandled ssh server state: %s' % + state) + class sshserver(object): def __init__(self, ui, repo): self._ui = ui @@ -423,36 +473,6 @@ util.setbinary(self._fin) util.setbinary(self._fout) - self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout) - def serve_forever(self): - while self.serve_one(): - pass + _runsshserver(self._ui, self._repo, self._fin, self._fout) sys.exit(0) - - def serve_one(self): - cmd = self._fin.readline()[:-1] - if cmd and wireproto.commands.commandavailable(cmd, self._proto): - rsp = wireproto.dispatch(self._repo, self._proto, cmd) - - if isinstance(rsp, bytes): - _sshv1respondbytes(self._fout, rsp) - elif isinstance(rsp, wireprototypes.bytesresponse): - _sshv1respondbytes(self._fout, rsp.data) - elif isinstance(rsp, wireprototypes.streamres): - _sshv1respondstream(self._fout, rsp) - elif isinstance(rsp, wireprototypes.streamreslegacy): - _sshv1respondstream(self._fout, rsp) - elif isinstance(rsp, wireprototypes.pushres): - _sshv1respondbytes(self._fout, b'') - _sshv1respondbytes(self._fout, b'%d' % rsp.res) - elif isinstance(rsp, wireprototypes.pusherr): - _sshv1respondbytes(self._fout, rsp.res) - elif isinstance(rsp, wireprototypes.ooberror): - _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message) - else: - raise error.ProgrammingError('unhandled response type from ' - 'wire protocol command: %s' % rsp) - elif cmd: - _sshv1respondbytes(self._fout, b'') - return cmd != ''
--- a/tests/sshprotoext.py Wed Feb 14 17:35:13 2018 -0700 +++ b/tests/sshprotoext.py Thu Feb 08 15:09:59 2018 -0800 @@ -48,7 +48,9 @@ wireprotoserver._sshv1respondbytes(self._fout, b'') l = self._fin.readline() assert l == b'between\n' - rsp = wireproto.dispatch(self._repo, self._proto, b'between') + proto = wireprotoserver.sshv1protocolhandler(self._ui, self._fin, + self._fout) + rsp = wireproto.dispatch(self._repo, proto, b'between') wireprotoserver._sshv1respondbytes(self._fout, rsp.data) super(prehelloserver, self).serve_forever() @@ -72,8 +74,10 @@ self._fin.read(81) # Send the upgrade response. + proto = wireprotoserver.sshv1protocolhandler(self._ui, self._fin, + self._fout) self._fout.write(b'upgraded %s %s\n' % (token, name)) - servercaps = wireproto.capabilities(self._repo, self._proto) + servercaps = wireproto.capabilities(self._repo, proto) rsp = b'capabilities: %s' % servercaps.data self._fout.write(b'%d\n' % len(rsp)) self._fout.write(rsp)
--- a/tests/test-sshserver.py Wed Feb 14 17:35:13 2018 -0700 +++ b/tests/test-sshserver.py Thu Feb 08 15:09:59 2018 -0800 @@ -23,8 +23,11 @@ def assertparse(self, cmd, input, expected): server = mockserver(input) + proto = wireprotoserver.sshv1protocolhandler(server._ui, + server._fin, + server._fout) _func, spec = wireproto.commands[cmd] - self.assertEqual(server._proto.getargs(spec), expected) + self.assertEqual(proto.getargs(spec), expected) def mockserver(inbytes): ui = mockui(inbytes)