diff hgext/convert/common.py @ 51685:0eb515c7bec8

typing: add trivial type hints to the convert extension's common modules This started as ensuring that the `encoding` and `orig_encoding` attributes has a type other than `Any`, so pytype can catch problems where it needs to be str for stdlib encoding and decoding. It turns out that adding the hint in `mercurial.encoding` is what was needed, but I picked a bunch of low hanging fruit while here. There's definitely more to do, and I see a problem where `shlex.shlex` is being fed bytes instead of str, but there are not enough type hints yet to make pytype notice.
author Matt Harbison <matt_harbison@yahoo.com>
date Thu, 11 Jul 2024 20:54:06 -0400
parents 20e2a20674dc
children 39033e7a6e0a
line wrap: on
line diff
--- a/hgext/convert/common.py	Thu Jul 11 14:46:00 2024 -0400
+++ b/hgext/convert/common.py	Thu Jul 11 20:54:06 2024 -0400
@@ -11,6 +11,13 @@
 import re
 import shlex
 import subprocess
+import typing
+
+from typing import (
+    Any,
+    AnyStr,
+    Optional,
+)
 
 from mercurial.i18n import _
 from mercurial.pycompat import open
@@ -26,9 +33,28 @@
     procutil,
 )
 
+if typing.TYPE_CHECKING:
+    from typing import (
+        overload,
+    )
+    from mercurial import (
+        ui as uimod,
+    )
+
 propertycache = util.propertycache
 
 
+if typing.TYPE_CHECKING:
+
+    @overload
+    def _encodeornone(d: str) -> bytes:
+        pass
+
+    @overload
+    def _encodeornone(d: None) -> None:
+        pass
+
+
 def _encodeornone(d):
     if d is None:
         return
@@ -36,7 +62,7 @@
 
 
 class _shlexpy3proxy:
-    def __init__(self, l):
+    def __init__(self, l: shlex.shlex) -> None:
         self._l = l
 
     def __iter__(self):
@@ -50,11 +76,16 @@
         return self._l.infile or b'<unknown>'
 
     @property
-    def lineno(self):
+    def lineno(self) -> int:
         return self._l.lineno
 
 
-def shlexer(data=None, filepath=None, wordchars=None, whitespace=None):
+def shlexer(
+    data=None,
+    filepath: Optional[str] = None,
+    wordchars: Optional[bytes] = None,
+    whitespace: Optional[bytes] = None,
+):
     if data is None:
         data = open(filepath, b'r', encoding='latin1')
     else:
@@ -72,8 +103,8 @@
     return _shlexpy3proxy(l)
 
 
-def encodeargs(args):
-    def encodearg(s):
+def encodeargs(args: Any) -> bytes:
+    def encodearg(s: bytes) -> bytes:
         lines = base64.encodebytes(s)
         lines = [l.splitlines()[0] for l in pycompat.iterbytestr(lines)]
         return b''.join(lines)
@@ -82,7 +113,7 @@
     return encodearg(s)
 
 
-def decodeargs(s):
+def decodeargs(s: bytes) -> Any:
     s = base64.decodebytes(s)
     return pickle.loads(s)
 
@@ -91,7 +122,9 @@
     pass
 
 
-def checktool(exe, name=None, abort=True):
+def checktool(
+    exe: bytes, name: Optional[bytes] = None, abort: bool = True
+) -> None:
     name = name or exe
     if not procutil.findexe(exe):
         if abort:
@@ -105,25 +138,25 @@
     pass
 
 
-SKIPREV = b'SKIP'
+SKIPREV: bytes = b'SKIP'
 
 
 class commit:
     def __init__(
         self,
-        author,
-        date,
-        desc,
+        author: bytes,
+        date: bytes,
+        desc: bytes,
         parents,
-        branch=None,
+        branch: Optional[bytes] = None,
         rev=None,
         extra=None,
         sortkey=None,
         saverev=True,
-        phase=phases.draft,
+        phase: int = phases.draft,
         optparents=None,
         ctx=None,
-    ):
+    ) -> None:
         self.author = author or b'unknown'
         self.date = date or b'0 0'
         self.desc = desc
@@ -141,7 +174,13 @@
 class converter_source:
     """Conversion source interface"""
 
-    def __init__(self, ui, repotype, path=None, revs=None):
+    def __init__(
+        self,
+        ui: "uimod.ui",
+        repotype: bytes,
+        path: Optional[bytes] = None,
+        revs=None,
+    ) -> None:
         """Initialize conversion source (or raise NoRepo("message")
         exception if path is not a valid repository)"""
         self.ui = ui
@@ -151,7 +190,9 @@
 
         self.encoding = b'utf-8'
 
-    def checkhexformat(self, revstr, mapname=b'splicemap'):
+    def checkhexformat(
+        self, revstr: bytes, mapname: bytes = b'splicemap'
+    ) -> None:
         """fails if revstr is not a 40 byte hex. mercurial and git both uses
         such format for their revision numbering
         """
@@ -161,10 +202,10 @@
                 % (mapname, revstr)
             )
 
-    def before(self):
+    def before(self) -> None:
         pass
 
-    def after(self):
+    def after(self) -> None:
         pass
 
     def targetfilebelongstosource(self, targetfilename):
@@ -223,7 +264,7 @@
         """
         raise NotImplementedError
 
-    def recode(self, s, encoding=None):
+    def recode(self, s: AnyStr, encoding: Optional[bytes] = None) -> bytes:
         if not encoding:
             encoding = self.encoding or b'utf-8'
 
@@ -252,17 +293,17 @@
         """
         raise NotImplementedError
 
-    def converted(self, rev, sinkrev):
+    def converted(self, rev, sinkrev) -> None:
         '''Notify the source that a revision has been converted.'''
 
-    def hasnativeorder(self):
+    def hasnativeorder(self) -> bool:
         """Return true if this source has a meaningful, native revision
         order. For instance, Mercurial revisions are store sequentially
         while there is no such global ordering with Darcs.
         """
         return False
 
-    def hasnativeclose(self):
+    def hasnativeclose(self) -> bool:
         """Return true if this source has ability to close branch."""
         return False
 
@@ -280,7 +321,7 @@
         """
         return {}
 
-    def checkrevformat(self, revstr, mapname=b'splicemap'):
+    def checkrevformat(self, revstr, mapname: bytes = b'splicemap') -> bool:
         """revstr is a string that describes a revision in the given
         source control system.  Return true if revstr has correct
         format.
@@ -291,7 +332,7 @@
 class converter_sink:
     """Conversion sink (target) interface"""
 
-    def __init__(self, ui, repotype, path):
+    def __init__(self, ui: "uimod.ui", repotype: bytes, path: bytes) -> None:
         """Initialize conversion sink (or raise NoRepo("message")
         exception if path is not a valid repository)
 
@@ -359,10 +400,10 @@
         filter empty revisions.
         """
 
-    def before(self):
+    def before(self) -> None:
         pass
 
-    def after(self):
+    def after(self) -> None:
         pass
 
     def putbookmarks(self, bookmarks):
@@ -385,17 +426,17 @@
 
 
 class commandline:
-    def __init__(self, ui, command):
+    def __init__(self, ui: "uimod.ui", command: bytes) -> None:
         self.ui = ui
         self.command = command
 
-    def prerun(self):
+    def prerun(self) -> None:
         pass
 
-    def postrun(self):
+    def postrun(self) -> None:
         pass
 
-    def _cmdline(self, cmd, *args, **kwargs):
+    def _cmdline(self, cmd: bytes, *args: bytes, **kwargs) -> bytes:
         kwargs = pycompat.byteskwargs(kwargs)
         cmdline = [self.command, cmd] + list(args)
         for k, v in kwargs.items():
@@ -416,7 +457,7 @@
         cmdline = b' '.join(cmdline)
         return cmdline
 
-    def _run(self, cmd, *args, **kwargs):
+    def _run(self, cmd: bytes, *args: bytes, **kwargs):
         def popen(cmdline):
             p = subprocess.Popen(
                 procutil.tonativestr(cmdline),
@@ -429,13 +470,13 @@
 
         return self._dorun(popen, cmd, *args, **kwargs)
 
-    def _run2(self, cmd, *args, **kwargs):
+    def _run2(self, cmd: bytes, *args: bytes, **kwargs):
         return self._dorun(procutil.popen2, cmd, *args, **kwargs)
 
-    def _run3(self, cmd, *args, **kwargs):
+    def _run3(self, cmd: bytes, *args: bytes, **kwargs):
         return self._dorun(procutil.popen3, cmd, *args, **kwargs)
 
-    def _dorun(self, openfunc, cmd, *args, **kwargs):
+    def _dorun(self, openfunc, cmd: bytes, *args: bytes, **kwargs):
         cmdline = self._cmdline(cmd, *args, **kwargs)
         self.ui.debug(b'running: %s\n' % (cmdline,))
         self.prerun()
@@ -444,20 +485,20 @@
         finally:
             self.postrun()
 
-    def run(self, cmd, *args, **kwargs):
+    def run(self, cmd: bytes, *args: bytes, **kwargs):
         p = self._run(cmd, *args, **kwargs)
         output = p.communicate()[0]
         self.ui.debug(output)
         return output, p.returncode
 
-    def runlines(self, cmd, *args, **kwargs):
+    def runlines(self, cmd: bytes, *args: bytes, **kwargs):
         p = self._run(cmd, *args, **kwargs)
         output = p.stdout.readlines()
         p.wait()
         self.ui.debug(b''.join(output))
         return output, p.returncode
 
-    def checkexit(self, status, output=b''):
+    def checkexit(self, status, output: bytes = b'') -> None:
         if status:
             if output:
                 self.ui.warn(_(b'%s error:\n') % self.command)
@@ -465,12 +506,12 @@
             msg = procutil.explainexit(status)
             raise error.Abort(b'%s %s' % (self.command, msg))
 
-    def run0(self, cmd, *args, **kwargs):
+    def run0(self, cmd: bytes, *args: bytes, **kwargs):
         output, status = self.run(cmd, *args, **kwargs)
         self.checkexit(status, output)
         return output
 
-    def runlines0(self, cmd, *args, **kwargs):
+    def runlines0(self, cmd: bytes, *args: bytes, **kwargs):
         output, status = self.runlines(cmd, *args, **kwargs)
         self.checkexit(status, b''.join(output))
         return output
@@ -493,7 +534,7 @@
         # (and make happy Windows shells while doing this).
         return argmax // 2 - 1
 
-    def _limit_arglist(self, arglist, cmd, *args, **kwargs):
+    def _limit_arglist(self, arglist, cmd: bytes, *args: bytes, **kwargs):
         cmdlen = len(self._cmdline(cmd, *args, **kwargs))
         limit = self.argmax - cmdlen
         numbytes = 0
@@ -510,13 +551,13 @@
         if fl:
             yield fl
 
-    def xargs(self, arglist, cmd, *args, **kwargs):
+    def xargs(self, arglist, cmd: bytes, *args: bytes, **kwargs):
         for l in self._limit_arglist(arglist, cmd, *args, **kwargs):
             self.run0(cmd, *(list(args) + l), **kwargs)
 
 
 class mapfile(dict):
-    def __init__(self, ui, path):
+    def __init__(self, ui: "uimod.ui", path: bytes) -> None:
         super(mapfile, self).__init__()
         self.ui = ui
         self.path = path
@@ -524,7 +565,7 @@
         self.order = []
         self._read()
 
-    def _read(self):
+    def _read(self) -> None:
         if not self.path:
             return
         try:
@@ -548,7 +589,7 @@
             super(mapfile, self).__setitem__(key, value)
         fp.close()
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key, value) -> None:
         if self.fp is None:
             try:
                 self.fp = open(self.path, b'ab')
@@ -561,7 +602,7 @@
         self.fp.flush()
         super(mapfile, self).__setitem__(key, value)
 
-    def close(self):
+    def close(self) -> None:
         if self.fp:
             self.fp.close()
             self.fp = None