--- a/mercurial/wireproto.py Tue Mar 06 15:08:33 2018 -0800
+++ b/mercurial/wireproto.py Wed Mar 07 16:02:24 2018 -0800
@@ -592,10 +592,12 @@
class commandentry(object):
"""Represents a declared wire protocol command."""
- def __init__(self, func, args='', transports=None):
+ def __init__(self, func, args='', transports=None,
+ permission='push'):
self.func = func
self.args = args
self.transports = transports or set()
+ self.permission = permission
def _merge(self, func, args):
"""Merge this instance with an incoming 2-tuple.
@@ -605,7 +607,8 @@
data not captured by the 2-tuple and a new instance containing
the union of the two objects is returned.
"""
- return commandentry(func, args=args, transports=set(self.transports))
+ return commandentry(func, args=args, transports=set(self.transports),
+ permission=self.permission)
# Old code treats instances as 2-tuples. So expose that interface.
def __iter__(self):
@@ -643,7 +646,8 @@
else:
# Use default values from @wireprotocommand.
v = commandentry(v[0], args=v[1],
- transports=set(wireprototypes.TRANSPORTS))
+ transports=set(wireprototypes.TRANSPORTS),
+ permission='push')
else:
raise ValueError('command entries must be commandentry instances '
'or 2-tuples')
@@ -672,12 +676,8 @@
commands = commanddict()
-# Maps wire protocol name to operation type. This is used for permissions
-# checking. All defined @wireiprotocommand should have an entry in this
-# dict.
-permissions = {}
-
-def wireprotocommand(name, args='', transportpolicy=POLICY_ALL):
+def wireprotocommand(name, args='', transportpolicy=POLICY_ALL,
+ permission='push'):
"""Decorator to declare a wire protocol command.
``name`` is the name of the wire protocol command being provided.
@@ -688,6 +688,12 @@
``transportpolicy`` is a POLICY_* constant denoting which transports
this wire protocol command should be exposed to. By default, commands
are exposed to all wire protocol transports.
+
+ ``permission`` defines the permission type needed to run this command.
+ Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
+ respectively. Default is to assume command requires ``push`` permissions
+ because otherwise commands not declaring their permissions could modify
+ a repository that is supposed to be read-only.
"""
if transportpolicy == POLICY_ALL:
transports = set(wireprototypes.TRANSPORTS)
@@ -701,14 +707,18 @@
raise error.Abort(_('invalid transport policy value: %s') %
transportpolicy)
+ if permission not in ('push', 'pull'):
+ raise error.Abort(_('invalid wire protocol permission; got %s; '
+ 'expected "push" or "pull"') % permission)
+
def register(func):
- commands[name] = commandentry(func, args=args, transports=transports)
+ commands[name] = commandentry(func, args=args, transports=transports,
+ permission=permission)
return func
return register
# TODO define a more appropriate permissions type to use for this.
-permissions['batch'] = 'pull'
-@wireprotocommand('batch', 'cmds *')
+@wireprotocommand('batch', 'cmds *', permission='pull')
def batch(repo, proto, cmds, others):
repo = repo.filtered("served")
res = []
@@ -725,11 +735,9 @@
# checking on each batched command.
# TODO formalize permission checking as part of protocol interface.
if util.safehasattr(proto, 'checkperm'):
- # Assume commands with no defined permissions are writes / for
- # pushes. This is the safest from a security perspective because
- # it doesn't allow commands with undefined semantics from
- # bypassing permissions checks.
- proto.checkperm(permissions.get(op, 'push'))
+ perm = commands[op].permission
+ assert perm in ('push', 'pull')
+ proto.checkperm(perm)
if spec:
keys = spec.split()
@@ -758,8 +766,8 @@
return bytesresponse(';'.join(res))
-permissions['between'] = 'pull'
-@wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY)
+@wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY,
+ permission='pull')
def between(repo, proto, pairs):
pairs = [decodelist(p, '-') for p in pairs.split(" ")]
r = []
@@ -768,8 +776,7 @@
return bytesresponse(''.join(r))
-permissions['branchmap'] = 'pull'
-@wireprotocommand('branchmap')
+@wireprotocommand('branchmap', permission='pull')
def branchmap(repo, proto):
branchmap = repo.branchmap()
heads = []
@@ -780,8 +787,8 @@
return bytesresponse('\n'.join(heads))
-permissions['branches'] = 'pull'
-@wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY)
+@wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY,
+ permission='pull')
def branches(repo, proto, nodes):
nodes = decodelist(nodes)
r = []
@@ -790,8 +797,7 @@
return bytesresponse(''.join(r))
-permissions['clonebundles'] = 'pull'
-@wireprotocommand('clonebundles', '')
+@wireprotocommand('clonebundles', '', permission='pull')
def clonebundles(repo, proto):
"""Server command for returning info for available bundles to seed clones.
@@ -843,13 +849,12 @@
# If you are writing an extension and consider wrapping this function. Wrap
# `_capabilities` instead.
-permissions['capabilities'] = 'pull'
-@wireprotocommand('capabilities')
+@wireprotocommand('capabilities', permission='pull')
def capabilities(repo, proto):
return bytesresponse(' '.join(_capabilities(repo, proto)))
-permissions['changegroup'] = 'pull'
-@wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY)
+@wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY,
+ permission='pull')
def changegroup(repo, proto, roots):
nodes = decodelist(roots)
outgoing = discovery.outgoing(repo, missingroots=nodes,
@@ -858,9 +863,9 @@
gen = iter(lambda: cg.read(32768), '')
return streamres(gen=gen)
-permissions['changegroupsubset'] = 'pull'
@wireprotocommand('changegroupsubset', 'bases heads',
- transportpolicy=POLICY_V1_ONLY)
+ transportpolicy=POLICY_V1_ONLY,
+ permission='pull')
def changegroupsubset(repo, proto, bases, heads):
bases = decodelist(bases)
heads = decodelist(heads)
@@ -870,16 +875,15 @@
gen = iter(lambda: cg.read(32768), '')
return streamres(gen=gen)
-permissions['debugwireargs'] = 'pull'
-@wireprotocommand('debugwireargs', 'one two *')
+@wireprotocommand('debugwireargs', 'one two *',
+ permission='pull')
def debugwireargs(repo, proto, one, two, others):
# only accept optional args from the known set
opts = options('debugwireargs', ['three', 'four'], others)
return bytesresponse(repo.debugwireargs(one, two,
**pycompat.strkwargs(opts)))
-permissions['getbundle'] = 'pull'
-@wireprotocommand('getbundle', '*')
+@wireprotocommand('getbundle', '*', permission='pull')
def getbundle(repo, proto, others):
opts = options('getbundle', gboptsmap.keys(), others)
for k, v in opts.iteritems():
@@ -945,14 +949,12 @@
return streamres(gen=chunks, prefer_uncompressed=not prefercompressed)
-permissions['heads'] = 'pull'
-@wireprotocommand('heads')
+@wireprotocommand('heads', permission='pull')
def heads(repo, proto):
h = repo.heads()
return bytesresponse(encodelist(h) + '\n')
-permissions['hello'] = 'pull'
-@wireprotocommand('hello')
+@wireprotocommand('hello', permission='pull')
def hello(repo, proto):
"""Called as part of SSH handshake to obtain server info.
@@ -967,14 +969,12 @@
caps = capabilities(repo, proto).data
return bytesresponse('capabilities: %s\n' % caps)
-permissions['listkeys'] = 'pull'
-@wireprotocommand('listkeys', 'namespace')
+@wireprotocommand('listkeys', 'namespace', permission='pull')
def listkeys(repo, proto, namespace):
d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
return bytesresponse(pushkeymod.encodekeys(d))
-permissions['lookup'] = 'pull'
-@wireprotocommand('lookup', 'key')
+@wireprotocommand('lookup', 'key', permission='pull')
def lookup(repo, proto, key):
try:
k = encoding.tolocal(key)
@@ -986,14 +986,12 @@
success = 0
return bytesresponse('%d %s\n' % (success, r))
-permissions['known'] = 'pull'
-@wireprotocommand('known', 'nodes *')
+@wireprotocommand('known', 'nodes *', permission='pull')
def known(repo, proto, nodes, others):
v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
return bytesresponse(v)
-permissions['pushkey'] = 'push'
-@wireprotocommand('pushkey', 'namespace key old new')
+@wireprotocommand('pushkey', 'namespace key old new', permission='push')
def pushkey(repo, proto, namespace, key, old, new):
# compatibility with pre-1.8 clients which were accidentally
# sending raw binary nodes rather than utf-8-encoded hex
@@ -1014,8 +1012,7 @@
output = output.getvalue() if output else ''
return bytesresponse('%d\n%s' % (int(r), output))
-permissions['stream_out'] = 'pull'
-@wireprotocommand('stream_out')
+@wireprotocommand('stream_out', permission='pull')
def stream(repo, proto):
'''If the server supports streaming clone, it advertises the "stream"
capability with a value representing the version and flags of the repo
@@ -1023,8 +1020,7 @@
'''
return streamres_legacy(streamclone.generatev1wireproto(repo))
-permissions['unbundle'] = 'push'
-@wireprotocommand('unbundle', 'heads')
+@wireprotocommand('unbundle', 'heads', permission='push')
def unbundle(repo, proto, heads):
their_heads = decodelist(heads)