changegroup: verify all stream reads
Mercurial often failed with struct.error or mpatch.mpatchError if incomplete
data was received from a server.
Now we validate all changegroup reads and aborts with
abort: stream ended unexpectedly (got %d bytes, expected %d)
if less than requested was read.
--- a/mercurial/changegroup.py Tue Feb 22 03:02:50 2011 +0100
+++ b/mercurial/changegroup.py Tue Feb 22 03:03:39 2011 +0100
@@ -9,18 +9,22 @@
import util
import struct, os, bz2, zlib, tempfile
-def getchunk(source):
- """return the next chunk from changegroup 'source' as a string"""
- d = source.read(4)
+def readexactly(stream, n):
+ '''read n bytes from stream.read and abort if less was available'''
+ s = stream.read(n)
+ if len(s) < n:
+ raise util.Abort(_("stream ended unexpectedly"
+ " (got %d bytes, expected %d)")
+ % (len(s), n))
+ return s
+
+def getchunk(stream):
+ """return the next chunk from stream as a string"""
+ d = readexactly(stream, 4)
l = struct.unpack(">l", d)[0]
if l <= 4:
return ""
- d = source.read(l - 4)
- if len(d) < l - 4:
- raise util.Abort(_("premature EOF reading chunk"
- " (got %d bytes, expected %d)")
- % (len(d), l - 4))
- return d
+ return readexactly(stream, l - 4)
def chunkheader(length):
"""return a changegroup chunk header (string)"""
@@ -145,7 +149,7 @@
return self._stream.close()
def chunklength(self):
- d = self.read(4)
+ d = readexactly(self._stream, 4)
l = max(0, struct.unpack(">l", d)[0] - 4)
if l and self.callback:
self.callback()
@@ -154,20 +158,15 @@
def chunk(self):
"""return the next chunk from changegroup 'source' as a string"""
l = self.chunklength()
- d = self.read(l)
- if len(d) < l:
- raise util.Abort(_("premature EOF reading chunk"
- " (got %d bytes, expected %d)")
- % (len(d), l))
- return d
+ return readexactly(self._stream, l)
def parsechunk(self):
l = self.chunklength()
if not l:
return {}
- h = self.read(80)
+ h = readexactly(self._stream, 80)
node, p1, p2, cs = struct.unpack("20s20s20s20s", h)
- data = self.read(l - 80)
+ data = readexactly(self._stream, l - 80)
return dict(node=node, p1=p1, p2=p2, cs=cs, data=data)
class headerlessfixup(object):
@@ -178,12 +177,12 @@
if self._h:
d, self._h = self._h[:n], self._h[n:]
if len(d) < n:
- d += self._fh.read(n - len(d))
+ d += readexactly(self._fh, n - len(d))
return d
- return self._fh.read(n)
+ return readexactly(self._fh, n)
def readbundle(fh, fname):
- header = fh.read(6)
+ header = readexactly(fh, 6)
if not fname:
fname = "stream"