--- a/mercurial/wireproto.py Fri Apr 06 16:49:57 2018 -0700
+++ b/mercurial/wireproto.py Fri Apr 06 17:14:06 2018 -0700
@@ -713,8 +713,11 @@
``name`` is the name of the wire protocol command being provided.
- ``args`` is a space-delimited list of named arguments that the command
- accepts. ``*`` is a special value that says to accept all arguments.
+ ``args`` defines the named arguments accepted by the command. It is
+ ideally a dict mapping argument names to their types. For backwards
+ compatibility, it can be a space-delimited list of argument names. For
+ version 1 transports, ``*`` denotes a special value that says to accept
+ all named arguments.
``transportpolicy`` is a POLICY_* constant denoting which transports
this wire protocol command should be exposed to. By default, commands
@@ -752,6 +755,17 @@
'got %s; expected "push" or "pull"' %
permission)
+ if 1 in transportversions and not isinstance(args, bytes):
+ raise error.ProgrammingError('arguments for version 1 commands must '
+ 'be declared as bytes')
+
+ if isinstance(args, bytes):
+ dictargs = {arg: b'legacy' for arg in args.split()}
+ elif isinstance(args, dict):
+ dictargs = args
+ else:
+ raise ValueError('args must be bytes or a dict')
+
def register(func):
if 1 in transportversions:
if name in commands:
@@ -764,7 +778,8 @@
if name in commandsv2:
raise error.ProgrammingError('%s command already registered '
'for version 2' % name)
- commandsv2[name] = commandentry(func, args=args,
+
+ commandsv2[name] = commandentry(func, args=dictargs,
transports=transports,
permission=permission)
@@ -1304,7 +1319,7 @@
for command, entry in commandsv2.items():
caps['commands'][command] = {
- 'args': sorted(entry.args.split()) if entry.args else [],
+ 'args': entry.args,
'permissions': [entry.permission],
}
@@ -1325,7 +1340,11 @@
return wireprototypes.cborresponse(caps)
-@wireprotocommand('heads', args='publiconly', permission='pull',
+@wireprotocommand('heads',
+ args={
+ 'publiconly': False,
+ },
+ permission='pull',
transportpolicy=POLICY_V2_ONLY)
def headsv2(repo, proto, publiconly=False):
if publiconly:
@@ -1333,14 +1352,22 @@
return wireprototypes.cborresponse(repo.heads())
-@wireprotocommand('known', 'nodes', permission='pull',
+@wireprotocommand('known',
+ args={
+ 'nodes': [b'deadbeef'],
+ },
+ permission='pull',
transportpolicy=POLICY_V2_ONLY)
def knownv2(repo, proto, nodes=None):
nodes = nodes or []
result = b''.join(b'1' if n else b'0' for n in repo.known(nodes))
return wireprototypes.cborresponse(result)
-@wireprotocommand('listkeys', 'namespace', permission='pull',
+@wireprotocommand('listkeys',
+ args={
+ 'namespace': b'ns',
+ },
+ permission='pull',
transportpolicy=POLICY_V2_ONLY)
def listkeysv2(repo, proto, namespace=None):
keys = repo.listkeys(encoding.tolocal(namespace))