wireproto: allow wire protocol commands to declare transport support
authorGregory Szorc <gregory.szorc@gmail.com>
Fri, 02 Mar 2018 09:47:37 -0500
changeset 36609 abc3b9801563
parent 36608 1151c731686e
child 36610 af0d38f015bb
wireproto: allow wire protocol commands to declare transport support Currently, wire protocol commands are exposed on all transports. Some wire protocol commands are only supported or sensical on some transports. In the future, new wire protocol commands may only be available on new transports and legacy wire protocol commands may not be available to newer transports. This commit introduces a mechanism to allow @wireprotocommand to declare transports for which they should not be available. The mechanism for determining if a wire protocol command is available for a given transport instance has been taught to take this knowledge into account. To help implement this feature, we add a dict to wireprototypes declaring all wire transports and their metadata. There's probably room to refactor the constants used to identify the wire protocols. But that can be in another commit. Differential Revision: https://phab.mercurial-scm.org/D2483
mercurial/wireproto.py
mercurial/wireprototypes.py
--- a/mercurial/wireproto.py	Fri Mar 02 18:50:49 2018 -0500
+++ b/mercurial/wireproto.py	Fri Mar 02 09:47:37 2018 -0500
@@ -592,9 +592,10 @@
 
 class commandentry(object):
     """Represents a declared wire protocol command."""
-    def __init__(self, func, args=''):
+    def __init__(self, func, args='', transports=None):
         self.func = func
         self.args = args
+        self.transports = transports or set()
 
     def _merge(self, func, args):
         """Merge this instance with an incoming 2-tuple.
@@ -604,7 +605,7 @@
         data not captured by the 2-tuple and a new instance containing
         the union of the two objects is returned.
         """
-        return commandentry(func, args=args)
+        return commandentry(func, args=args, transports=set(self.transports))
 
     # Old code treats instances as 2-tuples. So expose that interface.
     def __iter__(self):
@@ -640,7 +641,9 @@
             if k in self:
                 v = self[k]._merge(v[0], v[1])
             else:
-                v = commandentry(v[0], args=v[1])
+                # Use default values from @wireprotocommand.
+                v = commandentry(v[0], args=v[1],
+                                 transports=set(wireprototypes.TRANSPORTS))
         else:
             raise ValueError('command entries must be commandentry instances '
                              'or 2-tuples')
@@ -649,22 +652,52 @@
 
     def commandavailable(self, command, proto):
         """Determine if a command is available for the requested protocol."""
-        # For now, commands are available for all protocols. So do a simple
-        # membership test.
-        return command in self
+        assert proto.name in wireprototypes.TRANSPORTS
+
+        entry = self.get(command)
+
+        if not entry:
+            return False
+
+        if proto.name not in entry.transports:
+            return False
+
+        return True
+
+# Constants specifying which transports a wire protocol command should be
+# available on. For use with @wireprotocommand.
+POLICY_ALL = 'all'
+POLICY_V1_ONLY = 'v1-only'
+POLICY_V2_ONLY = 'v2-only'
 
 commands = commanddict()
 
-def wireprotocommand(name, args=''):
+def wireprotocommand(name, args='', transportpolicy=POLICY_ALL):
     """Decorator to declare a wire protocol command.
 
     ``name`` is the name of the wire protocol command being provided.
 
     ``args`` is a space-delimited list of named arguments that the command
     accepts. ``*`` is a special value that says to accept all arguments.
+
+    ``transportpolicy`` is a POLICY_* constant denoting which transports
+    this wire protocol command should be exposed to. By default, commands
+    are exposed to all wire protocol transports.
     """
+    if transportpolicy == POLICY_ALL:
+        transports = set(wireprototypes.TRANSPORTS)
+    elif transportpolicy == POLICY_V1_ONLY:
+        transports = {k for k, v in wireprototypes.TRANSPORTS.items()
+                      if v['version'] == 1}
+    elif transportpolicy == POLICY_V2_ONLY:
+        transports = {k for k, v in wireprototypes.TRANSPORTS.items()
+                      if v['version'] == 2}
+    else:
+        raise error.Abort(_('invalid transport policy value: %s') %
+                          transportpolicy)
+
     def register(func):
-        commands[name] = commandentry(func, args=args)
+        commands[name] = commandentry(func, args=args, transports=transports)
         return func
     return register
 
--- a/mercurial/wireprototypes.py	Fri Mar 02 18:50:49 2018 -0500
+++ b/mercurial/wireprototypes.py	Fri Mar 02 09:47:37 2018 -0500
@@ -13,6 +13,22 @@
 # to reflect BC breakages.
 SSHV2 = 'exp-ssh-v2-0001'
 
+# All available wire protocol transports.
+TRANSPORTS = {
+    SSHV1: {
+        'transport': 'ssh',
+        'version': 1,
+    },
+    SSHV2: {
+        'transport': 'ssh',
+        'version': 2,
+    },
+    'http-v1': {
+        'transport': 'http',
+        'version': 1,
+    }
+}
+
 class bytesresponse(object):
     """A wire protocol response consisting of raw bytes."""
     def __init__(self, data):