diff --git a/mercurial/wireprotoframing.py b/mercurial/wireprotoframing.py
--- a/mercurial/wireprotoframing.py
+++ b/mercurial/wireprotoframing.py
@@ -668,6 +668,14 @@
         return makeframe(requestid, self.streamid, streamflags, typeid, flags,
                          payload)
 
+    def setdecoder(self, name, extraobjs):
+        """Set the decoder for this stream.
+
+        Receives the stream profile name and any additional CBOR objects
+        decoded from the stream encoding settings frame payloads.
+        """
+        # Placeholder: nothing is installed yet, so any profile is accepted.
+
 def ensureserverstream(stream):
     if stream.streamid % 2:
         raise error.ProgrammingError('server should only write to even '
@@ -1367,6 +1375,7 @@
         self._pendingrequests = collections.deque()
         self._activerequests = {}
         self._incomingstreams = {}
+        self._streamsettingsdecoders = {}
 
     def callcommand(self, name, args, datafh=None, redirect=None):
         """Request that a command be executed.
@@ -1484,6 +1493,9 @@
         if frame.streamflags & STREAM_FLAG_END_STREAM:
             del self._incomingstreams[frame.streamid]
 
+        if frame.typeid == FRAME_TYPE_STREAM_SETTINGS:
+            return self._onstreamsettingsframe(frame)
+
         if frame.requestid not in self._activerequests:
             return 'error', {
                 'message': (_('received frame for inactive request ID: %d') %
@@ -1505,6 +1517,78 @@
 
         return meth(request, frame)
 
+    def _onstreamsettingsframe(self, frame):
+        """Handle a stream encoding settings frame.
+
+        A settings payload may be split across several frames: frames with
+        the continuation flag are buffered per stream and the frame with the
+        EOS flag completes the payload. The first decoded CBOR value is the
+        stream profile name; any remaining values are passed to the stream's
+        ``setdecoder()`` as extra objects.
+
+        Returns an ``(action, meta)`` tuple like the other frame handlers.
+        """
+        assert frame.typeid == FRAME_TYPE_STREAM_SETTINGS
+
+        more = frame.flags & FLAG_STREAM_ENCODING_SETTINGS_CONTINUATION
+        eos = frame.flags & FLAG_STREAM_ENCODING_SETTINGS_EOS
+
+        if more and eos:
+            return 'error', {
+                'message': (_('stream encoding settings frame cannot have both '
+                              'continuation and end of stream flags set')),
+            }
+
+        if not more and not eos:
+            return 'error', {
+                'message': _('stream encoding settings frame must have '
+                             'continuation or end of stream flag set'),
+            }
+
+        # Continuation payloads accumulate in a per-stream decoder until EOS.
+        decoder = self._streamsettingsdecoders.get(frame.streamid)
+        if decoder is None:
+            decoder = cborutil.bufferingdecoder()
+            self._streamsettingsdecoders[frame.streamid] = decoder
+
+        try:
+            decoder.decode(frame.payload)
+        except Exception as e:
+            # Drop the half-fed decoder so a later settings frame on this
+            # stream starts from a clean state instead of resuming from
+            # whatever the failed decode left behind.
+            del self._streamsettingsdecoders[frame.streamid]
+            return 'error', {
+                'message': (_('error decoding CBOR from stream encoding '
+                              'settings frame: %s') %
+                            stringutil.forcebytestr(e)),
+            }
+
+        if more:
+            return 'noop', {}
+
+        assert eos
+
+        decoded = decoder.getavailable()
+        del self._streamsettingsdecoders[frame.streamid]
+
+        if not decoded:
+            return 'error', {
+                'message': _('stream encoding settings frame did not contain '
+                             'CBOR data'),
+            }
+
+        try:
+            self._incomingstreams[frame.streamid].setdecoder(decoded[0],
+                                                             decoded[1:])
+        except Exception as e:
+            return 'error', {
+                'message': (_('error setting stream decoder: %s') %
+                            stringutil.forcebytestr(e)),
+            }
+
+        return 'noop', {}
+
     def _oncommandresponseframe(self, request, frame):
         if frame.flags & FLAG_COMMAND_RESPONSE_EOS:
             request.state = 'received'
diff --git a/tests/test-wireproto-clientreactor.py b/tests/test-wireproto-clientreactor.py
--- a/tests/test-wireproto-clientreactor.py
+++ b/tests/test-wireproto-clientreactor.py
@@ -6,6 +6,9 @@
     error,
     wireprotoframing as framing,
 )
+from mercurial.utils import (
+    cborutil,
+)
 
 ffs = framing.makeframefromhumanstring
 
@@ -162,6 +165,133 @@
                              b"b'redirect': {b'targets': [b'a', b'b'], "
                              b"b'hashes': [b'sha256']}}"))
 
+class StreamSettingsTests(unittest.TestCase):
+    def testnoflags(self):
+        reactor = framing.clientreactor(buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 stream-begin stream-settings 0 '))
+
+        self.assertEqual(action, b'error')
+        self.assertEqual(meta, {
+            b'message': b'stream encoding settings frame must have '
+                        b'continuation or end of stream flag set',
+        })
+
+    def testconflictflags(self):
+        reactor = framing.clientreactor(buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 stream-begin stream-settings continuation|eos '))
+
+        self.assertEqual(action, b'error')
+        self.assertEqual(meta, {
+            b'message': b'stream encoding settings frame cannot have both '
+                        b'continuation and end of stream flags set',
+        })
+
+    def testemptypayload(self):
+        reactor = framing.clientreactor(buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 stream-begin stream-settings eos '))
+
+        self.assertEqual(action, b'error')
+        self.assertEqual(meta, {
+            b'message': b'stream encoding settings frame did not contain '
+                        b'CBOR data'
+        })
+
+    def testbadcbor(self):
+        reactor = framing.clientreactor(buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 stream-begin stream-settings eos badvalue'))
+
+        self.assertEqual(action, b'error')
+
+    def testbadcborcontinuation(self):
+        reactor = framing.clientreactor(buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 stream-begin stream-settings continuation badvalue'))
+
+        self.assertEqual(action, b'error')
+
+        # The failed decoder must not be reused by later settings frames.
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 0 stream-settings eos cbor:b"identity"'))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
+    def testsingleobject(self):
+        reactor = framing.clientreactor(buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 stream-begin stream-settings eos cbor:b"identity"'))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
+    def testmultipleobjects(self):
+        reactor = framing.clientreactor(buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        data = b''.join([
+            b''.join(cborutil.streamencode(b'identity')),
+            b''.join(cborutil.streamencode({b'foo', b'bar'})),
+        ])
+
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 stream-begin stream-settings eos %s' % data))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
+    def testmultipleframes(self):
+        reactor = framing.clientreactor(buffersends=False)
+
+        request, action, meta = reactor.callcommand(b'foo', {})
+        for f in meta[b'framegen']:
+            pass
+
+        data = b''.join(cborutil.streamencode(b'identity'))
+
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 stream-begin stream-settings continuation %s' %
+                data[0:3]))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
+        action, meta = sendframe(reactor,
+            ffs(b'1 2 0 stream-settings eos %s' % data[3:]))
+
+        self.assertEqual(action, b'noop')
+        self.assertEqual(meta, {})
+
 if __name__ == '__main__':
     import silenttestrunner
     silenttestrunner.main(__name__)