Mercurial > hg
changeset 51685:0eb515c7bec8
typing: add trivial type hints to the convert extension's common modules
This started as ensuring that the `encoding` and `orig_encoding` attributes has
a type other than `Any`, so pytype can catch problems where it needs to be str
for stdlib encoding and decoding. It turns out that adding the hint in
`mercurial.encoding` is what was needed, but I picked a bunch of low hanging
fruit while here. There's definitely more to do, and I see a problem where
`shlex.shlex` is being fed bytes instead of str, but there are not enough type
hints yet to make pytype notice.
author | Matt Harbison <matt_harbison@yahoo.com> |
---|---|
date | Thu, 11 Jul 2024 20:54:06 -0400 |
parents | 20e2a20674dc |
children | 39033e7a6e0a |
files | hgext/convert/common.py hgext/convert/convcmd.py hgext/convert/filemap.py |
diffstat | 3 files changed, 173 insertions(+), 80 deletions(-) [+] |
line wrap: on
line diff
--- a/hgext/convert/common.py Thu Jul 11 14:46:00 2024 -0400 +++ b/hgext/convert/common.py Thu Jul 11 20:54:06 2024 -0400 @@ -11,6 +11,13 @@ import re import shlex import subprocess +import typing + +from typing import ( + Any, + AnyStr, + Optional, +) from mercurial.i18n import _ from mercurial.pycompat import open @@ -26,9 +33,28 @@ procutil, ) +if typing.TYPE_CHECKING: + from typing import ( + overload, + ) + from mercurial import ( + ui as uimod, + ) + propertycache = util.propertycache +if typing.TYPE_CHECKING: + + @overload + def _encodeornone(d: str) -> bytes: + pass + + @overload + def _encodeornone(d: None) -> None: + pass + + def _encodeornone(d): if d is None: return @@ -36,7 +62,7 @@ class _shlexpy3proxy: - def __init__(self, l): + def __init__(self, l: shlex.shlex) -> None: self._l = l def __iter__(self): @@ -50,11 +76,16 @@ return self._l.infile or b'<unknown>' @property - def lineno(self): + def lineno(self) -> int: return self._l.lineno -def shlexer(data=None, filepath=None, wordchars=None, whitespace=None): +def shlexer( + data=None, + filepath: Optional[str] = None, + wordchars: Optional[bytes] = None, + whitespace: Optional[bytes] = None, +): if data is None: data = open(filepath, b'r', encoding='latin1') else: @@ -72,8 +103,8 @@ return _shlexpy3proxy(l) -def encodeargs(args): - def encodearg(s): +def encodeargs(args: Any) -> bytes: + def encodearg(s: bytes) -> bytes: lines = base64.encodebytes(s) lines = [l.splitlines()[0] for l in pycompat.iterbytestr(lines)] return b''.join(lines) @@ -82,7 +113,7 @@ return encodearg(s) -def decodeargs(s): +def decodeargs(s: bytes) -> Any: s = base64.decodebytes(s) return pickle.loads(s) @@ -91,7 +122,9 @@ pass -def checktool(exe, name=None, abort=True): +def checktool( + exe: bytes, name: Optional[bytes] = None, abort: bool = True +) -> None: name = name or exe if not procutil.findexe(exe): if abort: @@ -105,25 +138,25 @@ pass -SKIPREV = b'SKIP' +SKIPREV: bytes = b'SKIP' class commit: def __init__( self, - author, - date, - desc, + author: bytes, + date: bytes, + desc: bytes, parents, - branch=None, + branch: Optional[bytes] = None, rev=None, extra=None, sortkey=None, saverev=True, - phase=phases.draft, + phase: int = phases.draft, optparents=None, ctx=None, - ): + ) -> None: self.author = author or b'unknown' self.date = date or b'0 0' self.desc = desc @@ -141,7 +174,13 @@ class converter_source: """Conversion source interface""" - def __init__(self, ui, repotype, path=None, revs=None): + def __init__( + self, + ui: "uimod.ui", + repotype: bytes, + path: Optional[bytes] = None, + revs=None, + ) -> None: """Initialize conversion source (or raise NoRepo("message") exception if path is not a valid repository)""" self.ui = ui @@ -151,7 +190,9 @@ self.encoding = b'utf-8' - def checkhexformat(self, revstr, mapname=b'splicemap'): + def checkhexformat( + self, revstr: bytes, mapname: bytes = b'splicemap' + ) -> None: """fails if revstr is not a 40 byte hex. mercurial and git both uses such format for their revision numbering """ @@ -161,10 +202,10 @@ % (mapname, revstr) ) - def before(self): + def before(self) -> None: pass - def after(self): + def after(self) -> None: pass def targetfilebelongstosource(self, targetfilename): @@ -223,7 +264,7 @@ """ raise NotImplementedError - def recode(self, s, encoding=None): + def recode(self, s: AnyStr, encoding: Optional[bytes] = None) -> bytes: if not encoding: encoding = self.encoding or b'utf-8' @@ -252,17 +293,17 @@ """ raise NotImplementedError - def converted(self, rev, sinkrev): + def converted(self, rev, sinkrev) -> None: '''Notify the source that a revision has been converted.''' - def hasnativeorder(self): + def hasnativeorder(self) -> bool: """Return true if this source has a meaningful, native revision order. For instance, Mercurial revisions are store sequentially while there is no such global ordering with Darcs. """ return False - def hasnativeclose(self): + def hasnativeclose(self) -> bool: """Return true if this source has ability to close branch.""" return False @@ -280,7 +321,7 @@ """ return {} - def checkrevformat(self, revstr, mapname=b'splicemap'): + def checkrevformat(self, revstr, mapname: bytes = b'splicemap') -> bool: """revstr is a string that describes a revision in the given source control system. Return true if revstr has correct format. @@ -291,7 +332,7 @@ class converter_sink: """Conversion sink (target) interface""" - def __init__(self, ui, repotype, path): + def __init__(self, ui: "uimod.ui", repotype: bytes, path: bytes) -> None: """Initialize conversion sink (or raise NoRepo("message") exception if path is not a valid repository) @@ -359,10 +400,10 @@ filter empty revisions. """ - def before(self): + def before(self) -> None: pass - def after(self): + def after(self) -> None: pass def putbookmarks(self, bookmarks): @@ -385,17 +426,17 @@ class commandline: - def __init__(self, ui, command): + def __init__(self, ui: "uimod.ui", command: bytes) -> None: self.ui = ui self.command = command - def prerun(self): + def prerun(self) -> None: pass - def postrun(self): + def postrun(self) -> None: pass - def _cmdline(self, cmd, *args, **kwargs): + def _cmdline(self, cmd: bytes, *args: bytes, **kwargs) -> bytes: kwargs = pycompat.byteskwargs(kwargs) cmdline = [self.command, cmd] + list(args) for k, v in kwargs.items(): @@ -416,7 +457,7 @@ cmdline = b' '.join(cmdline) return cmdline - def _run(self, cmd, *args, **kwargs): + def _run(self, cmd: bytes, *args: bytes, **kwargs): def popen(cmdline): p = subprocess.Popen( procutil.tonativestr(cmdline), @@ -429,13 +470,13 @@ return self._dorun(popen, cmd, *args, **kwargs) - def _run2(self, cmd, *args, **kwargs): + def _run2(self, cmd: bytes, *args: bytes, **kwargs): return self._dorun(procutil.popen2, cmd, *args, **kwargs) - def _run3(self, cmd, *args, **kwargs): + def _run3(self, cmd: bytes, *args: bytes, **kwargs): return self._dorun(procutil.popen3, cmd, *args, **kwargs) - def _dorun(self, openfunc, cmd, *args, **kwargs): + def _dorun(self, openfunc, cmd: bytes, *args: bytes, **kwargs): cmdline = self._cmdline(cmd, *args, **kwargs) self.ui.debug(b'running: %s\n' % (cmdline,)) self.prerun() @@ -444,20 +485,20 @@ finally: self.postrun() - def run(self, cmd, *args, **kwargs): + def run(self, cmd: bytes, *args: bytes, **kwargs): p = self._run(cmd, *args, **kwargs) output = p.communicate()[0] self.ui.debug(output) return output, p.returncode - def runlines(self, cmd, *args, **kwargs): + def runlines(self, cmd: bytes, *args: bytes, **kwargs): p = self._run(cmd, *args, **kwargs) output = p.stdout.readlines() p.wait() self.ui.debug(b''.join(output)) return output, p.returncode - def checkexit(self, status, output=b''): + def checkexit(self, status, output: bytes = b'') -> None: if status: if output: self.ui.warn(_(b'%s error:\n') % self.command) @@ -465,12 +506,12 @@ msg = procutil.explainexit(status) raise error.Abort(b'%s %s' % (self.command, msg)) - def run0(self, cmd, *args, **kwargs): + def run0(self, cmd: bytes, *args: bytes, **kwargs): output, status = self.run(cmd, *args, **kwargs) self.checkexit(status, output) return output - def runlines0(self, cmd, *args, **kwargs): + def runlines0(self, cmd: bytes, *args: bytes, **kwargs): output, status = self.runlines(cmd, *args, **kwargs) self.checkexit(status, b''.join(output)) return output @@ -493,7 +534,7 @@ # (and make happy Windows shells while doing this). return argmax // 2 - 1 - def _limit_arglist(self, arglist, cmd, *args, **kwargs): + def _limit_arglist(self, arglist, cmd: bytes, *args: bytes, **kwargs): cmdlen = len(self._cmdline(cmd, *args, **kwargs)) limit = self.argmax - cmdlen numbytes = 0 @@ -510,13 +551,13 @@ if fl: yield fl - def xargs(self, arglist, cmd, *args, **kwargs): + def xargs(self, arglist, cmd: bytes, *args: bytes, **kwargs): for l in self._limit_arglist(arglist, cmd, *args, **kwargs): self.run0(cmd, *(list(args) + l), **kwargs) class mapfile(dict): - def __init__(self, ui, path): + def __init__(self, ui: "uimod.ui", path: bytes) -> None: super(mapfile, self).__init__() self.ui = ui self.path = path @@ -524,7 +565,7 @@ self.order = [] self._read() - def _read(self): + def _read(self) -> None: if not self.path: return try: @@ -548,7 +589,7 @@ super(mapfile, self).__setitem__(key, value) fp.close() - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: if self.fp is None: try: self.fp = open(self.path, b'ab') @@ -561,7 +602,7 @@ self.fp.flush() super(mapfile, self).__setitem__(key, value) - def close(self): + def close(self) -> None: if self.fp: self.fp.close() self.fp = None
--- a/hgext/convert/convcmd.py Thu Jul 11 14:46:00 2024 -0400 +++ b/hgext/convert/convcmd.py Thu Jul 11 20:54:06 2024 -0400 @@ -9,6 +9,14 @@ import heapq import os import shutil +import typing + +from typing import ( + AnyStr, + Mapping, + Optional, + Union, +) from mercurial.i18n import _ from mercurial.pycompat import open @@ -36,6 +44,11 @@ subversion, ) +if typing.TYPE_CHECKING: + from mercurial import ( + ui as uimod, + ) + mapfile = common.mapfile MissingTool = common.MissingTool NoRepo = common.NoRepo @@ -53,10 +66,10 @@ svn_sink = subversion.svn_sink svn_source = subversion.svn_source -orig_encoding = b'ascii' +orig_encoding: bytes = b'ascii' -def readauthormap(ui, authorfile, authors=None): +def readauthormap(ui: "uimod.ui", authorfile, authors=None): if authors is None: authors = {} with open(authorfile, b'rb') as afile: @@ -86,7 +99,7 @@ return authors -def recode(s): +def recode(s: AnyStr) -> bytes: if isinstance(s, str): return s.encode(pycompat.sysstr(orig_encoding), 'replace') else: @@ -95,7 +108,7 @@ ) -def mapbranch(branch, branchmap): +def mapbranch(branch: bytes, branchmap: Mapping[bytes, bytes]) -> bytes: """ >>> bmap = {b'default': b'branch1'} >>> for i in [b'', None]: @@ -147,7 +160,7 @@ ] -def convertsource(ui, path, type, revs): +def convertsource(ui: "uimod.ui", path: bytes, type: bytes, revs): exceptions = [] if type and type not in [s[0] for s in source_converters]: raise error.Abort(_(b'%s: invalid source repository type') % type) @@ -163,7 +176,9 @@ raise error.Abort(_(b'%s: missing or unsupported repository') % path) -def convertsink(ui, path, type): +def convertsink( + ui: "uimod.ui", path: bytes, type: bytes +) -> Union[hgconvert.mercurial_sink, subversion.svn_sink]: if type and type not in [s[0] for s in sink_converters]: raise error.Abort(_(b'%s: invalid destination repository type') % type) for name, sink in sink_converters: @@ -178,7 +193,9 @@ class progresssource: - def __init__(self, ui, source, filecount): + def __init__( + self, ui: "uimod.ui", source, filecount: Optional[int] + ) -> None: self.ui = ui self.source = source self.progress = ui.makeprogress( @@ -253,7 +270,7 @@ class converter: - def __init__(self, ui, source, dest, revmapfile, opts): + def __init__(self, ui: "uimod.ui", source, dest, revmapfile, opts) -> None: self.source = source self.dest = dest @@ -280,7 +297,7 @@ self.splicemap = self.parsesplicemap(opts.get(b'splicemap')) self.branchmap = mapfile(ui, opts.get(b'branchmap')) - def parsesplicemap(self, path): + def parsesplicemap(self, path: bytes): """check and validate the splicemap format and return a child/parents dictionary. Format checking has two parts. @@ -356,7 +373,7 @@ return parents - def mergesplicemap(self, parents, splicemap): + def mergesplicemap(self, parents, splicemap) -> None: """A splicemap redefines child/parent relationships. Check the map contains valid revision identifiers and merge the new links in the source graph. @@ -488,7 +505,7 @@ return s - def writeauthormap(self): + def writeauthormap(self) -> None: authorfile = self.authorfile if authorfile: self.ui.status(_(b'writing author map file %s\n') % authorfile) @@ -501,7 +518,7 @@ ) ofile.close() - def readauthormap(self, authorfile): + def readauthormap(self, authorfile) -> None: self.authors = readauthormap(self.ui, authorfile, self.authors) def cachecommit(self, rev): @@ -511,7 +528,7 @@ self.commitcache[rev] = commit return commit - def copy(self, rev): + def copy(self, rev) -> None: commit = self.commitcache[rev] full = self.opts.get(b'full') changes = self.source.getchanges(rev, full) @@ -563,7 +580,7 @@ self.source.converted(rev, newnode) self.map[rev] = newnode - def convert(self, sortmode): + def convert(self, sortmode) -> None: try: self.source.before() self.dest.before() @@ -628,7 +645,7 @@ finally: self.cleanup() - def cleanup(self): + def cleanup(self) -> None: try: self.dest.after() finally: @@ -636,7 +653,9 @@ self.map.close() -def convert(ui, src, dest=None, revmapfile=None, **opts): +def convert( + ui: "uimod.ui", src, dest: Optional[bytes] = None, revmapfile=None, **opts +) -> None: opts = pycompat.byteskwargs(opts) global orig_encoding orig_encoding = encoding.encoding
--- a/hgext/convert/filemap.py Thu Jul 11 14:46:00 2024 -0400 +++ b/hgext/convert/filemap.py Thu Jul 11 20:54:06 2024 -0400 @@ -6,6 +6,17 @@ import posixpath +import typing + +from typing import ( + Iterator, + Mapping, + MutableMapping, + Optional, + Set, + Tuple, + overload, +) from mercurial.i18n import _ from mercurial import ( @@ -14,10 +25,15 @@ ) from . import common +if typing.TYPE_CHECKING: + from mercurial import ( + ui as uimod, + ) + SKIPREV = common.SKIPREV -def rpairs(path): +def rpairs(path: bytes) -> Iterator[Tuple[bytes, bytes]]: """Yield tuples with path split at '/', starting with the full path. No leading, trailing or double '/', please. >>> for x in rpairs(b'foo/bar/baz'): print(x) @@ -33,6 +49,17 @@ yield b'.', path +if typing.TYPE_CHECKING: + + @overload + def normalize(path: bytes) -> bytes: + pass + + @overload + def normalize(path: None) -> None: + pass + + def normalize(path): """We use posixpath.normpath to support cross-platform path format. However, it doesn't handle None input. So we wrap it up.""" @@ -46,7 +73,10 @@ A name can be mapped to itself, a new name, or None (omit from new repository).""" - def __init__(self, ui, path=None): + rename: MutableMapping[bytes, bytes] + targetprefixes: Optional[Set[bytes]] + + def __init__(self, ui: "uimod.ui", path=None) -> None: self.ui = ui self.include = {} self.exclude = {} @@ -56,10 +86,11 @@ if self.parse(path): raise error.Abort(_(b'errors in filemap')) - def parse(self, path): + # TODO: cmd==b'source' case breaks if ``path``is str + def parse(self, path) -> int: errs = 0 - def check(name, mapping, listname): + def check(name: bytes, mapping, listname: bytes): if not name: self.ui.warn( _(b'%s:%d: path to %s is missing\n') @@ -110,7 +141,9 @@ cmd = lex.get_token() return errs - def lookup(self, name, mapping): + def lookup( + self, name: bytes, mapping: Mapping[bytes, bytes] + ) -> Tuple[bytes, bytes, bytes]: name = normalize(name) for pre, suf in rpairs(name): try: @@ -119,7 +152,7 @@ pass return b'', name, b'' - def istargetfile(self, filename): + def istargetfile(self, filename: bytes) -> bool: """Return true if the given target filename is covered as a destination of the filemap. This is useful for identifying what parts of the target repo belong to the source repo and what parts don't.""" @@ -143,7 +176,7 @@ return True return False - def __call__(self, name): + def __call__(self, name: bytes) -> Optional[bytes]: if self.include: inc = self.lookup(name, self.include)[0] else: @@ -165,7 +198,7 @@ return newpre return name - def active(self): + def active(self) -> bool: return bool(self.include or self.exclude or self.rename) @@ -185,7 +218,7 @@ class filemap_source(common.converter_source): - def __init__(self, ui, baseconverter, filemap): + def __init__(self, ui: "uimod.ui", baseconverter, filemap) -> None: super(filemap_source, self).__init__(ui, baseconverter.repotype) self.base = baseconverter self.filemapper = filemapper(ui, filemap) @@ -206,10 +239,10 @@ b'convert', b'ignoreancestorcheck' ) - def before(self): + def before(self) -> None: self.base.before() - def after(self): + def after(self) -> None: self.base.after() def setrevmap(self, revmap): @@ -243,7 +276,7 @@ self.convertedorder = converted return self.base.setrevmap(revmap) - def rebuild(self): + def rebuild(self) -> bool: if self._rebuilt: return True self._rebuilt = True @@ -276,7 +309,7 @@ def getheads(self): return self.base.getheads() - def getcommit(self, rev): + def getcommit(self, rev: bytes): # We want to save a reference to the commit objects to be able # to rewrite their parents later on. c = self.commits[rev] = self.base.getcommit(rev) @@ -292,7 +325,7 @@ return self.commits[rev] return self.base.getcommit(rev) - def _discard(self, *revs): + def _discard(self, *revs) -> None: for r in revs: if r is None: continue @@ -304,7 +337,7 @@ if self._rebuilt: del self.children[r] - def wanted(self, rev, i): + def wanted(self, rev, i) -> bool: # Return True if we're directly interested in rev. # # i is an index selecting one of the parents of rev (if rev @@ -332,7 +365,7 @@ # doesn't consider it significant, and this revision should be dropped. return not files and b'close' not in self.commits[rev].extra - def mark_not_wanted(self, rev, p): + def mark_not_wanted(self, rev, p) -> None: # Mark rev as not interesting and update data structures. if p is None: @@ -347,7 +380,7 @@ self.parentmap[rev] = self.parentmap[p] self.wantedancestors[rev] = self.wantedancestors[p] - def mark_wanted(self, rev, parents): + def mark_wanted(self, rev, parents) -> None: # Mark rev ss wanted and update data structures. # rev will be in the restricted graph, so children of rev in @@ -474,7 +507,7 @@ return files, ncopies, ncleanp2 - def targetfilebelongstosource(self, targetfilename): + def targetfilebelongstosource(self, targetfilename: bytes) -> bool: return self.filemapper.istargetfile(targetfilename) def getfile(self, name, rev): @@ -484,7 +517,7 @@ def gettags(self): return self.base.gettags() - def hasnativeorder(self): + def hasnativeorder(self) -> bool: return self.base.hasnativeorder() def lookuprev(self, rev):