diff -r 5d44c4d1d516 -r e67522413ca8 mercurial/wireprotoframing.py --- a/mercurial/wireprotoframing.py Thu Oct 04 17:39:16 2018 -0700 +++ b/mercurial/wireprotoframing.py Mon Oct 08 17:10:59 2018 -0700 @@ -648,6 +648,140 @@ flags=FLAG_COMMAND_RESPONSE_CONTINUATION, payload=payload) +# TODO consider defining encoders/decoders using the util.compressionengine +# mechanism. + +class identityencoder(object): + """Encoder for the "identity" stream encoding profile.""" + def __init__(self, ui): + pass + + def encode(self, data): + return data + + def flush(self): + return b'' + + def finish(self): + return b'' + +class identitydecoder(object): + """Decoder for the "identity" stream encoding profile.""" + + def __init__(self, ui, extraobjs): + if extraobjs: + raise error.Abort(_('identity decoder received unexpected ' + 'additional values')) + + def decode(self, data): + return data + +class zlibencoder(object): + def __init__(self, ui): + import zlib + self._zlib = zlib + self._compressor = zlib.compressobj() + + def encode(self, data): + return self._compressor.compress(data) + + def flush(self): + # Z_SYNC_FLUSH doesn't reset compression context, which is + # what we want. + return self._compressor.flush(self._zlib.Z_SYNC_FLUSH) + + def finish(self): + res = self._compressor.flush(self._zlib.Z_FINISH) + self._compressor = None + return res + +class zlibdecoder(object): + def __init__(self, ui, extraobjs): + import zlib + + if extraobjs: + raise error.Abort(_('zlib decoder received unexpected ' + 'additional values')) + + self._decompressor = zlib.decompressobj() + + def decode(self, data): + # Python 2's zlib module doesn't use the buffer protocol and can't + # handle all bytes-like types. + if not pycompat.ispy3 and isinstance(data, bytearray): + data = bytes(data) + + return self._decompressor.decompress(data) + +class zstdbaseencoder(object): + def __init__(self, level): + from . import zstd + + self._zstd = zstd + cctx = zstd.ZstdCompressor(level=level) + self._compressor = cctx.compressobj() + + def encode(self, data): + return self._compressor.compress(data) + + def flush(self): + # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the + # compressor and allows a decompressor to access all encoded data + # up to this point. + return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK) + + def finish(self): + res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH) + self._compressor = None + return res + +class zstd8mbencoder(zstdbaseencoder): + def __init__(self, ui): + super(zstd8mbencoder, self).__init__(3) + +class zstdbasedecoder(object): + def __init__(self, maxwindowsize): + from . import zstd + dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize) + self._decompressor = dctx.decompressobj() + + def decode(self, data): + return self._decompressor.decompress(data) + +class zstd8mbdecoder(zstdbasedecoder): + def __init__(self, ui, extraobjs): + if extraobjs: + raise error.Abort(_('zstd8mb decoder received unexpected ' + 'additional values')) + + super(zstd8mbdecoder, self).__init__(maxwindowsize=8 * 1048576) + +# We lazily populate this to avoid excessive module imports when importing +# this module. +STREAM_ENCODERS = {} +STREAM_ENCODERS_ORDER = [] + +def populatestreamencoders(): + if STREAM_ENCODERS: + return + + try: + from . import zstd + zstd.__version__ + except ImportError: + zstd = None + + # zstandard is fastest and is preferred. + if zstd: + STREAM_ENCODERS[b'zstd-8mb'] = (zstd8mbencoder, zstd8mbdecoder) + STREAM_ENCODERS_ORDER.append(b'zstd-8mb') + + STREAM_ENCODERS[b'zlib'] = (zlibencoder, zlibdecoder) + STREAM_ENCODERS_ORDER.append(b'zlib') + + STREAM_ENCODERS[b'identity'] = (identityencoder, identitydecoder) + STREAM_ENCODERS_ORDER.append(b'identity') + class stream(object): """Represents a logical unidirectional series of frames.""" @@ -671,16 +805,70 @@ class inputstream(stream): """Represents a stream used for receiving data.""" - def setdecoder(self, name, extraobjs): + def __init__(self, streamid, active=False): + super(inputstream, self).__init__(streamid, active=active) + self._decoder = None + + def setdecoder(self, ui, name, extraobjs): """Set the decoder for this stream. Receives the stream profile name and any additional CBOR objects decoded from the stream encoding settings frame payloads. """ + if name not in STREAM_ENCODERS: + raise error.Abort(_('unknown stream decoder: %s') % name) + + self._decoder = STREAM_ENCODERS[name][1](ui, extraobjs) + + def decode(self, data): + # Default is identity decoder. We don't bother instantiating one + # because it is trivial. + if not self._decoder: + return data + + return self._decoder.decode(data) + + def flush(self): + if not self._decoder: + return b'' + + return self._decoder.flush() class outputstream(stream): """Represents a stream used for sending data.""" + def __init__(self, streamid, active=False): + super(outputstream, self).__init__(streamid, active=active) + self._encoder = None + + def setencoder(self, ui, name): + """Set the encoder for this stream. + + Receives the stream profile name. + """ + if name not in STREAM_ENCODERS: + raise error.Abort(_('unknown stream encoder: %s') % name) + + self._encoder = STREAM_ENCODERS[name][0](ui) + + def encode(self, data): + if not self._encoder: + return data + + return self._encoder.encode(data) + + def flush(self): + if not self._encoder: + return b'' + + return self._encoder.flush() + + def finish(self): + if not self._encoder: + return b'' + + self._encoder.finish() + def ensureserverstream(stream): if stream.streamid % 2: raise error.ProgrammingError('server should only write to even ' @@ -786,6 +974,8 @@ # Sender protocol settings are optional. Set implied default values. self._sendersettings = dict(DEFAULT_PROTOCOL_SETTINGS) + populatestreamencoders() + def onframerecv(self, frame): """Process a frame that has been received off the wire. @@ -1384,6 +1574,8 @@ self._incomingstreams = {} self._streamsettingsdecoders = {} + populatestreamencoders() + def callcommand(self, name, args, datafh=None, redirect=None): """Request that a command be executed. @@ -1494,9 +1686,13 @@ self._incomingstreams[frame.streamid] = inputstream( frame.streamid) + stream = self._incomingstreams[frame.streamid] + + # If the payload is encoded, ask the stream to decode it. We + # merely substitute the decoded result into the frame payload as + # if it had been transferred all along. if frame.streamflags & STREAM_FLAG_ENCODING_APPLIED: - raise error.ProgrammingError('support for decoding stream ' - 'payloads not yet implemneted') + frame.payload = stream.decode(frame.payload) if frame.streamflags & STREAM_FLAG_END_STREAM: del self._incomingstreams[frame.streamid] @@ -1573,7 +1769,8 @@ } try: - self._incomingstreams[frame.streamid].setdecoder(decoded[0], + self._incomingstreams[frame.streamid].setdecoder(self._ui, + decoded[0], decoded[1:]) except Exception as e: return 'error', {