changeset 36395:01e29e885600

util: add a file object proxy that can read at most N bytes Sometimes we have data of a known size within a stream. For performance reasons, we don't want to pre-read this data (we want to allow consumers to read on demand). For simplicity reasons, we don't want callers to necessarily know their data is coming from within an outer stream and there is a limit to how much they should read. The class introduced by this commit provides a very simple proxy around an underlying file object that allows the consumer to .read() up to N bytes from the file object. Attempts to read past this many bytes result in a simulated EOF. Differential Revision: https://phab.mercurial-scm.org/D2377
author Gregory Szorc <gregory.szorc@gmail.com>
date Wed, 21 Feb 2018 13:41:20 -0800
parents a2d11d23bb25
children 7f8f74531b0b
files mercurial/util.py tests/test-cappedreader.py
diffstat 2 files changed, 120 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/mercurial/util.py	Mon Feb 05 15:03:51 2018 +0100
+++ b/mercurial/util.py	Wed Feb 21 13:41:20 2018 -0800
@@ -1980,6 +1980,35 @@
             limit -= len(s)
         yield s
 
+class cappedreader(object):
+    """A file object proxy that allows reading up to N bytes.
+
+    Given a source file object, instances of this type allow reading up to
+    N bytes from that source file object. Attempts to read past the allowed
+    limit are treated as EOF.
+
+    It is assumed that I/O is not performed on the original file object
+    in addition to I/O that is performed by this instance. If there is,
+    state tracking will get out of sync and unexpected results will ensue.
+    """
+    def __init__(self, fh, limit):
+        """Allow reading up to <limit> bytes from <fh>."""
+        self._fh = fh
+        self._left = limit
+
+    def read(self, n=-1):
+        """Read up to <n> bytes; None or a negative <n> reads what is left."""
+        if not self._left:
+            # Cap reached: simulate EOF without touching the source.
+            return b''
+        if n is None or n < 0:
+            n = self._left
+
+        data = self._fh.read(min(n, self._left))
+        self._left -= len(data)
+        assert self._left >= 0
+        return data
+
 def makedate(timestamp=None):
     '''Return a unix timestamp (or the current time) as a (unixtime,
     offset) tuple based off the local timezone.'''
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/test-cappedreader.py	Wed Feb 21 13:41:20 2018 -0800
@@ -0,0 +1,91 @@
+from __future__ import absolute_import, print_function
+
+import io
+import unittest
+
+from mercurial import (
+    util,
+)
+
+class CappedReaderTests(unittest.TestCase):
+    def testreadfull(self):
+        source = io.BytesIO(b'x' * 100)
+        # Request exactly as many bytes as the cap allows.
+        reader = util.cappedreader(source, 10)
+        res = reader.read(10)
+        self.assertEqual(res, b'x' * 10)
+        self.assertEqual(source.tell(), 10)
+        source.seek(0)
+        # Request more than the cap: the read is truncated to the cap.
+        reader = util.cappedreader(source, 15)
+        res = reader.read(16)
+        self.assertEqual(res, b'x' * 15)
+        self.assertEqual(source.tell(), 15)
+        source.seek(0)
+        # The cap equals the size of the source.
+        reader = util.cappedreader(source, 100)
+        res = reader.read(100)
+        self.assertEqual(res, b'x' * 100)
+        self.assertEqual(source.tell(), 100)
+        source.seek(0)
+        # read() without an argument returns everything up to the cap.
+        reader = util.cappedreader(source, 50)
+        res = reader.read()
+        self.assertEqual(res, b'x' * 50)
+        self.assertEqual(source.tell(), 50)
+        source.seek(0)
+
+    def testreadnegative(self):
+        source = io.BytesIO(b'x' * 100)
+        # A negative size means "read everything up to the cap".
+        reader = util.cappedreader(source, 20)
+        res = reader.read(-1)
+        self.assertEqual(res, b'x' * 20)
+        self.assertEqual(source.tell(), 20)
+        source.seek(0)
+        # Same, with the cap equal to the size of the source.
+        reader = util.cappedreader(source, 100)
+        res = reader.read(-1)
+        self.assertEqual(res, b'x' * 100)
+        self.assertEqual(source.tell(), 100)
+        source.seek(0)
+
+    def testreadmultiple(self):
+        source = io.BytesIO(b'x' * 100)
+        # Consume the cap one byte at a time.
+        reader = util.cappedreader(source, 10)
+        for i in range(10):
+            res = reader.read(1)
+            self.assertEqual(res, b'x')
+            self.assertEqual(source.tell(), i + 1)
+        # Past the cap, reads return EOF and leave the source untouched.
+        self.assertEqual(source.tell(), 10)
+        res = reader.read(1)
+        self.assertEqual(res, b'')
+        self.assertEqual(source.tell(), 10)
+        source.seek(0)
+        # Chunked reads where the cap is not a multiple of the chunk size.
+        reader = util.cappedreader(source, 45)
+        for i in range(4):
+            res = reader.read(10)
+            self.assertEqual(res, b'x' * 10)
+            self.assertEqual(source.tell(), (i + 1) * 10)
+        # The final chunk is short: only 5 bytes remain under the cap.
+        res = reader.read(10)
+        self.assertEqual(res, b'x' * 5)
+        self.assertEqual(source.tell(), 45)
+
+    def testreadlimitpasteof(self):
+        source = io.BytesIO(b'x' * 100)
+        # The cap exceeds the source size: the source's real EOF wins.
+        reader = util.cappedreader(source, 1024)
+        res = reader.read(1000)
+        self.assertEqual(res, b'x' * 100)
+        self.assertEqual(source.tell(), 100)
+        res = reader.read(1000)
+        self.assertEqual(res, b'')
+        self.assertEqual(source.tell(), 100)
+
+if __name__ == '__main__':
+    from silenttestrunner import main as runtests  # only for script runs
+    runtests(__name__)