Mercurial > hg
diff mercurial/wireprotov1server.py @ 43076:2372284d9457
formatting: blacken the codebase
This is using my patch to black
(https://github.com/psf/black/pull/826) so we don't un-wrap collection
literals.
Done with:
hg files 'set:**.py - mercurial/thirdparty/** - "contrib/python-zstandard/**"' | xargs black -S
# skip-blame mass-reformatting only
# no-check-commit reformats foo_bar functions
Differential Revision: https://phab.mercurial-scm.org/D6971
author | Augie Fackler <augie@google.com> |
---|---|
date | Sun, 06 Oct 2019 09:45:02 -0400 |
parents | 566daffc607d |
children | 687b865b95ad |
line wrap: on
line diff
--- a/mercurial/wireprotov1server.py Sat Oct 05 10:29:34 2019 -0400 +++ b/mercurial/wireprotov1server.py Sun Oct 06 09:45:02 2019 -0400 @@ -39,10 +39,12 @@ urlreq = util.urlreq bundle2requiredmain = _('incompatible Mercurial client; bundle2 required') -bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/' - 'IncompatibleClient') +bundle2requiredhint = _( + 'see https://www.mercurial-scm.org/wiki/' 'IncompatibleClient' +) bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint) + def clientcompressionsupport(proto): """Returns a list of compression methods supported by the client. @@ -55,8 +57,10 @@ return cap[5:].split(',') return ['zlib', 'none'] + # wire protocol command can either return a string or one of these classes. + def getdispatchrepo(repo, proto, command): """Obtain the repo used for processing wire protocol commands. @@ -67,6 +71,7 @@ viewconfig = repo.ui.config('server', 'view') return repo.filtered(viewconfig) + def dispatch(repo, proto, command): repo = getdispatchrepo(repo, proto, command) @@ -75,6 +80,7 @@ return func(repo, proto, *args) + def options(cmd, keys, others): opts = {} for k in keys: @@ -82,10 +88,13 @@ opts[k] = others[k] del others[k] if others: - procutil.stderr.write("warning: %s ignored unexpected arguments %s\n" - % (cmd, ",".join(others))) + procutil.stderr.write( + "warning: %s ignored unexpected arguments %s\n" + % (cmd, ",".join(others)) + ) return opts + def bundle1allowed(repo, action): """Whether a bundle1 operation is allowed from the server. @@ -115,8 +124,10 @@ return ui.configbool('server', 'bundle1') + commands = wireprototypes.commanddict() + def wireprotocommand(name, args=None, permission='push'): """Decorator to declare a wire protocol command. @@ -132,8 +143,9 @@ because otherwise commands not declaring their permissions could modify a repository that is supposed to be read-only. """ - transports = {k for k, v in wireprototypes.TRANSPORTS.items() - if v['version'] == 1} + transports = { + k for k, v in wireprototypes.TRANSPORTS.items() if v['version'] == 1 + } # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to # SSHv2. @@ -142,27 +154,33 @@ transports.add(wireprototypes.SSHV2) if permission not in ('push', 'pull'): - raise error.ProgrammingError('invalid wire protocol permission; ' - 'got %s; expected "push" or "pull"' % - permission) + raise error.ProgrammingError( + 'invalid wire protocol permission; ' + 'got %s; expected "push" or "pull"' % permission + ) if args is None: args = '' if not isinstance(args, bytes): - raise error.ProgrammingError('arguments for version 1 commands ' - 'must be declared as bytes') + raise error.ProgrammingError( + 'arguments for version 1 commands ' 'must be declared as bytes' + ) def register(func): if name in commands: - raise error.ProgrammingError('%s command already registered ' - 'for version 1' % name) + raise error.ProgrammingError( + '%s command already registered ' 'for version 1' % name + ) commands[name] = wireprototypes.commandentry( - func, args=args, transports=transports, permission=permission) + func, args=args, transports=transports, permission=permission + ) return func + return register + # TODO define a more appropriate permissions type to use for this. @wireprotocommand('batch', 'cmds *', permission='pull') def batch(repo, proto, cmds, others): @@ -209,6 +227,7 @@ return wireprototypes.bytesresponse(';'.join(res)) + @wireprotocommand('between', 'pairs', permission='pull') def between(repo, proto, pairs): pairs = [wireprototypes.decodelist(p, '-') for p in pairs.split(" ")] @@ -218,6 +237,7 @@ return wireprototypes.bytesresponse(''.join(r)) + @wireprotocommand('branchmap', permission='pull') def branchmap(repo, proto): branchmap = repo.branchmap() @@ -229,6 +249,7 @@ return wireprototypes.bytesresponse('\n'.join(heads)) + @wireprotocommand('branches', 'nodes', permission='pull') def branches(repo, proto, nodes): nodes = wireprototypes.decodelist(nodes) @@ -238,6 +259,7 @@ return wireprototypes.bytesresponse(''.join(r)) + @wireprotocommand('clonebundles', '', permission='pull') def clonebundles(repo, proto): """Server command for returning info for available bundles to seed clones. @@ -249,10 +271,19 @@ data center given the client's IP address. """ return wireprototypes.bytesresponse( - repo.vfs.tryread('clonebundles.manifest')) + repo.vfs.tryread('clonebundles.manifest') + ) + -wireprotocaps = ['lookup', 'branchmap', 'pushkey', - 'known', 'getbundle', 'unbundlehash'] +wireprotocaps = [ + 'lookup', + 'branchmap', + 'pushkey', + 'known', + 'getbundle', + 'unbundlehash', +] + def _capabilities(repo, proto): """return a list of capabilities for a repo @@ -294,6 +325,7 @@ return proto.addcapabilities(repo, caps) + # If you are writing an extension and consider wrapping this function. Wrap # `_capabilities` instead. @wireprotocommand('capabilities', permission='pull') @@ -301,33 +333,36 @@ caps = _capabilities(repo, proto) return wireprototypes.bytesresponse(' '.join(sorted(caps))) + @wireprotocommand('changegroup', 'roots', permission='pull') def changegroup(repo, proto, roots): nodes = wireprototypes.decodelist(roots) - outgoing = discovery.outgoing(repo, missingroots=nodes, - missingheads=repo.heads()) + outgoing = discovery.outgoing( + repo, missingroots=nodes, missingheads=repo.heads() + ) cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve') gen = iter(lambda: cg.read(32768), '') return wireprototypes.streamres(gen=gen) -@wireprotocommand('changegroupsubset', 'bases heads', - permission='pull') + +@wireprotocommand('changegroupsubset', 'bases heads', permission='pull') def changegroupsubset(repo, proto, bases, heads): bases = wireprototypes.decodelist(bases) heads = wireprototypes.decodelist(heads) - outgoing = discovery.outgoing(repo, missingroots=bases, - missingheads=heads) + outgoing = discovery.outgoing(repo, missingroots=bases, missingheads=heads) cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve') gen = iter(lambda: cg.read(32768), '') return wireprototypes.streamres(gen=gen) -@wireprotocommand('debugwireargs', 'one two *', - permission='pull') + +@wireprotocommand('debugwireargs', 'one two *', permission='pull') def debugwireargs(repo, proto, one, two, others): # only accept optional args from the known set opts = options('debugwireargs', ['three', 'four'], others) - return wireprototypes.bytesresponse(repo.debugwireargs( - one, two, **pycompat.strkwargs(opts))) + return wireprototypes.bytesresponse( + repo.debugwireargs(one, two, **pycompat.strkwargs(opts)) + ) + def find_pullbundle(repo, proto, opts, clheads, heads, common): """Return a file object for the first matching pullbundle. @@ -344,6 +379,7 @@ E.g. do not send a bundle of all changes if the client wants only one specific branch of many. """ + def decodehexstring(s): return {binascii.unhexlify(h) for h in s.split(';')} @@ -372,11 +408,13 @@ # Bad heads entry continue if bundle_heads.issubset(common): - continue # Nothing new + continue # Nothing new if all(cl.rev(rev) in common_anc for rev in bundle_heads): - continue # Still nothing new - if any(cl.rev(rev) not in heads_anc and - cl.rev(rev) not in common_anc for rev in bundle_heads): + continue # Still nothing new + if any( + cl.rev(rev) not in heads_anc and cl.rev(rev) not in common_anc + for rev in bundle_heads + ): continue if 'bases' in entry: try: @@ -395,10 +433,12 @@ continue return None + @wireprotocommand('getbundle', '*', permission='pull') def getbundle(repo, proto, others): - opts = options('getbundle', wireprototypes.GETBUNDLE_ARGUMENTS.keys(), - others) + opts = options( + 'getbundle', wireprototypes.GETBUNDLE_ARGUMENTS.keys(), others + ) for k, v in opts.iteritems(): keytype = wireprototypes.GETBUNDLE_ARGUMENTS[k] if keytype == 'nodes': @@ -415,28 +455,29 @@ else: opts[k] = bool(v) elif keytype != 'plain': - raise KeyError('unknown getbundle option type %s' - % keytype) + raise KeyError('unknown getbundle option type %s' % keytype) if not bundle1allowed(repo, 'pull'): if not exchange.bundle2requested(opts.get('bundlecaps')): if proto.name == 'http-v1': return wireprototypes.ooberror(bundle2required) - raise error.Abort(bundle2requiredmain, - hint=bundle2requiredhint) + raise error.Abort(bundle2requiredmain, hint=bundle2requiredhint) try: clheads = set(repo.changelog.heads()) heads = set(opts.get('heads', set())) common = set(opts.get('common', set())) common.discard(nullid) - if (repo.ui.configbool('server', 'pullbundle') and - 'partial-pull' in proto.getprotocaps()): + if ( + repo.ui.configbool('server', 'pullbundle') + and 'partial-pull' in proto.getprotocaps() + ): # Check if a pre-built bundle covers this request. bundle = find_pullbundle(repo, proto, opts, clheads, heads, common) if bundle: - return wireprototypes.streamres(gen=util.filechunkiter(bundle), - prefer_uncompressed=True) + return wireprototypes.streamres( + gen=util.filechunkiter(bundle), prefer_uncompressed=True + ) if repo.ui.configbool('server', 'disablefullbundle'): # Check to see if this is a full clone. @@ -444,36 +485,40 @@ if changegroup and not common and clheads == heads: raise error.Abort( _('server has pull-based clones disabled'), - hint=_('remove --pull if specified or upgrade Mercurial')) + hint=_('remove --pull if specified or upgrade Mercurial'), + ) - info, chunks = exchange.getbundlechunks(repo, 'serve', - **pycompat.strkwargs(opts)) + info, chunks = exchange.getbundlechunks( + repo, 'serve', **pycompat.strkwargs(opts) + ) prefercompressed = info.get('prefercompressed', True) except error.Abort as exc: # cleanly forward Abort error to the client if not exchange.bundle2requested(opts.get('bundlecaps')): if proto.name == 'http-v1': return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n') - raise # cannot do better for bundle1 + ssh + raise # cannot do better for bundle1 + ssh # bundle2 request expect a bundle2 reply bundler = bundle2.bundle20(repo.ui) manargs = [('message', pycompat.bytestr(exc))] advargs = [] if exc.hint is not None: advargs.append(('hint', exc.hint)) - bundler.addpart(bundle2.bundlepart('error:abort', - manargs, advargs)) + bundler.addpart(bundle2.bundlepart('error:abort', manargs, advargs)) chunks = bundler.getchunks() prefercompressed = False return wireprototypes.streamres( - gen=chunks, prefer_uncompressed=not prefercompressed) + gen=chunks, prefer_uncompressed=not prefercompressed + ) + @wireprotocommand('heads', permission='pull') def heads(repo, proto): h = repo.heads() return wireprototypes.bytesresponse(wireprototypes.encodelist(h) + '\n') + @wireprotocommand('hello', permission='pull') def hello(repo, proto): """Called as part of SSH handshake to obtain server info. @@ -489,11 +534,13 @@ caps = capabilities(repo, proto).data return wireprototypes.bytesresponse('capabilities: %s\n' % caps) + @wireprotocommand('listkeys', 'namespace', permission='pull') def listkeys(repo, proto, namespace): d = sorted(repo.listkeys(encoding.tolocal(namespace)).items()) return wireprototypes.bytesresponse(pushkeymod.encodekeys(d)) + @wireprotocommand('lookup', 'key', permission='pull') def lookup(repo, proto, key): try: @@ -506,18 +553,22 @@ success = 0 return wireprototypes.bytesresponse('%d %s\n' % (success, r)) + @wireprotocommand('known', 'nodes *', permission='pull') def known(repo, proto, nodes, others): - v = ''.join(b and '1' or '0' - for b in repo.known(wireprototypes.decodelist(nodes))) + v = ''.join( + b and '1' or '0' for b in repo.known(wireprototypes.decodelist(nodes)) + ) return wireprototypes.bytesresponse(v) + @wireprotocommand('protocaps', 'caps', permission='pull') def protocaps(repo, proto, caps): if proto.name == wireprototypes.SSHV1: proto._protocaps = set(caps.split(' ')) return wireprototypes.bytesresponse('OK') + @wireprotocommand('pushkey', 'namespace key old new', permission='push') def pushkey(repo, proto, namespace, key, old, new): # compatibility with pre-1.8 clients which were accidentally @@ -526,27 +577,35 @@ # looks like it could be a binary node try: new.decode('utf-8') - new = encoding.tolocal(new) # but cleanly decodes as UTF-8 + new = encoding.tolocal(new) # but cleanly decodes as UTF-8 except UnicodeDecodeError: - pass # binary, leave unmodified + pass # binary, leave unmodified else: - new = encoding.tolocal(new) # normal path + new = encoding.tolocal(new) # normal path with proto.mayberedirectstdio() as output: - r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key), - encoding.tolocal(old), new) or False + r = ( + repo.pushkey( + encoding.tolocal(namespace), + encoding.tolocal(key), + encoding.tolocal(old), + new, + ) + or False + ) output = output.getvalue() if output else '' return wireprototypes.bytesresponse('%d\n%s' % (int(r), output)) + @wireprotocommand('stream_out', permission='pull') def stream(repo, proto): '''If the server supports streaming clone, it advertises the "stream" capability with a value representing the version and flags of the repo it is serving. Client checks to see if it understands the format. ''' - return wireprototypes.streamreslegacy( - streamclone.generatev1wireproto(repo)) + return wireprototypes.streamreslegacy(streamclone.generatev1wireproto(repo)) + @wireprotocommand('unbundle', 'heads', permission='push') def unbundle(repo, proto, heads): @@ -559,48 +618,57 @@ try: payload = proto.getpayload() if repo.ui.configbool('server', 'streamunbundle'): + def cleanup(): # Ensure that the full payload is consumed, so # that the connection doesn't contain trailing garbage. for p in payload: pass + fp = util.chunkbuffer(payload) else: # write bundle data to temporary file as it can be big fp, tempname = None, None + def cleanup(): if fp: fp.close() if tempname: os.unlink(tempname) + fd, tempname = pycompat.mkstemp(prefix='hg-unbundle-') - repo.ui.debug('redirecting incoming bundle to %s\n' % - tempname) + repo.ui.debug( + 'redirecting incoming bundle to %s\n' % tempname + ) fp = os.fdopen(fd, pycompat.sysstr('wb+')) for p in payload: fp.write(p) fp.seek(0) gen = exchange.readbundle(repo.ui, fp, None) - if (isinstance(gen, changegroupmod.cg1unpacker) - and not bundle1allowed(repo, 'push')): + if isinstance( + gen, changegroupmod.cg1unpacker + ) and not bundle1allowed(repo, 'push'): if proto.name == 'http-v1': # need to special case http because stderr do not get to # the http client on failed push so we need to abuse # some other error type to make sure the message get to # the user. return wireprototypes.ooberror(bundle2required) - raise error.Abort(bundle2requiredmain, - hint=bundle2requiredhint) + raise error.Abort( + bundle2requiredmain, hint=bundle2requiredhint + ) - r = exchange.unbundle(repo, gen, their_heads, 'serve', - proto.client()) + r = exchange.unbundle( + repo, gen, their_heads, 'serve', proto.client() + ) if util.safehasattr(r, 'addpart'): # The return looks streamable, we are in the bundle2 case # and should return a stream. return wireprototypes.streamreslegacy(gen=r.getchunks()) return wireprototypes.pushres( - r, output.getvalue() if output else '') + r, output.getvalue() if output else '' + ) finally: cleanup() @@ -620,11 +688,13 @@ procutil.stderr.write("(%s)\n" % exc.hint) procutil.stderr.flush() return wireprototypes.pushres( - 0, output.getvalue() if output else '') + 0, output.getvalue() if output else '' + ) except error.PushRaced: return wireprototypes.pusherr( pycompat.bytestr(exc), - output.getvalue() if output else '') + output.getvalue() if output else '', + ) bundler = bundle2.bundle20(repo.ui) for out in getattr(exc, '_bundle2salvagedoutput', ()): @@ -635,15 +705,18 @@ except error.PushkeyFailed as exc: # check client caps remotecaps = getattr(exc, '_replycaps', None) - if (remotecaps is not None - and 'pushkey' not in remotecaps.get('error', ())): + if ( + remotecaps is not None + and 'pushkey' not in remotecaps.get('error', ()) + ): # no support remote side, fallback to Abort handler. raise part = bundler.newpart('error:pushkey') part.addparam('in-reply-to', exc.partid) if exc.namespace is not None: - part.addparam('namespace', exc.namespace, - mandatory=False) + part.addparam( + 'namespace', exc.namespace, mandatory=False + ) if exc.key is not None: part.addparam('key', exc.key, mandatory=False) if exc.new is not None: @@ -663,9 +736,12 @@ advargs = [] if exc.hint is not None: advargs.append(('hint', exc.hint)) - bundler.addpart(bundle2.bundlepart('error:abort', - manargs, advargs)) + bundler.addpart( + bundle2.bundlepart('error:abort', manargs, advargs) + ) except error.PushRaced as exc: - bundler.newpart('error:pushraced', - [('message', stringutil.forcebytestr(exc))]) + bundler.newpart( + 'error:pushraced', + [('message', stringutil.forcebytestr(exc))], + ) return wireprototypes.streamreslegacy(gen=bundler.getchunks())