changeset 51685:0eb515c7bec8

typing: add trivial type hints to the convert extension's common modules This started as ensuring that the `encoding` and `orig_encoding` attributes has a type other than `Any`, so pytype can catch problems where it needs to be str for stdlib encoding and decoding. It turns out that adding the hint in `mercurial.encoding` is what was needed, but I picked a bunch of low hanging fruit while here. There's definitely more to do, and I see a problem where `shlex.shlex` is being fed bytes instead of str, but there are not enough type hints yet to make pytype notice.
author Matt Harbison <matt_harbison@yahoo.com>
date Thu, 11 Jul 2024 20:54:06 -0400
parents 20e2a20674dc
children 39033e7a6e0a
files hgext/convert/common.py hgext/convert/convcmd.py hgext/convert/filemap.py
diffstat 3 files changed, 173 insertions(+), 80 deletions(-) [+]
line wrap: on
line diff
--- a/hgext/convert/common.py	Thu Jul 11 14:46:00 2024 -0400
+++ b/hgext/convert/common.py	Thu Jul 11 20:54:06 2024 -0400
@@ -11,6 +11,13 @@
 import re
 import shlex
 import subprocess
+import typing
+
+from typing import (
+    Any,
+    AnyStr,
+    Optional,
+)
 
 from mercurial.i18n import _
 from mercurial.pycompat import open
@@ -26,9 +33,28 @@
     procutil,
 )
 
+if typing.TYPE_CHECKING:
+    from typing import (
+        overload,
+    )
+    from mercurial import (
+        ui as uimod,
+    )
+
 propertycache = util.propertycache
 
 
+if typing.TYPE_CHECKING:
+
+    @overload
+    def _encodeornone(d: str) -> bytes:
+        pass
+
+    @overload
+    def _encodeornone(d: None) -> None:
+        pass
+
+
 def _encodeornone(d):
     if d is None:
         return
@@ -36,7 +62,7 @@
 
 
 class _shlexpy3proxy:
-    def __init__(self, l):
+    def __init__(self, l: shlex.shlex) -> None:
         self._l = l
 
     def __iter__(self):
@@ -50,11 +76,16 @@
         return self._l.infile or b'<unknown>'
 
     @property
-    def lineno(self):
+    def lineno(self) -> int:
         return self._l.lineno
 
 
-def shlexer(data=None, filepath=None, wordchars=None, whitespace=None):
+def shlexer(
+    data=None,
+    filepath: Optional[str] = None,
+    wordchars: Optional[bytes] = None,
+    whitespace: Optional[bytes] = None,
+):
     if data is None:
         data = open(filepath, b'r', encoding='latin1')
     else:
@@ -72,8 +103,8 @@
     return _shlexpy3proxy(l)
 
 
-def encodeargs(args):
-    def encodearg(s):
+def encodeargs(args: Any) -> bytes:
+    def encodearg(s: bytes) -> bytes:
         lines = base64.encodebytes(s)
         lines = [l.splitlines()[0] for l in pycompat.iterbytestr(lines)]
         return b''.join(lines)
@@ -82,7 +113,7 @@
     return encodearg(s)
 
 
-def decodeargs(s):
+def decodeargs(s: bytes) -> Any:
     s = base64.decodebytes(s)
     return pickle.loads(s)
 
@@ -91,7 +122,9 @@
     pass
 
 
-def checktool(exe, name=None, abort=True):
+def checktool(
+    exe: bytes, name: Optional[bytes] = None, abort: bool = True
+) -> None:
     name = name or exe
     if not procutil.findexe(exe):
         if abort:
@@ -105,25 +138,25 @@
     pass
 
 
-SKIPREV = b'SKIP'
+SKIPREV: bytes = b'SKIP'
 
 
 class commit:
     def __init__(
         self,
-        author,
-        date,
-        desc,
+        author: bytes,
+        date: bytes,
+        desc: bytes,
         parents,
-        branch=None,
+        branch: Optional[bytes] = None,
         rev=None,
         extra=None,
         sortkey=None,
         saverev=True,
-        phase=phases.draft,
+        phase: int = phases.draft,
         optparents=None,
         ctx=None,
-    ):
+    ) -> None:
         self.author = author or b'unknown'
         self.date = date or b'0 0'
         self.desc = desc
@@ -141,7 +174,13 @@
 class converter_source:
     """Conversion source interface"""
 
-    def __init__(self, ui, repotype, path=None, revs=None):
+    def __init__(
+        self,
+        ui: "uimod.ui",
+        repotype: bytes,
+        path: Optional[bytes] = None,
+        revs=None,
+    ) -> None:
         """Initialize conversion source (or raise NoRepo("message")
         exception if path is not a valid repository)"""
         self.ui = ui
@@ -151,7 +190,9 @@
 
         self.encoding = b'utf-8'
 
-    def checkhexformat(self, revstr, mapname=b'splicemap'):
+    def checkhexformat(
+        self, revstr: bytes, mapname: bytes = b'splicemap'
+    ) -> None:
         """fails if revstr is not a 40 byte hex. mercurial and git both uses
         such format for their revision numbering
         """
@@ -161,10 +202,10 @@
                 % (mapname, revstr)
             )
 
-    def before(self):
+    def before(self) -> None:
         pass
 
-    def after(self):
+    def after(self) -> None:
         pass
 
     def targetfilebelongstosource(self, targetfilename):
@@ -223,7 +264,7 @@
         """
         raise NotImplementedError
 
-    def recode(self, s, encoding=None):
+    def recode(self, s: AnyStr, encoding: Optional[bytes] = None) -> bytes:
         if not encoding:
             encoding = self.encoding or b'utf-8'
 
@@ -252,17 +293,17 @@
         """
         raise NotImplementedError
 
-    def converted(self, rev, sinkrev):
+    def converted(self, rev, sinkrev) -> None:
         '''Notify the source that a revision has been converted.'''
 
-    def hasnativeorder(self):
+    def hasnativeorder(self) -> bool:
         """Return true if this source has a meaningful, native revision
         order. For instance, Mercurial revisions are store sequentially
         while there is no such global ordering with Darcs.
         """
         return False
 
-    def hasnativeclose(self):
+    def hasnativeclose(self) -> bool:
         """Return true if this source has ability to close branch."""
         return False
 
@@ -280,7 +321,7 @@
         """
         return {}
 
-    def checkrevformat(self, revstr, mapname=b'splicemap'):
+    def checkrevformat(self, revstr, mapname: bytes = b'splicemap') -> bool:
         """revstr is a string that describes a revision in the given
         source control system.  Return true if revstr has correct
         format.
@@ -291,7 +332,7 @@
 class converter_sink:
     """Conversion sink (target) interface"""
 
-    def __init__(self, ui, repotype, path):
+    def __init__(self, ui: "uimod.ui", repotype: bytes, path: bytes) -> None:
         """Initialize conversion sink (or raise NoRepo("message")
         exception if path is not a valid repository)
 
@@ -359,10 +400,10 @@
         filter empty revisions.
         """
 
-    def before(self):
+    def before(self) -> None:
         pass
 
-    def after(self):
+    def after(self) -> None:
         pass
 
     def putbookmarks(self, bookmarks):
@@ -385,17 +426,17 @@
 
 
 class commandline:
-    def __init__(self, ui, command):
+    def __init__(self, ui: "uimod.ui", command: bytes) -> None:
         self.ui = ui
         self.command = command
 
-    def prerun(self):
+    def prerun(self) -> None:
         pass
 
-    def postrun(self):
+    def postrun(self) -> None:
         pass
 
-    def _cmdline(self, cmd, *args, **kwargs):
+    def _cmdline(self, cmd: bytes, *args: bytes, **kwargs) -> bytes:
         kwargs = pycompat.byteskwargs(kwargs)
         cmdline = [self.command, cmd] + list(args)
         for k, v in kwargs.items():
@@ -416,7 +457,7 @@
         cmdline = b' '.join(cmdline)
         return cmdline
 
-    def _run(self, cmd, *args, **kwargs):
+    def _run(self, cmd: bytes, *args: bytes, **kwargs):
         def popen(cmdline):
             p = subprocess.Popen(
                 procutil.tonativestr(cmdline),
@@ -429,13 +470,13 @@
 
         return self._dorun(popen, cmd, *args, **kwargs)
 
-    def _run2(self, cmd, *args, **kwargs):
+    def _run2(self, cmd: bytes, *args: bytes, **kwargs):
         return self._dorun(procutil.popen2, cmd, *args, **kwargs)
 
-    def _run3(self, cmd, *args, **kwargs):
+    def _run3(self, cmd: bytes, *args: bytes, **kwargs):
         return self._dorun(procutil.popen3, cmd, *args, **kwargs)
 
-    def _dorun(self, openfunc, cmd, *args, **kwargs):
+    def _dorun(self, openfunc, cmd: bytes, *args: bytes, **kwargs):
         cmdline = self._cmdline(cmd, *args, **kwargs)
         self.ui.debug(b'running: %s\n' % (cmdline,))
         self.prerun()
@@ -444,20 +485,20 @@
         finally:
             self.postrun()
 
-    def run(self, cmd, *args, **kwargs):
+    def run(self, cmd: bytes, *args: bytes, **kwargs):
         p = self._run(cmd, *args, **kwargs)
         output = p.communicate()[0]
         self.ui.debug(output)
         return output, p.returncode
 
-    def runlines(self, cmd, *args, **kwargs):
+    def runlines(self, cmd: bytes, *args: bytes, **kwargs):
         p = self._run(cmd, *args, **kwargs)
         output = p.stdout.readlines()
         p.wait()
         self.ui.debug(b''.join(output))
         return output, p.returncode
 
-    def checkexit(self, status, output=b''):
+    def checkexit(self, status, output: bytes = b'') -> None:
         if status:
             if output:
                 self.ui.warn(_(b'%s error:\n') % self.command)
@@ -465,12 +506,12 @@
             msg = procutil.explainexit(status)
             raise error.Abort(b'%s %s' % (self.command, msg))
 
-    def run0(self, cmd, *args, **kwargs):
+    def run0(self, cmd: bytes, *args: bytes, **kwargs):
         output, status = self.run(cmd, *args, **kwargs)
         self.checkexit(status, output)
         return output
 
-    def runlines0(self, cmd, *args, **kwargs):
+    def runlines0(self, cmd: bytes, *args: bytes, **kwargs):
         output, status = self.runlines(cmd, *args, **kwargs)
         self.checkexit(status, b''.join(output))
         return output
@@ -493,7 +534,7 @@
         # (and make happy Windows shells while doing this).
         return argmax // 2 - 1
 
-    def _limit_arglist(self, arglist, cmd, *args, **kwargs):
+    def _limit_arglist(self, arglist, cmd: bytes, *args: bytes, **kwargs):
         cmdlen = len(self._cmdline(cmd, *args, **kwargs))
         limit = self.argmax - cmdlen
         numbytes = 0
@@ -510,13 +551,13 @@
         if fl:
             yield fl
 
-    def xargs(self, arglist, cmd, *args, **kwargs):
+    def xargs(self, arglist, cmd: bytes, *args: bytes, **kwargs):
         for l in self._limit_arglist(arglist, cmd, *args, **kwargs):
             self.run0(cmd, *(list(args) + l), **kwargs)
 
 
 class mapfile(dict):
-    def __init__(self, ui, path):
+    def __init__(self, ui: "uimod.ui", path: bytes) -> None:
         super(mapfile, self).__init__()
         self.ui = ui
         self.path = path
@@ -524,7 +565,7 @@
         self.order = []
         self._read()
 
-    def _read(self):
+    def _read(self) -> None:
         if not self.path:
             return
         try:
@@ -548,7 +589,7 @@
             super(mapfile, self).__setitem__(key, value)
         fp.close()
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key, value) -> None:
         if self.fp is None:
             try:
                 self.fp = open(self.path, b'ab')
@@ -561,7 +602,7 @@
         self.fp.flush()
         super(mapfile, self).__setitem__(key, value)
 
-    def close(self):
+    def close(self) -> None:
         if self.fp:
             self.fp.close()
             self.fp = None
--- a/hgext/convert/convcmd.py	Thu Jul 11 14:46:00 2024 -0400
+++ b/hgext/convert/convcmd.py	Thu Jul 11 20:54:06 2024 -0400
@@ -9,6 +9,14 @@
 import heapq
 import os
 import shutil
+import typing
+
+from typing import (
+    AnyStr,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from mercurial.i18n import _
 from mercurial.pycompat import open
@@ -36,6 +44,11 @@
     subversion,
 )
 
+if typing.TYPE_CHECKING:
+    from mercurial import (
+        ui as uimod,
+    )
+
 mapfile = common.mapfile
 MissingTool = common.MissingTool
 NoRepo = common.NoRepo
@@ -53,10 +66,10 @@
 svn_sink = subversion.svn_sink
 svn_source = subversion.svn_source
 
-orig_encoding = b'ascii'
+orig_encoding: bytes = b'ascii'
 
 
-def readauthormap(ui, authorfile, authors=None):
+def readauthormap(ui: "uimod.ui", authorfile, authors=None):
     if authors is None:
         authors = {}
     with open(authorfile, b'rb') as afile:
@@ -86,7 +99,7 @@
     return authors
 
 
-def recode(s):
+def recode(s: AnyStr) -> bytes:
     if isinstance(s, str):
         return s.encode(pycompat.sysstr(orig_encoding), 'replace')
     else:
@@ -95,7 +108,7 @@
         )
 
 
-def mapbranch(branch, branchmap):
+def mapbranch(branch: bytes, branchmap: Mapping[bytes, bytes]) -> bytes:
     """
     >>> bmap = {b'default': b'branch1'}
     >>> for i in [b'', None]:
@@ -147,7 +160,7 @@
 ]
 
 
-def convertsource(ui, path, type, revs):
+def convertsource(ui: "uimod.ui", path: bytes, type: bytes, revs):
     exceptions = []
     if type and type not in [s[0] for s in source_converters]:
         raise error.Abort(_(b'%s: invalid source repository type') % type)
@@ -163,7 +176,9 @@
     raise error.Abort(_(b'%s: missing or unsupported repository') % path)
 
 
-def convertsink(ui, path, type):
+def convertsink(
+    ui: "uimod.ui", path: bytes, type: bytes
+) -> Union[hgconvert.mercurial_sink, subversion.svn_sink]:
     if type and type not in [s[0] for s in sink_converters]:
         raise error.Abort(_(b'%s: invalid destination repository type') % type)
     for name, sink in sink_converters:
@@ -178,7 +193,9 @@
 
 
 class progresssource:
-    def __init__(self, ui, source, filecount):
+    def __init__(
+        self, ui: "uimod.ui", source, filecount: Optional[int]
+    ) -> None:
         self.ui = ui
         self.source = source
         self.progress = ui.makeprogress(
@@ -253,7 +270,7 @@
 
 
 class converter:
-    def __init__(self, ui, source, dest, revmapfile, opts):
+    def __init__(self, ui: "uimod.ui", source, dest, revmapfile, opts) -> None:
 
         self.source = source
         self.dest = dest
@@ -280,7 +297,7 @@
         self.splicemap = self.parsesplicemap(opts.get(b'splicemap'))
         self.branchmap = mapfile(ui, opts.get(b'branchmap'))
 
-    def parsesplicemap(self, path):
+    def parsesplicemap(self, path: bytes):
         """check and validate the splicemap format and
         return a child/parents dictionary.
         Format checking has two parts.
@@ -356,7 +373,7 @@
 
         return parents
 
-    def mergesplicemap(self, parents, splicemap):
+    def mergesplicemap(self, parents, splicemap) -> None:
         """A splicemap redefines child/parent relationships. Check the
         map contains valid revision identifiers and merge the new
         links in the source graph.
@@ -488,7 +505,7 @@
 
         return s
 
-    def writeauthormap(self):
+    def writeauthormap(self) -> None:
         authorfile = self.authorfile
         if authorfile:
             self.ui.status(_(b'writing author map file %s\n') % authorfile)
@@ -501,7 +518,7 @@
                 )
             ofile.close()
 
-    def readauthormap(self, authorfile):
+    def readauthormap(self, authorfile) -> None:
         self.authors = readauthormap(self.ui, authorfile, self.authors)
 
     def cachecommit(self, rev):
@@ -511,7 +528,7 @@
         self.commitcache[rev] = commit
         return commit
 
-    def copy(self, rev):
+    def copy(self, rev) -> None:
         commit = self.commitcache[rev]
         full = self.opts.get(b'full')
         changes = self.source.getchanges(rev, full)
@@ -563,7 +580,7 @@
         self.source.converted(rev, newnode)
         self.map[rev] = newnode
 
-    def convert(self, sortmode):
+    def convert(self, sortmode) -> None:
         try:
             self.source.before()
             self.dest.before()
@@ -628,7 +645,7 @@
         finally:
             self.cleanup()
 
-    def cleanup(self):
+    def cleanup(self) -> None:
         try:
             self.dest.after()
         finally:
@@ -636,7 +653,9 @@
         self.map.close()
 
 
-def convert(ui, src, dest=None, revmapfile=None, **opts):
+def convert(
+    ui: "uimod.ui", src, dest: Optional[bytes] = None, revmapfile=None, **opts
+) -> None:
     opts = pycompat.byteskwargs(opts)
     global orig_encoding
     orig_encoding = encoding.encoding
--- a/hgext/convert/filemap.py	Thu Jul 11 14:46:00 2024 -0400
+++ b/hgext/convert/filemap.py	Thu Jul 11 20:54:06 2024 -0400
@@ -6,6 +6,17 @@
 
 
 import posixpath
+import typing
+
+from typing import (
+    Iterator,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Set,
+    Tuple,
+    overload,
+)
 
 from mercurial.i18n import _
 from mercurial import (
@@ -14,10 +25,15 @@
 )
 from . import common
 
+if typing.TYPE_CHECKING:
+    from mercurial import (
+        ui as uimod,
+    )
+
 SKIPREV = common.SKIPREV
 
 
-def rpairs(path):
+def rpairs(path: bytes) -> Iterator[Tuple[bytes, bytes]]:
     """Yield tuples with path split at '/', starting with the full path.
     No leading, trailing or double '/', please.
     >>> for x in rpairs(b'foo/bar/baz'): print(x)
@@ -33,6 +49,17 @@
     yield b'.', path
 
 
+if typing.TYPE_CHECKING:
+
+    @overload
+    def normalize(path: bytes) -> bytes:
+        pass
+
+    @overload
+    def normalize(path: None) -> None:
+        pass
+
+
 def normalize(path):
     """We use posixpath.normpath to support cross-platform path format.
     However, it doesn't handle None input. So we wrap it up."""
@@ -46,7 +73,10 @@
     A name can be mapped to itself, a new name, or None (omit from new
     repository)."""
 
-    def __init__(self, ui, path=None):
+    rename: MutableMapping[bytes, bytes]
+    targetprefixes: Optional[Set[bytes]]
+
+    def __init__(self, ui: "uimod.ui", path=None) -> None:
         self.ui = ui
         self.include = {}
         self.exclude = {}
@@ -56,10 +86,11 @@
             if self.parse(path):
                 raise error.Abort(_(b'errors in filemap'))
 
-    def parse(self, path):
+    # TODO: cmd==b'source' case breaks if ``path``is str
+    def parse(self, path) -> int:
         errs = 0
 
-        def check(name, mapping, listname):
+        def check(name: bytes, mapping, listname: bytes):
             if not name:
                 self.ui.warn(
                     _(b'%s:%d: path to %s is missing\n')
@@ -110,7 +141,9 @@
             cmd = lex.get_token()
         return errs
 
-    def lookup(self, name, mapping):
+    def lookup(
+        self, name: bytes, mapping: Mapping[bytes, bytes]
+    ) -> Tuple[bytes, bytes, bytes]:
         name = normalize(name)
         for pre, suf in rpairs(name):
             try:
@@ -119,7 +152,7 @@
                 pass
         return b'', name, b''
 
-    def istargetfile(self, filename):
+    def istargetfile(self, filename: bytes) -> bool:
         """Return true if the given target filename is covered as a destination
         of the filemap. This is useful for identifying what parts of the target
         repo belong to the source repo and what parts don't."""
@@ -143,7 +176,7 @@
                 return True
         return False
 
-    def __call__(self, name):
+    def __call__(self, name: bytes) -> Optional[bytes]:
         if self.include:
             inc = self.lookup(name, self.include)[0]
         else:
@@ -165,7 +198,7 @@
             return newpre
         return name
 
-    def active(self):
+    def active(self) -> bool:
         return bool(self.include or self.exclude or self.rename)
 
 
@@ -185,7 +218,7 @@
 
 
 class filemap_source(common.converter_source):
-    def __init__(self, ui, baseconverter, filemap):
+    def __init__(self, ui: "uimod.ui", baseconverter, filemap) -> None:
         super(filemap_source, self).__init__(ui, baseconverter.repotype)
         self.base = baseconverter
         self.filemapper = filemapper(ui, filemap)
@@ -206,10 +239,10 @@
             b'convert', b'ignoreancestorcheck'
         )
 
-    def before(self):
+    def before(self) -> None:
         self.base.before()
 
-    def after(self):
+    def after(self) -> None:
         self.base.after()
 
     def setrevmap(self, revmap):
@@ -243,7 +276,7 @@
         self.convertedorder = converted
         return self.base.setrevmap(revmap)
 
-    def rebuild(self):
+    def rebuild(self) -> bool:
         if self._rebuilt:
             return True
         self._rebuilt = True
@@ -276,7 +309,7 @@
     def getheads(self):
         return self.base.getheads()
 
-    def getcommit(self, rev):
+    def getcommit(self, rev: bytes):
         # We want to save a reference to the commit objects to be able
         # to rewrite their parents later on.
         c = self.commits[rev] = self.base.getcommit(rev)
@@ -292,7 +325,7 @@
             return self.commits[rev]
         return self.base.getcommit(rev)
 
-    def _discard(self, *revs):
+    def _discard(self, *revs) -> None:
         for r in revs:
             if r is None:
                 continue
@@ -304,7 +337,7 @@
                 if self._rebuilt:
                     del self.children[r]
 
-    def wanted(self, rev, i):
+    def wanted(self, rev, i) -> bool:
         # Return True if we're directly interested in rev.
         #
         # i is an index selecting one of the parents of rev (if rev
@@ -332,7 +365,7 @@
         # doesn't consider it significant, and this revision should be dropped.
         return not files and b'close' not in self.commits[rev].extra
 
-    def mark_not_wanted(self, rev, p):
+    def mark_not_wanted(self, rev, p) -> None:
         # Mark rev as not interesting and update data structures.
 
         if p is None:
@@ -347,7 +380,7 @@
         self.parentmap[rev] = self.parentmap[p]
         self.wantedancestors[rev] = self.wantedancestors[p]
 
-    def mark_wanted(self, rev, parents):
+    def mark_wanted(self, rev, parents) -> None:
         # Mark rev ss wanted and update data structures.
 
         # rev will be in the restricted graph, so children of rev in
@@ -474,7 +507,7 @@
 
         return files, ncopies, ncleanp2
 
-    def targetfilebelongstosource(self, targetfilename):
+    def targetfilebelongstosource(self, targetfilename: bytes) -> bool:
         return self.filemapper.istargetfile(targetfilename)
 
     def getfile(self, name, rev):
@@ -484,7 +517,7 @@
     def gettags(self):
         return self.base.gettags()
 
-    def hasnativeorder(self):
+    def hasnativeorder(self) -> bool:
         return self.base.hasnativeorder()
 
     def lookuprev(self, rev):