mercurial/bundle2.py
changeset 35112 073eec083e25
parent 35046 241d9caca11e
child 35113 8aa43ff9c12c
--- a/mercurial/bundle2.py	Mon Nov 13 19:20:34 2017 -0800
+++ b/mercurial/bundle2.py	Mon Nov 13 19:22:11 2017 -0800
@@ -854,7 +854,7 @@
         indebug(self.ui, 'start extraction of bundle2 parts')
         headerblock = self._readpartheader()
         while headerblock is not None:
-            part = unbundlepart(self.ui, headerblock, self._fp)
+            part = seekableunbundlepart(self.ui, headerblock, self._fp)
             yield part
             # Seek to the end of the part to force it's consumption so the next
             # part can be read. But then seek back to the beginning so the
@@ -1155,7 +1155,7 @@
         if headerblock is None:
             indebug(self.ui, 'no part found during interruption.')
             return
-        part = unbundlepart(self.ui, headerblock, self._fp)
+        part = seekableunbundlepart(self.ui, headerblock, self._fp)
         op = interruptoperation(self.ui)
         hardabort = False
         try:
@@ -1207,10 +1207,8 @@
         self.advisoryparams = None
         self.params = None
         self.mandatorykeys = ()
-        self._payloadstream = None
         self._readheader()
         self._mandatory = None
-        self._chunkindex = [] #(payload, file) position tuples for chunk starts
         self._pos = 0
 
     def _fromheader(self, size):
@@ -1237,46 +1235,6 @@
         self.params.update(self.advisoryparams)
         self.mandatorykeys = frozenset(p[0] for p in mandatoryparams)
 
-    def _payloadchunks(self, chunknum=0):
-        '''seek to specified chunk and start yielding data'''
-        if len(self._chunkindex) == 0:
-            assert chunknum == 0, 'Must start with chunk 0'
-            self._chunkindex.append((0, self._tellfp()))
-        else:
-            assert chunknum < len(self._chunkindex), \
-                   'Unknown chunk %d' % chunknum
-            self._seekfp(self._chunkindex[chunknum][1])
-
-        pos = self._chunkindex[chunknum][0]
-        payloadsize = self._unpack(_fpayloadsize)[0]
-        indebug(self.ui, 'payload chunk size: %i' % payloadsize)
-        while payloadsize:
-            if payloadsize == flaginterrupt:
-                # interruption detection, the handler will now read a
-                # single part and process it.
-                interrupthandler(self.ui, self._fp)()
-            elif payloadsize < 0:
-                msg = 'negative payload chunk size: %i' %  payloadsize
-                raise error.BundleValueError(msg)
-            else:
-                result = self._readexact(payloadsize)
-                chunknum += 1
-                pos += payloadsize
-                if chunknum == len(self._chunkindex):
-                    self._chunkindex.append((pos, self._tellfp()))
-                yield result
-            payloadsize = self._unpack(_fpayloadsize)[0]
-            indebug(self.ui, 'payload chunk size: %i' % payloadsize)
-
-    def _findchunk(self, pos):
-        '''for a given payload position, return a chunk number and offset'''
-        for chunk, (ppos, fpos) in enumerate(self._chunkindex):
-            if ppos == pos:
-                return chunk, 0
-            elif ppos > pos:
-                return chunk - 1, pos - self._chunkindex[chunk - 1][0]
-        raise ValueError('Unknown chunk')
-
     def _readheader(self):
         """read the header and setup the object"""
         typesize = self._unpackheader(_fparttypesize)[0]
@@ -1328,6 +1286,69 @@
             self.consumed = True
         return data
 
+class seekableunbundlepart(unbundlepart):
+    """A bundle2 part in a bundle that is seekable.
+
+    Regular ``unbundlepart`` instances can only be read once. This class
+    extends ``unbundlepart`` to enable bi-directional seeking within the
+    part.
+
+    Bundle2 part data consists of framed chunks. Offsets when seeking
+    refer to the decoded data, not the offsets in the underlying bundle2
+    stream.
+
+    To facilitate quickly seeking within the decoded data, instances of this
+    class maintain a mapping between offsets in the underlying stream and
+    the decoded payload. This mapping will consume memory in proportion
+    to the number of chunks within the payload (which almost certainly
+    increases in proportion with the size of the part).
+    """
+    def __init__(self, ui, header, fp):
+        # Seek index: (payload offset, file offset) pair per chunk start.
+        self._chunkindex = []
+
+        super(seekableunbundlepart, self).__init__(ui, header, fp)
+
+    def _payloadchunks(self, chunknum=0):
+        '''seek to chunk ``chunknum`` and yield data, indexing chunk starts'''
+        if len(self._chunkindex) == 0:
+            assert chunknum == 0, 'Must start with chunk 0'
+            self._chunkindex.append((0, self._tellfp()))
+        else:
+            assert chunknum < len(self._chunkindex), \
+                   'Unknown chunk %d' % chunknum
+            self._seekfp(self._chunkindex[chunknum][1])
+
+        pos = self._chunkindex[chunknum][0]  # payload offset of that chunk
+        payloadsize = self._unpack(_fpayloadsize)[0]
+        indebug(self.ui, 'payload chunk size: %i' % payloadsize)
+        while payloadsize:  # a zero size ends the payload
+            if payloadsize == flaginterrupt:
+                # Interruption marker rather than a chunk size: the handler
+                # reads a single part from the stream and processes it.
+                interrupthandler(self.ui, self._fp)()
+            elif payloadsize < 0:
+                msg = 'negative payload chunk size: %i' %  payloadsize
+                raise error.BundleValueError(msg)
+            else:
+                result = self._readexact(payloadsize)
+                chunknum += 1
+                pos += payloadsize
+                if chunknum == len(self._chunkindex):  # not indexed yet
+                    self._chunkindex.append((pos, self._tellfp()))
+                yield result
+            payloadsize = self._unpack(_fpayloadsize)[0]
+            indebug(self.ui, 'payload chunk size: %i' % payloadsize)
+
+    def _findchunk(self, pos):
+        '''return (chunk number, offset in chunk) for payload position pos'''
+        for chunk, (ppos, fpos) in enumerate(self._chunkindex):
+            if ppos == pos:  # pos is exactly a chunk start
+                return chunk, 0
+            elif ppos > pos:  # pos lies inside the previous chunk
+                return chunk - 1, pos - self._chunkindex[chunk - 1][0]
+        raise ValueError('Unknown chunk')  # pos is past the index
+
     def tell(self):
         return self._pos