diff --git a/mercurial/wireprotoframing.py b/mercurial/wireprotoframing.py --- a/mercurial/wireprotoframing.py +++ b/mercurial/wireprotoframing.py @@ -218,7 +218,7 @@ return frame(h.requestid, h.typeid, h.flags, payload) -def createcommandframes(requestid, cmd, args, datafh=None): +def createcommandframes(stream, requestid, cmd, args, datafh=None): """Create frames necessary to transmit a request to run a command. This is a generator of bytearrays. Each item represents a frame @@ -233,8 +233,8 @@ if not flags: flags |= FLAG_COMMAND_NAME_EOS - yield makeframe(requestid=requestid, typeid=FRAME_TYPE_COMMAND_NAME, - flags=flags, payload=cmd) + yield stream.makeframe(requestid=requestid, typeid=FRAME_TYPE_COMMAND_NAME, + flags=flags, payload=cmd) for i, k in enumerate(sorted(args)): v = args[k] @@ -250,10 +250,10 @@ payload[offset:offset + len(v)] = v flags = FLAG_COMMAND_ARGUMENT_EOA if last else 0 - yield makeframe(requestid=requestid, - typeid=FRAME_TYPE_COMMAND_ARGUMENT, - flags=flags, - payload=payload) + yield stream.makeframe(requestid=requestid, + typeid=FRAME_TYPE_COMMAND_ARGUMENT, + flags=flags, + payload=payload) if datafh: while True: @@ -267,15 +267,15 @@ assert datafh.read(1) == b'' done = True - yield makeframe(requestid=requestid, - typeid=FRAME_TYPE_COMMAND_DATA, - flags=flags, - payload=data) + yield stream.makeframe(requestid=requestid, + typeid=FRAME_TYPE_COMMAND_DATA, + flags=flags, + payload=data) if done: break -def createbytesresponseframesfrombytes(requestid, data, +def createbytesresponseframesfrombytes(stream, requestid, data, maxframesize=DEFAULT_MAX_FRAME_SIZE): """Create a raw frame to send a bytes response from static bytes input. @@ -284,10 +284,10 @@ # Simple case of a single frame. if len(data) <= maxframesize: - yield makeframe(requestid=requestid, - typeid=FRAME_TYPE_BYTES_RESPONSE, - flags=FLAG_BYTES_RESPONSE_EOS, - payload=data) + yield stream.makeframe(requestid=requestid, + typeid=FRAME_TYPE_BYTES_RESPONSE, + flags=FLAG_BYTES_RESPONSE_EOS, + payload=data) return offset = 0 @@ -301,15 +301,15 @@ else: flags = FLAG_BYTES_RESPONSE_CONTINUATION - yield makeframe(requestid=requestid, - typeid=FRAME_TYPE_BYTES_RESPONSE, - flags=flags, - payload=chunk) + yield stream.makeframe(requestid=requestid, + typeid=FRAME_TYPE_BYTES_RESPONSE, + flags=flags, + payload=chunk) if done: break -def createerrorframe(requestid, msg, protocol=False, application=False): +def createerrorframe(stream, requestid, msg, protocol=False, application=False): # TODO properly handle frame size limits. assert len(msg) <= DEFAULT_MAX_FRAME_SIZE @@ -319,12 +319,12 @@ if application: flags |= FLAG_ERROR_RESPONSE_APPLICATION - yield makeframe(requestid=requestid, - typeid=FRAME_TYPE_ERROR_RESPONSE, - flags=flags, - payload=msg) + yield stream.makeframe(requestid=requestid, + typeid=FRAME_TYPE_ERROR_RESPONSE, + flags=flags, + payload=msg) -def createtextoutputframe(requestid, atoms): +def createtextoutputframe(stream, requestid, atoms): """Create a text output frame to render text to people. ``atoms`` is a 3-tuple of (formatting string, args, labels). @@ -390,10 +390,20 @@ if bytesleft < 0: raise ValueError('cannot encode data in a single frame') - yield makeframe(requestid=requestid, - typeid=FRAME_TYPE_TEXT_OUTPUT, - flags=0, - payload=b''.join(atomchunks)) + yield stream.makeframe(requestid=requestid, + typeid=FRAME_TYPE_TEXT_OUTPUT, + flags=0, + payload=b''.join(atomchunks)) + +class stream(object): + """Represents a logical unidirectional series of frames.""" + + def makeframe(self, requestid, typeid, flags, payload): + """Create a frame to be sent out over this stream. + + Only returns the frame instance. Does not actually send it. + """ + return makeframe(requestid, typeid, flags, payload) class serverreactor(object): """Holds state of a server handling frame-based protocol requests. @@ -498,13 +508,14 @@ return meth(frame) - def onbytesresponseready(self, requestid, data): + def onbytesresponseready(self, stream, requestid, data): """Signal that a bytes response is ready to be sent to the client. The raw bytes response is passed as an argument. """ def sendframes(): - for frame in createbytesresponseframesfrombytes(requestid, data): + for frame in createbytesresponseframesfrombytes(stream, requestid, + data): yield frame self._activecommands.remove(requestid) @@ -540,9 +551,10 @@ 'framegen': makegen(), } - def onapplicationerror(self, requestid, msg): + def onapplicationerror(self, stream, requestid, msg): return 'sendframes', { - 'framegen': createerrorframe(requestid, msg, application=True), + 'framegen': createerrorframe(stream, requestid, msg, + application=True), } def _makeerrorresult(self, msg): diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py --- a/mercurial/wireprotoserver.py +++ b/mercurial/wireprotoserver.py @@ -546,9 +546,11 @@ res.status = b'200 OK' res.headers[b'Content-Type'] = FRAMINGTYPE + stream = wireprotoframing.stream() if isinstance(rsp, wireprototypes.bytesresponse): - action, meta = reactor.onbytesresponseready(command['requestid'], + action, meta = reactor.onbytesresponseready(stream, + command['requestid'], rsp.data) else: action, meta = reactor.onapplicationerror( diff --git a/tests/test-wireproto-serverreactor.py b/tests/test-wireproto-serverreactor.py --- a/tests/test-wireproto-serverreactor.py +++ b/tests/test-wireproto-serverreactor.py @@ -27,16 +27,19 @@ header.flags, payload)) -def sendcommandframes(reactor, rid, cmd, args, datafh=None): +def sendcommandframes(reactor, stream, rid, cmd, args, datafh=None): """Generate frames to run a command and send them to a reactor.""" return sendframes(reactor, - framing.createcommandframes(rid, cmd, args, datafh)) + framing.createcommandframes(stream, rid, cmd, args, + datafh)) class FrameTests(unittest.TestCase): def testdataexactframesize(self): data = util.bytesio(b'x' * framing.DEFAULT_MAX_FRAME_SIZE) - frames = list(framing.createcommandframes(1, b'command', {}, data)) + stream = framing.stream() + frames = list(framing.createcommandframes(stream, 1, b'command', + {}, data)) self.assertEqual(frames, [ ffs(b'1 command-name have-data command'), ffs(b'1 command-data continuation %s' % data.getvalue()), @@ -45,7 +48,10 @@ def testdatamultipleframes(self): data = util.bytesio(b'x' * (framing.DEFAULT_MAX_FRAME_SIZE + 1)) - frames = list(framing.createcommandframes(1, b'command', {}, data)) + + stream = framing.stream() + frames = list(framing.createcommandframes(stream, 1, b'command', {}, + data)) self.assertEqual(frames, [ ffs(b'1 command-name have-data command'), ffs(b'1 command-data continuation %s' % ( @@ -56,7 +62,8 @@ def testargsanddata(self): data = util.bytesio(b'x' * 100) - frames = list(framing.createcommandframes(1, b'command', { + stream = framing.stream() + frames = list(framing.createcommandframes(stream, 1, b'command', { b'key1': b'key1value', b'key2': b'key2value', b'key3': b'key3value', @@ -75,51 +82,54 @@ with self.assertRaisesRegexp(ValueError, 'cannot use more than 255 formatting'): args = [b'x' for i in range(256)] - list(framing.createtextoutputframe(1, [(b'bleh', args, [])])) + list(framing.createtextoutputframe(None, 1, + [(b'bleh', args, [])])) def testtextoutputexcessivelabels(self): """At most 255 labels are allowed.""" with self.assertRaisesRegexp(ValueError, 'cannot use more than 255 labels'): labels = [b'l' for i in range(256)] - list(framing.createtextoutputframe(1, [(b'bleh', [], labels)])) + list(framing.createtextoutputframe(None, 1, + [(b'bleh', [], labels)])) def testtextoutputformattingstringtype(self): """Formatting string must be bytes.""" with self.assertRaisesRegexp(ValueError, 'must use bytes formatting '): - list(framing.createtextoutputframe(1, [ + list(framing.createtextoutputframe(None, 1, [ (b'foo'.decode('ascii'), [], [])])) def testtextoutputargumentbytes(self): with self.assertRaisesRegexp(ValueError, 'must use bytes for argument'): - list(framing.createtextoutputframe(1, [ + list(framing.createtextoutputframe(None, 1, [ (b'foo', [b'foo'.decode('ascii')], [])])) def testtextoutputlabelbytes(self): with self.assertRaisesRegexp(ValueError, 'must use bytes for labels'): - list(framing.createtextoutputframe(1, [ + list(framing.createtextoutputframe(None, 1, [ (b'foo', [], [b'foo'.decode('ascii')])])) def testtextoutputtoolongformatstring(self): with self.assertRaisesRegexp(ValueError, 'formatting string cannot be longer than'): - list(framing.createtextoutputframe(1, [ + list(framing.createtextoutputframe(None, 1, [ (b'x' * 65536, [], [])])) def testtextoutputtoolongargumentstring(self): with self.assertRaisesRegexp(ValueError, 'argument string cannot be longer than'): - list(framing.createtextoutputframe(1, [ + list(framing.createtextoutputframe(None, 1, [ (b'bleh', [b'x' * 65536], [])])) def testtextoutputtoolonglabelstring(self): with self.assertRaisesRegexp(ValueError, 'label string cannot be longer than'): - list(framing.createtextoutputframe(1, [ + list(framing.createtextoutputframe(None, 1, [ (b'bleh', [], [b'x' * 65536])])) def testtextoutput1simpleatom(self): - val = list(framing.createtextoutputframe(1, [ + stream = framing.stream() + val = list(framing.createtextoutputframe(stream, 1, [ (b'foo', [], [])])) self.assertEqual(val, [ @@ -127,7 +137,8 @@ ]) def testtextoutput2simpleatoms(self): - val = list(framing.createtextoutputframe(1, [ + stream = framing.stream() + val = list(framing.createtextoutputframe(stream, 1, [ (b'foo', [], []), (b'bar', [], []), ])) @@ -137,7 +148,8 @@ ]) def testtextoutput1arg(self): - val = list(framing.createtextoutputframe(1, [ + stream = framing.stream() + val = list(framing.createtextoutputframe(stream, 1, [ (b'foo %s', [b'val1'], []), ])) @@ -146,7 +158,8 @@ ]) def testtextoutput2arg(self): - val = list(framing.createtextoutputframe(1, [ + stream = framing.stream() + val = list(framing.createtextoutputframe(stream, 1, [ (b'foo %s %s', [b'val', b'value'], []), ])) @@ -156,7 +169,8 @@ ]) def testtextoutput1label(self): - val = list(framing.createtextoutputframe(1, [ + stream = framing.stream() + val = list(framing.createtextoutputframe(stream, 1, [ (b'foo', [], [b'label']), ])) @@ -165,7 +179,8 @@ ]) def testargandlabel(self): - val = list(framing.createtextoutputframe(1, [ + stream = framing.stream() + val = list(framing.createtextoutputframe(stream, 1, [ (b'foo %s', [b'arg'], [b'label']), ])) @@ -193,7 +208,8 @@ def test1framecommand(self): """Receiving a command in a single frame yields request to run it.""" reactor = makereactor() - results = list(sendcommandframes(reactor, 1, b'mycommand', {})) + stream = framing.stream() + results = list(sendcommandframes(reactor, stream, 1, b'mycommand', {})) self.assertEqual(len(results), 1) self.assertaction(results[0], 'runcommand') self.assertEqual(results[0][1], { @@ -208,7 +224,8 @@ def test1argument(self): reactor = makereactor() - results = list(sendcommandframes(reactor, 41, b'mycommand', + stream = framing.stream() + results = list(sendcommandframes(reactor, stream, 41, b'mycommand', {b'foo': b'bar'})) self.assertEqual(len(results), 2) self.assertaction(results[0], 'wantframe') @@ -222,7 +239,8 @@ def testmultiarguments(self): reactor = makereactor() - results = list(sendcommandframes(reactor, 1, b'mycommand', + stream = framing.stream() + results = list(sendcommandframes(reactor, stream, 1, b'mycommand', {b'foo': b'bar', b'biz': b'baz'})) self.assertEqual(len(results), 3) self.assertaction(results[0], 'wantframe') @@ -237,7 +255,8 @@ def testsimplecommanddata(self): reactor = makereactor() - results = list(sendcommandframes(reactor, 1, b'mycommand', {}, + stream = framing.stream() + results = list(sendcommandframes(reactor, stream, 1, b'mycommand', {}, util.bytesio(b'data!'))) self.assertEqual(len(results), 2) self.assertaction(results[0], 'wantframe') @@ -488,9 +507,11 @@ def testsimpleresponse(self): """Bytes response to command sends result frames.""" reactor = makereactor() - list(sendcommandframes(reactor, 1, b'mycommand', {})) + instream = framing.stream() + list(sendcommandframes(reactor, instream, 1, b'mycommand', {})) - result = reactor.onbytesresponseready(1, b'response') + outstream = framing.stream() + result = reactor.onbytesresponseready(outstream, 1, b'response') self.assertaction(result, 'sendframes') self.assertframesequal(result[1]['framegen'], [ b'1 bytes-response eos response', @@ -502,9 +523,11 @@ second = b'y' * 100 reactor = makereactor() - list(sendcommandframes(reactor, 1, b'mycommand', {})) + instream = framing.stream() + list(sendcommandframes(reactor, instream, 1, b'mycommand', {})) - result = reactor.onbytesresponseready(1, first + second) + outstream = framing.stream() + result = reactor.onbytesresponseready(outstream, 1, first + second) self.assertaction(result, 'sendframes') self.assertframesequal(result[1]['framegen'], [ b'1 bytes-response continuation %s' % first, @@ -513,9 +536,11 @@ def testapplicationerror(self): reactor = makereactor() - list(sendcommandframes(reactor, 1, b'mycommand', {})) + instream = framing.stream() + list(sendcommandframes(reactor, instream, 1, b'mycommand', {})) - result = reactor.onapplicationerror(1, b'some message') + outstream = framing.stream() + result = reactor.onapplicationerror(outstream, 1, b'some message') self.assertaction(result, 'sendframes') self.assertframesequal(result[1]['framegen'], [ b'1 error-response application some message', @@ -524,11 +549,14 @@ def test1commanddeferresponse(self): """Responses when in deferred output mode are delayed until EOF.""" reactor = makereactor(deferoutput=True) - results = list(sendcommandframes(reactor, 1, b'mycommand', {})) + instream = framing.stream() + results = list(sendcommandframes(reactor, instream, 1, b'mycommand', + {})) self.assertEqual(len(results), 1) self.assertaction(results[0], 'runcommand') - result = reactor.onbytesresponseready(1, b'response') + outstream = framing.stream() + result = reactor.onbytesresponseready(outstream, 1, b'response') self.assertaction(result, 'noop') result = reactor.oninputeof() self.assertaction(result, 'sendframes') @@ -538,12 +566,14 @@ def testmultiplecommanddeferresponse(self): reactor = makereactor(deferoutput=True) - list(sendcommandframes(reactor, 1, b'command1', {})) - list(sendcommandframes(reactor, 3, b'command2', {})) + instream = framing.stream() + list(sendcommandframes(reactor, instream, 1, b'command1', {})) + list(sendcommandframes(reactor, instream, 3, b'command2', {})) - result = reactor.onbytesresponseready(1, b'response1') + outstream = framing.stream() + result = reactor.onbytesresponseready(outstream, 1, b'response1') self.assertaction(result, 'noop') - result = reactor.onbytesresponseready(3, b'response2') + result = reactor.onbytesresponseready(outstream, 3, b'response2') self.assertaction(result, 'noop') result = reactor.oninputeof() self.assertaction(result, 'sendframes') @@ -554,14 +584,16 @@ def testrequestidtracking(self): reactor = makereactor(deferoutput=True) - list(sendcommandframes(reactor, 1, b'command1', {})) - list(sendcommandframes(reactor, 3, b'command2', {})) - list(sendcommandframes(reactor, 5, b'command3', {})) + instream = framing.stream() + list(sendcommandframes(reactor, instream, 1, b'command1', {})) + list(sendcommandframes(reactor, instream, 3, b'command2', {})) + list(sendcommandframes(reactor, instream, 5, b'command3', {})) # Register results for commands out of order. - reactor.onbytesresponseready(3, b'response3') - reactor.onbytesresponseready(1, b'response1') - reactor.onbytesresponseready(5, b'response5') + outstream = framing.stream() + reactor.onbytesresponseready(outstream, 3, b'response3') + reactor.onbytesresponseready(outstream, 1, b'response1') + reactor.onbytesresponseready(outstream, 5, b'response5') result = reactor.oninputeof() self.assertaction(result, 'sendframes') @@ -574,8 +606,9 @@ def testduplicaterequestonactivecommand(self): """Receiving a request ID that matches a request that isn't finished.""" reactor = makereactor() - list(sendcommandframes(reactor, 1, b'command1', {})) - results = list(sendcommandframes(reactor, 1, b'command1', {})) + stream = framing.stream() + list(sendcommandframes(reactor, stream, 1, b'command1', {})) + results = list(sendcommandframes(reactor, stream, 1, b'command1', {})) self.assertaction(results[0], 'error') self.assertEqual(results[0][1], { @@ -585,13 +618,15 @@ def testduplicaterequestonactivecommandnosend(self): """Same as above but we've registered a response but haven't sent it.""" reactor = makereactor() - list(sendcommandframes(reactor, 1, b'command1', {})) - reactor.onbytesresponseready(1, b'response') + instream = framing.stream() + list(sendcommandframes(reactor, instream, 1, b'command1', {})) + outstream = framing.stream() + reactor.onbytesresponseready(outstream, 1, b'response') # We've registered the response but haven't sent it. From the # perspective of the reactor, the command is still active. - results = list(sendcommandframes(reactor, 1, b'command1', {})) + results = list(sendcommandframes(reactor, instream, 1, b'command1', {})) self.assertaction(results[0], 'error') self.assertEqual(results[0][1], { 'message': b'request with ID 1 is already active', @@ -600,7 +635,8 @@ def testduplicaterequestargumentframe(self): """Variant on above except we sent an argument frame instead of name.""" reactor = makereactor() - list(sendcommandframes(reactor, 1, b'command', {})) + stream = framing.stream() + list(sendcommandframes(reactor, stream, 1, b'command', {})) results = list(sendframes(reactor, [ ffs(b'3 command-name have-args command'), ffs(b'1 command-argument 0 ignored'), @@ -614,11 +650,13 @@ def testduplicaterequestaftersend(self): """We can use a duplicate request ID after we've sent the response.""" reactor = makereactor() - list(sendcommandframes(reactor, 1, b'command1', {})) - res = reactor.onbytesresponseready(1, b'response') + instream = framing.stream() + list(sendcommandframes(reactor, instream, 1, b'command1', {})) + outstream = framing.stream() + res = reactor.onbytesresponseready(outstream, 1, b'response') list(res[1]['framegen']) - results = list(sendcommandframes(reactor, 1, b'command1', {})) + results = list(sendcommandframes(reactor, instream, 1, b'command1', {})) self.assertaction(results[0], 'runcommand') if __name__ == '__main__':