typing: add some type hints for bundle2 capabilities
authorMatt Harbison <matt_harbison@yahoo.com>
Wed, 10 Jul 2024 17:09:34 -0400
changeset 51731 6fc31e7bd5db
parent 51730 a09435c0eb14
child 51732 138ab7c6a6ff
typing: add some type hints for bundle2 capabilities Somewhere between hg 3dbc7b1ecaba and hg 8e3f6b5bf720, pytype determined the signature of `bundle20.capabilities` changed from `Dict[bytes, Tuple[bytes]]` to `Dict[bytes, Union[List[bytes], Tuple[bytes]]]`. First, I did try to simply be explicit about the previously inferred type, but it does seem to mix and match list/tuple now (e.g. in `writenewbundle()`). I tried changing the new list usage to tuple, but a couple of things complained, (and I think lists of one item are a little more clear to read anyway). So then I typed the dict value as `Sequence[bytes]`, which worked fine. But there's also a module level `capabilities` field, and when that's typed, pytype complains about `Sequence[bytes]` lacking `__add__`[1]. So I gave up, and just assigned it the type it wanted, with an alias. If somebody feels motivated to make the type consistent, it's simple enough to change the alias. The mutable default value to the constructor was removed to appease PyCharm's type checking on the field. (I didn't bother running the code through pytype prior to changing it, because we've previously made an effort to remove this pattern anyway.) I'm not sure why `getrepocaps()` has a default value for `role` that apparently raises an exception. It's just flagged for now so this series can land without risking additional problems. [1] https://foss.heptapod.net/mercurial/mercurial-devel/-/jobs/2466903
mercurial/bundle2.py
--- a/mercurial/bundle2.py	Wed Jul 10 16:04:53 2024 -0400
+++ b/mercurial/bundle2.py	Wed Jul 10 17:09:34 2024 -0400
@@ -153,6 +153,7 @@
 import string
 import struct
 import sys
+import typing
 
 from .i18n import _
 from .node import (
@@ -181,6 +182,17 @@
 )
 from .interfaces import repository
 
+if typing.TYPE_CHECKING:
+    from typing import (
+        Dict,
+        List,
+        Optional,
+        Tuple,
+        Union,
+    )
+
+    Capabilities = Dict[bytes, Union[List[bytes], Tuple[bytes, ...]]]
+
 urlerr = util.urlerr
 urlreq = util.urlreq
 
@@ -602,7 +614,7 @@
             )
 
 
-def decodecaps(blob):
+def decodecaps(blob: bytes) -> "Capabilities":
     """decode a bundle2 caps bytes blob into a dictionary
 
     The blob is a list of capabilities (one per line)
@@ -662,11 +674,14 @@
 
     _magicstring = b'HG20'
 
-    def __init__(self, ui, capabilities=()):
+    def __init__(self, ui, capabilities: "Optional[Capabilities]" = None):
+        if capabilities is None:
+            capabilities = {}
+
         self.ui = ui
         self._params = []
         self._parts = []
-        self.capabilities = dict(capabilities)
+        self.capabilities: "Capabilities" = dict(capabilities)
         self._compengine = util.compengines.forbundletype(b'UN')
         self._compopts = None
         # If compression is being handled by a consumer of the raw
@@ -1612,7 +1627,7 @@
 
 # These are only the static capabilities.
 # Check the 'getrepocaps' function for the rest.
-capabilities = {
+capabilities: "Capabilities" = {
     b'HG20': (),
     b'bookmarks': (),
     b'error': (b'abort', b'unsupportedcontent', b'pushraced', b'pushkey'),
@@ -1626,7 +1641,8 @@
 }
 
 
-def getrepocaps(repo, allowpushback=False, role=None):
+# TODO: drop the default value for 'role'
+def getrepocaps(repo, allowpushback: bool = False, role=None) -> "Capabilities":
     """return the bundle2 capabilities for a given repo
 
     Exists to allow extensions (like evolution) to mutate the capabilities.
@@ -1675,7 +1691,7 @@
     return caps
 
 
-def bundle2caps(remote):
+def bundle2caps(remote) -> "Capabilities":
     """return the bundle capabilities of a peer as dict"""
     raw = remote.capable(b'bundle2')
     if not raw and raw != b'':
@@ -1684,7 +1700,7 @@
     return decodecaps(capsblob)
 
 
-def obsmarkersversion(caps):
+def obsmarkersversion(caps: "Capabilities"):
     """extract the list of supported obsmarkers versions from a bundle2caps dict"""
     obscaps = caps.get(b'obsmarkers', ())
     return [int(c[1:]) for c in obscaps if c.startswith(b'V')]
@@ -1725,7 +1741,7 @@
         msg %= count
         raise error.ProgrammingError(msg)
 
-    caps = {}
+    caps: "Capabilities" = {}
     if opts.get(b'obsolescence', False):
         caps[b'obsmarkers'] = (b'V1',)
     stream_version = opts.get(b'stream', b"")