typing: add trivial type hints to the convert extension's common modules
This started as ensuring that the `encoding` and `orig_encoding` attributes has
a type other than `Any`, so pytype can catch problems where it needs to be str
for stdlib encoding and decoding. It turns out that adding the hint in
`mercurial.encoding` is what was needed, but I picked a bunch of low hanging
fruit while here. There's definitely more to do, and I see a problem where
`shlex.shlex` is being fed bytes instead of str, but there are not enough type
hints yet to make pytype notice.
--- a/hgext/convert/common.py Thu Jul 11 14:46:00 2024 -0400
+++ b/hgext/convert/common.py Thu Jul 11 20:54:06 2024 -0400
@@ -11,6 +11,13 @@
import re
import shlex
import subprocess
+import typing
+
+from typing import (
+ Any,
+ AnyStr,
+ Optional,
+)
from mercurial.i18n import _
from mercurial.pycompat import open
@@ -26,9 +33,28 @@
procutil,
)
+if typing.TYPE_CHECKING:
+ from typing import (
+ overload,
+ )
+ from mercurial import (
+ ui as uimod,
+ )
+
propertycache = util.propertycache
+if typing.TYPE_CHECKING:
+
+ @overload
+ def _encodeornone(d: str) -> bytes:
+ pass
+
+ @overload
+ def _encodeornone(d: None) -> None:
+ pass
+
+
def _encodeornone(d):
if d is None:
return
@@ -36,7 +62,7 @@
class _shlexpy3proxy:
- def __init__(self, l):
+ def __init__(self, l: shlex.shlex) -> None:
self._l = l
def __iter__(self):
@@ -50,11 +76,16 @@
return self._l.infile or b'<unknown>'
@property
- def lineno(self):
+ def lineno(self) -> int:
return self._l.lineno
-def shlexer(data=None, filepath=None, wordchars=None, whitespace=None):
+def shlexer(
+ data=None,
+ filepath: Optional[str] = None,
+ wordchars: Optional[bytes] = None,
+ whitespace: Optional[bytes] = None,
+):
if data is None:
data = open(filepath, b'r', encoding='latin1')
else:
@@ -72,8 +103,8 @@
return _shlexpy3proxy(l)
-def encodeargs(args):
- def encodearg(s):
+def encodeargs(args: Any) -> bytes:
+ def encodearg(s: bytes) -> bytes:
lines = base64.encodebytes(s)
lines = [l.splitlines()[0] for l in pycompat.iterbytestr(lines)]
return b''.join(lines)
@@ -82,7 +113,7 @@
return encodearg(s)
-def decodeargs(s):
+def decodeargs(s: bytes) -> Any:
s = base64.decodebytes(s)
return pickle.loads(s)
@@ -91,7 +122,9 @@
pass
-def checktool(exe, name=None, abort=True):
+def checktool(
+ exe: bytes, name: Optional[bytes] = None, abort: bool = True
+) -> None:
name = name or exe
if not procutil.findexe(exe):
if abort:
@@ -105,25 +138,25 @@
pass
-SKIPREV = b'SKIP'
+SKIPREV: bytes = b'SKIP'
class commit:
def __init__(
self,
- author,
- date,
- desc,
+ author: bytes,
+ date: bytes,
+ desc: bytes,
parents,
- branch=None,
+ branch: Optional[bytes] = None,
rev=None,
extra=None,
sortkey=None,
saverev=True,
- phase=phases.draft,
+ phase: int = phases.draft,
optparents=None,
ctx=None,
- ):
+ ) -> None:
self.author = author or b'unknown'
self.date = date or b'0 0'
self.desc = desc
@@ -141,7 +174,13 @@
class converter_source:
"""Conversion source interface"""
- def __init__(self, ui, repotype, path=None, revs=None):
+ def __init__(
+ self,
+ ui: "uimod.ui",
+ repotype: bytes,
+ path: Optional[bytes] = None,
+ revs=None,
+ ) -> None:
"""Initialize conversion source (or raise NoRepo("message")
exception if path is not a valid repository)"""
self.ui = ui
@@ -151,7 +190,9 @@
self.encoding = b'utf-8'
- def checkhexformat(self, revstr, mapname=b'splicemap'):
+ def checkhexformat(
+ self, revstr: bytes, mapname: bytes = b'splicemap'
+ ) -> None:
"""fails if revstr is not a 40 byte hex. mercurial and git both uses
such format for their revision numbering
"""
@@ -161,10 +202,10 @@
% (mapname, revstr)
)
- def before(self):
+ def before(self) -> None:
pass
- def after(self):
+ def after(self) -> None:
pass
def targetfilebelongstosource(self, targetfilename):
@@ -223,7 +264,7 @@
"""
raise NotImplementedError
- def recode(self, s, encoding=None):
+ def recode(self, s: AnyStr, encoding: Optional[bytes] = None) -> bytes:
if not encoding:
encoding = self.encoding or b'utf-8'
@@ -252,17 +293,17 @@
"""
raise NotImplementedError
- def converted(self, rev, sinkrev):
+ def converted(self, rev, sinkrev) -> None:
'''Notify the source that a revision has been converted.'''
- def hasnativeorder(self):
+ def hasnativeorder(self) -> bool:
"""Return true if this source has a meaningful, native revision
order. For instance, Mercurial revisions are store sequentially
while there is no such global ordering with Darcs.
"""
return False
- def hasnativeclose(self):
+ def hasnativeclose(self) -> bool:
"""Return true if this source has ability to close branch."""
return False
@@ -280,7 +321,7 @@
"""
return {}
- def checkrevformat(self, revstr, mapname=b'splicemap'):
+ def checkrevformat(self, revstr, mapname: bytes = b'splicemap') -> bool:
"""revstr is a string that describes a revision in the given
source control system. Return true if revstr has correct
format.
@@ -291,7 +332,7 @@
class converter_sink:
"""Conversion sink (target) interface"""
- def __init__(self, ui, repotype, path):
+ def __init__(self, ui: "uimod.ui", repotype: bytes, path: bytes) -> None:
"""Initialize conversion sink (or raise NoRepo("message")
exception if path is not a valid repository)
@@ -359,10 +400,10 @@
filter empty revisions.
"""
- def before(self):
+ def before(self) -> None:
pass
- def after(self):
+ def after(self) -> None:
pass
def putbookmarks(self, bookmarks):
@@ -385,17 +426,17 @@
class commandline:
- def __init__(self, ui, command):
+ def __init__(self, ui: "uimod.ui", command: bytes) -> None:
self.ui = ui
self.command = command
- def prerun(self):
+ def prerun(self) -> None:
pass
- def postrun(self):
+ def postrun(self) -> None:
pass
- def _cmdline(self, cmd, *args, **kwargs):
+ def _cmdline(self, cmd: bytes, *args: bytes, **kwargs) -> bytes:
kwargs = pycompat.byteskwargs(kwargs)
cmdline = [self.command, cmd] + list(args)
for k, v in kwargs.items():
@@ -416,7 +457,7 @@
cmdline = b' '.join(cmdline)
return cmdline
- def _run(self, cmd, *args, **kwargs):
+ def _run(self, cmd: bytes, *args: bytes, **kwargs):
def popen(cmdline):
p = subprocess.Popen(
procutil.tonativestr(cmdline),
@@ -429,13 +470,13 @@
return self._dorun(popen, cmd, *args, **kwargs)
- def _run2(self, cmd, *args, **kwargs):
+ def _run2(self, cmd: bytes, *args: bytes, **kwargs):
return self._dorun(procutil.popen2, cmd, *args, **kwargs)
- def _run3(self, cmd, *args, **kwargs):
+ def _run3(self, cmd: bytes, *args: bytes, **kwargs):
return self._dorun(procutil.popen3, cmd, *args, **kwargs)
- def _dorun(self, openfunc, cmd, *args, **kwargs):
+ def _dorun(self, openfunc, cmd: bytes, *args: bytes, **kwargs):
cmdline = self._cmdline(cmd, *args, **kwargs)
self.ui.debug(b'running: %s\n' % (cmdline,))
self.prerun()
@@ -444,20 +485,20 @@
finally:
self.postrun()
- def run(self, cmd, *args, **kwargs):
+ def run(self, cmd: bytes, *args: bytes, **kwargs):
p = self._run(cmd, *args, **kwargs)
output = p.communicate()[0]
self.ui.debug(output)
return output, p.returncode
- def runlines(self, cmd, *args, **kwargs):
+ def runlines(self, cmd: bytes, *args: bytes, **kwargs):
p = self._run(cmd, *args, **kwargs)
output = p.stdout.readlines()
p.wait()
self.ui.debug(b''.join(output))
return output, p.returncode
- def checkexit(self, status, output=b''):
+ def checkexit(self, status, output: bytes = b'') -> None:
if status:
if output:
self.ui.warn(_(b'%s error:\n') % self.command)
@@ -465,12 +506,12 @@
msg = procutil.explainexit(status)
raise error.Abort(b'%s %s' % (self.command, msg))
- def run0(self, cmd, *args, **kwargs):
+ def run0(self, cmd: bytes, *args: bytes, **kwargs):
output, status = self.run(cmd, *args, **kwargs)
self.checkexit(status, output)
return output
- def runlines0(self, cmd, *args, **kwargs):
+ def runlines0(self, cmd: bytes, *args: bytes, **kwargs):
output, status = self.runlines(cmd, *args, **kwargs)
self.checkexit(status, b''.join(output))
return output
@@ -493,7 +534,7 @@
# (and make happy Windows shells while doing this).
return argmax // 2 - 1
- def _limit_arglist(self, arglist, cmd, *args, **kwargs):
+ def _limit_arglist(self, arglist, cmd: bytes, *args: bytes, **kwargs):
cmdlen = len(self._cmdline(cmd, *args, **kwargs))
limit = self.argmax - cmdlen
numbytes = 0
@@ -510,13 +551,13 @@
if fl:
yield fl
- def xargs(self, arglist, cmd, *args, **kwargs):
+ def xargs(self, arglist, cmd: bytes, *args: bytes, **kwargs):
for l in self._limit_arglist(arglist, cmd, *args, **kwargs):
self.run0(cmd, *(list(args) + l), **kwargs)
class mapfile(dict):
- def __init__(self, ui, path):
+ def __init__(self, ui: "uimod.ui", path: bytes) -> None:
super(mapfile, self).__init__()
self.ui = ui
self.path = path
@@ -524,7 +565,7 @@
self.order = []
self._read()
- def _read(self):
+ def _read(self) -> None:
if not self.path:
return
try:
@@ -548,7 +589,7 @@
super(mapfile, self).__setitem__(key, value)
fp.close()
- def __setitem__(self, key, value):
+ def __setitem__(self, key, value) -> None:
if self.fp is None:
try:
self.fp = open(self.path, b'ab')
@@ -561,7 +602,7 @@
self.fp.flush()
super(mapfile, self).__setitem__(key, value)
- def close(self):
+ def close(self) -> None:
if self.fp:
self.fp.close()
self.fp = None
--- a/hgext/convert/convcmd.py Thu Jul 11 14:46:00 2024 -0400
+++ b/hgext/convert/convcmd.py Thu Jul 11 20:54:06 2024 -0400
@@ -9,6 +9,14 @@
import heapq
import os
import shutil
+import typing
+
+from typing import (
+ AnyStr,
+ Mapping,
+ Optional,
+ Union,
+)
from mercurial.i18n import _
from mercurial.pycompat import open
@@ -36,6 +44,11 @@
subversion,
)
+if typing.TYPE_CHECKING:
+ from mercurial import (
+ ui as uimod,
+ )
+
mapfile = common.mapfile
MissingTool = common.MissingTool
NoRepo = common.NoRepo
@@ -53,10 +66,10 @@
svn_sink = subversion.svn_sink
svn_source = subversion.svn_source
-orig_encoding = b'ascii'
+orig_encoding: bytes = b'ascii'
-def readauthormap(ui, authorfile, authors=None):
+def readauthormap(ui: "uimod.ui", authorfile, authors=None):
if authors is None:
authors = {}
with open(authorfile, b'rb') as afile:
@@ -86,7 +99,7 @@
return authors
-def recode(s):
+def recode(s: AnyStr) -> bytes:
if isinstance(s, str):
return s.encode(pycompat.sysstr(orig_encoding), 'replace')
else:
@@ -95,7 +108,7 @@
)
-def mapbranch(branch, branchmap):
+def mapbranch(branch: bytes, branchmap: Mapping[bytes, bytes]) -> bytes:
"""
>>> bmap = {b'default': b'branch1'}
>>> for i in [b'', None]:
@@ -147,7 +160,7 @@
]
-def convertsource(ui, path, type, revs):
+def convertsource(ui: "uimod.ui", path: bytes, type: bytes, revs):
exceptions = []
if type and type not in [s[0] for s in source_converters]:
raise error.Abort(_(b'%s: invalid source repository type') % type)
@@ -163,7 +176,9 @@
raise error.Abort(_(b'%s: missing or unsupported repository') % path)
-def convertsink(ui, path, type):
+def convertsink(
+ ui: "uimod.ui", path: bytes, type: bytes
+) -> Union[hgconvert.mercurial_sink, subversion.svn_sink]:
if type and type not in [s[0] for s in sink_converters]:
raise error.Abort(_(b'%s: invalid destination repository type') % type)
for name, sink in sink_converters:
@@ -178,7 +193,9 @@
class progresssource:
- def __init__(self, ui, source, filecount):
+ def __init__(
+ self, ui: "uimod.ui", source, filecount: Optional[int]
+ ) -> None:
self.ui = ui
self.source = source
self.progress = ui.makeprogress(
@@ -253,7 +270,7 @@
class converter:
- def __init__(self, ui, source, dest, revmapfile, opts):
+ def __init__(self, ui: "uimod.ui", source, dest, revmapfile, opts) -> None:
self.source = source
self.dest = dest
@@ -280,7 +297,7 @@
self.splicemap = self.parsesplicemap(opts.get(b'splicemap'))
self.branchmap = mapfile(ui, opts.get(b'branchmap'))
- def parsesplicemap(self, path):
+ def parsesplicemap(self, path: bytes):
"""check and validate the splicemap format and
return a child/parents dictionary.
Format checking has two parts.
@@ -356,7 +373,7 @@
return parents
- def mergesplicemap(self, parents, splicemap):
+ def mergesplicemap(self, parents, splicemap) -> None:
"""A splicemap redefines child/parent relationships. Check the
map contains valid revision identifiers and merge the new
links in the source graph.
@@ -488,7 +505,7 @@
return s
- def writeauthormap(self):
+ def writeauthormap(self) -> None:
authorfile = self.authorfile
if authorfile:
self.ui.status(_(b'writing author map file %s\n') % authorfile)
@@ -501,7 +518,7 @@
)
ofile.close()
- def readauthormap(self, authorfile):
+ def readauthormap(self, authorfile) -> None:
self.authors = readauthormap(self.ui, authorfile, self.authors)
def cachecommit(self, rev):
@@ -511,7 +528,7 @@
self.commitcache[rev] = commit
return commit
- def copy(self, rev):
+ def copy(self, rev) -> None:
commit = self.commitcache[rev]
full = self.opts.get(b'full')
changes = self.source.getchanges(rev, full)
@@ -563,7 +580,7 @@
self.source.converted(rev, newnode)
self.map[rev] = newnode
- def convert(self, sortmode):
+ def convert(self, sortmode) -> None:
try:
self.source.before()
self.dest.before()
@@ -628,7 +645,7 @@
finally:
self.cleanup()
- def cleanup(self):
+ def cleanup(self) -> None:
try:
self.dest.after()
finally:
@@ -636,7 +653,9 @@
self.map.close()
-def convert(ui, src, dest=None, revmapfile=None, **opts):
+def convert(
+ ui: "uimod.ui", src, dest: Optional[bytes] = None, revmapfile=None, **opts
+) -> None:
opts = pycompat.byteskwargs(opts)
global orig_encoding
orig_encoding = encoding.encoding
--- a/hgext/convert/filemap.py Thu Jul 11 14:46:00 2024 -0400
+++ b/hgext/convert/filemap.py Thu Jul 11 20:54:06 2024 -0400
@@ -6,6 +6,17 @@
import posixpath
+import typing
+
+from typing import (
+ Iterator,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Set,
+ Tuple,
+ overload,
+)
from mercurial.i18n import _
from mercurial import (
@@ -14,10 +25,15 @@
)
from . import common
+if typing.TYPE_CHECKING:
+ from mercurial import (
+ ui as uimod,
+ )
+
SKIPREV = common.SKIPREV
-def rpairs(path):
+def rpairs(path: bytes) -> Iterator[Tuple[bytes, bytes]]:
"""Yield tuples with path split at '/', starting with the full path.
No leading, trailing or double '/', please.
>>> for x in rpairs(b'foo/bar/baz'): print(x)
@@ -33,6 +49,17 @@
yield b'.', path
+if typing.TYPE_CHECKING:
+
+ @overload
+ def normalize(path: bytes) -> bytes:
+ pass
+
+ @overload
+ def normalize(path: None) -> None:
+ pass
+
+
def normalize(path):
"""We use posixpath.normpath to support cross-platform path format.
However, it doesn't handle None input. So we wrap it up."""
@@ -46,7 +73,10 @@
A name can be mapped to itself, a new name, or None (omit from new
repository)."""
- def __init__(self, ui, path=None):
+ rename: MutableMapping[bytes, bytes]
+ targetprefixes: Optional[Set[bytes]]
+
+ def __init__(self, ui: "uimod.ui", path=None) -> None:
self.ui = ui
self.include = {}
self.exclude = {}
@@ -56,10 +86,11 @@
if self.parse(path):
raise error.Abort(_(b'errors in filemap'))
- def parse(self, path):
+ # TODO: cmd==b'source' case breaks if ``path``is str
+ def parse(self, path) -> int:
errs = 0
- def check(name, mapping, listname):
+ def check(name: bytes, mapping, listname: bytes):
if not name:
self.ui.warn(
_(b'%s:%d: path to %s is missing\n')
@@ -110,7 +141,9 @@
cmd = lex.get_token()
return errs
- def lookup(self, name, mapping):
+ def lookup(
+ self, name: bytes, mapping: Mapping[bytes, bytes]
+ ) -> Tuple[bytes, bytes, bytes]:
name = normalize(name)
for pre, suf in rpairs(name):
try:
@@ -119,7 +152,7 @@
pass
return b'', name, b''
- def istargetfile(self, filename):
+ def istargetfile(self, filename: bytes) -> bool:
"""Return true if the given target filename is covered as a destination
of the filemap. This is useful for identifying what parts of the target
repo belong to the source repo and what parts don't."""
@@ -143,7 +176,7 @@
return True
return False
- def __call__(self, name):
+ def __call__(self, name: bytes) -> Optional[bytes]:
if self.include:
inc = self.lookup(name, self.include)[0]
else:
@@ -165,7 +198,7 @@
return newpre
return name
- def active(self):
+ def active(self) -> bool:
return bool(self.include or self.exclude or self.rename)
@@ -185,7 +218,7 @@
class filemap_source(common.converter_source):
- def __init__(self, ui, baseconverter, filemap):
+ def __init__(self, ui: "uimod.ui", baseconverter, filemap) -> None:
super(filemap_source, self).__init__(ui, baseconverter.repotype)
self.base = baseconverter
self.filemapper = filemapper(ui, filemap)
@@ -206,10 +239,10 @@
b'convert', b'ignoreancestorcheck'
)
- def before(self):
+ def before(self) -> None:
self.base.before()
- def after(self):
+ def after(self) -> None:
self.base.after()
def setrevmap(self, revmap):
@@ -243,7 +276,7 @@
self.convertedorder = converted
return self.base.setrevmap(revmap)
- def rebuild(self):
+ def rebuild(self) -> bool:
if self._rebuilt:
return True
self._rebuilt = True
@@ -276,7 +309,7 @@
def getheads(self):
return self.base.getheads()
- def getcommit(self, rev):
+ def getcommit(self, rev: bytes):
# We want to save a reference to the commit objects to be able
# to rewrite their parents later on.
c = self.commits[rev] = self.base.getcommit(rev)
@@ -292,7 +325,7 @@
return self.commits[rev]
return self.base.getcommit(rev)
- def _discard(self, *revs):
+ def _discard(self, *revs) -> None:
for r in revs:
if r is None:
continue
@@ -304,7 +337,7 @@
if self._rebuilt:
del self.children[r]
- def wanted(self, rev, i):
+ def wanted(self, rev, i) -> bool:
# Return True if we're directly interested in rev.
#
# i is an index selecting one of the parents of rev (if rev
@@ -332,7 +365,7 @@
# doesn't consider it significant, and this revision should be dropped.
return not files and b'close' not in self.commits[rev].extra
- def mark_not_wanted(self, rev, p):
+ def mark_not_wanted(self, rev, p) -> None:
# Mark rev as not interesting and update data structures.
if p is None:
@@ -347,7 +380,7 @@
self.parentmap[rev] = self.parentmap[p]
self.wantedancestors[rev] = self.wantedancestors[p]
- def mark_wanted(self, rev, parents):
+ def mark_wanted(self, rev, parents) -> None:
# Mark rev ss wanted and update data structures.
# rev will be in the restricted graph, so children of rev in
@@ -474,7 +507,7 @@
return files, ncopies, ncleanp2
- def targetfilebelongstosource(self, targetfilename):
+ def targetfilebelongstosource(self, targetfilename: bytes) -> bool:
return self.filemapper.istargetfile(targetfilename)
def getfile(self, name, rev):
@@ -484,7 +517,7 @@
def gettags(self):
return self.base.gettags()
- def hasnativeorder(self):
+ def hasnativeorder(self) -> bool:
return self.base.hasnativeorder()
def lookuprev(self, rev):