mercurial/sshpeer.py
changeset 46701 db8037e38085
parent 46699 0738bc25d6ac
child 46702 a4c19a162615
--- a/mercurial/sshpeer.py	Mon Feb 15 14:15:02 2021 -0500
+++ b/mercurial/sshpeer.py	Mon Feb 15 14:40:17 2021 -0500
@@ -148,14 +148,18 @@
         return self._main.flush()
 
 
-def _cleanuppipes(ui, pipei, pipeo, pipee):
+def _cleanuppipes(ui, pipei, pipeo, pipee, warn):
     """Clean up pipes used by an SSH connection."""
-    if pipeo:
+    didsomething = False
+    if pipeo and not pipeo.closed:
+        didsomething = True
         pipeo.close()
-    if pipei:
+    if pipei and not pipei.closed:
+        didsomething = True
         pipei.close()
 
-    if pipee:
+    if pipee and not pipee.closed:
+        didsomething = True
         # Try to read from the err descriptor until EOF.
         try:
             for l in pipee:
@@ -165,6 +169,17 @@
 
         pipee.close()
 
+    if didsomething and warn is not None:
+        # Encourage explicit close of sshpeers. Closing via __del__ is
+        # not very predictable when exceptions are thrown, which has led
+        # to deadlocks due to a peer get gc'ed in a fork
+        # We add our own stack trace, because the stacktrace when called
+        # from __del__ is useless.
+        if False:  # enabled in next commit
+            ui.develwarn(
+                b'missing close on SSH connection created at:\n%s' % warn
+            )
+
 
 def _makeconnection(ui, sshcmd, args, remotecmd, path, sshenv=None):
     """Create an SSH connection to a server.
@@ -416,6 +431,7 @@
         self._pipee = stderr
         self._caps = caps
         self._autoreadstderr = autoreadstderr
+        self._initstack = b''.join(util.getstackframes(1))
 
     # Commands that have a "framed" response where the first line of the
     # response contains the length of that response.
@@ -456,10 +472,11 @@
         self._cleanup()
         raise exception
 
-    def _cleanup(self):
-        _cleanuppipes(self.ui, self._pipei, self._pipeo, self._pipee)
+    def _cleanup(self, warn=None):
+        _cleanuppipes(self.ui, self._pipei, self._pipeo, self._pipee, warn=warn)
 
-    __del__ = _cleanup
+    def __del__(self):
+        self._cleanup(warn=self._initstack)
 
     def _sendrequest(self, cmd, args, framed=False):
         if self.ui.debugflag and self.ui.configbool(
@@ -611,7 +628,7 @@
     try:
         protoname, caps = _performhandshake(ui, stdin, stdout, stderr)
     except Exception:
-        _cleanuppipes(ui, stdout, stdin, stderr)
+        _cleanuppipes(ui, stdout, stdin, stderr, warn=None)
         raise
 
     if protoname == wireprototypes.SSHV1:
@@ -637,7 +654,7 @@
             autoreadstderr=autoreadstderr,
         )
     else:
-        _cleanuppipes(ui, stdout, stdin, stderr)
+        _cleanuppipes(ui, stdout, stdin, stderr, warn=None)
         raise error.RepoError(
             _(b'unknown version of SSH protocol: %s') % protoname
         )