wireprotoserver: ensure that output stream gets flushed on exception
Previously flush was happening due to Python finalizer being run on
`BufferedWriter`. With upgrade to Python 3.11 this started randomly
failing.
My guess is that the finalizer on the raw `FileIO` object may
be running before the finalizer of `BufferedWriter` has a chance to run.
At any rate, since we're not relying on finalizers in the happy case
we should also not rely on them in case of exception.
--- a/mercurial/wireprotoserver.py Mon Apr 15 16:33:37 2024 +0100
+++ b/mercurial/wireprotoserver.py Thu Apr 04 14:15:32 2024 +0100
@@ -527,24 +527,34 @@
def __init__(self, ui, repo, logfh=None, accesshidden=False):
self._ui = ui
self._repo = repo
- self._fin, self._fout = ui.protectfinout()
self._accesshidden = accesshidden
-
- # Log write I/O to stdout and stderr if configured.
- if logfh:
- self._fout = util.makeloggingfileobject(
- logfh, self._fout, b'o', logdata=True
- )
- ui.ferr = util.makeloggingfileobject(
- logfh, ui.ferr, b'e', logdata=True
- )
+ self._logfh = logfh
def serve_forever(self):
self.serveuntil(threading.Event())
- self._ui.restorefinout(self._fin, self._fout)
def serveuntil(self, ev):
"""Serve until a threading.Event is set."""
- _runsshserver(
- self._ui, self._repo, self._fin, self._fout, ev, self._accesshidden
- )
+ with self._ui.protectedfinout() as (fin, fout):
+ if self._logfh:
+ # Log write I/O to stdout and stderr if configured.
+ fout = util.makeloggingfileobject(
+ self._logfh,
+ fout,
+ b'o',
+ logdata=True,
+ )
+ self._ui.ferr = util.makeloggingfileobject(
+ self._logfh,
+ self._ui.ferr,
+ b'e',
+ logdata=True,
+ )
+ _runsshserver(
+ self._ui,
+ self._repo,
+ fin,
+ fout,
+ ev,
+ self._accesshidden,
+ )
--- a/tests/sshprotoext.py Mon Apr 15 16:33:37 2024 +0100
+++ b/tests/sshprotoext.py Thu Apr 04 14:15:32 2024 +0100
@@ -30,7 +30,7 @@
def serve_forever(self):
for i in range(10):
- self._fout.write(b'banner: line %d\n' % i)
+ self._ui.fout.write(b'banner: line %d\n' % i)
super(bannerserver, self).serve_forever()
@@ -45,17 +45,16 @@
"""
def serve_forever(self):
- l = self._fin.readline()
+ ui = self._ui
+ l = ui.fin.readline()
assert l == b'hello\n'
# Respond to unknown commands with an empty reply.
- wireprotoserver._sshv1respondbytes(self._fout, b'')
- l = self._fin.readline()
+ wireprotoserver._sshv1respondbytes(ui.fout, b'')
+ l = ui.fin.readline()
assert l == b'between\n'
- proto = wireprotoserver.sshv1protocolhandler(
- self._ui, self._fin, self._fout
- )
+ proto = wireprotoserver.sshv1protocolhandler(ui, ui.fin, ui.fout)
rsp = wireprotov1server.dispatch(self._repo, proto, b'between')
- wireprotoserver._sshv1respondbytes(self._fout, rsp.data)
+ wireprotoserver._sshv1respondbytes(ui.fout, rsp.data)
super(prehelloserver, self).serve_forever()
--- a/tests/test-sshserver.py Mon Apr 15 16:33:37 2024 +0100
+++ b/tests/test-sshserver.py Thu Apr 04 14:15:32 2024 +0100
@@ -25,9 +25,8 @@
def assertparse(self, cmd, input, expected):
server = mockserver(input)
- proto = wireprotoserver.sshv1protocolhandler(
- server._ui, server._fin, server._fout
- )
+ ui = server._ui
+ proto = wireprotoserver.sshv1protocolhandler(ui, ui.fin, ui.fout)
_func, spec = wireprotov1server.commands[cmd]
self.assertEqual(proto.getargs(spec), expected)
@@ -35,6 +34,9 @@
def mockserver(inbytes):
ui = mockui(inbytes)
repo = mockrepo(ui)
+ # note: this test unfortunately doesn't really test anything about
+ # `sshserver` class anymore: the entirety of logic of that class lives
+ # in `serveuntil`, and that function is not even called by this test.
return wireprotoserver.sshserver(ui, repo)