diff mercurial/wireprotov1server.py @ 43076:2372284d9457

formatting: blacken the codebase This is using my patch to black (https://github.com/psf/black/pull/826) so we don't un-wrap collection literals. Done with: hg files 'set:**.py - mercurial/thirdparty/** - "contrib/python-zstandard/**"' | xargs black -S # skip-blame mass-reformatting only # no-check-commit reformats foo_bar functions Differential Revision: https://phab.mercurial-scm.org/D6971
author Augie Fackler <augie@google.com>
date Sun, 06 Oct 2019 09:45:02 -0400
parents 566daffc607d
children 687b865b95ad
line wrap: on
line diff
--- a/mercurial/wireprotov1server.py	Sat Oct 05 10:29:34 2019 -0400
+++ b/mercurial/wireprotov1server.py	Sun Oct 06 09:45:02 2019 -0400
@@ -39,10 +39,12 @@
 urlreq = util.urlreq
 
 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
-bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
-                        'IncompatibleClient')
+bundle2requiredhint = _(
+    'see https://www.mercurial-scm.org/wiki/' 'IncompatibleClient'
+)
 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
 
+
 def clientcompressionsupport(proto):
     """Returns a list of compression methods supported by the client.
 
@@ -55,8 +57,10 @@
             return cap[5:].split(',')
     return ['zlib', 'none']
 
+
 # wire protocol command can either return a string or one of these classes.
 
+
 def getdispatchrepo(repo, proto, command):
     """Obtain the repo used for processing wire protocol commands.
 
@@ -67,6 +71,7 @@
     viewconfig = repo.ui.config('server', 'view')
     return repo.filtered(viewconfig)
 
+
 def dispatch(repo, proto, command):
     repo = getdispatchrepo(repo, proto, command)
 
@@ -75,6 +80,7 @@
 
     return func(repo, proto, *args)
 
+
 def options(cmd, keys, others):
     opts = {}
     for k in keys:
@@ -82,10 +88,13 @@
             opts[k] = others[k]
             del others[k]
     if others:
-        procutil.stderr.write("warning: %s ignored unexpected arguments %s\n"
-                              % (cmd, ",".join(others)))
+        procutil.stderr.write(
+            "warning: %s ignored unexpected arguments %s\n"
+            % (cmd, ",".join(others))
+        )
     return opts
 
+
 def bundle1allowed(repo, action):
     """Whether a bundle1 operation is allowed from the server.
 
@@ -115,8 +124,10 @@
 
     return ui.configbool('server', 'bundle1')
 
+
 commands = wireprototypes.commanddict()
 
+
 def wireprotocommand(name, args=None, permission='push'):
     """Decorator to declare a wire protocol command.
 
@@ -132,8 +143,9 @@
     because otherwise commands not declaring their permissions could modify
     a repository that is supposed to be read-only.
     """
-    transports = {k for k, v in wireprototypes.TRANSPORTS.items()
-                  if v['version'] == 1}
+    transports = {
+        k for k, v in wireprototypes.TRANSPORTS.items() if v['version'] == 1
+    }
 
     # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
     # SSHv2.
@@ -142,27 +154,33 @@
         transports.add(wireprototypes.SSHV2)
 
     if permission not in ('push', 'pull'):
-        raise error.ProgrammingError('invalid wire protocol permission; '
-                                     'got %s; expected "push" or "pull"' %
-                                     permission)
+        raise error.ProgrammingError(
+            'invalid wire protocol permission; '
+            'got %s; expected "push" or "pull"' % permission
+        )
 
     if args is None:
         args = ''
 
     if not isinstance(args, bytes):
-        raise error.ProgrammingError('arguments for version 1 commands '
-                                     'must be declared as bytes')
+        raise error.ProgrammingError(
+            'arguments for version 1 commands ' 'must be declared as bytes'
+        )
 
     def register(func):
         if name in commands:
-            raise error.ProgrammingError('%s command already registered '
-                                         'for version 1' % name)
+            raise error.ProgrammingError(
+                '%s command already registered ' 'for version 1' % name
+            )
         commands[name] = wireprototypes.commandentry(
-            func, args=args, transports=transports, permission=permission)
+            func, args=args, transports=transports, permission=permission
+        )
 
         return func
+
     return register
 
+
 # TODO define a more appropriate permissions type to use for this.
 @wireprotocommand('batch', 'cmds *', permission='pull')
 def batch(repo, proto, cmds, others):
@@ -209,6 +227,7 @@
 
     return wireprototypes.bytesresponse(';'.join(res))
 
+
 @wireprotocommand('between', 'pairs', permission='pull')
 def between(repo, proto, pairs):
     pairs = [wireprototypes.decodelist(p, '-') for p in pairs.split(" ")]
@@ -218,6 +237,7 @@
 
     return wireprototypes.bytesresponse(''.join(r))
 
+
 @wireprotocommand('branchmap', permission='pull')
 def branchmap(repo, proto):
     branchmap = repo.branchmap()
@@ -229,6 +249,7 @@
 
     return wireprototypes.bytesresponse('\n'.join(heads))
 
+
 @wireprotocommand('branches', 'nodes', permission='pull')
 def branches(repo, proto, nodes):
     nodes = wireprototypes.decodelist(nodes)
@@ -238,6 +259,7 @@
 
     return wireprototypes.bytesresponse(''.join(r))
 
+
 @wireprotocommand('clonebundles', '', permission='pull')
 def clonebundles(repo, proto):
     """Server command for returning info for available bundles to seed clones.
@@ -249,10 +271,19 @@
     data center given the client's IP address.
     """
     return wireprototypes.bytesresponse(
-        repo.vfs.tryread('clonebundles.manifest'))
+        repo.vfs.tryread('clonebundles.manifest')
+    )
+
 
-wireprotocaps = ['lookup', 'branchmap', 'pushkey',
-                 'known', 'getbundle', 'unbundlehash']
+wireprotocaps = [
+    'lookup',
+    'branchmap',
+    'pushkey',
+    'known',
+    'getbundle',
+    'unbundlehash',
+]
+
 
 def _capabilities(repo, proto):
     """return a list of capabilities for a repo
@@ -294,6 +325,7 @@
 
     return proto.addcapabilities(repo, caps)
 
+
 # If you are writing an extension and consider wrapping this function. Wrap
 # `_capabilities` instead.
 @wireprotocommand('capabilities', permission='pull')
@@ -301,33 +333,36 @@
     caps = _capabilities(repo, proto)
     return wireprototypes.bytesresponse(' '.join(sorted(caps)))
 
+
 @wireprotocommand('changegroup', 'roots', permission='pull')
 def changegroup(repo, proto, roots):
     nodes = wireprototypes.decodelist(roots)
-    outgoing = discovery.outgoing(repo, missingroots=nodes,
-                                  missingheads=repo.heads())
+    outgoing = discovery.outgoing(
+        repo, missingroots=nodes, missingheads=repo.heads()
+    )
     cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
     gen = iter(lambda: cg.read(32768), '')
     return wireprototypes.streamres(gen=gen)
 
-@wireprotocommand('changegroupsubset', 'bases heads',
-                  permission='pull')
+
+@wireprotocommand('changegroupsubset', 'bases heads', permission='pull')
 def changegroupsubset(repo, proto, bases, heads):
     bases = wireprototypes.decodelist(bases)
     heads = wireprototypes.decodelist(heads)
-    outgoing = discovery.outgoing(repo, missingroots=bases,
-                                  missingheads=heads)
+    outgoing = discovery.outgoing(repo, missingroots=bases, missingheads=heads)
     cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
     gen = iter(lambda: cg.read(32768), '')
     return wireprototypes.streamres(gen=gen)
 
-@wireprotocommand('debugwireargs', 'one two *',
-                  permission='pull')
+
+@wireprotocommand('debugwireargs', 'one two *', permission='pull')
 def debugwireargs(repo, proto, one, two, others):
     # only accept optional args from the known set
     opts = options('debugwireargs', ['three', 'four'], others)
-    return wireprototypes.bytesresponse(repo.debugwireargs(
-        one, two, **pycompat.strkwargs(opts)))
+    return wireprototypes.bytesresponse(
+        repo.debugwireargs(one, two, **pycompat.strkwargs(opts))
+    )
+
 
 def find_pullbundle(repo, proto, opts, clheads, heads, common):
     """Return a file object for the first matching pullbundle.
@@ -344,6 +379,7 @@
       E.g. do not send a bundle of all changes if the client wants only
       one specific branch of many.
     """
+
     def decodehexstring(s):
         return {binascii.unhexlify(h) for h in s.split(';')}
 
@@ -372,11 +408,13 @@
                 # Bad heads entry
                 continue
             if bundle_heads.issubset(common):
-                continue # Nothing new
+                continue  # Nothing new
             if all(cl.rev(rev) in common_anc for rev in bundle_heads):
-                continue # Still nothing new
-            if any(cl.rev(rev) not in heads_anc and
-                   cl.rev(rev) not in common_anc for rev in bundle_heads):
+                continue  # Still nothing new
+            if any(
+                cl.rev(rev) not in heads_anc and cl.rev(rev) not in common_anc
+                for rev in bundle_heads
+            ):
                 continue
         if 'bases' in entry:
             try:
@@ -395,10 +433,12 @@
             continue
     return None
 
+
 @wireprotocommand('getbundle', '*', permission='pull')
 def getbundle(repo, proto, others):
-    opts = options('getbundle', wireprototypes.GETBUNDLE_ARGUMENTS.keys(),
-                   others)
+    opts = options(
+        'getbundle', wireprototypes.GETBUNDLE_ARGUMENTS.keys(), others
+    )
     for k, v in opts.iteritems():
         keytype = wireprototypes.GETBUNDLE_ARGUMENTS[k]
         if keytype == 'nodes':
@@ -415,28 +455,29 @@
             else:
                 opts[k] = bool(v)
         elif keytype != 'plain':
-            raise KeyError('unknown getbundle option type %s'
-                           % keytype)
+            raise KeyError('unknown getbundle option type %s' % keytype)
 
     if not bundle1allowed(repo, 'pull'):
         if not exchange.bundle2requested(opts.get('bundlecaps')):
             if proto.name == 'http-v1':
                 return wireprototypes.ooberror(bundle2required)
-            raise error.Abort(bundle2requiredmain,
-                              hint=bundle2requiredhint)
+            raise error.Abort(bundle2requiredmain, hint=bundle2requiredhint)
 
     try:
         clheads = set(repo.changelog.heads())
         heads = set(opts.get('heads', set()))
         common = set(opts.get('common', set()))
         common.discard(nullid)
-        if (repo.ui.configbool('server', 'pullbundle') and
-            'partial-pull' in proto.getprotocaps()):
+        if (
+            repo.ui.configbool('server', 'pullbundle')
+            and 'partial-pull' in proto.getprotocaps()
+        ):
             # Check if a pre-built bundle covers this request.
             bundle = find_pullbundle(repo, proto, opts, clheads, heads, common)
             if bundle:
-                return wireprototypes.streamres(gen=util.filechunkiter(bundle),
-                                                prefer_uncompressed=True)
+                return wireprototypes.streamres(
+                    gen=util.filechunkiter(bundle), prefer_uncompressed=True
+                )
 
         if repo.ui.configbool('server', 'disablefullbundle'):
             # Check to see if this is a full clone.
@@ -444,36 +485,40 @@
             if changegroup and not common and clheads == heads:
                 raise error.Abort(
                     _('server has pull-based clones disabled'),
-                    hint=_('remove --pull if specified or upgrade Mercurial'))
+                    hint=_('remove --pull if specified or upgrade Mercurial'),
+                )
 
-        info, chunks = exchange.getbundlechunks(repo, 'serve',
-                                                **pycompat.strkwargs(opts))
+        info, chunks = exchange.getbundlechunks(
+            repo, 'serve', **pycompat.strkwargs(opts)
+        )
         prefercompressed = info.get('prefercompressed', True)
     except error.Abort as exc:
         # cleanly forward Abort error to the client
         if not exchange.bundle2requested(opts.get('bundlecaps')):
             if proto.name == 'http-v1':
                 return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n')
-            raise # cannot do better for bundle1 + ssh
+            raise  # cannot do better for bundle1 + ssh
         # bundle2 request expect a bundle2 reply
         bundler = bundle2.bundle20(repo.ui)
         manargs = [('message', pycompat.bytestr(exc))]
         advargs = []
         if exc.hint is not None:
             advargs.append(('hint', exc.hint))
-        bundler.addpart(bundle2.bundlepart('error:abort',
-                                           manargs, advargs))
+        bundler.addpart(bundle2.bundlepart('error:abort', manargs, advargs))
         chunks = bundler.getchunks()
         prefercompressed = False
 
     return wireprototypes.streamres(
-        gen=chunks, prefer_uncompressed=not prefercompressed)
+        gen=chunks, prefer_uncompressed=not prefercompressed
+    )
+
 
 @wireprotocommand('heads', permission='pull')
 def heads(repo, proto):
     h = repo.heads()
     return wireprototypes.bytesresponse(wireprototypes.encodelist(h) + '\n')
 
+
 @wireprotocommand('hello', permission='pull')
 def hello(repo, proto):
     """Called as part of SSH handshake to obtain server info.
@@ -489,11 +534,13 @@
     caps = capabilities(repo, proto).data
     return wireprototypes.bytesresponse('capabilities: %s\n' % caps)
 
+
 @wireprotocommand('listkeys', 'namespace', permission='pull')
 def listkeys(repo, proto, namespace):
     d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
     return wireprototypes.bytesresponse(pushkeymod.encodekeys(d))
 
+
 @wireprotocommand('lookup', 'key', permission='pull')
 def lookup(repo, proto, key):
     try:
@@ -506,18 +553,22 @@
         success = 0
     return wireprototypes.bytesresponse('%d %s\n' % (success, r))
 
+
 @wireprotocommand('known', 'nodes *', permission='pull')
 def known(repo, proto, nodes, others):
-    v = ''.join(b and '1' or '0'
-                for b in repo.known(wireprototypes.decodelist(nodes)))
+    v = ''.join(
+        b and '1' or '0' for b in repo.known(wireprototypes.decodelist(nodes))
+    )
     return wireprototypes.bytesresponse(v)
 
+
 @wireprotocommand('protocaps', 'caps', permission='pull')
 def protocaps(repo, proto, caps):
     if proto.name == wireprototypes.SSHV1:
         proto._protocaps = set(caps.split(' '))
     return wireprototypes.bytesresponse('OK')
 
+
 @wireprotocommand('pushkey', 'namespace key old new', permission='push')
 def pushkey(repo, proto, namespace, key, old, new):
     # compatibility with pre-1.8 clients which were accidentally
@@ -526,27 +577,35 @@
         # looks like it could be a binary node
         try:
             new.decode('utf-8')
-            new = encoding.tolocal(new) # but cleanly decodes as UTF-8
+            new = encoding.tolocal(new)  # but cleanly decodes as UTF-8
         except UnicodeDecodeError:
-            pass # binary, leave unmodified
+            pass  # binary, leave unmodified
     else:
-        new = encoding.tolocal(new) # normal path
+        new = encoding.tolocal(new)  # normal path
 
     with proto.mayberedirectstdio() as output:
-        r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
-                         encoding.tolocal(old), new) or False
+        r = (
+            repo.pushkey(
+                encoding.tolocal(namespace),
+                encoding.tolocal(key),
+                encoding.tolocal(old),
+                new,
+            )
+            or False
+        )
 
     output = output.getvalue() if output else ''
     return wireprototypes.bytesresponse('%d\n%s' % (int(r), output))
 
+
 @wireprotocommand('stream_out', permission='pull')
 def stream(repo, proto):
     '''If the server supports streaming clone, it advertises the "stream"
     capability with a value representing the version and flags of the repo
     it is serving. Client checks to see if it understands the format.
     '''
-    return wireprototypes.streamreslegacy(
-        streamclone.generatev1wireproto(repo))
+    return wireprototypes.streamreslegacy(streamclone.generatev1wireproto(repo))
+
 
 @wireprotocommand('unbundle', 'heads', permission='push')
 def unbundle(repo, proto, heads):
@@ -559,48 +618,57 @@
             try:
                 payload = proto.getpayload()
                 if repo.ui.configbool('server', 'streamunbundle'):
+
                     def cleanup():
                         # Ensure that the full payload is consumed, so
                         # that the connection doesn't contain trailing garbage.
                         for p in payload:
                             pass
+
                     fp = util.chunkbuffer(payload)
                 else:
                     # write bundle data to temporary file as it can be big
                     fp, tempname = None, None
+
                     def cleanup():
                         if fp:
                             fp.close()
                         if tempname:
                             os.unlink(tempname)
+
                     fd, tempname = pycompat.mkstemp(prefix='hg-unbundle-')
-                    repo.ui.debug('redirecting incoming bundle to %s\n' %
-                        tempname)
+                    repo.ui.debug(
+                        'redirecting incoming bundle to %s\n' % tempname
+                    )
                     fp = os.fdopen(fd, pycompat.sysstr('wb+'))
                     for p in payload:
                         fp.write(p)
                     fp.seek(0)
 
                 gen = exchange.readbundle(repo.ui, fp, None)
-                if (isinstance(gen, changegroupmod.cg1unpacker)
-                    and not bundle1allowed(repo, 'push')):
+                if isinstance(
+                    gen, changegroupmod.cg1unpacker
+                ) and not bundle1allowed(repo, 'push'):
                     if proto.name == 'http-v1':
                         # need to special case http because stderr do not get to
                         # the http client on failed push so we need to abuse
                         # some other error type to make sure the message get to
                         # the user.
                         return wireprototypes.ooberror(bundle2required)
-                    raise error.Abort(bundle2requiredmain,
-                                      hint=bundle2requiredhint)
+                    raise error.Abort(
+                        bundle2requiredmain, hint=bundle2requiredhint
+                    )
 
-                r = exchange.unbundle(repo, gen, their_heads, 'serve',
-                                      proto.client())
+                r = exchange.unbundle(
+                    repo, gen, their_heads, 'serve', proto.client()
+                )
                 if util.safehasattr(r, 'addpart'):
                     # The return looks streamable, we are in the bundle2 case
                     # and should return a stream.
                     return wireprototypes.streamreslegacy(gen=r.getchunks())
                 return wireprototypes.pushres(
-                    r, output.getvalue() if output else '')
+                    r, output.getvalue() if output else ''
+                )
 
             finally:
                 cleanup()
@@ -620,11 +688,13 @@
                         procutil.stderr.write("(%s)\n" % exc.hint)
                     procutil.stderr.flush()
                     return wireprototypes.pushres(
-                        0, output.getvalue() if output else '')
+                        0, output.getvalue() if output else ''
+                    )
                 except error.PushRaced:
                     return wireprototypes.pusherr(
                         pycompat.bytestr(exc),
-                        output.getvalue() if output else '')
+                        output.getvalue() if output else '',
+                    )
 
             bundler = bundle2.bundle20(repo.ui)
             for out in getattr(exc, '_bundle2salvagedoutput', ()):
@@ -635,15 +705,18 @@
                 except error.PushkeyFailed as exc:
                     # check client caps
                     remotecaps = getattr(exc, '_replycaps', None)
-                    if (remotecaps is not None
-                            and 'pushkey' not in remotecaps.get('error', ())):
+                    if (
+                        remotecaps is not None
+                        and 'pushkey' not in remotecaps.get('error', ())
+                    ):
                         # no support remote side, fallback to Abort handler.
                         raise
                     part = bundler.newpart('error:pushkey')
                     part.addparam('in-reply-to', exc.partid)
                     if exc.namespace is not None:
-                        part.addparam('namespace', exc.namespace,
-                                      mandatory=False)
+                        part.addparam(
+                            'namespace', exc.namespace, mandatory=False
+                        )
                     if exc.key is not None:
                         part.addparam('key', exc.key, mandatory=False)
                     if exc.new is not None:
@@ -663,9 +736,12 @@
                 advargs = []
                 if exc.hint is not None:
                     advargs.append(('hint', exc.hint))
-                bundler.addpart(bundle2.bundlepart('error:abort',
-                                                   manargs, advargs))
+                bundler.addpart(
+                    bundle2.bundlepart('error:abort', manargs, advargs)
+                )
             except error.PushRaced as exc:
-                bundler.newpart('error:pushraced',
-                                [('message', stringutil.forcebytestr(exc))])
+                bundler.newpart(
+                    'error:pushraced',
+                    [('message', stringutil.forcebytestr(exc))],
+                )
             return wireprototypes.streamreslegacy(gen=bundler.getchunks())