changeset 37780:8acd3a9ac4fd

wireproto: make version 2 @wireprotocommand an independent function Previously, the code for this decorator was shared between version 1 and version 2 commands. Very few parts of the function were identical. So I don't think sharing is justified. wireprotov2server now has its own @wireprotocommand decorator function. Because the decorator is no longer shared, code for configuring the transport policy has been removed. i.e. commands must have separate implementations for each wire protocol version. Differential Revision: https://phab.mercurial-scm.org/D3395
author Gregory Szorc <gregory.szorc@gmail.com>
date Mon, 16 Apr 2018 21:49:59 -0700
parents 379d54eae6eb
children 352932a11905
files mercurial/wireproto.py mercurial/wireprotov2server.py tests/wireprotohelpers.sh
diffstat 3 files changed, 55 insertions(+), 64 deletions(-) [+]
line wrap: on
line diff
--- a/mercurial/wireproto.py	Mon Apr 16 21:38:52 2018 -0700
+++ b/mercurial/wireproto.py	Mon Apr 16 21:49:59 2018 -0700
@@ -251,32 +251,20 @@
 
         return True
 
-# Constants specifying which transports a wire protocol command should be
-# available on. For use with @wireprotocommand.
-POLICY_V1_ONLY = 'v1-only'
-POLICY_V2_ONLY = 'v2-only'
-
 # For version 1 transports.
 commands = commanddict()
 
 # For version 2 transports.
 commandsv2 = commanddict()
 
-def wireprotocommand(name, args=None, transportpolicy=POLICY_V1_ONLY,
-                     permission='push'):
+def wireprotocommand(name, args=None, permission='push'):
     """Decorator to declare a wire protocol command.
 
     ``name`` is the name of the wire protocol command being provided.
 
     ``args`` defines the named arguments accepted by the command. It is
-    ideally a dict mapping argument names to their types. For backwards
-    compatibility, it can be a space-delimited list of argument names. For
-    version 1 transports, ``*`` denotes a special value that says to accept
-    all named arguments.
-
-    ``transportpolicy`` is a POLICY_* constant denoting which transports
-    this wire protocol command should be exposed to. By default, commands
-    are exposed to all wire protocol transports.
+    a space-delimited list of argument names. ``*`` denotes a special value
+    that says to accept all named arguments.
 
     ``permission`` defines the permission type needed to run this command.
     Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
@@ -284,17 +272,8 @@
     because otherwise commands not declaring their permissions could modify
     a repository that is supposed to be read-only.
     """
-    if transportpolicy == POLICY_V1_ONLY:
-        transports = {k for k, v in wireprototypes.TRANSPORTS.items()
-                      if v['version'] == 1}
-        transportversion = 1
-    elif transportpolicy == POLICY_V2_ONLY:
-        transports = {k for k, v in wireprototypes.TRANSPORTS.items()
-                      if v['version'] == 2}
-        transportversion = 2
-    else:
-        raise error.ProgrammingError('invalid transport policy value: %s' %
-                                     transportpolicy)
+    transports = {k for k, v in wireprototypes.TRANSPORTS.items()
+                  if v['version'] == 1}
 
     # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
     # SSHv2.
@@ -307,40 +286,20 @@
                                      'got %s; expected "push" or "pull"' %
                                      permission)
 
-    if transportversion == 1:
-        if args is None:
-            args = ''
+    if args is None:
+        args = ''
 
-        if not isinstance(args, bytes):
-            raise error.ProgrammingError('arguments for version 1 commands '
-                                         'must be declared as bytes')
-    elif transportversion == 2:
-        if args is None:
-            args = {}
-
-        if not isinstance(args, dict):
-            raise error.ProgrammingError('arguments for version 2 commands '
-                                         'must be declared as dicts')
+    if not isinstance(args, bytes):
+        raise error.ProgrammingError('arguments for version 1 commands '
+                                     'must be declared as bytes')
 
     def register(func):
-        if transportversion == 1:
-            if name in commands:
-                raise error.ProgrammingError('%s command already registered '
-                                             'for version 1' % name)
-            commands[name] = commandentry(func, args=args,
-                                          transports=transports,
-                                          permission=permission)
-        elif transportversion == 2:
-            if name in commandsv2:
-                raise error.ProgrammingError('%s command already registered '
-                                             'for version 2' % name)
-
-            commandsv2[name] = commandentry(func, args=args,
-                                            transports=transports,
-                                            permission=permission)
-        else:
-            raise error.ProgrammingError('unhandled transport version: %d' %
-                                         transportversion)
+        if name in commands:
+            raise error.ProgrammingError('%s command already registered '
+                                         'for version 1' % name)
+        commands[name] = commandentry(func, args=args,
+                                      transports=transports,
+                                      permission=permission)
 
         return func
     return register
--- a/mercurial/wireprotov2server.py	Mon Apr 16 21:38:52 2018 -0700
+++ b/mercurial/wireprotov2server.py	Mon Apr 16 21:49:59 2018 -0700
@@ -405,10 +405,43 @@
 
     return proto.addcapabilities(repo, caps)
 
-def wireprotocommand(*args, **kwargs):
+def wireprotocommand(name, args=None, permission='push'):
+    """Decorator to declare a wire protocol command.
+
+    ``name`` is the name of the wire protocol command being provided.
+
+    ``args`` is a dict of argument names to example values.
+
+    ``permission`` defines the permission type needed to run this command.
+    Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
+    respectively. Default is to assume command requires ``push`` permissions
+    because otherwise commands not declaring their permissions could modify
+    a repository that is supposed to be read-only.
+    """
+    transports = {k for k, v in wireprototypes.TRANSPORTS.items()
+                  if v['version'] == 2}
+
+    if permission not in ('push', 'pull'):
+        raise error.ProgrammingError('invalid wire protocol permission; '
+                                     'got %s; expected "push" or "pull"' %
+                                     permission)
+
+    if args is None:
+        args = {}
+
+    if not isinstance(args, dict):
+        raise error.ProgrammingError('arguments for version 2 commands '
+                                     'must be declared as dicts')
+
     def register(func):
-        return wireproto.wireprotocommand(
-            *args, transportpolicy=wireproto.POLICY_V2_ONLY, **kwargs)(func)
+        if name in wireproto.commandsv2:
+            raise error.ProgrammingError('%s command already registered '
+                                         'for version 2' % name)
+
+        wireproto.commandsv2[name] = wireproto.commandentry(
+            func, args=args, transports=transports, permission=permission)
+
+        return func
 
     return register
 
--- a/tests/wireprotohelpers.sh	Mon Apr 16 21:38:52 2018 -0700
+++ b/tests/wireprotohelpers.sh	Mon Apr 16 21:49:59 2018 -0700
@@ -16,6 +16,7 @@
 cat > dummycommands.py << EOF
 from mercurial import (
     wireprototypes,
+    wireprotov2server,
     wireproto,
 )
 
@@ -23,8 +24,7 @@
 def customreadonlyv1(repo, proto):
     return wireprototypes.bytesresponse(b'customreadonly bytes response')
 
-@wireproto.wireprotocommand('customreadonly', permission='pull',
-                            transportpolicy=wireproto.POLICY_V2_ONLY)
+@wireprotov2server.wireprotocommand('customreadonly', permission='pull')
 def customreadonlyv2(repo, proto):
     return wireprototypes.cborresponse(b'customreadonly bytes response')
 
@@ -32,8 +32,7 @@
 def customreadwrite(repo, proto):
     return wireprototypes.bytesresponse(b'customreadwrite bytes response')
 
-@wireproto.wireprotocommand('customreadwrite', permission='push',
-                            transportpolicy=wireproto.POLICY_V2_ONLY)
+@wireprotov2server.wireprotocommand('customreadwrite', permission='push')
 def customreadwritev2(repo, proto):
     return wireprototypes.cborresponse(b'customreadwrite bytes response')
 EOF