mercurial/wireproto.py
changeset 37535 69e46c1834ac
parent 37534 465187fec06f
child 37536 2003da12f49b
--- a/mercurial/wireproto.py	Fri Apr 06 16:49:57 2018 -0700
+++ b/mercurial/wireproto.py	Fri Apr 06 17:14:06 2018 -0700
@@ -713,8 +713,11 @@
 
     ``name`` is the name of the wire protocol command being provided.
 
-    ``args`` is a space-delimited list of named arguments that the command
-    accepts. ``*`` is a special value that says to accept all arguments.
+    ``args`` defines the named arguments accepted by the command. It is
+    ideally a dict mapping argument names to their types. For backwards
+    compatibility, it can be a space-delimited list of argument names. For
+    version 1 transports, ``*`` denotes a special value that says to accept
+    all named arguments.
 
     ``transportpolicy`` is a POLICY_* constant denoting which transports
     this wire protocol command should be exposed to. By default, commands
@@ -752,6 +755,17 @@
                                      'got %s; expected "push" or "pull"' %
                                      permission)
 
+    if 1 in transportversions and not isinstance(args, bytes):
+        raise error.ProgrammingError('arguments for version 1 commands must '
+                                     'be declared as bytes')
+
+    if isinstance(args, bytes):
+        dictargs = {arg: b'legacy' for arg in args.split()}
+    elif isinstance(args, dict):
+        dictargs = args
+    else:
+        raise ValueError('args must be bytes or a dict')
+
     def register(func):
         if 1 in transportversions:
             if name in commands:
@@ -764,7 +778,8 @@
             if name in commandsv2:
                 raise error.ProgrammingError('%s command already registered '
                                              'for version 2' % name)
-            commandsv2[name] = commandentry(func, args=args,
+
+            commandsv2[name] = commandentry(func, args=dictargs,
                                             transports=transports,
                                             permission=permission)
 
@@ -1304,7 +1319,7 @@
 
     for command, entry in commandsv2.items():
         caps['commands'][command] = {
-            'args': sorted(entry.args.split()) if entry.args else [],
+            'args': entry.args,
             'permissions': [entry.permission],
         }
 
@@ -1325,7 +1340,11 @@
 
     return wireprototypes.cborresponse(caps)
 
-@wireprotocommand('heads', args='publiconly', permission='pull',
+@wireprotocommand('heads',
+                  args={
+                      'publiconly': False,
+                  },
+                  permission='pull',
                   transportpolicy=POLICY_V2_ONLY)
 def headsv2(repo, proto, publiconly=False):
     if publiconly:
@@ -1333,14 +1352,22 @@
 
     return wireprototypes.cborresponse(repo.heads())
 
-@wireprotocommand('known', 'nodes', permission='pull',
+@wireprotocommand('known',
+                  args={
+                      'nodes': [b'deadbeef'],
+                  },
+                  permission='pull',
                   transportpolicy=POLICY_V2_ONLY)
 def knownv2(repo, proto, nodes=None):
     nodes = nodes or []
     result = b''.join(b'1' if n else b'0' for n in repo.known(nodes))
     return wireprototypes.cborresponse(result)
 
-@wireprotocommand('listkeys', 'namespace', permission='pull',
+@wireprotocommand('listkeys',
+                  args={
+                      'namespace': b'ns',
+                  },
+                  permission='pull',
                   transportpolicy=POLICY_V2_ONLY)
 def listkeysv2(repo, proto, namespace=None):
     keys = repo.listkeys(encoding.tolocal(namespace))