changeset 43540:bad4a26b4607

repoview: define filteredchangelog as a top-level (non-local) class As suggested by Greg. This makes it easier for extensions to override the filtering. Differential Revision: https://phab.mercurial-scm.org/D7256
author Martin von Zweigbergk <martinvonz@google.com>
date Wed, 06 Nov 2019 00:35:41 -0800
parents 9391784299e9
children 3a463e5e470b
files mercurial/repoview.py
diffstat 1 files changed, 88 insertions(+), 85 deletions(-) [+]
line wrap: on
line diff
--- a/mercurial/repoview.py	Wed Nov 06 17:35:24 2019 -0500
+++ b/mercurial/repoview.py	Wed Nov 06 00:35:41 2019 -0800
@@ -227,108 +227,111 @@
     cl = copy.copy(unfichangelog)
     cl.filteredrevs = filteredrevs
 
-    class filteredchangelog(cl.__class__):
-        def tiprev(self):
-            """filtered version of revlog.tiprev"""
-            for i in pycompat.xrange(len(self) - 1, -2, -1):
-                if i not in self.filteredrevs:
-                    return i
+    cl.__class__ = type(
+        'filteredchangelog', (filteredchangelogmixin, cl.__class__), {}
+    )
 
-        def __contains__(self, rev):
-            """filtered version of revlog.__contains__"""
-            return 0 <= rev < len(self) and rev not in self.filteredrevs
+    return cl
+
 
-        def __iter__(self):
-            """filtered version of revlog.__iter__"""
+class filteredchangelogmixin(object):
+    def tiprev(self):
+        """filtered version of revlog.tiprev"""
+        for i in pycompat.xrange(len(self) - 1, -2, -1):
+            if i not in self.filteredrevs:
+                return i
 
-            def filterediter():
-                for i in pycompat.xrange(len(self)):
-                    if i not in self.filteredrevs:
-                        yield i
+    def __contains__(self, rev):
+        """filtered version of revlog.__contains__"""
+        return 0 <= rev < len(self) and rev not in self.filteredrevs
 
-            return filterediter()
+    def __iter__(self):
+        """filtered version of revlog.__iter__"""
 
-        def revs(self, start=0, stop=None):
-            """filtered version of revlog.revs"""
-            for i in super(filteredchangelog, self).revs(start, stop):
+        def filterediter():
+            for i in pycompat.xrange(len(self)):
                 if i not in self.filteredrevs:
                     yield i
 
-        def _checknofilteredinrevs(self, revs):
-            """raise the appropriate error if 'revs' contains a filtered revision
+        return filterediter()
+
+    def revs(self, start=0, stop=None):
+        """filtered version of revlog.revs"""
+        for i in super(filteredchangelogmixin, self).revs(start, stop):
+            if i not in self.filteredrevs:
+                yield i
 
-            This returns a version of 'revs' to be used thereafter by the caller.
-            In particular, if revs is an iterator, it is converted into a set.
-            """
-            safehasattr = util.safehasattr
-            if safehasattr(revs, '__next__'):
-                # Note that inspect.isgenerator() is not true for iterators,
-                revs = set(revs)
+    def _checknofilteredinrevs(self, revs):
+        """raise the appropriate error if 'revs' contains a filtered revision
+
+        This returns a version of 'revs' to be used thereafter by the caller.
+        In particular, if revs is an iterator, it is converted into a set.
+        """
+        safehasattr = util.safehasattr
+        if safehasattr(revs, '__next__'):
+            # Note that inspect.isgenerator() is not true for iterators,
+            revs = set(revs)
 
-            filteredrevs = self.filteredrevs
-            if safehasattr(revs, 'first'):  # smartset
-                offenders = revs & filteredrevs
-            else:
-                offenders = filteredrevs.intersection(revs)
-
-            for rev in offenders:
-                raise error.FilteredIndexError(rev)
-            return revs
+        filteredrevs = self.filteredrevs
+        if safehasattr(revs, 'first'):  # smartset
+            offenders = revs & filteredrevs
+        else:
+            offenders = filteredrevs.intersection(revs)
 
-        def headrevs(self, revs=None):
-            if revs is None:
-                try:
-                    return self.index.headrevsfiltered(self.filteredrevs)
-                # AttributeError covers non-c-extension environments and
-                # old c extensions without filter handling.
-                except AttributeError:
-                    return self._headrevs()
+        for rev in offenders:
+            raise error.FilteredIndexError(rev)
+        return revs
 
-            revs = self._checknofilteredinrevs(revs)
-            return super(filteredchangelog, self).headrevs(revs)
+    def headrevs(self, revs=None):
+        if revs is None:
+            try:
+                return self.index.headrevsfiltered(self.filteredrevs)
+            # AttributeError covers non-c-extension environments and
+            # old c extensions without filter handling.
+            except AttributeError:
+                return self._headrevs()
 
-        def strip(self, *args, **kwargs):
-            # XXX make something better than assert
-            # We can't expect proper strip behavior if we are filtered.
-            assert not self.filteredrevs
-            super(filteredchangelog, self).strip(*args, **kwargs)
+        revs = self._checknofilteredinrevs(revs)
+        return super(filteredchangelogmixin, self).headrevs(revs)
+
+    def strip(self, *args, **kwargs):
+        # XXX make something better than assert
+        # We can't expect proper strip behavior if we are filtered.
+        assert not self.filteredrevs
+        super(filteredchangelogmixin, self).strip(*args, **kwargs)
 
-        def rev(self, node):
-            """filtered version of revlog.rev"""
-            r = super(filteredchangelog, self).rev(node)
-            if r in self.filteredrevs:
-                raise error.FilteredLookupError(
-                    hex(node), self.indexfile, _(b'filtered node')
-                )
-            return r
-
-        def node(self, rev):
-            """filtered version of revlog.node"""
-            if rev in self.filteredrevs:
-                raise error.FilteredIndexError(rev)
-            return super(filteredchangelog, self).node(rev)
+    def rev(self, node):
+        """filtered version of revlog.rev"""
+        r = super(filteredchangelogmixin, self).rev(node)
+        if r in self.filteredrevs:
+            raise error.FilteredLookupError(
+                hex(node), self.indexfile, _(b'filtered node')
+            )
+        return r
 
-        def linkrev(self, rev):
-            """filtered version of revlog.linkrev"""
-            if rev in self.filteredrevs:
-                raise error.FilteredIndexError(rev)
-            return super(filteredchangelog, self).linkrev(rev)
+    def node(self, rev):
+        """filtered version of revlog.node"""
+        if rev in self.filteredrevs:
+            raise error.FilteredIndexError(rev)
+        return super(filteredchangelogmixin, self).node(rev)
+
+    def linkrev(self, rev):
+        """filtered version of revlog.linkrev"""
+        if rev in self.filteredrevs:
+            raise error.FilteredIndexError(rev)
+        return super(filteredchangelogmixin, self).linkrev(rev)
 
-        def parentrevs(self, rev):
-            """filtered version of revlog.parentrevs"""
-            if rev in self.filteredrevs:
-                raise error.FilteredIndexError(rev)
-            return super(filteredchangelog, self).parentrevs(rev)
+    def parentrevs(self, rev):
+        """filtered version of revlog.parentrevs"""
+        if rev in self.filteredrevs:
+            raise error.FilteredIndexError(rev)
+        return super(filteredchangelogmixin, self).parentrevs(rev)
 
-        def flags(self, rev):
-            """filtered version of revlog.flags"""
-            if rev in self.filteredrevs:
-                raise error.FilteredIndexError(rev)
-            return super(filteredchangelog, self).flags(rev)
-
-    cl.__class__ = filteredchangelog
-
-    return cl
+    def flags(self, rev):
+        """filtered version of revlog.flags"""
+        if rev in self.filteredrevs:
+            raise error.FilteredIndexError(rev)
+        return super(filteredchangelogmixin, self).flags(rev)
 
 
 class repoview(object):