--- a/mercurial/wireprotoframing.py Thu Oct 04 17:39:16 2018 -0700
+++ b/mercurial/wireprotoframing.py Mon Oct 08 17:10:59 2018 -0700
@@ -648,6 +648,140 @@
flags=FLAG_COMMAND_RESPONSE_CONTINUATION,
payload=payload)
+# TODO consider defining encoders/decoders using the util.compressionengine
+# mechanism.
+
+class identityencoder(object):
+ """Encoder for the "identity" stream encoding profile."""
+ def __init__(self, ui):
+ pass
+
+ def encode(self, data):
+ return data
+
+ def flush(self):
+ return b''
+
+ def finish(self):
+ return b''
+
+class identitydecoder(object):
+ """Decoder for the "identity" stream encoding profile."""
+
+ def __init__(self, ui, extraobjs):
+ if extraobjs:
+ raise error.Abort(_('identity decoder received unexpected '
+ 'additional values'))
+
+ def decode(self, data):
+ return data
+
+class zlibencoder(object):
+ def __init__(self, ui):
+ import zlib
+ self._zlib = zlib
+ self._compressor = zlib.compressobj()
+
+ def encode(self, data):
+ return self._compressor.compress(data)
+
+ def flush(self):
+ # Z_SYNC_FLUSH doesn't reset compression context, which is
+ # what we want.
+ return self._compressor.flush(self._zlib.Z_SYNC_FLUSH)
+
+ def finish(self):
+ res = self._compressor.flush(self._zlib.Z_FINISH)
+ self._compressor = None
+ return res
+
+class zlibdecoder(object):
+ def __init__(self, ui, extraobjs):
+ import zlib
+
+ if extraobjs:
+ raise error.Abort(_('zlib decoder received unexpected '
+ 'additional values'))
+
+ self._decompressor = zlib.decompressobj()
+
+ def decode(self, data):
+ # Python 2's zlib module doesn't use the buffer protocol and can't
+ # handle all bytes-like types.
+ if not pycompat.ispy3 and isinstance(data, bytearray):
+ data = bytes(data)
+
+ return self._decompressor.decompress(data)
+
+class zstdbaseencoder(object):
+ def __init__(self, level):
+ from . import zstd
+
+ self._zstd = zstd
+ cctx = zstd.ZstdCompressor(level=level)
+ self._compressor = cctx.compressobj()
+
+ def encode(self, data):
+ return self._compressor.compress(data)
+
+ def flush(self):
+ # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the
+ # compressor and allows a decompressor to access all encoded data
+ # up to this point.
+ return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK)
+
+ def finish(self):
+ res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH)
+ self._compressor = None
+ return res
+
+class zstd8mbencoder(zstdbaseencoder):
+ def __init__(self, ui):
+ super(zstd8mbencoder, self).__init__(3)
+
+class zstdbasedecoder(object):
+ def __init__(self, maxwindowsize):
+ from . import zstd
+ dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize)
+ self._decompressor = dctx.decompressobj()
+
+ def decode(self, data):
+ return self._decompressor.decompress(data)
+
+class zstd8mbdecoder(zstdbasedecoder):
+ def __init__(self, ui, extraobjs):
+ if extraobjs:
+ raise error.Abort(_('zstd8mb decoder received unexpected '
+ 'additional values'))
+
+ super(zstd8mbdecoder, self).__init__(maxwindowsize=8 * 1048576)
+
+# We lazily populate this to avoid excessive module imports when importing
+# this module.
+STREAM_ENCODERS = {}
+STREAM_ENCODERS_ORDER = []
+
+def populatestreamencoders():
+ if STREAM_ENCODERS:
+ return
+
+ try:
+ from . import zstd
+ zstd.__version__
+ except ImportError:
+ zstd = None
+
+ # zstandard is fastest and is preferred.
+ if zstd:
+ STREAM_ENCODERS[b'zstd-8mb'] = (zstd8mbencoder, zstd8mbdecoder)
+ STREAM_ENCODERS_ORDER.append(b'zstd-8mb')
+
+ STREAM_ENCODERS[b'zlib'] = (zlibencoder, zlibdecoder)
+ STREAM_ENCODERS_ORDER.append(b'zlib')
+
+ STREAM_ENCODERS[b'identity'] = (identityencoder, identitydecoder)
+ STREAM_ENCODERS_ORDER.append(b'identity')
+
class stream(object):
"""Represents a logical unidirectional series of frames."""
@@ -671,16 +805,70 @@
class inputstream(stream):
"""Represents a stream used for receiving data."""
- def setdecoder(self, name, extraobjs):
+ def __init__(self, streamid, active=False):
+ super(inputstream, self).__init__(streamid, active=active)
+ self._decoder = None
+
+ def setdecoder(self, ui, name, extraobjs):
"""Set the decoder for this stream.
Receives the stream profile name and any additional CBOR objects
decoded from the stream encoding settings frame payloads.
"""
+ if name not in STREAM_ENCODERS:
+ raise error.Abort(_('unknown stream decoder: %s') % name)
+
+ self._decoder = STREAM_ENCODERS[name][1](ui, extraobjs)
+
+ def decode(self, data):
+ # Default is identity decoder. We don't bother instantiating one
+ # because it is trivial.
+ if not self._decoder:
+ return data
+
+ return self._decoder.decode(data)
+
+ def flush(self):
+ if not self._decoder:
+ return b''
+
+ return self._decoder.flush()
class outputstream(stream):
"""Represents a stream used for sending data."""
+ def __init__(self, streamid, active=False):
+ super(outputstream, self).__init__(streamid, active=active)
+ self._encoder = None
+
+ def setencoder(self, ui, name):
+ """Set the encoder for this stream.
+
+ Receives the stream profile name.
+ """
+ if name not in STREAM_ENCODERS:
+ raise error.Abort(_('unknown stream encoder: %s') % name)
+
+ self._encoder = STREAM_ENCODERS[name][0](ui)
+
+ def encode(self, data):
+ if not self._encoder:
+ return data
+
+ return self._encoder.encode(data)
+
+ def flush(self):
+ if not self._encoder:
+ return b''
+
+ return self._encoder.flush()
+
+ def finish(self):
+ if not self._encoder:
+ return b''
+
+ self._encoder.finish()
+
def ensureserverstream(stream):
if stream.streamid % 2:
raise error.ProgrammingError('server should only write to even '
@@ -786,6 +974,8 @@
# Sender protocol settings are optional. Set implied default values.
self._sendersettings = dict(DEFAULT_PROTOCOL_SETTINGS)
+ populatestreamencoders()
+
def onframerecv(self, frame):
"""Process a frame that has been received off the wire.
@@ -1384,6 +1574,8 @@
self._incomingstreams = {}
self._streamsettingsdecoders = {}
+ populatestreamencoders()
+
def callcommand(self, name, args, datafh=None, redirect=None):
"""Request that a command be executed.
@@ -1494,9 +1686,13 @@
self._incomingstreams[frame.streamid] = inputstream(
frame.streamid)
+ stream = self._incomingstreams[frame.streamid]
+
+ # If the payload is encoded, ask the stream to decode it. We
+ # merely substitute the decoded result into the frame payload as
+ # if it had been transferred all along.
if frame.streamflags & STREAM_FLAG_ENCODING_APPLIED:
- raise error.ProgrammingError('support for decoding stream '
- 'payloads not yet implemneted')
+ frame.payload = stream.decode(frame.payload)
if frame.streamflags & STREAM_FLAG_END_STREAM:
del self._incomingstreams[frame.streamid]
@@ -1573,7 +1769,8 @@
}
try:
- self._incomingstreams[frame.streamid].setdecoder(decoded[0],
+ self._incomingstreams[frame.streamid].setdecoder(self._ui,
+ decoded[0],
decoded[1:])
except Exception as e:
return 'error', {
--- a/tests/test-wireproto-clientreactor.py Thu Oct 04 17:39:16 2018 -0700
+++ b/tests/test-wireproto-clientreactor.py Mon Oct 08 17:10:59 2018 -0700
@@ -1,6 +1,7 @@
from __future__ import absolute_import
import unittest
+import zlib
from mercurial import (
error,
@@ -11,6 +12,12 @@
cborutil,
)
+try:
+ from mercurial import zstd
+ zstd.__version__
+except ImportError:
+ zstd = None
+
ffs = framing.makeframefromhumanstring
globalui = uimod.ui()
@@ -261,8 +268,11 @@
action, meta = sendframe(reactor,
ffs(b'1 2 stream-begin stream-settings eos %s' % data))
- self.assertEqual(action, b'noop')
- self.assertEqual(meta, {})
+ self.assertEqual(action, b'error')
+ self.assertEqual(meta, {
+ b'message': b'error setting stream decoder: identity decoder '
+ b'received unexpected additional values',
+ })
def testmultipleframes(self):
reactor = framing.clientreactor(globalui, buffersends=False)
@@ -286,6 +296,309 @@
self.assertEqual(action, b'noop')
self.assertEqual(meta, {})
+ def testinvalidencoder(self):
+ reactor = framing.clientreactor(globalui, buffersends=False)
+
+ request, action, meta = reactor.callcommand(b'foo', {})
+ for f in meta[b'framegen']:
+ pass
+
+ action, meta = sendframe(reactor,
+ ffs(b'1 2 stream-begin stream-settings eos cbor:b"badvalue"'))
+
+ self.assertEqual(action, b'error')
+ self.assertEqual(meta, {
+ b'message': b'error setting stream decoder: unknown stream '
+ b'decoder: badvalue',
+ })
+
+ def testzlibencoding(self):
+ reactor = framing.clientreactor(globalui, buffersends=False)
+
+ request, action, meta = reactor.callcommand(b'foo', {})
+ for f in meta[b'framegen']:
+ pass
+
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zlib"' %
+ request.requestid))
+
+ self.assertEqual(action, b'noop')
+ self.assertEqual(meta, {})
+
+ result = {
+ b'status': b'ok',
+ }
+ encoded = b''.join(cborutil.streamencode(result))
+
+ compressed = zlib.compress(encoded)
+ self.assertEqual(zlib.decompress(compressed), encoded)
+
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response eos %s' %
+ (request.requestid, compressed)))
+
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], encoded)
+
+ def testzlibencodingsinglebyteframes(self):
+ reactor = framing.clientreactor(globalui, buffersends=False)
+
+ request, action, meta = reactor.callcommand(b'foo', {})
+ for f in meta[b'framegen']:
+ pass
+
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zlib"' %
+ request.requestid))
+
+ self.assertEqual(action, b'noop')
+ self.assertEqual(meta, {})
+
+ result = {
+ b'status': b'ok',
+ }
+ encoded = b''.join(cborutil.streamencode(result))
+
+ compressed = zlib.compress(encoded)
+ self.assertEqual(zlib.decompress(compressed), encoded)
+
+ chunks = []
+
+ for i in range(len(compressed)):
+ char = compressed[i:i + 1]
+ if char == b'\\':
+ char = b'\\\\'
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response continuation %s' %
+ (request.requestid, char)))
+
+ self.assertEqual(action, b'responsedata')
+ chunks.append(meta[b'data'])
+ self.assertTrue(meta[b'expectmore'])
+ self.assertFalse(meta[b'eos'])
+
+ # zlib will have the full data decoded at this point, even though
+ # we haven't flushed.
+ self.assertEqual(b''.join(chunks), encoded)
+
+ # End the stream for good measure.
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 stream-end command-response eos ' % request.requestid))
+
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], b'')
+ self.assertFalse(meta[b'expectmore'])
+ self.assertTrue(meta[b'eos'])
+
+ def testzlibmultipleresponses(self):
+ # We feed in zlib compressed data on the same stream but belonging to
+ # 2 different requests. This tests our flushing behavior.
+ reactor = framing.clientreactor(globalui, buffersends=False,
+ hasmultiplesend=True)
+
+ request1, action, meta = reactor.callcommand(b'foo', {})
+ for f in meta[b'framegen']:
+ pass
+
+ request2, action, meta = reactor.callcommand(b'foo', {})
+ for f in meta[b'framegen']:
+ pass
+
+ outstream = framing.outputstream(2)
+ outstream.setencoder(globalui, b'zlib')
+
+ response1 = b''.join(cborutil.streamencode({
+ b'status': b'ok',
+ b'extra': b'response1' * 10,
+ }))
+
+ response2 = b''.join(cborutil.streamencode({
+ b'status': b'error',
+ b'extra': b'response2' * 10,
+ }))
+
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zlib"' %
+ request1.requestid))
+
+ self.assertEqual(action, b'noop')
+ self.assertEqual(meta, {})
+
+ # Feeding partial data in won't get anything useful out.
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response continuation %s' % (
+ request1.requestid, outstream.encode(response1))))
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], b'')
+
+ # But flushing data at both ends will get our original data.
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response eos %s' % (
+ request1.requestid, outstream.flush())))
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], response1)
+
+ # We should be able to reuse the compressor/decompressor for the
+ # 2nd response.
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response continuation %s' % (
+ request2.requestid, outstream.encode(response2))))
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], b'')
+
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response eos %s' % (
+ request2.requestid, outstream.flush())))
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], response2)
+
+ @unittest.skipUnless(zstd, 'zstd not available')
+ def testzstd8mbencoding(self):
+ reactor = framing.clientreactor(globalui, buffersends=False)
+
+ request, action, meta = reactor.callcommand(b'foo', {})
+ for f in meta[b'framegen']:
+ pass
+
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zstd-8mb"' %
+ request.requestid))
+
+ self.assertEqual(action, b'noop')
+ self.assertEqual(meta, {})
+
+ result = {
+ b'status': b'ok',
+ }
+ encoded = b''.join(cborutil.streamencode(result))
+
+ encoder = framing.zstd8mbencoder(globalui)
+ compressed = encoder.encode(encoded) + encoder.finish()
+ self.assertEqual(zstd.ZstdDecompressor().decompress(
+ compressed, max_output_size=len(encoded)), encoded)
+
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response eos %s' %
+ (request.requestid, compressed)))
+
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], encoded)
+
+ @unittest.skipUnless(zstd, 'zstd not available')
+ def testzstd8mbencodingsinglebyteframes(self):
+ reactor = framing.clientreactor(globalui, buffersends=False)
+
+ request, action, meta = reactor.callcommand(b'foo', {})
+ for f in meta[b'framegen']:
+ pass
+
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zstd-8mb"' %
+ request.requestid))
+
+ self.assertEqual(action, b'noop')
+ self.assertEqual(meta, {})
+
+ result = {
+ b'status': b'ok',
+ }
+ encoded = b''.join(cborutil.streamencode(result))
+
+ compressed = zstd.ZstdCompressor().compress(encoded)
+ self.assertEqual(zstd.ZstdDecompressor().decompress(compressed),
+ encoded)
+
+ chunks = []
+
+ for i in range(len(compressed)):
+ char = compressed[i:i + 1]
+ if char == b'\\':
+ char = b'\\\\'
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response continuation %s' %
+ (request.requestid, char)))
+
+ self.assertEqual(action, b'responsedata')
+ chunks.append(meta[b'data'])
+ self.assertTrue(meta[b'expectmore'])
+ self.assertFalse(meta[b'eos'])
+
+ # zstd decompressor will flush at frame boundaries.
+ self.assertEqual(b''.join(chunks), encoded)
+
+ # End the stream for good measure.
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 stream-end command-response eos ' % request.requestid))
+
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], b'')
+ self.assertFalse(meta[b'expectmore'])
+ self.assertTrue(meta[b'eos'])
+
+ @unittest.skipUnless(zstd, 'zstd not available')
+ def testzstd8mbmultipleresponses(self):
+ # We feed in zstd compressed data on the same stream but belonging to
+ # 2 different requests. This tests our flushing behavior.
+ reactor = framing.clientreactor(globalui, buffersends=False,
+ hasmultiplesend=True)
+
+ request1, action, meta = reactor.callcommand(b'foo', {})
+ for f in meta[b'framegen']:
+ pass
+
+ request2, action, meta = reactor.callcommand(b'foo', {})
+ for f in meta[b'framegen']:
+ pass
+
+ outstream = framing.outputstream(2)
+ outstream.setencoder(globalui, b'zstd-8mb')
+
+ response1 = b''.join(cborutil.streamencode({
+ b'status': b'ok',
+ b'extra': b'response1' * 10,
+ }))
+
+ response2 = b''.join(cborutil.streamencode({
+ b'status': b'error',
+ b'extra': b'response2' * 10,
+ }))
+
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zstd-8mb"' %
+ request1.requestid))
+
+ self.assertEqual(action, b'noop')
+ self.assertEqual(meta, {})
+
+ # Feeding partial data in won't get anything useful out.
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response continuation %s' % (
+ request1.requestid, outstream.encode(response1))))
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], b'')
+
+ # But flushing data at both ends will get our original data.
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response eos %s' % (
+ request1.requestid, outstream.flush())))
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], response1)
+
+ # We should be able to reuse the compressor/decompressor for the
+ # 2nd response.
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response continuation %s' % (
+ request2.requestid, outstream.encode(response2))))
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], b'')
+
+ action, meta = sendframe(reactor,
+ ffs(b'%d 2 encoded command-response eos %s' % (
+ request2.requestid, outstream.flush())))
+ self.assertEqual(action, b'responsedata')
+ self.assertEqual(meta[b'data'], response2)
+
if __name__ == '__main__':
import silenttestrunner
silenttestrunner.main(__name__)