httppeer: add support for httppostargs when we're sending a file
authorAugie Fackler <augie@google.com>
Wed, 26 Jul 2017 17:58:19 -0400
changeset 33821 3c91cc0c5fde
parent 33820 fa7e30efe05a
child 33822 42ad7cc645a4
httppeer: add support for httppostargs when we're sending a file This is probably only used in the 'unbundle' command, but the code ended up being cleaner to make it generic and treat *all* httppostargs with a non-args request body as though they were file-like in nature. It also means we get test coverage more or less for free. A previous version of this change didn't use io.BytesIO, and it was a lot more complicated. This also fixes a server-side bug, so anyone using httppostargs should update all of their servers to this revision or later *before* this gets to their clients, otherwise servers will hang trying to over-read the POST body. Differential Revision: https://phab.mercurial-scm.org/D231
mercurial/hgweb/protocol.py
mercurial/httppeer.py
--- a/mercurial/hgweb/protocol.py	Tue Aug 15 21:09:33 2017 +0900
+++ b/mercurial/hgweb/protocol.py	Wed Jul 26 17:58:19 2017 -0400
@@ -75,6 +75,9 @@
         return args
     def getfile(self, fp):
         length = int(self.req.env['CONTENT_LENGTH'])
+        # If httppostargs is used, we need to read Content-Length
+        # minus the amount that was consumed by args.
+        length -= int(self.req.env.get('HTTP_X_HGARGS_POST', 0))
         for s in util.filechunkiter(self.req, limit=length):
             fp.write(s)
     def redirect(self):
--- a/mercurial/httppeer.py	Tue Aug 15 21:09:33 2017 +0900
+++ b/mercurial/httppeer.py	Wed Jul 26 17:58:19 2017 -0400
@@ -9,6 +9,7 @@
 from __future__ import absolute_import
 
 import errno
+import io
 import os
 import socket
 import struct
@@ -86,6 +87,45 @@
 
     resp.__class__ = readerproxy
 
+class _multifile(object):
+    def __init__(self, *fileobjs):
+        for f in fileobjs:
+            if not util.safehasattr(f, 'length'):
+                raise ValueError(
+                    '_multifile only supports file objects that '
+                    'have a length but this one does not:', type(f), f)
+        self._fileobjs = fileobjs
+        self._index = 0
+
+    @property
+    def length(self):
+        return sum(f.length for f in self._fileobjs)
+
+    def read(self, amt=None):
+        if amt <= 0:
+            return ''.join(f.read() for f in self._fileobjs)
+        parts = []
+        while amt and self._index < len(self._fileobjs):
+            parts.append(self._fileobjs[self._index].read(amt))
+            got = len(parts[-1])
+            if got < amt:
+                self._index += 1
+            amt -= got
+        return ''.join(parts)
+
+    def seek(self, offset, whence=os.SEEK_SET):
+        if whence != os.SEEK_SET:
+            raise NotImplementedError(
+                '_multifile does not support anything other'
+                ' than os.SEEK_SET for whence on seek()')
+        if offset != 0:
+            raise NotImplementedError(
+                '_multifile only supports seeking to start, but that '
+                'could be fixed if you need it')
+        for f in self._fileobjs:
+            f.seek(0)
+        self._index = 0
+
 class httppeer(wireproto.wirepeer):
     def __init__(self, ui, path):
         self._path = path
@@ -169,17 +209,19 @@
         # with infinite recursion when trying to look up capabilities
         # for the first time.
         postargsok = self._caps is not None and 'httppostargs' in self._caps
-        # TODO: support for httppostargs when data is a file-like
-        # object rather than a basestring
-        canmungedata = not data or isinstance(data, basestring)
-        if postargsok and canmungedata:
+        if postargsok and args:
             strargs = urlreq.urlencode(sorted(args.items()))
-            if strargs:
-                if not data:
-                    data = strargs
-                elif isinstance(data, basestring):
-                    data = strargs + data
-                headers['X-HgArgs-Post'] = len(strargs)
+            if not data:
+                data = strargs
+            else:
+                if isinstance(data, basestring):
+                    i = io.BytesIO(data)
+                    i.length = len(data)
+                    data = i
+                argsio = io.BytesIO(strargs)
+                argsio.length = len(strargs)
+                data = _multifile(argsio, data)
+            headers['X-HgArgs-Post'] = len(strargs)
         else:
             if len(args) > 0:
                 httpheader = self.capable('httpheader')