wireproto: move value encoding functions to wireprototypes (API)
authorGregory Szorc <gregory.szorc@gmail.com>
Wed, 11 Apr 2018 10:50:58 -0700
changeset 37612 5e71dea79aae
parent 37611 ae8730877371
child 37613 96d735601ca1
wireproto: move value encoding functions to wireprototypes (API) These functions should live in the same place. I plan to separate client from server code in upcoming commits. wireprototypes is where we are putting shared code like this. Differential Revision: https://phab.mercurial-scm.org/D3257
hgext/infinitepush/__init__.py
mercurial/wireproto.py
mercurial/wireprototypes.py
--- a/hgext/infinitepush/__init__.py	Tue Apr 10 19:09:35 2018 -0700
+++ b/hgext/infinitepush/__init__.py	Wed Apr 11 10:50:58 2018 -0700
@@ -127,6 +127,7 @@
     registrar,
     util,
     wireproto,
+    wireprototypes,
 )
 
 from . import (
@@ -331,7 +332,7 @@
     return orig(pushop)
 
 def wireprotolistkeyspatterns(repo, proto, namespace, patterns):
-    patterns = wireproto.decodelist(patterns)
+    patterns = wireprototypes.decodelist(patterns)
     d = repo.listkeys(encoding.tolocal(namespace), patterns).iteritems()
     return pushkey.encodekeys(d)
 
@@ -361,7 +362,7 @@
                   (namespace, patterns))
     yield {
         'namespace': encoding.fromlocal(namespace),
-        'patterns': wireproto.encodelist(patterns)
+        'patterns': wireprototypes.encodelist(patterns)
     }, f
     d = f.value
     self.ui.debug('received listkey for "%s": %i bytes\n'
--- a/mercurial/wireproto.py	Tue Apr 10 19:09:35 2018 -0700
+++ b/mercurial/wireproto.py	Wed Apr 11 10:50:58 2018 -0700
@@ -115,37 +115,11 @@
 batchable = peer.batchable
 future = peer.future
 
-# list of nodes encoding / decoding
-
-def decodelist(l, sep=' '):
-    if l:
-        return [bin(v) for v in  l.split(sep)]
-    return []
-
-def encodelist(l, sep=' '):
-    try:
-        return sep.join(map(hex, l))
-    except TypeError:
-        raise
-
-# batched call argument encoding
-
-def escapearg(plain):
-    return (plain
-            .replace(':', ':c')
-            .replace(',', ':o')
-            .replace(';', ':s')
-            .replace('=', ':e'))
-
-def unescapearg(escaped):
-    return (escaped
-            .replace(':e', '=')
-            .replace(':s', ';')
-            .replace(':o', ',')
-            .replace(':c', ':'))
 
 def encodebatchcmds(req):
     """Return a ``cmds`` argument value for the ``batch`` command."""
+    escapearg = wireprototypes.escapebatcharg
+
     cmds = []
     for op, argsdict in req:
         # Old servers didn't properly unescape argument names. So prevent
@@ -227,14 +201,14 @@
         yield {}, f
         d = f.value
         try:
-            yield decodelist(d[:-1])
+            yield wireprototypes.decodelist(d[:-1])
         except ValueError:
             self._abort(error.ResponseError(_("unexpected response:"), d))
 
     @batchable
     def known(self, nodes):
         f = future()
-        yield {'nodes': encodelist(nodes)}, f
+        yield {'nodes': wireprototypes.encodelist(nodes)}, f
         d = f.value
         try:
             yield [bool(int(b)) for b in d]
@@ -251,7 +225,7 @@
             for branchpart in d.splitlines():
                 branchname, branchheads = branchpart.split(' ', 1)
                 branchname = encoding.tolocal(urlreq.unquote(branchname))
-                branchheads = decodelist(branchheads)
+                branchheads = wireprototypes.decodelist(branchheads)
                 branchmap[branchname] = branchheads
             yield branchmap
         except TypeError:
@@ -306,7 +280,7 @@
                 raise error.ProgrammingError(
                     'Unexpectedly None keytype for key %s' % key)
             elif keytype == 'nodes':
-                value = encodelist(value)
+                value = wireprototypes.encodelist(value)
             elif keytype == 'csv':
                 value = ','.join(value)
             elif keytype == 'scsv':
@@ -338,10 +312,10 @@
         '''
 
         if heads != ['force'] and self.capable('unbundlehash'):
-            heads = encodelist(['hashed',
-                                hashlib.sha1(''.join(sorted(heads))).digest()])
+            heads = wireprototypes.encodelist(
+                ['hashed', hashlib.sha1(''.join(sorted(heads))).digest()])
         else:
-            heads = encodelist(heads)
+            heads = wireprototypes.encodelist(heads)
 
         if util.safehasattr(cg, 'deltaheader'):
             # this a bundle10, do the old style call sequence
@@ -368,10 +342,10 @@
     # Begin of ipeerlegacycommands interface.
 
     def branches(self, nodes):
-        n = encodelist(nodes)
+        n = wireprototypes.encodelist(nodes)
         d = self._call("branches", nodes=n)
         try:
-            br = [tuple(decodelist(b)) for b in d.splitlines()]
+            br = [tuple(wireprototypes.decodelist(b)) for b in d.splitlines()]
             return br
         except ValueError:
             self._abort(error.ResponseError(_("unexpected response:"), d))
@@ -380,23 +354,25 @@
         batch = 8 # avoid giant requests
         r = []
         for i in xrange(0, len(pairs), batch):
-            n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
+            n = " ".join([wireprototypes.encodelist(p, '-')
+                          for p in pairs[i:i + batch]])
             d = self._call("between", pairs=n)
             try:
-                r.extend(l and decodelist(l) or [] for l in d.splitlines())
+                r.extend(l and wireprototypes.decodelist(l) or []
+                         for l in d.splitlines())
             except ValueError:
                 self._abort(error.ResponseError(_("unexpected response:"), d))
         return r
 
     def changegroup(self, nodes, kind):
-        n = encodelist(nodes)
+        n = wireprototypes.encodelist(nodes)
         f = self._callcompressable("changegroup", roots=n)
         return changegroupmod.cg1unpacker(f, 'UN')
 
     def changegroupsubset(self, bases, heads, kind):
         self.requirecap('changegroupsubset', _('look up remote changes'))
-        bases = encodelist(bases)
-        heads = encodelist(heads)
+        bases = wireprototypes.encodelist(bases)
+        heads = wireprototypes.encodelist(heads)
         f = self._callcompressable("changegroupsubset",
                                    bases=bases, heads=heads)
         return changegroupmod.cg1unpacker(f, 'UN')
@@ -415,6 +391,8 @@
                 msg = 'devel-peer-request:    - %s (%d arguments)\n'
                 ui.debug(msg % (op, len(args)))
 
+        unescapearg = wireprototypes.unescapebatcharg
+
         rsp = self._callstream("batch", cmds=encodebatchcmds(req))
         chunk = rsp.read(1024)
         work = [chunk]
@@ -793,6 +771,7 @@
 @wireprotocommand('batch', 'cmds *', permission='pull',
                   transportpolicy=POLICY_V1_ONLY)
 def batch(repo, proto, cmds, others):
+    unescapearg = wireprototypes.unescapebatcharg
     repo = repo.filtered("served")
     res = []
     for pair in cmds.split(';'):
@@ -832,17 +811,17 @@
         assert isinstance(result, (wireprototypes.bytesresponse, bytes))
         if isinstance(result, wireprototypes.bytesresponse):
             result = result.data
-        res.append(escapearg(result))
+        res.append(wireprototypes.escapebatcharg(result))
 
     return wireprototypes.bytesresponse(';'.join(res))
 
 @wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY,
                   permission='pull')
 def between(repo, proto, pairs):
-    pairs = [decodelist(p, '-') for p in pairs.split(" ")]
+    pairs = [wireprototypes.decodelist(p, '-') for p in pairs.split(" ")]
     r = []
     for b in repo.between(pairs):
-        r.append(encodelist(b) + "\n")
+        r.append(wireprototypes.encodelist(b) + "\n")
 
     return wireprototypes.bytesresponse(''.join(r))
 
@@ -853,7 +832,7 @@
     heads = []
     for branch, nodes in branchmap.iteritems():
         branchname = urlreq.quote(encoding.fromlocal(branch))
-        branchnodes = encodelist(nodes)
+        branchnodes = wireprototypes.encodelist(nodes)
         heads.append('%s %s' % (branchname, branchnodes))
 
     return wireprototypes.bytesresponse('\n'.join(heads))
@@ -861,10 +840,10 @@
 @wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY,
                   permission='pull')
 def branches(repo, proto, nodes):
-    nodes = decodelist(nodes)
+    nodes = wireprototypes.decodelist(nodes)
     r = []
     for b in repo.branches(nodes):
-        r.append(encodelist(b) + "\n")
+        r.append(wireprototypes.encodelist(b) + "\n")
 
     return wireprototypes.bytesresponse(''.join(r))
 
@@ -931,7 +910,7 @@
 @wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY,
                   permission='pull')
 def changegroup(repo, proto, roots):
-    nodes = decodelist(roots)
+    nodes = wireprototypes.decodelist(roots)
     outgoing = discovery.outgoing(repo, missingroots=nodes,
                                   missingheads=repo.heads())
     cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
@@ -942,8 +921,8 @@
                   transportpolicy=POLICY_V1_ONLY,
                   permission='pull')
 def changegroupsubset(repo, proto, bases, heads):
-    bases = decodelist(bases)
-    heads = decodelist(heads)
+    bases = wireprototypes.decodelist(bases)
+    heads = wireprototypes.decodelist(heads)
     outgoing = discovery.outgoing(repo, missingroots=bases,
                                   missingheads=heads)
     cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
@@ -1029,7 +1008,7 @@
     for k, v in opts.iteritems():
         keytype = gboptsmap[k]
         if keytype == 'nodes':
-            opts[k] = decodelist(v)
+            opts[k] = wireprototypes.decodelist(v)
         elif keytype == 'csv':
             opts[k] = list(v.split(','))
         elif keytype == 'scsv':
@@ -1101,7 +1080,7 @@
 @wireprotocommand('heads', permission='pull', transportpolicy=POLICY_V1_ONLY)
 def heads(repo, proto):
     h = repo.heads()
-    return wireprototypes.bytesresponse(encodelist(h) + '\n')
+    return wireprototypes.bytesresponse(wireprototypes.encodelist(h) + '\n')
 
 @wireprotocommand('hello', permission='pull', transportpolicy=POLICY_V1_ONLY)
 def hello(repo, proto):
@@ -1140,7 +1119,8 @@
 @wireprotocommand('known', 'nodes *', permission='pull',
                   transportpolicy=POLICY_V1_ONLY)
 def known(repo, proto, nodes, others):
-    v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
+    v = ''.join(b and '1' or '0'
+                for b in repo.known(wireprototypes.decodelist(nodes)))
     return wireprototypes.bytesresponse(v)
 
 @wireprotocommand('protocaps', 'caps', permission='pull',
@@ -1185,7 +1165,7 @@
 @wireprotocommand('unbundle', 'heads', permission='push',
                   transportpolicy=POLICY_V1_ONLY)
 def unbundle(repo, proto, heads):
-    their_heads = decodelist(heads)
+    their_heads = wireprototypes.decodelist(heads)
 
     with proto.mayberedirectstdio() as output:
         try:
--- a/mercurial/wireprototypes.py	Tue Apr 10 19:09:35 2018 -0700
+++ b/mercurial/wireprototypes.py	Wed Apr 11 10:50:58 2018 -0700
@@ -5,6 +5,10 @@
 
 from __future__ import absolute_import
 
+from .node import (
+    bin,
+    hex,
+)
 from .thirdparty.zope import (
     interface as zi,
 )
@@ -102,6 +106,34 @@
     def __init__(self, v):
         self.value = v
 
+# list of nodes encoding / decoding
+def decodelist(l, sep=' '):
+    if l:
+        return [bin(v) for v in  l.split(sep)]
+    return []
+
+def encodelist(l, sep=' '):
+    try:
+        return sep.join(map(hex, l))
+    except TypeError:
+        raise
+
+# batched call argument encoding
+
+def escapebatcharg(plain):
+    return (plain
+            .replace(':', ':c')
+            .replace(',', ':o')
+            .replace(';', ':s')
+            .replace('=', ':e'))
+
+def unescapebatcharg(escaped):
+    return (escaped
+            .replace(':e', '=')
+            .replace(':s', ';')
+            .replace(':o', ',')
+            .replace(':c', ':'))
+
 class baseprotocolhandler(zi.Interface):
     """Abstract base class for wire protocol handlers.