mercurial/ancestor.py
changeset 18986 2f7186400a07
parent 18091 f7f8159caad3
child 18987 3605d4e7e618
--- a/mercurial/ancestor.py	Mon Apr 15 01:59:11 2013 +0200
+++ b/mercurial/ancestor.py	Tue Apr 16 10:08:18 2013 -0700
@@ -5,10 +5,132 @@
 # This software may be used and distributed according to the terms of the
 # GNU General Public License version 2 or any later version.
 
-import heapq, util
+import error, heapq, util
 from node import nullrev
 
-def ancestor(a, b, pfunc):
+def ancestors(pfunc, *orignodes):
+    """
+    Returns the common ancestors of a and b that are furthest from a
+    root (as measured by longest path).
+
+    pfunc must return a list of parent vertices for a given vertex.
+    """
+    if not isinstance(orignodes, set):
+        orignodes = set(orignodes)
+    if nullrev in orignodes:
+        return set()
+    if len(orignodes) <= 1:
+        return orignodes
+
+    def candidates(nodes):
+        allseen = (1 << len(nodes)) - 1
+        seen = [0] * (max(nodes) + 1)
+        for i, n in enumerate(nodes):
+            seen[n] = 1 << i
+        poison = 1 << (i + 1)
+
+        gca = set()
+        interesting = left = len(nodes)
+        nv = len(seen) - 1
+        while nv >= 0 and interesting:
+            v = nv
+            nv -= 1
+            if not seen[v]:
+                continue
+            sv = seen[v]
+            if sv < poison:
+                interesting -= 1
+                if sv == allseen:
+                    gca.add(v)
+                    sv |= poison
+                    if v in nodes:
+                        left -= 1
+                        if left <= 1:
+                            # history is linear
+                            return set([v])
+            if sv < poison:
+                for p in pfunc(v):
+                    sp = seen[p]
+                    if p == nullrev:
+                        continue
+                    if sp == 0:
+                        seen[p] = sv
+                        interesting += 1
+                    elif sp != sv:
+                        seen[p] |= sv
+            else:
+                for p in pfunc(v):
+                    if p == nullrev:
+                        continue
+                    sp = seen[p]
+                    if sp and sp < poison:
+                        interesting -= 1
+                    seen[p] = sv
+        return gca
+
+    def deepest(nodes):
+        interesting = {}
+        count = max(nodes) + 1
+        depth = [0] * count
+        seen = [0] * count
+        mapping = []
+        for (i, n) in enumerate(sorted(nodes)):
+            depth[n] = 1
+            b = 1 << i
+            seen[n] = b
+            interesting[b] = 1
+            mapping.append((b, n))
+        nv = count - 1
+        while nv >= 0 and len(interesting) > 1:
+            v = nv
+            nv -= 1
+            dv = depth[v]
+            if dv == 0:
+                continue
+            sv = seen[v]
+            for p in pfunc(v):
+                if p == nullrev:
+                    continue
+                dp = depth[p]
+                nsp = sp = seen[p]
+                if dp <= dv:
+                    depth[p] = dv + 1
+                    if sp != sv:
+                        interesting[sv] += 1
+                        nsp = seen[p] = sv
+                        if sp:
+                            interesting[sp] -= 1
+                            if interesting[sp] == 0:
+                                del interesting[sp]
+                elif dv == dp - 1:
+                    nsp = sp | sv
+                    if nsp == sp:
+                        continue
+                    seen[p] = nsp
+                    interesting.setdefault(nsp, 0)
+                    interesting[nsp] += 1
+                    interesting[sp] -= 1
+                    if interesting[sp] == 0:
+                        del interesting[sp]
+            interesting[sv] -= 1
+            if interesting[sv] == 0:
+                del interesting[sv]
+
+        if len(interesting) != 1:
+            return []
+
+        k = 0
+        for i in interesting:
+            k |= i
+        return set(n for (i, n) in mapping if k & i)
+
+    gca = candidates(orignodes)
+
+    if len(gca) <= 1:
+        return gca
+    return deepest(gca)
+
+def genericancestor(a, b, pfunc):
     """
     Returns the common ancestor of a and b that is furthest from a
     root (as measured by longest path) or None if no ancestor is
@@ -30,7 +152,7 @@
     depth = {}
     while visit:
         vertex = visit[-1]
-        pl = pfunc(vertex)
+        pl = [p for p in pfunc(vertex) if p != nullrev]
         parentcache[vertex] = pl
         if not pl:
             depth[vertex] = 0
@@ -91,6 +213,51 @@
     except StopIteration:
         return None
 
+def finddepths(nodes, pfunc):
+    visit = list(nodes)
+    rootpl = [nullrev, nullrev]
+    depth = {}
+    while visit:
+        vertex = visit[-1]
+        pl = pfunc(vertex)
+        if not pl or pl == rootpl:
+            depth[vertex] = 0
+            visit.pop()
+        else:
+            for p in pl:
+                if p != nullrev and p not in depth:
+                    visit.append(p)
+            if visit[-1] == vertex:
+                dp = [depth[p] for p in pl if p != nullrev]
+                if dp:
+                    depth[vertex] = max(dp) + 1
+                else:
+                    depth[vertex] = 0
+                visit.pop()
+    return depth
+
+def ancestor(a, b, pfunc):
+    xs = ancestors(pfunc, a, b)
+    y = genericancestor(a, b, pfunc)
+    if y == -1:
+        y = None
+    if not xs:
+        if y is None:
+            return None
+        print xs, y
+        raise error.RepoError('ancestors disagree on whether a gca exists')
+    elif y is None:
+        print xs, y
+        raise error.RepoError('ancestors disagree on whether a gca exists')
+    if y in xs:
+        return y
+    xds = finddepths(xs, pfunc)
+    xds = [ds[x] for x in xs]
+    yd = finddepths([y], pfunc)[y]
+    if len([xd != yd for xd in xds]) > 0:
+        raise error.RepoError('ancestor depths do not match')
+    return xs.pop()
+
 def missingancestors(revs, bases, pfunc):
     """Return all the ancestors of revs that are not ancestors of bases.