wireproto: wrap raw byte command responses in wireprototypes.bytesresponse

Adds a bytesresponse class to mercurial/wireprototypes.py that holds the
response payload in its .data attribute. mercurial/wireproto.py aliases it
as wireproto.bytesresponse. The commands between, branchmap, branches,
clonebundles, capabilities, debugwireargs, heads, hello, listkeys, lookup,
known and pushkey now return bytesresponse instead of a bare string.

hgext/largefiles/proto.py changes the same way: its handler that checks
lfutil.findfile() now wraps its '2\n' / '0\n' replies in bytesresponse.

batch changes as follows:
- It unwraps bytesresponse results before escapearg().
- It still accepts raw bytes from batchable commands, for backwards
  compatibility.
- It returns the ';'-joined result as a bytesresponse.

hello now formats capabilities(repo, proto).data.

In mercurial/wireprotoserver.py, both the HTTP response path
(req.respond) and the SSH response path (_sshv1respondbytes) send rsp.data
when given a bytesresponse.

Test updates:
- tests/sshprotoext.py reads .data from the dispatched 'between' response
  and from the capabilities response.
- The dummy client in tests/test-wireproto.py returns bytesresponse.data,
  passes plain bytes through, and raises error.Abort for any other
  response type.

NOTE(review): the new batch type check is an assert, which python -O
strips. Under -O, an unexpected result type would reach escapearg()
unchecked. Consider raising error.ProgrammingError instead.

diff --git a/hgext/largefiles/proto.py b/hgext/largefiles/proto.py --- a/hgext/largefiles/proto.py +++ b/hgext/largefiles/proto.py @@ -14,6 +14,7 @@ httppeer, util, wireproto, + wireprototypes, ) from . import ( @@ -85,8 +86,8 @@ server side.''' filename = lfutil.findfile(repo, sha) if not filename: - return '2\n' - return '0\n' + return wireprototypes.bytesresponse('2\n') + return wireprototypes.bytesresponse('0\n') def wirereposetup(ui, repo): class lfileswirerepository(repo.__class__): diff --git a/mercurial/wireproto.py b/mercurial/wireproto.py --- a/mercurial/wireproto.py +++ b/mercurial/wireproto.py @@ -37,6 +37,7 @@ urlerr = util.urlerr urlreq = util.urlreq +bytesresponse = wireprototypes.bytesresponse ooberror = wireprototypes.ooberror pushres = wireprototypes.pushres pusherr = wireprototypes.pusherr @@ -696,8 +697,15 @@ result = func(repo, proto) if isinstance(result, ooberror): return result + + # For now, all batchable commands must return bytesresponse or + # raw bytes (for backwards compatibility).
+ assert isinstance(result, (bytesresponse, bytes)) + if isinstance(result, bytesresponse): + result = result.data res.append(escapearg(result)) - return ';'.join(res) + + return bytesresponse(';'.join(res)) @wireprotocommand('between', 'pairs') def between(repo, proto, pairs): @@ -705,7 +713,8 @@ r = [] for b in repo.between(pairs): r.append(encodelist(b) + "\n") - return "".join(r) + + return bytesresponse(''.join(r)) @wireprotocommand('branchmap') def branchmap(repo, proto): @@ -715,7 +724,8 @@ branchname = urlreq.quote(encoding.fromlocal(branch)) branchnodes = encodelist(nodes) heads.append('%s %s' % (branchname, branchnodes)) - return '\n'.join(heads) + + return bytesresponse('\n'.join(heads)) @wireprotocommand('branches', 'nodes') def branches(repo, proto, nodes): @@ -723,7 +733,8 @@ r = [] for b in repo.branches(nodes): r.append(encodelist(b) + "\n") - return "".join(r) + + return bytesresponse(''.join(r)) @wireprotocommand('clonebundles', '') def clonebundles(repo, proto): @@ -735,7 +746,7 @@ depending on the request. e.g. you could advertise URLs for the closest data center given the client's IP address. """ - return repo.vfs.tryread('clonebundles.manifest') + return bytesresponse(repo.vfs.tryread('clonebundles.manifest')) wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey', 'known', 'getbundle', 'unbundlehash', 'batch'] @@ -789,7 +800,7 @@ # `_capabilities` instead.
@wireprotocommand('capabilities') def capabilities(repo, proto): - return ' '.join(_capabilities(repo, proto)) + return bytesresponse(' '.join(_capabilities(repo, proto))) @wireprotocommand('changegroup', 'roots') def changegroup(repo, proto, roots): @@ -814,7 +825,8 @@ def debugwireargs(repo, proto, one, two, others): # only accept optional args from the known set opts = options('debugwireargs', ['three', 'four'], others) - return repo.debugwireargs(one, two, **pycompat.strkwargs(opts)) + return bytesresponse(repo.debugwireargs(one, two, + **pycompat.strkwargs(opts))) @wireprotocommand('getbundle', '*') def getbundle(repo, proto, others): @@ -885,7 +897,7 @@ @wireprotocommand('heads') def heads(repo, proto): h = repo.heads() - return encodelist(h) + "\n" + return bytesresponse(encodelist(h) + '\n') @wireprotocommand('hello') def hello(repo, proto): @@ -896,12 +908,13 @@ capabilities: space separated list of tokens ''' - return "capabilities: %s\n" % (capabilities(repo, proto)) + caps = capabilities(repo, proto).data + return bytesresponse('capabilities: %s\n' % caps) @wireprotocommand('listkeys', 'namespace') def listkeys(repo, proto, namespace): d = repo.listkeys(encoding.tolocal(namespace)).items() - return pushkeymod.encodekeys(d) + return bytesresponse(pushkeymod.encodekeys(d)) @wireprotocommand('lookup', 'key') def lookup(repo, proto, key): @@ -913,11 +926,12 @@ except Exception as inst: r = str(inst) success = 0 - return "%d %s\n" % (success, r) + return bytesresponse('%d %s\n' % (success, r)) @wireprotocommand('known', 'nodes *') def known(repo, proto, nodes, others): - return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes))) + v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes))) + return bytesresponse(v) @wireprotocommand('pushkey', 'namespace key old new') def pushkey(repo, proto, namespace, key, old, new): @@ -938,7 +952,7 @@ encoding.tolocal(old), new) or False output = output.getvalue() if output else '' - return '%s\n%s'
% (int(r), output) + return bytesresponse('%s\n%s' % (int(r), output)) @wireprotocommand('stream_out') def stream(repo, proto): diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py --- a/mercurial/wireprotoserver.py +++ b/mercurial/wireprotoserver.py @@ -274,6 +274,9 @@ if isinstance(rsp, bytes): req.respond(HTTP_OK, HGTYPE, body=rsp) return [] + elif isinstance(rsp, wireprototypes.bytesresponse): + req.respond(HTTP_OK, HGTYPE, body=rsp.data) + return [] elif isinstance(rsp, wireprototypes.streamreslegacy): gen = rsp.gen req.respond(HTTP_OK, HGTYPE) @@ -435,6 +438,8 @@ if isinstance(rsp, bytes): _sshv1respondbytes(self._fout, rsp) + elif isinstance(rsp, wireprototypes.bytesresponse): + _sshv1respondbytes(self._fout, rsp.data) elif isinstance(rsp, wireprototypes.streamres): _sshv1respondstream(self._fout, rsp) elif isinstance(rsp, wireprototypes.streamreslegacy): diff --git a/mercurial/wireprototypes.py b/mercurial/wireprototypes.py --- a/mercurial/wireprototypes.py +++ b/mercurial/wireprototypes.py @@ -5,6 +5,11 @@ from __future__ import absolute_import +class bytesresponse(object): + """A wire protocol response consisting of raw bytes.""" + def __init__(self, data): + self.data = data + class ooberror(object): """wireproto reply: failure of a batch of operation diff --git a/tests/sshprotoext.py b/tests/sshprotoext.py --- a/tests/sshprotoext.py +++ b/tests/sshprotoext.py @@ -49,7 +49,7 @@ l = self._fin.readline() assert l == b'between\n' rsp = wireproto.dispatch(self._repo, self._proto, b'between') - wireprotoserver._sshv1respondbytes(self._fout, rsp) + wireprotoserver._sshv1respondbytes(self._fout, rsp.data) super(prehelloserver, self).serve_forever() @@ -74,7 +74,7 @@ # Send the upgrade response.
self._fout.write(b'upgraded %s %s\n' % (token, name)) servercaps = wireproto.capabilities(self._repo, self._proto) - rsp = b'capabilities: %s' % servercaps + rsp = b'capabilities: %s' % servercaps.data self._fout.write(b'%d\n' % len(rsp)) self._fout.write(rsp) self._fout.write(b'\n') diff --git a/tests/test-wireproto.py b/tests/test-wireproto.py --- a/tests/test-wireproto.py +++ b/tests/test-wireproto.py @@ -1,8 +1,10 @@ from __future__ import absolute_import, print_function from mercurial import ( + error, util, wireproto, + wireprototypes, ) stringio = util.stringio @@ -42,7 +44,13 @@ return ['batch'] def _call(self, cmd, **args): - return wireproto.dispatch(self.serverrepo, proto(args), cmd) + res = wireproto.dispatch(self.serverrepo, proto(args), cmd) + if isinstance(res, wireprototypes.bytesresponse): + return res.data + elif isinstance(res, bytes): + return res + else: + raise error.Abort('dummy client does not support response type') def _callstream(self, cmd, **args): return stringio(self._call(cmd, **args))