changegroup: extract the file management part in its own function
authorPierre-Yves David <pierre-yves.david@fb.com>
Mon, 05 Oct 2015 00:14:47 -0700
changeset 26540 7469067de2ba
parent 26539 1956026e4db2
child 26541 d40029b4296e
changegroup: extract the file management part in its own function The current writebundle function do two things: - taking a changegroup-packer instance and storing it into a valid bundle with proper header. - creating a temporary or requested file to store that bundle We would like to make it easier to forward bundle stream directly from a remote peer to a file, so we split the two logic to be able to skip the one about building a valid bundle (the remote is already sending one).
mercurial/changegroup.py
--- a/mercurial/changegroup.py	Sun Oct 04 21:48:19 2015 -0700
+++ b/mercurial/changegroup.py	Mon Oct 05 00:14:47 2015 -0700
@@ -92,15 +92,13 @@
 # hgweb uses this list to communicate its preferred type
 bundlepriority = ['HG10GZ', 'HG10BZ', 'HG10UN']
 
-def writebundle(ui, cg, filename, bundletype, vfs=None, compression=None):
-    """Write a bundle file and return its filename.
+def writechunks(ui, chunks, filename, vfs=None):
+    """Write chunks to a file and return its filename.
 
+    The stream is assumed to be a bundle file.
     Existing files will not be overwritten.
     If no filename is specified, a temporary file is created.
-    bz2 compression can be turned off.
-    The bundle file will be deleted in case of errors.
     """
-
     fh = None
     cleanup = None
     try:
@@ -113,38 +111,8 @@
             fd, filename = tempfile.mkstemp(prefix="hg-bundle-", suffix=".hg")
             fh = os.fdopen(fd, "wb")
         cleanup = filename
-
-        if bundletype == "HG20":
-            from . import bundle2
-            bundle = bundle2.bundle20(ui)
-            bundle.setcompression(compression)
-            part = bundle.newpart('changegroup', data=cg.getchunks())
-            part.addparam('version', cg.version)
-            z = util.compressors[None]()
-            chunkiter = bundle.getchunks()
-        else:
-            # compression argument is only for the bundle2 case
-            assert compression is None
-            if cg.version != '01':
-                raise util.Abort(_('old bundle types only supports v1 '
-                                   'changegroups'))
-            header, comp = bundletypes[bundletype]
-            fh.write(header)
-            if comp not in util.compressors:
-                raise util.Abort(_('unknown stream compression type: %s')
-                                 % comp)
-            z = util.compressors[comp]()
-            chunkiter = cg.getchunks()
-
-        # parse the changegroup data, otherwise we will block
-        # in case of sshrepo because we don't know the end of the stream
-
-        # an empty chunkgroup is the end of the changegroup
-        # a changegroup has at least 2 chunkgroups (changelog and manifest).
-        # after that, an empty chunkgroup is the end of the changegroup
-        for chunk in chunkiter:
-            fh.write(z.compress(chunk))
-        fh.write(z.flush())
+        for c in chunks:
+            fh.write(c)
         cleanup = None
         return filename
     finally:
@@ -156,6 +124,49 @@
             else:
                 os.unlink(cleanup)
 
+def writebundle(ui, cg, filename, bundletype, vfs=None, compression=None):
+    """Write a bundle file and return its filename.
+
+    Existing files will not be overwritten.
+    If no filename is specified, a temporary file is created.
+    bz2 compression can be turned off.
+    The bundle file will be deleted in case of errors.
+    """
+
+    if bundletype == "HG20":
+        from . import bundle2
+        bundle = bundle2.bundle20(ui)
+        bundle.setcompression(compression)
+        part = bundle.newpart('changegroup', data=cg.getchunks())
+        part.addparam('version', cg.version)
+        chunkiter = bundle.getchunks()
+    else:
+        # compression argument is only for the bundle2 case
+        assert compression is None
+        if cg.version != '01':
+            raise util.Abort(_('old bundle types only supports v1 '
+                               'changegroups'))
+        header, comp = bundletypes[bundletype]
+        if comp not in util.compressors:
+            raise util.Abort(_('unknown stream compression type: %s')
+                             % comp)
+        z = util.compressors[comp]()
+        subchunkiter = cg.getchunks()
+        def chunkiter():
+            yield header
+            for chunk in subchunkiter:
+                yield z.compress(chunk)
+            yield z.flush()
+        chunkiter = chunkiter()
+
+    # parse the changegroup data, otherwise we will block
+    # in case of sshrepo because we don't know the end of the stream
+
+    # an empty chunkgroup is the end of the changegroup
+    # a changegroup has at least 2 chunkgroups (changelog and manifest).
+    # after that, an empty chunkgroup is the end of the changegroup
+    return writechunks(ui, chunkiter, filename, vfs=vfs)
+
 class cg1unpacker(object):
     deltaheader = _CHANGEGROUPV1_DELTA_HEADER
     deltaheadersize = struct.calcsize(deltaheader)