diff hgext3rd/topic/discovery.py @ 5683:1dece375d2ab

topic: extract awful `ctx.branch` hijacking used in discovery We will need it for `hg summary`, so we start by making it its own context manager.
author Pierre-Yves David <pierre-yves.david@octobus.net>
date Wed, 23 Dec 2020 13:36:02 +0100
parents 36ccafa69095
children 4de216446c53
line wrap: on
line diff
--- a/hgext3rd/topic/discovery.py	Wed Dec 23 21:34:22 2020 +0800
+++ b/hgext3rd/topic/discovery.py	Wed Dec 23 13:36:02 2020 +0100
@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 
 import collections
+import contextlib
 import weakref
 
 from mercurial.i18n import _
@@ -20,6 +21,59 @@
 
 from mercurial import wireprotov1server
 
+@contextlib.contextmanager
+def override_context_branch(repo, publishedset=()):
+    unfi = repo.unfiltered()
+
+    class repocls(unfi.__class__):
+        # awful hack to see branch as "branch:topic"
+        def __getitem__(self, key):
+            ctx = super(repocls, self).__getitem__(key)
+            oldbranch = ctx.branch
+            rev = ctx.rev()
+
+            def branch():
+                branch = oldbranch()
+                if rev in publishedset:
+                    return branch
+                topic = ctx.topic()
+                if topic:
+                    branch = b"%s:%s" % (branch, topic)
+                return branch
+
+            ctx.branch = branch
+            return ctx
+
+        def revbranchcache(self):
+            rbc = super(repocls, self).revbranchcache()
+            localchangelog = self.changelog
+
+            def branchinfo(rev, changelog=None):
+                if changelog is None:
+                    changelog = localchangelog
+                branch, close = changelog.branchinfo(rev)
+                if rev in publishedset:
+                    return branch, close
+                topic = unfi[rev].topic()
+                if topic:
+                    branch = b"%s:%s" % (branch, topic)
+                return branch, close
+
+            rbc.branchinfo = branchinfo
+            return rbc
+
+    oldrepocls = unfi.__class__
+    try:
+        unfi.__class__ = repocls
+        if repo.filtername is not None:
+            repo = unfi.filtered(repo.filtername)
+        else:
+            repo = unfi
+        yield repo
+    finally:
+        unfi.__class__ = oldrepocls
+
+
 def _headssummary(orig, pushop, *args, **kwargs):
     repo = pushop.repo.unfiltered()
     remote = pushop.remote
@@ -61,63 +115,24 @@
             heads.sort()
         return result
 
-    class repocls(repo.__class__):
-        # awful hack to see branch as "branch:topic"
-        def __getitem__(self, key):
-            ctx = super(repocls, self).__getitem__(key)
-            oldbranch = ctx.branch
-            rev = ctx.rev()
-
-            def branch():
-                branch = oldbranch()
-                if rev in publishedset:
-                    return branch
-                topic = ctx.topic()
-                if topic:
-                    branch = b"%s:%s" % (branch, topic)
-                return branch
-
-            ctx.branch = branch
-            return ctx
-
-        def revbranchcache(self):
-            rbc = super(repocls, self).revbranchcache()
-            localchangelog = self.changelog
-
-            def branchinfo(rev, changelog=None):
-                if changelog is None:
-                    changelog = localchangelog
-                branch, close = changelog.branchinfo(rev)
-                if rev in publishedset:
-                    return branch, close
-                topic = repo[rev].topic()
-                if topic:
-                    branch = b"%s:%s" % (branch, topic)
-                return branch, close
-
-            rbc.branchinfo = branchinfo
-            return rbc
-
-    oldrepocls = repo.__class__
-    try:
-        repo.__class__ = repocls
-        if remotebranchmap is not None:
-            remote.branchmap = remotebranchmap
-        unxx = repo.filtered(b'unfiltered-topic')
-        repo.unfiltered = lambda: unxx
-        pushop.repo = repo
-        summary = orig(pushop)
-        for key, value in summary.items():
-            if b':' in key: # This is a topic
-                if value[0] is None and value[1]:
-                    summary[key] = ([value[1][0]], ) + value[1:]
-        return summary
-    finally:
-        if r'unfiltered' in vars(repo):
-            del repo.unfiltered
-        repo.__class__ = oldrepocls
-        if remotebranchmap is not None:
-            remote.branchmap = origremotebranchmap
+    with override_context_branch(repo, publishedset=publishedset):
+        try:
+            if remotebranchmap is not None:
+                remote.branchmap = remotebranchmap
+            unxx = repo.filtered(b'unfiltered-topic')
+            repo.unfiltered = lambda: unxx
+            pushop.repo = repo
+            summary = orig(pushop)
+            for key, value in summary.items():
+                if b':' in key: # This is a topic
+                    if value[0] is None and value[1]:
+                        summary[key] = ([value[1][0]], ) + value[1:]
+            return summary
+        finally:
+            if r'unfiltered' in vars(repo):
+                del repo.unfiltered
+            if remotebranchmap is not None:
+                remote.branchmap = origremotebranchmap
 
 def wireprotobranchmap(orig, repo, proto):
     if not common.hastopicext(repo):