changegroup: move all compressions utilities in util
authorPierre-Yves David <pierre-yves.david@fb.com>
Tue, 15 Sep 2015 17:35:32 -0700
changeset 26266 1e042e31bd0c
parent 26265 077f20eed4b2
child 26267 eca468b8fae4
changegroup: move all compressions utilities in util We'll reuse the compression for other things (next target bundle2), so let's make it more accessible and organised.
mercurial/changegroup.py
mercurial/util.py
--- a/mercurial/changegroup.py	Tue Sep 15 13:12:03 2015 -0700
+++ b/mercurial/changegroup.py	Tue Sep 15 17:35:32 2015 -0700
@@ -7,12 +7,10 @@
 
 from __future__ import absolute_import
 
-import bz2
 import os
 import struct
 import tempfile
 import weakref
-import zlib
 
 from .i18n import _
 from .node import (
@@ -81,20 +79,14 @@
         result = -1 + changedheads
     return result
 
-class nocompress(object):
-    def compress(self, x):
-        return x
-    def flush(self):
-        return ""
-
 bundletypes = {
-    "": ("", nocompress), # only when using unbundle on ssh and old http servers
+    "": ("", 'UN'),       # only when using unbundle on ssh and old http servers
                           # since the unification ssh accepts a header but there
                           # is no capability signaling it.
     "HG20": (), # special-cased below
-    "HG10UN": ("HG10UN", nocompress),
-    "HG10BZ": ("HG10", lambda: bz2.BZ2Compressor()),
-    "HG10GZ": ("HG10GZ", lambda: zlib.compressobj()),
+    "HG10UN": ("HG10UN", 'UN'),
+    "HG10BZ": ("HG10", 'BZ'),
+    "HG10GZ": ("HG10GZ", 'GZ'),
 }
 
 # hgweb uses this list to communicate its preferred type
@@ -127,15 +119,18 @@
             bundle = bundle2.bundle20(ui)
             part = bundle.newpart('changegroup', data=cg.getchunks())
             part.addparam('version', cg.version)
-            z = nocompress()
+            z = util.compressors['UN']()
             chunkiter = bundle.getchunks()
         else:
             if cg.version != '01':
                 raise util.Abort(_('old bundle types only supports v1 '
                                    'changegroups'))
-            header, compressor = bundletypes[bundletype]
+            header, comp = bundletypes[bundletype]
             fh.write(header)
-            z = compressor()
+            if comp not in util.compressors:
+                raise util.Abort(_('unknown stream compression type: %s')
+                                 % comp)
+            z = util.compressors[comp]()
             chunkiter = cg.getchunks()
 
         # parse the changegroup data, otherwise we will block
@@ -158,30 +153,15 @@
             else:
                 os.unlink(cleanup)
 
-def decompressor(fh, alg):
-    if alg == 'UN':
-        return fh
-    elif alg == 'GZ':
-        def generator(f):
-            zd = zlib.decompressobj()
-            for chunk in util.filechunkiter(f):
-                yield zd.decompress(chunk)
-    elif alg == 'BZ':
-        def generator(f):
-            zd = bz2.BZ2Decompressor()
-            zd.decompress("BZ")
-            for chunk in util.filechunkiter(f, 4096):
-                yield zd.decompress(chunk)
-    else:
-        raise util.Abort("unknown bundle compression '%s'" % alg)
-    return util.chunkbuffer(generator(fh))
-
 class cg1unpacker(object):
     deltaheader = _CHANGEGROUPV1_DELTA_HEADER
     deltaheadersize = struct.calcsize(deltaheader)
     version = '01'
     def __init__(self, fh, alg):
-        self._stream = decompressor(fh, alg)
+        if not alg in util.decompressors:
+            raise util.Abort(_('unknown stream compression type: %s')
+                             % alg)
+        self._stream = util.decompressors[alg](fh)
         self._type = alg
         self.callback = None
     def compressed(self):
--- a/mercurial/util.py	Tue Sep 15 13:12:03 2015 -0700
+++ b/mercurial/util.py	Tue Sep 15 17:35:32 2015 -0700
@@ -21,6 +21,8 @@
 import os, time, datetime, calendar, textwrap, signal, collections
 import imp, socket, urllib
 import gc
+import bz2
+import zlib
 
 if os.name == 'nt':
     import windows as platform
@@ -2338,5 +2340,41 @@
         yield path[:pos]
         pos = path.rfind('/', 0, pos)
 
+# compression utility
+
+class nocompress(object):
+    def compress(self, x):
+        return x
+    def flush(self):
+        return ""
+
+compressors = {
+    'UN': nocompress,
+    # lambda to prevent early import
+    'BZ': lambda: bz2.BZ2Compressor(),
+    'GZ': lambda: zlib.compressobj(),
+    }
+
+def _makedecompressor(decompcls):
+    def generator(f):
+        d = decompcls()
+        for chunk in filechunkiter(f):
+            yield d.decompress(chunk)
+    def func(fh):
+        return chunkbuffer(generator(fh))
+    return func
+
+def _bz2():
+    d = bz2.BZ2Decompressor()
+    # Bzip2 stream start with BZ, but we stripped it.
+    # we put it back for good measure.
+    d.decompress('BZ')
+    return d
+
+decompressors = {'UN': lambda fh: fh,
+                 'BZ': _makedecompressor(_bz2),
+                 'GZ': _makedecompressor(lambda: zlib.decompressobj()),
+                 }
+
 # convenient shortcut
 dst = debugstacktrace