diff --git a/mercurial/wireprotoframing.py b/mercurial/wireprotoframing.py
--- a/mercurial/wireprotoframing.py
+++ b/mercurial/wireprotoframing.py
@@ -648,6 +648,151 @@
                             flags=FLAG_COMMAND_RESPONSE_CONTINUATION,
                             payload=payload)
 
+# TODO consider defining encoders/decoders using the util.compressionengine
+# mechanism.
+
+class identityencoder(object):
+    """Encoder for the "identity" stream encoding profile."""
+    def __init__(self, ui):
+        pass
+
+    def encode(self, data):
+        return data
+
+    def flush(self):
+        return b''
+
+    def finish(self):
+        return b''
+
+class identitydecoder(object):
+    """Decoder for the "identity" stream encoding profile."""
+
+    def __init__(self, ui, extraobjs):
+        if extraobjs:
+            raise error.Abort(_('identity decoder received unexpected '
+                                'additional values'))
+
+    def decode(self, data):
+        return data
+
+    def flush(self):
+        return b''
+
+class zlibencoder(object):
+    def __init__(self, ui):
+        import zlib
+        self._zlib = zlib
+        self._compressor = zlib.compressobj()
+
+    def encode(self, data):
+        return self._compressor.compress(data)
+
+    def flush(self):
+        # Z_SYNC_FLUSH doesn't reset compression context, which is
+        # what we want.
+        return self._compressor.flush(self._zlib.Z_SYNC_FLUSH)
+
+    def finish(self):
+        res = self._compressor.flush(self._zlib.Z_FINISH)
+        self._compressor = None
+        return res
+
+class zlibdecoder(object):
+    def __init__(self, ui, extraobjs):
+        import zlib
+
+        if extraobjs:
+            raise error.Abort(_('zlib decoder received unexpected '
+                                'additional values'))
+
+        self._decompressor = zlib.decompressobj()
+
+    def decode(self, data):
+        # Python 2's zlib module doesn't use the buffer protocol and can't
+        # handle all bytes-like types.
+        if not pycompat.ispy3 and isinstance(data, bytearray):
+            data = bytes(data)
+
+        return self._decompressor.decompress(data)
+
+    def flush(self):
+        # Output is returned by decode() as soon as it is available. Don't
+        # call zlib's flush(): decompress() may not be called after it.
+        return b''
+
+class zstdbaseencoder(object):
+    def __init__(self, level):
+        from . import zstd
+
+        self._zstd = zstd
+        cctx = zstd.ZstdCompressor(level=level)
+        self._compressor = cctx.compressobj()
+
+    def encode(self, data):
+        return self._compressor.compress(data)
+
+    def flush(self):
+        # COMPRESSOBJ_FLUSH_BLOCK flushes all data previously fed into the
+        # compressor and allows a decompressor to access all encoded data
+        # up to this point.
+        return self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_BLOCK)
+
+    def finish(self):
+        res = self._compressor.flush(self._zstd.COMPRESSOBJ_FLUSH_FINISH)
+        self._compressor = None
+        return res
+
+class zstd8mbencoder(zstdbaseencoder):
+    def __init__(self, ui):
+        super(zstd8mbencoder, self).__init__(3)
+
+class zstdbasedecoder(object):
+    def __init__(self, maxwindowsize):
+        from . import zstd
+        dctx = zstd.ZstdDecompressor(max_window_size=maxwindowsize)
+        self._decompressor = dctx.decompressobj()
+
+    def decode(self, data):
+        return self._decompressor.decompress(data)
+
+    def flush(self):
+        return b''
+
+class zstd8mbdecoder(zstdbasedecoder):
+    def __init__(self, ui, extraobjs):
+        if extraobjs:
+            raise error.Abort(_('zstd8mb decoder received unexpected '
+                                'additional values'))
+
+        super(zstd8mbdecoder, self).__init__(maxwindowsize=8 * 1048576)
+
+# We lazily populate this to avoid excessive module imports when importing
+# this module.
+STREAM_ENCODERS = {}
+STREAM_ENCODERS_ORDER = []
+
+def populatestreamencoders():
+    if STREAM_ENCODERS:
+        return
+
+    try:
+        from . import zstd
+        zstd.__version__
+    except ImportError:
+        zstd = None
+
+    # zstandard is fastest and is preferred.
+    if zstd:
+        STREAM_ENCODERS[b'zstd-8mb'] = (zstd8mbencoder, zstd8mbdecoder)
+        STREAM_ENCODERS_ORDER.append(b'zstd-8mb')
+
+    STREAM_ENCODERS[b'zlib'] = (zlibencoder, zlibdecoder)
+    STREAM_ENCODERS_ORDER.append(b'zlib')
+
+    STREAM_ENCODERS[b'identity'] = (identityencoder, identitydecoder)
+    STREAM_ENCODERS_ORDER.append(b'identity')
+
 class stream(object):
     """Represents a logical unidirectional series of frames."""
 
@@ -671,16 +816,70 @@
 class inputstream(stream):
     """Represents a stream used for receiving data."""
 
-    def setdecoder(self, name, extraobjs):
+    def __init__(self, streamid, active=False):
+        super(inputstream, self).__init__(streamid, active=active)
+        self._decoder = None
+
+    def setdecoder(self, ui, name, extraobjs):
         """Set the decoder for this stream.
 
         Receives the stream profile name and any additional CBOR objects
         decoded from the stream encoding settings frame payloads.
         """
+        if name not in STREAM_ENCODERS:
+            raise error.Abort(_('unknown stream decoder: %s') % name)
+
+        self._decoder = STREAM_ENCODERS[name][1](ui, extraobjs)
+
+    def decode(self, data):
+        # Default is identity decoder. We don't bother instantiating one
+        # because it is trivial.
+        if not self._decoder:
+            return data
+
+        return self._decoder.decode(data)
+
+    def flush(self):
+        if not self._decoder:
+            return b''
+
+        return self._decoder.flush()
 
 class outputstream(stream):
     """Represents a stream used for sending data."""
 
+    def __init__(self, streamid, active=False):
+        super(outputstream, self).__init__(streamid, active=active)
+        self._encoder = None
+
+    def setencoder(self, ui, name):
+        """Set the encoder for this stream.
+
+        Receives the stream profile name.
+        """
+        if name not in STREAM_ENCODERS:
+            raise error.Abort(_('unknown stream encoder: %s') % name)
+
+        self._encoder = STREAM_ENCODERS[name][0](ui)
+
+    def encode(self, data):
+        if not self._encoder:
+            return data
+
+        return self._encoder.encode(data)
+
+    def flush(self):
+        if not self._encoder:
+            return b''
+
+        return self._encoder.flush()
+
+    def finish(self):
+        if not self._encoder:
+            return b''
+
+        return self._encoder.finish()
+
 def ensureserverstream(stream):
     if stream.streamid % 2:
         raise error.ProgrammingError('server should only write to even '
@@ -786,6 +985,8 @@
         # Sender protocol settings are optional. Set implied default values.
         self._sendersettings = dict(DEFAULT_PROTOCOL_SETTINGS)
 
+        populatestreamencoders()
+
     def onframerecv(self, frame):
         """Process a frame that has been received off the wire.
 
@@ -1384,6 +1585,8 @@
         self._incomingstreams = {}
         self._streamsettingsdecoders = {}
 
+        populatestreamencoders()
+
     def callcommand(self, name, args, datafh=None, redirect=None):
         """Request that a command be executed.
 
@@ -1494,9 +1697,13 @@
             self._incomingstreams[frame.streamid] = inputstream(
                 frame.streamid)
 
+        stream = self._incomingstreams[frame.streamid]
+
+        # If the payload is encoded, ask the stream to decode it. We
+        # merely substitute the decoded result into the frame payload as
+        # if it had been transferred all along.
         if frame.streamflags & STREAM_FLAG_ENCODING_APPLIED:
-            raise error.ProgrammingError('support for decoding stream '
-                                         'payloads not yet implemneted')
+            frame.payload = stream.decode(frame.payload)
 
         if frame.streamflags & STREAM_FLAG_END_STREAM:
             del self._incomingstreams[frame.streamid]
@@ -1573,7 +1780,8 @@
             }
 
         try:
-            self._incomingstreams[frame.streamid].setdecoder(decoded[0],
+            self._incomingstreams[frame.streamid].setdecoder(self._ui,
+                                                             decoded[0],
                                                              decoded[1:])
         except Exception as e:
             return 'error', {
diff --git a/tests/test-wireproto-clientreactor.py b/tests/test-wireproto-clientreactor.py
--- a/tests/test-wireproto-clientreactor.py
+++ b/tests/test-wireproto-clientreactor.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 
 import unittest
+import zlib
 
 from mercurial import (
     error,
@@ -11,6 +12,12 @@
     cborutil,
 )
 
+try:
+    from mercurial import zstd
+    zstd.__version__
+except ImportError:
+    zstd = None
+
 ffs = framing.makeframefromhumanstring
 
 globalui = uimod.ui()
@@ -261,8 +268,11 @@
         action, meta = sendframe(reactor,
             ffs(b'1 2 stream-begin stream-settings eos %s' % data))
 
-        self.assertEqual(action, b'noop')
-        self.assertEqual(meta, {})
+        self.assertEqual(action, b'error')
+        self.assertEqual(meta, {
+            b'message': b'error setting stream decoder: identity decoder '
+                        b'received unexpected additional values',
+        })
 
     def testmultipleframes(self):
         reactor = framing.clientreactor(globalui, buffersends=False)
@@ -286,6 +296,309 @@
         self.assertEqual(action, b'noop')
         self.assertEqual(meta, {})
 
+    def testinvalidencoder(self):
+        reactor = framing.clientreactor(globalui, buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 stream-begin stream-settings eos cbor:b"badvalue"'))
+
+        self.assertEqual(action, b'error')
+        self.assertEqual(meta, {
+            b'message': b'error setting stream decoder: unknown stream '
+                        b'decoder: badvalue',
+        })
+
+    def testzlibencoding(self):
+        reactor = framing.clientreactor(globalui, buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zlib"' %
+                request.requestid))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
+        result = {
+            b'status': b'ok',
+        }
+        encoded = b''.join(cborutil.streamencode(result))
+
+        compressed = zlib.compress(encoded)
+        self.assertEqual(zlib.decompress(compressed), encoded)
+
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 encoded command-response eos %s' %
+                (request.requestid, compressed)))
+
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], encoded)
+
+    def testzlibencodingsinglebyteframes(self):
+        reactor = framing.clientreactor(globalui, buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zlib"' %
+                request.requestid))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
+        result = {
+            b'status': b'ok',
+        }
+        encoded = b''.join(cborutil.streamencode(result))
+
+        compressed = zlib.compress(encoded)
+        self.assertEqual(zlib.decompress(compressed), encoded)
+
+        chunks = []
+
+        for i in range(len(compressed)):
+            char = compressed[i:i + 1]
+            if char == b'\\':
+                char = b'\\\\'
+            action, meta = sendframe(reactor,
+                ffs(b'%d 2 encoded command-response continuation %s' %
+                    (request.requestid, char)))
+
+            self.assertEqual(action, b'responsedata')
+            chunks.append(meta[b'data'])
+            self.assertTrue(meta[b'expectmore'])
+            self.assertFalse(meta[b'eos'])
+
+        # zlib will have the full data decoded at this point, even though
+        # we haven't flushed.
+        self.assertEqual(b''.join(chunks), encoded)
+
+        # End the stream for good measure.
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 stream-end command-response eos ' % request.requestid))
+
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], b'')
+        self.assertFalse(meta[b'expectmore'])
+        self.assertTrue(meta[b'eos'])
+
+    def testzlibmultipleresponses(self):
+        # We feed in zlib compressed data on the same stream but belonging to
+        # 2 different requests. This tests our flushing behavior.
+        reactor = framing.clientreactor(globalui, buffersends=False,
+                                        hasmultiplesend=True)
+
+        request1, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        request2, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        outstream = framing.outputstream(2)
+        outstream.setencoder(globalui, b'zlib')
+
+        response1 = b''.join(cborutil.streamencode({
+            b'status': b'ok',
+            b'extra': b'response1' * 10,
+        }))
+
+        response2 = b''.join(cborutil.streamencode({
+            b'status': b'error',
+            b'extra': b'response2' * 10,
+        }))
+
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zlib"' %
+                request1.requestid))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
+        # Feeding partial data in won't get anything useful out.
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 encoded command-response continuation %s' % (
+                request1.requestid, outstream.encode(response1))))
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], b'')
+
+        # But flushing data at both ends will get our original data.
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 encoded command-response eos %s' % (
+                request1.requestid, outstream.flush())))
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], response1)
+
+        # We should be able to reuse the compressor/decompressor for the
+        # 2nd response.
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 encoded command-response continuation %s' % (
+                request2.requestid, outstream.encode(response2))))
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], b'')
+
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 encoded command-response eos %s' % (
+                request2.requestid, outstream.flush())))
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], response2)
+
+    @unittest.skipUnless(zstd, 'zstd not available')
+    def testzstd8mbencoding(self):
+        reactor = framing.clientreactor(globalui, buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zstd-8mb"' %
+                request.requestid))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
+        result = {
+            b'status': b'ok',
+        }
+        encoded = b''.join(cborutil.streamencode(result))
+
+        encoder = framing.zstd8mbencoder(globalui)
+        compressed = encoder.encode(encoded) + encoder.finish()
+        self.assertEqual(zstd.ZstdDecompressor().decompress(
+            compressed, max_output_size=len(encoded)), encoded)
+
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 encoded command-response eos %s' %
+                (request.requestid, compressed)))
+
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], encoded)
+
+    @unittest.skipUnless(zstd, 'zstd not available')
+    def testzstd8mbencodingsinglebyteframes(self):
+        reactor = framing.clientreactor(globalui, buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zstd-8mb"' %
+                request.requestid))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
+        result = {
+            b'status': b'ok',
+        }
+        encoded = b''.join(cborutil.streamencode(result))
+
+        compressed = zstd.ZstdCompressor().compress(encoded)
+        self.assertEqual(zstd.ZstdDecompressor().decompress(compressed),
+                         encoded)
+
+        chunks = []
+
+        for i in range(len(compressed)):
+            char = compressed[i:i + 1]
+            if char == b'\\':
+                char = b'\\\\'
+            action, meta = sendframe(reactor,
+                ffs(b'%d 2 encoded command-response continuation %s' %
+                    (request.requestid, char)))
+
+            self.assertEqual(action, b'responsedata')
+            chunks.append(meta[b'data'])
+            self.assertTrue(meta[b'expectmore'])
+            self.assertFalse(meta[b'eos'])
+
+        # zstd decompressor will flush at frame boundaries.
+        self.assertEqual(b''.join(chunks), encoded)
+
+        # End the stream for good measure.
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 stream-end command-response eos ' % request.requestid))
+
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], b'')
+        self.assertFalse(meta[b'expectmore'])
+        self.assertTrue(meta[b'eos'])
+
+    @unittest.skipUnless(zstd, 'zstd not available')
+    def testzstd8mbmultipleresponses(self):
+        # We feed in zstd compressed data on the same stream but belonging to
+        # 2 different requests. This tests our flushing behavior.
+        reactor = framing.clientreactor(globalui, buffersends=False,
+                                        hasmultiplesend=True)
+
+        request1, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        request2, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        outstream = framing.outputstream(2)
+        outstream.setencoder(globalui, b'zstd-8mb')
+
+        response1 = b''.join(cborutil.streamencode({
+            b'status': b'ok',
+            b'extra': b'response1' * 10,
+        }))
+
+        response2 = b''.join(cborutil.streamencode({
+            b'status': b'error',
+            b'extra': b'response2' * 10,
+        }))
+
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 stream-begin stream-settings eos cbor:b"zstd-8mb"' %
+                request1.requestid))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
+        # Feeding partial data in won't get anything useful out.
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 encoded command-response continuation %s' % (
+                request1.requestid, outstream.encode(response1))))
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], b'')
+
+        # But flushing data at both ends will get our original data.
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 encoded command-response eos %s' % (
+                request1.requestid, outstream.flush())))
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], response1)
+
+        # We should be able to reuse the compressor/decompressor for the
+        # 2nd response.
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 encoded command-response continuation %s' % (
+                request2.requestid, outstream.encode(response2))))
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], b'')
+
+        action, meta = sendframe(reactor,
+            ffs(b'%d 2 encoded command-response eos %s' % (
+                request2.requestid, outstream.flush())))
+        self.assertEqual(action, b'responsedata')
+        self.assertEqual(meta[b'data'], response2)
+
 if __name__ == '__main__':
     import silenttestrunner
     silenttestrunner.main(__name__)