diff --git a/mercurial/sshpeer.py b/mercurial/sshpeer.py
--- a/mercurial/sshpeer.py
+++ b/mercurial/sshpeer.py
@@ -349,6 +349,12 @@
         self._pipee = stderr
         self._caps = caps
 
+    # Commands that have a "framed" response where the first line of the
+    # response contains the length of that response.
+    _FRAMED_COMMANDS = {
+        'batch',
+    }
+
     # Begin of _basepeer interface.
 
     @util.propertycache
@@ -391,26 +397,7 @@
 
     __del__ = _cleanup
 
-    def _submitbatch(self, req):
-        rsp = self._callstream("batch", cmds=wireproto.encodebatchcmds(req))
-        available = self._getamount()
-        # TODO this response parsing is probably suboptimal for large
-        # batches with large responses.
-        toread = min(available, 1024)
-        work = rsp.read(toread)
-        available -= toread
-        chunk = work
-        while chunk:
-            while ';' in work:
-                one, work = work.split(';', 1)
-                yield wireproto.unescapearg(one)
-            toread = min(available, 1024)
-            chunk = rsp.read(toread)
-            available -= toread
-            work += chunk
-        yield wireproto.unescapearg(work)
-
-    def _sendrequest(self, cmd, args):
+    def _sendrequest(self, cmd, args, framed=False):
         if (self.ui.debugflag
             and self.ui.configbool('devel', 'debug.peer-request')):
             dbg = self.ui.debug
@@ -444,20 +431,27 @@
                 self._pipeo.write(v)
         self._pipeo.flush()
 
+        # We know exactly how many bytes are in the response. So return a proxy
+        # around the raw output stream that allows reading exactly this many
+        # bytes. Callers then can read() without fear of overrunning the
+        # response.
+        if framed:
+            amount = self._getamount()
+            return util.cappedreader(self._pipei, amount)
+
         return self._pipei
 
     def _callstream(self, cmd, **args):
         args = pycompat.byteskwargs(args)
-        return self._sendrequest(cmd, args)
+        return self._sendrequest(cmd, args, framed=cmd in self._FRAMED_COMMANDS)
 
     def _callcompressable(self, cmd, **args):
         args = pycompat.byteskwargs(args)
-        return self._sendrequest(cmd, args)
+        return self._sendrequest(cmd, args, framed=cmd in self._FRAMED_COMMANDS)
 
     def _call(self, cmd, **args):
         args = pycompat.byteskwargs(args)
-        self._sendrequest(cmd, args)
-        return self._readframed()
+        return self._sendrequest(cmd, args, framed=True).read()
 
     def _callpush(self, cmd, fp, **args):
         r = self._call(cmd, **args)