--- a/mercurial/wireprotov1server.py Sat Oct 05 10:29:34 2019 -0400
+++ b/mercurial/wireprotov1server.py Sun Oct 06 09:45:02 2019 -0400
@@ -39,10 +39,12 @@
urlreq = util.urlreq
bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
-bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
- 'IncompatibleClient')
+bundle2requiredhint = _(
+ 'see https://www.mercurial-scm.org/wiki/' 'IncompatibleClient'
+)
bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
+
def clientcompressionsupport(proto):
"""Returns a list of compression methods supported by the client.
@@ -55,8 +57,10 @@
return cap[5:].split(',')
return ['zlib', 'none']
+
# wire protocol command can either return a string or one of these classes.
+
def getdispatchrepo(repo, proto, command):
"""Obtain the repo used for processing wire protocol commands.
@@ -67,6 +71,7 @@
viewconfig = repo.ui.config('server', 'view')
return repo.filtered(viewconfig)
+
def dispatch(repo, proto, command):
repo = getdispatchrepo(repo, proto, command)
@@ -75,6 +80,7 @@
return func(repo, proto, *args)
+
def options(cmd, keys, others):
opts = {}
for k in keys:
@@ -82,10 +88,13 @@
opts[k] = others[k]
del others[k]
if others:
- procutil.stderr.write("warning: %s ignored unexpected arguments %s\n"
- % (cmd, ",".join(others)))
+ procutil.stderr.write(
+ "warning: %s ignored unexpected arguments %s\n"
+ % (cmd, ",".join(others))
+ )
return opts
+
def bundle1allowed(repo, action):
"""Whether a bundle1 operation is allowed from the server.
@@ -115,8 +124,10 @@
return ui.configbool('server', 'bundle1')
+
commands = wireprototypes.commanddict()
+
def wireprotocommand(name, args=None, permission='push'):
"""Decorator to declare a wire protocol command.
@@ -132,8 +143,9 @@
because otherwise commands not declaring their permissions could modify
a repository that is supposed to be read-only.
"""
- transports = {k for k, v in wireprototypes.TRANSPORTS.items()
- if v['version'] == 1}
+ transports = {
+ k for k, v in wireprototypes.TRANSPORTS.items() if v['version'] == 1
+ }
# Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
# SSHv2.
@@ -142,27 +154,33 @@
transports.add(wireprototypes.SSHV2)
if permission not in ('push', 'pull'):
- raise error.ProgrammingError('invalid wire protocol permission; '
- 'got %s; expected "push" or "pull"' %
- permission)
+ raise error.ProgrammingError(
+ 'invalid wire protocol permission; '
+ 'got %s; expected "push" or "pull"' % permission
+ )
if args is None:
args = ''
if not isinstance(args, bytes):
- raise error.ProgrammingError('arguments for version 1 commands '
- 'must be declared as bytes')
+ raise error.ProgrammingError(
+ 'arguments for version 1 commands ' 'must be declared as bytes'
+ )
def register(func):
if name in commands:
- raise error.ProgrammingError('%s command already registered '
- 'for version 1' % name)
+ raise error.ProgrammingError(
+ '%s command already registered ' 'for version 1' % name
+ )
commands[name] = wireprototypes.commandentry(
- func, args=args, transports=transports, permission=permission)
+ func, args=args, transports=transports, permission=permission
+ )
return func
+
return register
+
# TODO define a more appropriate permissions type to use for this.
@wireprotocommand('batch', 'cmds *', permission='pull')
def batch(repo, proto, cmds, others):
@@ -209,6 +227,7 @@
return wireprototypes.bytesresponse(';'.join(res))
+
@wireprotocommand('between', 'pairs', permission='pull')
def between(repo, proto, pairs):
pairs = [wireprototypes.decodelist(p, '-') for p in pairs.split(" ")]
@@ -218,6 +237,7 @@
return wireprototypes.bytesresponse(''.join(r))
+
@wireprotocommand('branchmap', permission='pull')
def branchmap(repo, proto):
branchmap = repo.branchmap()
@@ -229,6 +249,7 @@
return wireprototypes.bytesresponse('\n'.join(heads))
+
@wireprotocommand('branches', 'nodes', permission='pull')
def branches(repo, proto, nodes):
nodes = wireprototypes.decodelist(nodes)
@@ -238,6 +259,7 @@
return wireprototypes.bytesresponse(''.join(r))
+
@wireprotocommand('clonebundles', '', permission='pull')
def clonebundles(repo, proto):
"""Server command for returning info for available bundles to seed clones.
@@ -249,10 +271,19 @@
data center given the client's IP address.
"""
return wireprototypes.bytesresponse(
- repo.vfs.tryread('clonebundles.manifest'))
+ repo.vfs.tryread('clonebundles.manifest')
+ )
+
-wireprotocaps = ['lookup', 'branchmap', 'pushkey',
- 'known', 'getbundle', 'unbundlehash']
+wireprotocaps = [
+ 'lookup',
+ 'branchmap',
+ 'pushkey',
+ 'known',
+ 'getbundle',
+ 'unbundlehash',
+]
+
def _capabilities(repo, proto):
"""return a list of capabilities for a repo
@@ -294,6 +325,7 @@
return proto.addcapabilities(repo, caps)
+
# If you are writing an extension and consider wrapping this function. Wrap
# `_capabilities` instead.
@wireprotocommand('capabilities', permission='pull')
@@ -301,33 +333,36 @@
caps = _capabilities(repo, proto)
return wireprototypes.bytesresponse(' '.join(sorted(caps)))
+
@wireprotocommand('changegroup', 'roots', permission='pull')
def changegroup(repo, proto, roots):
nodes = wireprototypes.decodelist(roots)
- outgoing = discovery.outgoing(repo, missingroots=nodes,
- missingheads=repo.heads())
+ outgoing = discovery.outgoing(
+ repo, missingroots=nodes, missingheads=repo.heads()
+ )
cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
gen = iter(lambda: cg.read(32768), '')
return wireprototypes.streamres(gen=gen)
-@wireprotocommand('changegroupsubset', 'bases heads',
- permission='pull')
+
+@wireprotocommand('changegroupsubset', 'bases heads', permission='pull')
def changegroupsubset(repo, proto, bases, heads):
bases = wireprototypes.decodelist(bases)
heads = wireprototypes.decodelist(heads)
- outgoing = discovery.outgoing(repo, missingroots=bases,
- missingheads=heads)
+ outgoing = discovery.outgoing(repo, missingroots=bases, missingheads=heads)
cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
gen = iter(lambda: cg.read(32768), '')
return wireprototypes.streamres(gen=gen)
-@wireprotocommand('debugwireargs', 'one two *',
- permission='pull')
+
+@wireprotocommand('debugwireargs', 'one two *', permission='pull')
def debugwireargs(repo, proto, one, two, others):
# only accept optional args from the known set
opts = options('debugwireargs', ['three', 'four'], others)
- return wireprototypes.bytesresponse(repo.debugwireargs(
- one, two, **pycompat.strkwargs(opts)))
+ return wireprototypes.bytesresponse(
+ repo.debugwireargs(one, two, **pycompat.strkwargs(opts))
+ )
+
def find_pullbundle(repo, proto, opts, clheads, heads, common):
"""Return a file object for the first matching pullbundle.
@@ -344,6 +379,7 @@
E.g. do not send a bundle of all changes if the client wants only
one specific branch of many.
"""
+
def decodehexstring(s):
return {binascii.unhexlify(h) for h in s.split(';')}
@@ -372,11 +408,13 @@
# Bad heads entry
continue
if bundle_heads.issubset(common):
- continue # Nothing new
+ continue # Nothing new
if all(cl.rev(rev) in common_anc for rev in bundle_heads):
- continue # Still nothing new
- if any(cl.rev(rev) not in heads_anc and
- cl.rev(rev) not in common_anc for rev in bundle_heads):
+ continue # Still nothing new
+ if any(
+ cl.rev(rev) not in heads_anc and cl.rev(rev) not in common_anc
+ for rev in bundle_heads
+ ):
continue
if 'bases' in entry:
try:
@@ -395,10 +433,12 @@
continue
return None
+
@wireprotocommand('getbundle', '*', permission='pull')
def getbundle(repo, proto, others):
- opts = options('getbundle', wireprototypes.GETBUNDLE_ARGUMENTS.keys(),
- others)
+ opts = options(
+ 'getbundle', wireprototypes.GETBUNDLE_ARGUMENTS.keys(), others
+ )
for k, v in opts.iteritems():
keytype = wireprototypes.GETBUNDLE_ARGUMENTS[k]
if keytype == 'nodes':
@@ -415,28 +455,29 @@
else:
opts[k] = bool(v)
elif keytype != 'plain':
- raise KeyError('unknown getbundle option type %s'
- % keytype)
+ raise KeyError('unknown getbundle option type %s' % keytype)
if not bundle1allowed(repo, 'pull'):
if not exchange.bundle2requested(opts.get('bundlecaps')):
if proto.name == 'http-v1':
return wireprototypes.ooberror(bundle2required)
- raise error.Abort(bundle2requiredmain,
- hint=bundle2requiredhint)
+ raise error.Abort(bundle2requiredmain, hint=bundle2requiredhint)
try:
clheads = set(repo.changelog.heads())
heads = set(opts.get('heads', set()))
common = set(opts.get('common', set()))
common.discard(nullid)
- if (repo.ui.configbool('server', 'pullbundle') and
- 'partial-pull' in proto.getprotocaps()):
+ if (
+ repo.ui.configbool('server', 'pullbundle')
+ and 'partial-pull' in proto.getprotocaps()
+ ):
# Check if a pre-built bundle covers this request.
bundle = find_pullbundle(repo, proto, opts, clheads, heads, common)
if bundle:
- return wireprototypes.streamres(gen=util.filechunkiter(bundle),
- prefer_uncompressed=True)
+ return wireprototypes.streamres(
+ gen=util.filechunkiter(bundle), prefer_uncompressed=True
+ )
if repo.ui.configbool('server', 'disablefullbundle'):
# Check to see if this is a full clone.
@@ -444,36 +485,40 @@
if changegroup and not common and clheads == heads:
raise error.Abort(
_('server has pull-based clones disabled'),
- hint=_('remove --pull if specified or upgrade Mercurial'))
+ hint=_('remove --pull if specified or upgrade Mercurial'),
+ )
- info, chunks = exchange.getbundlechunks(repo, 'serve',
- **pycompat.strkwargs(opts))
+ info, chunks = exchange.getbundlechunks(
+ repo, 'serve', **pycompat.strkwargs(opts)
+ )
prefercompressed = info.get('prefercompressed', True)
except error.Abort as exc:
# cleanly forward Abort error to the client
if not exchange.bundle2requested(opts.get('bundlecaps')):
if proto.name == 'http-v1':
return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n')
- raise # cannot do better for bundle1 + ssh
+ raise # cannot do better for bundle1 + ssh
# bundle2 request expect a bundle2 reply
bundler = bundle2.bundle20(repo.ui)
manargs = [('message', pycompat.bytestr(exc))]
advargs = []
if exc.hint is not None:
advargs.append(('hint', exc.hint))
- bundler.addpart(bundle2.bundlepart('error:abort',
- manargs, advargs))
+ bundler.addpart(bundle2.bundlepart('error:abort', manargs, advargs))
chunks = bundler.getchunks()
prefercompressed = False
return wireprototypes.streamres(
- gen=chunks, prefer_uncompressed=not prefercompressed)
+ gen=chunks, prefer_uncompressed=not prefercompressed
+ )
+
@wireprotocommand('heads', permission='pull')
def heads(repo, proto):
h = repo.heads()
return wireprototypes.bytesresponse(wireprototypes.encodelist(h) + '\n')
+
@wireprotocommand('hello', permission='pull')
def hello(repo, proto):
"""Called as part of SSH handshake to obtain server info.
@@ -489,11 +534,13 @@
caps = capabilities(repo, proto).data
return wireprototypes.bytesresponse('capabilities: %s\n' % caps)
+
@wireprotocommand('listkeys', 'namespace', permission='pull')
def listkeys(repo, proto, namespace):
d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
return wireprototypes.bytesresponse(pushkeymod.encodekeys(d))
+
@wireprotocommand('lookup', 'key', permission='pull')
def lookup(repo, proto, key):
try:
@@ -506,18 +553,22 @@
success = 0
return wireprototypes.bytesresponse('%d %s\n' % (success, r))
+
@wireprotocommand('known', 'nodes *', permission='pull')
def known(repo, proto, nodes, others):
- v = ''.join(b and '1' or '0'
- for b in repo.known(wireprototypes.decodelist(nodes)))
+ v = ''.join(
+ b and '1' or '0' for b in repo.known(wireprototypes.decodelist(nodes))
+ )
return wireprototypes.bytesresponse(v)
+
@wireprotocommand('protocaps', 'caps', permission='pull')
def protocaps(repo, proto, caps):
if proto.name == wireprototypes.SSHV1:
proto._protocaps = set(caps.split(' '))
return wireprototypes.bytesresponse('OK')
+
@wireprotocommand('pushkey', 'namespace key old new', permission='push')
def pushkey(repo, proto, namespace, key, old, new):
# compatibility with pre-1.8 clients which were accidentally
@@ -526,27 +577,35 @@
# looks like it could be a binary node
try:
new.decode('utf-8')
- new = encoding.tolocal(new) # but cleanly decodes as UTF-8
+ new = encoding.tolocal(new) # but cleanly decodes as UTF-8
except UnicodeDecodeError:
- pass # binary, leave unmodified
+ pass # binary, leave unmodified
else:
- new = encoding.tolocal(new) # normal path
+ new = encoding.tolocal(new) # normal path
with proto.mayberedirectstdio() as output:
- r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
- encoding.tolocal(old), new) or False
+ r = (
+ repo.pushkey(
+ encoding.tolocal(namespace),
+ encoding.tolocal(key),
+ encoding.tolocal(old),
+ new,
+ )
+ or False
+ )
output = output.getvalue() if output else ''
return wireprototypes.bytesresponse('%d\n%s' % (int(r), output))
+
@wireprotocommand('stream_out', permission='pull')
def stream(repo, proto):
'''If the server supports streaming clone, it advertises the "stream"
capability with a value representing the version and flags of the repo
it is serving. Client checks to see if it understands the format.
'''
- return wireprototypes.streamreslegacy(
- streamclone.generatev1wireproto(repo))
+ return wireprototypes.streamreslegacy(streamclone.generatev1wireproto(repo))
+
@wireprotocommand('unbundle', 'heads', permission='push')
def unbundle(repo, proto, heads):
@@ -559,48 +618,57 @@
try:
payload = proto.getpayload()
if repo.ui.configbool('server', 'streamunbundle'):
+
def cleanup():
# Ensure that the full payload is consumed, so
# that the connection doesn't contain trailing garbage.
for p in payload:
pass
+
fp = util.chunkbuffer(payload)
else:
# write bundle data to temporary file as it can be big
fp, tempname = None, None
+
def cleanup():
if fp:
fp.close()
if tempname:
os.unlink(tempname)
+
fd, tempname = pycompat.mkstemp(prefix='hg-unbundle-')
- repo.ui.debug('redirecting incoming bundle to %s\n' %
- tempname)
+ repo.ui.debug(
+ 'redirecting incoming bundle to %s\n' % tempname
+ )
fp = os.fdopen(fd, pycompat.sysstr('wb+'))
for p in payload:
fp.write(p)
fp.seek(0)
gen = exchange.readbundle(repo.ui, fp, None)
- if (isinstance(gen, changegroupmod.cg1unpacker)
- and not bundle1allowed(repo, 'push')):
+ if isinstance(
+ gen, changegroupmod.cg1unpacker
+ ) and not bundle1allowed(repo, 'push'):
if proto.name == 'http-v1':
# need to special case http because stderr do not get to
# the http client on failed push so we need to abuse
# some other error type to make sure the message get to
# the user.
return wireprototypes.ooberror(bundle2required)
- raise error.Abort(bundle2requiredmain,
- hint=bundle2requiredhint)
+ raise error.Abort(
+ bundle2requiredmain, hint=bundle2requiredhint
+ )
- r = exchange.unbundle(repo, gen, their_heads, 'serve',
- proto.client())
+ r = exchange.unbundle(
+ repo, gen, their_heads, 'serve', proto.client()
+ )
if util.safehasattr(r, 'addpart'):
# The return looks streamable, we are in the bundle2 case
# and should return a stream.
return wireprototypes.streamreslegacy(gen=r.getchunks())
return wireprototypes.pushres(
- r, output.getvalue() if output else '')
+ r, output.getvalue() if output else ''
+ )
finally:
cleanup()
@@ -620,11 +688,13 @@
procutil.stderr.write("(%s)\n" % exc.hint)
procutil.stderr.flush()
return wireprototypes.pushres(
- 0, output.getvalue() if output else '')
+ 0, output.getvalue() if output else ''
+ )
except error.PushRaced:
return wireprototypes.pusherr(
pycompat.bytestr(exc),
- output.getvalue() if output else '')
+ output.getvalue() if output else '',
+ )
bundler = bundle2.bundle20(repo.ui)
for out in getattr(exc, '_bundle2salvagedoutput', ()):
@@ -635,15 +705,18 @@
except error.PushkeyFailed as exc:
# check client caps
remotecaps = getattr(exc, '_replycaps', None)
- if (remotecaps is not None
- and 'pushkey' not in remotecaps.get('error', ())):
+ if (
+ remotecaps is not None
+ and 'pushkey' not in remotecaps.get('error', ())
+ ):
# no support remote side, fallback to Abort handler.
raise
part = bundler.newpart('error:pushkey')
part.addparam('in-reply-to', exc.partid)
if exc.namespace is not None:
- part.addparam('namespace', exc.namespace,
- mandatory=False)
+ part.addparam(
+ 'namespace', exc.namespace, mandatory=False
+ )
if exc.key is not None:
part.addparam('key', exc.key, mandatory=False)
if exc.new is not None:
@@ -663,9 +736,12 @@
advargs = []
if exc.hint is not None:
advargs.append(('hint', exc.hint))
- bundler.addpart(bundle2.bundlepart('error:abort',
- manargs, advargs))
+ bundler.addpart(
+ bundle2.bundlepart('error:abort', manargs, advargs)
+ )
except error.PushRaced as exc:
- bundler.newpart('error:pushraced',
- [('message', stringutil.forcebytestr(exc))])
+ bundler.newpart(
+ 'error:pushraced',
+ [('message', stringutil.forcebytestr(exc))],
+ )
return wireprototypes.streamreslegacy(gen=bundler.getchunks())