--- a/mercurial/bundle2.py Wed Jul 10 16:04:53 2024 -0400
+++ b/mercurial/bundle2.py Wed Jul 10 17:09:34 2024 -0400
@@ -153,6 +153,7 @@
import string
import struct
import sys
+import typing
from .i18n import _
from .node import (
@@ -181,6 +182,17 @@
)
from .interfaces import repository
+if typing.TYPE_CHECKING:
+ from typing import (
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ )
+
+ Capabilities = Dict[bytes, Union[List[bytes], Tuple[bytes, ...]]]
+
urlerr = util.urlerr
urlreq = util.urlreq
@@ -602,7 +614,7 @@
)
-def decodecaps(blob):
+def decodecaps(blob: bytes) -> "Capabilities":
"""decode a bundle2 caps bytes blob into a dictionary
The blob is a list of capabilities (one per line)
@@ -662,11 +674,14 @@
_magicstring = b'HG20'
- def __init__(self, ui, capabilities=()):
+ def __init__(self, ui, capabilities: "Optional[Capabilities]" = None):
+ if capabilities is None:
+ capabilities = {}
+
self.ui = ui
self._params = []
self._parts = []
- self.capabilities = dict(capabilities)
+ self.capabilities: "Capabilities" = dict(capabilities)
self._compengine = util.compengines.forbundletype(b'UN')
self._compopts = None
# If compression is being handled by a consumer of the raw
@@ -1612,7 +1627,7 @@
# These are only the static capabilities.
# Check the 'getrepocaps' function for the rest.
-capabilities = {
+capabilities: "Capabilities" = {
b'HG20': (),
b'bookmarks': (),
b'error': (b'abort', b'unsupportedcontent', b'pushraced', b'pushkey'),
@@ -1626,7 +1641,8 @@
}
-def getrepocaps(repo, allowpushback=False, role=None):
+# TODO: drop the default value for 'role'
+def getrepocaps(repo, allowpushback: bool = False, role=None) -> "Capabilities":
"""return the bundle2 capabilities for a given repo
Exists to allow extensions (like evolution) to mutate the capabilities.
@@ -1675,7 +1691,7 @@
return caps
-def bundle2caps(remote):
+def bundle2caps(remote) -> "Capabilities":
"""return the bundle capabilities of a peer as dict"""
raw = remote.capable(b'bundle2')
if not raw and raw != b'':
@@ -1684,7 +1700,7 @@
return decodecaps(capsblob)
-def obsmarkersversion(caps):
+def obsmarkersversion(caps: "Capabilities"):
"""extract the list of supported obsmarkers versions from a bundle2caps dict"""
obscaps = caps.get(b'obsmarkers', ())
return [int(c[1:]) for c in obscaps if c.startswith(b'V')]
@@ -1725,7 +1741,7 @@
msg %= count
raise error.ProgrammingError(msg)
- caps = {}
+ caps: "Capabilities" = {}
if opts.get(b'obsolescence', False):
caps[b'obsmarkers'] = (b'V1',)
stream_version = opts.get(b'stream', b"")