changeset 37621:5537d8f5e989

patch: make extract() a context manager (API) Previously, this function was creating a temporary file and relying on callers to unlink it. Yuck. We convert the function to a context manager and tie the lifetime of the temporary file to that of the context manager. This changed indentation not only from the context manager, but also from the elination of try blocks. It was just easier to split the heart of extract() into its own function. The single consumer of this function has been refactored to use it as a context manager. Code for cleaning up the file in tryimportone() has also been removed. .. api:: ``patch.extract()`` is now a context manager. Callers no longer have to worry about deleting the temporary file it creates, as the file is tied to the lifetime of the context manager. Differential Revision: https://phab.mercurial-scm.org/D3306
author Gregory Szorc <gregory.szorc@gmail.com>
date Thu, 12 Apr 2018 23:14:38 -0700
parents fd1dd79cff20
children bfdd20d22a86
files mercurial/cmdutil.py mercurial/commands.py mercurial/patch.py
diffstat 3 files changed, 210 insertions(+), 208 deletions(-) [+]
line wrap: on
line diff
--- a/mercurial/cmdutil.py	Thu Apr 12 23:06:27 2018 -0700
+++ b/mercurial/cmdutil.py	Thu Apr 12 23:14:38 2018 -0700
@@ -1379,141 +1379,139 @@
     strip = opts["strip"]
     prefix = opts["prefix"]
     sim = float(opts.get('similarity') or 0)
+
     if not tmpname:
-        return (None, None, False)
+        return None, None, False
 
     rejects = False
 
-    try:
-        cmdline_message = logmessage(ui, opts)
-        if cmdline_message:
-            # pickup the cmdline msg
-            message = cmdline_message
-        elif message:
-            # pickup the patch msg
-            message = message.strip()
-        else:
-            # launch the editor
-            message = None
-        ui.debug('message:\n%s\n' % (message or ''))
-
-        if len(parents) == 1:
-            parents.append(repo[nullid])
-        if opts.get('exact'):
-            if not nodeid or not p1:
-                raise error.Abort(_('not a Mercurial patch'))
+    cmdline_message = logmessage(ui, opts)
+    if cmdline_message:
+        # pickup the cmdline msg
+        message = cmdline_message
+    elif message:
+        # pickup the patch msg
+        message = message.strip()
+    else:
+        # launch the editor
+        message = None
+    ui.debug('message:\n%s\n' % (message or ''))
+
+    if len(parents) == 1:
+        parents.append(repo[nullid])
+    if opts.get('exact'):
+        if not nodeid or not p1:
+            raise error.Abort(_('not a Mercurial patch'))
+        p1 = repo[p1]
+        p2 = repo[p2 or nullid]
+    elif p2:
+        try:
             p1 = repo[p1]
-            p2 = repo[p2 or nullid]
-        elif p2:
-            try:
-                p1 = repo[p1]
-                p2 = repo[p2]
-                # Without any options, consider p2 only if the
-                # patch is being applied on top of the recorded
-                # first parent.
-                if p1 != parents[0]:
-                    p1 = parents[0]
-                    p2 = repo[nullid]
-            except error.RepoError:
-                p1, p2 = parents
-            if p2.node() == nullid:
-                ui.warn(_("warning: import the patch as a normal revision\n"
-                          "(use --exact to import the patch as a merge)\n"))
+            p2 = repo[p2]
+            # Without any options, consider p2 only if the
+            # patch is being applied on top of the recorded
+            # first parent.
+            if p1 != parents[0]:
+                p1 = parents[0]
+                p2 = repo[nullid]
+        except error.RepoError:
+            p1, p2 = parents
+        if p2.node() == nullid:
+            ui.warn(_("warning: import the patch as a normal revision\n"
+                      "(use --exact to import the patch as a merge)\n"))
+    else:
+        p1, p2 = parents
+
+    n = None
+    if update:
+        if p1 != parents[0]:
+            updatefunc(repo, p1.node())
+        if p2 != parents[1]:
+            repo.setparents(p1.node(), p2.node())
+
+        if opts.get('exact') or importbranch:
+            repo.dirstate.setbranch(branch or 'default')
+
+        partial = opts.get('partial', False)
+        files = set()
+        try:
+            patch.patch(ui, repo, tmpname, strip=strip, prefix=prefix,
+                        files=files, eolmode=None, similarity=sim / 100.0)
+        except error.PatchError as e:
+            if not partial:
+                raise error.Abort(pycompat.bytestr(e))
+            if partial:
+                rejects = True
+
+        files = list(files)
+        if nocommit:
+            if message:
+                msgs.append(message)
         else:
-            p1, p2 = parents
-
-        n = None
-        if update:
-            if p1 != parents[0]:
-                updatefunc(repo, p1.node())
-            if p2 != parents[1]:
-                repo.setparents(p1.node(), p2.node())
-
-            if opts.get('exact') or importbranch:
-                repo.dirstate.setbranch(branch or 'default')
-
-            partial = opts.get('partial', False)
+            if opts.get('exact') or p2:
+                # If you got here, you either use --force and know what
+                # you are doing or used --exact or a merge patch while
+                # being updated to its first parent.
+                m = None
+            else:
+                m = scmutil.matchfiles(repo, files or [])
+            editform = mergeeditform(repo[None], 'import.normal')
+            if opts.get('exact'):
+                editor = None
+            else:
+                editor = getcommiteditor(editform=editform,
+                                         **pycompat.strkwargs(opts))
+            extra = {}
+            for idfunc in extrapreimport:
+                extrapreimportmap[idfunc](repo, patchdata, extra, opts)
+            overrides = {}
+            if partial:
+                overrides[('ui', 'allowemptycommit')] = True
+            with repo.ui.configoverride(overrides, 'import'):
+                n = repo.commit(message, user,
+                                date, match=m,
+                                editor=editor, extra=extra)
+                for idfunc in extrapostimport:
+                    extrapostimportmap[idfunc](repo[n])
+    else:
+        if opts.get('exact') or importbranch:
+            branch = branch or 'default'
+        else:
+            branch = p1.branch()
+        store = patch.filestore()
+        try:
             files = set()
             try:
-                patch.patch(ui, repo, tmpname, strip=strip, prefix=prefix,
-                            files=files, eolmode=None, similarity=sim / 100.0)
+                patch.patchrepo(ui, repo, p1, store, tmpname, strip, prefix,
+                                files, eolmode=None)
             except error.PatchError as e:
-                if not partial:
-                    raise error.Abort(pycompat.bytestr(e))
-                if partial:
-                    rejects = True
-
-            files = list(files)
-            if nocommit:
-                if message:
-                    msgs.append(message)
+                raise error.Abort(stringutil.forcebytestr(e))
+            if opts.get('exact'):
+                editor = None
             else:
-                if opts.get('exact') or p2:
-                    # If you got here, you either use --force and know what
-                    # you are doing or used --exact or a merge patch while
-                    # being updated to its first parent.
-                    m = None
-                else:
-                    m = scmutil.matchfiles(repo, files or [])
-                editform = mergeeditform(repo[None], 'import.normal')
-                if opts.get('exact'):
-                    editor = None
-                else:
-                    editor = getcommiteditor(editform=editform,
-                                             **pycompat.strkwargs(opts))
-                extra = {}
-                for idfunc in extrapreimport:
-                    extrapreimportmap[idfunc](repo, patchdata, extra, opts)
-                overrides = {}
-                if partial:
-                    overrides[('ui', 'allowemptycommit')] = True
-                with repo.ui.configoverride(overrides, 'import'):
-                    n = repo.commit(message, user,
-                                    date, match=m,
-                                    editor=editor, extra=extra)
-                    for idfunc in extrapostimport:
-                        extrapostimportmap[idfunc](repo[n])
-        else:
-            if opts.get('exact') or importbranch:
-                branch = branch or 'default'
-            else:
-                branch = p1.branch()
-            store = patch.filestore()
-            try:
-                files = set()
-                try:
-                    patch.patchrepo(ui, repo, p1, store, tmpname, strip, prefix,
-                                    files, eolmode=None)
-                except error.PatchError as e:
-                    raise error.Abort(stringutil.forcebytestr(e))
-                if opts.get('exact'):
-                    editor = None
-                else:
-                    editor = getcommiteditor(editform='import.bypass')
-                memctx = context.memctx(repo, (p1.node(), p2.node()),
-                                            message,
-                                            files=files,
-                                            filectxfn=store,
-                                            user=user,
-                                            date=date,
-                                            branch=branch,
-                                            editor=editor)
-                n = memctx.commit()
-            finally:
-                store.close()
-        if opts.get('exact') and nocommit:
-            # --exact with --no-commit is still useful in that it does merge
-            # and branch bits
-            ui.warn(_("warning: can't check exact import with --no-commit\n"))
-        elif opts.get('exact') and hex(n) != nodeid:
-            raise error.Abort(_('patch is damaged or loses information'))
-        msg = _('applied to working directory')
-        if n:
-            # i18n: refers to a short changeset id
-            msg = _('created %s') % short(n)
-        return (msg, n, rejects)
-    finally:
-        os.unlink(tmpname)
+                editor = getcommiteditor(editform='import.bypass')
+            memctx = context.memctx(repo, (p1.node(), p2.node()),
+                                    message,
+                                    files=files,
+                                    filectxfn=store,
+                                    user=user,
+                                    date=date,
+                                    branch=branch,
+                                    editor=editor)
+            n = memctx.commit()
+        finally:
+            store.close()
+    if opts.get('exact') and nocommit:
+        # --exact with --no-commit is still useful in that it does merge
+        # and branch bits
+        ui.warn(_("warning: can't check exact import with --no-commit\n"))
+    elif opts.get('exact') and hex(n) != nodeid:
+        raise error.Abort(_('patch is damaged or loses information'))
+    msg = _('applied to working directory')
+    if n:
+        # i18n: refers to a short changeset id
+        msg = _('created %s') % short(n)
+    return msg, n, rejects
 
 # facility to let extensions include additional data in an exported patch
 # list of identifiers to be executed in order
--- a/mercurial/commands.py	Thu Apr 12 23:06:27 2018 -0700
+++ b/mercurial/commands.py	Thu Apr 12 23:14:38 2018 -0700
@@ -3089,11 +3089,10 @@
 
             haspatch = False
             for hunk in patch.split(patchfile):
-                patchdata = patch.extract(ui, hunk)
-
-                msg, node, rej = cmdutil.tryimportone(ui, repo, patchdata,
-                                                      parents, opts,
-                                                      msgs, hg.clean)
+                with patch.extract(ui, hunk) as patchdata:
+                    msg, node, rej = cmdutil.tryimportone(ui, repo, patchdata,
+                                                          parents, opts,
+                                                          msgs, hg.clean)
                 if msg:
                     haspatch = True
                     ui.note(msg + '\n')
--- a/mercurial/patch.py	Thu Apr 12 23:06:27 2018 -0700
+++ b/mercurial/patch.py	Thu Apr 12 23:14:38 2018 -0700
@@ -9,6 +9,7 @@
 from __future__ import absolute_import, print_function
 
 import collections
+import contextlib
 import copy
 import difflib
 import email
@@ -192,6 +193,7 @@
                   ('Node ID', 'nodeid'),
                  ]
 
+@contextlib.contextmanager
 def extract(ui, fileobj):
     '''extract patch from data read from fileobj.
 
@@ -209,6 +211,16 @@
     Any item can be missing from the dictionary. If filename is missing,
     fileobj did not contain a patch. Caller must unlink filename when done.'''
 
+    fd, tmpname = tempfile.mkstemp(prefix='hg-patch-')
+    tmpfp = os.fdopen(fd, r'wb')
+    try:
+        yield _extract(ui, fileobj, tmpname, tmpfp)
+    finally:
+        tmpfp.close()
+        os.unlink(tmpname)
+
+def _extract(ui, fileobj, tmpname, tmpfp):
+
     # attempt to detect the start of a patch
     # (this heuristic is borrowed from quilt)
     diffre = re.compile(br'^(?:Index:[ \t]|diff[ \t]-|RCS file: |'
@@ -218,86 +230,80 @@
                         re.MULTILINE | re.DOTALL)
 
     data = {}
-    fd, tmpname = tempfile.mkstemp(prefix='hg-patch-')
-    tmpfp = os.fdopen(fd, r'wb')
-    try:
-        msg = pycompat.emailparser().parse(fileobj)
+
+    msg = pycompat.emailparser().parse(fileobj)
 
-        subject = msg[r'Subject'] and mail.headdecode(msg[r'Subject'])
-        data['user'] = msg[r'From'] and mail.headdecode(msg[r'From'])
-        if not subject and not data['user']:
-            # Not an email, restore parsed headers if any
-            subject = '\n'.join(': '.join(map(encoding.strtolocal, h))
-                                for h in msg.items()) + '\n'
+    subject = msg[r'Subject'] and mail.headdecode(msg[r'Subject'])
+    data['user'] = msg[r'From'] and mail.headdecode(msg[r'From'])
+    if not subject and not data['user']:
+        # Not an email, restore parsed headers if any
+        subject = '\n'.join(': '.join(map(encoding.strtolocal, h))
+                            for h in msg.items()) + '\n'
 
-        # should try to parse msg['Date']
-        parents = []
+    # should try to parse msg['Date']
+    parents = []
 
-        if subject:
-            if subject.startswith('[PATCH'):
-                pend = subject.find(']')
-                if pend >= 0:
-                    subject = subject[pend + 1:].lstrip()
-            subject = re.sub(br'\n[ \t]+', ' ', subject)
-            ui.debug('Subject: %s\n' % subject)
-        if data['user']:
-            ui.debug('From: %s\n' % data['user'])
-        diffs_seen = 0
-        ok_types = ('text/plain', 'text/x-diff', 'text/x-patch')
-        message = ''
-        for part in msg.walk():
-            content_type = pycompat.bytestr(part.get_content_type())
-            ui.debug('Content-Type: %s\n' % content_type)
-            if content_type not in ok_types:
-                continue
-            payload = part.get_payload(decode=True)
-            m = diffre.search(payload)
-            if m:
-                hgpatch = False
-                hgpatchheader = False
-                ignoretext = False
+    if subject:
+        if subject.startswith('[PATCH'):
+            pend = subject.find(']')
+            if pend >= 0:
+                subject = subject[pend + 1:].lstrip()
+        subject = re.sub(br'\n[ \t]+', ' ', subject)
+        ui.debug('Subject: %s\n' % subject)
+    if data['user']:
+        ui.debug('From: %s\n' % data['user'])
+    diffs_seen = 0
+    ok_types = ('text/plain', 'text/x-diff', 'text/x-patch')
+    message = ''
+    for part in msg.walk():
+        content_type = pycompat.bytestr(part.get_content_type())
+        ui.debug('Content-Type: %s\n' % content_type)
+        if content_type not in ok_types:
+            continue
+        payload = part.get_payload(decode=True)
+        m = diffre.search(payload)
+        if m:
+            hgpatch = False
+            hgpatchheader = False
+            ignoretext = False
 
-                ui.debug('found patch at byte %d\n' % m.start(0))
-                diffs_seen += 1
-                cfp = stringio()
-                for line in payload[:m.start(0)].splitlines():
-                    if line.startswith('# HG changeset patch') and not hgpatch:
-                        ui.debug('patch generated by hg export\n')
-                        hgpatch = True
-                        hgpatchheader = True
-                        # drop earlier commit message content
-                        cfp.seek(0)
-                        cfp.truncate()
-                        subject = None
-                    elif hgpatchheader:
-                        if line.startswith('# User '):
-                            data['user'] = line[7:]
-                            ui.debug('From: %s\n' % data['user'])
-                        elif line.startswith("# Parent "):
-                            parents.append(line[9:].lstrip())
-                        elif line.startswith("# "):
-                            for header, key in patchheadermap:
-                                prefix = '# %s ' % header
-                                if line.startswith(prefix):
-                                    data[key] = line[len(prefix):]
-                        else:
-                            hgpatchheader = False
-                    elif line == '---':
-                        ignoretext = True
-                    if not hgpatchheader and not ignoretext:
-                        cfp.write(line)
-                        cfp.write('\n')
-                message = cfp.getvalue()
-                if tmpfp:
-                    tmpfp.write(payload)
-                    if not payload.endswith('\n'):
-                        tmpfp.write('\n')
-            elif not diffs_seen and message and content_type == 'text/plain':
-                message += '\n' + payload
-    except: # re-raises
-        tmpfp.close()
-        os.unlink(tmpname)
-        raise
+            ui.debug('found patch at byte %d\n' % m.start(0))
+            diffs_seen += 1
+            cfp = stringio()
+            for line in payload[:m.start(0)].splitlines():
+                if line.startswith('# HG changeset patch') and not hgpatch:
+                    ui.debug('patch generated by hg export\n')
+                    hgpatch = True
+                    hgpatchheader = True
+                    # drop earlier commit message content
+                    cfp.seek(0)
+                    cfp.truncate()
+                    subject = None
+                elif hgpatchheader:
+                    if line.startswith('# User '):
+                        data['user'] = line[7:]
+                        ui.debug('From: %s\n' % data['user'])
+                    elif line.startswith("# Parent "):
+                        parents.append(line[9:].lstrip())
+                    elif line.startswith("# "):
+                        for header, key in patchheadermap:
+                            prefix = '# %s ' % header
+                            if line.startswith(prefix):
+                                data[key] = line[len(prefix):]
+                    else:
+                        hgpatchheader = False
+                elif line == '---':
+                    ignoretext = True
+                if not hgpatchheader and not ignoretext:
+                    cfp.write(line)
+                    cfp.write('\n')
+            message = cfp.getvalue()
+            if tmpfp:
+                tmpfp.write(payload)
+                if not payload.endswith('\n'):
+                    tmpfp.write('\n')
+        elif not diffs_seen and message and content_type == 'text/plain':
+            message += '\n' + payload
 
     if subject and not message.startswith(subject):
         message = '%s\n%s' % (subject, message)
@@ -310,8 +316,7 @@
 
     if diffs_seen:
         data['filename'] = tmpname
-    else:
-        os.unlink(tmpname)
+
     return data
 
 class patchmeta(object):