diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py --- a/mercurial/wireprotoserver.py +++ b/mercurial/wireprotoserver.py @@ -409,6 +409,56 @@ client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0] return 'remote:ssh:' + client +def _runsshserver(ui, repo, fin, fout): + state = 'protov1-serving' + proto = sshv1protocolhandler(ui, fin, fout) + + while True: + if state == 'protov1-serving': + # Commands are issued on new lines. + request = fin.readline()[:-1] + + # Empty lines signal to terminate the connection. + if not request: + state = 'shutdown' + continue + + available = wireproto.commands.commandavailable(request, proto) + + # This command isn't available. Send an empty response and go + # back to waiting for a new command. + if not available: + _sshv1respondbytes(fout, b'') + continue + + rsp = wireproto.dispatch(repo, proto, request) + + if isinstance(rsp, bytes): + _sshv1respondbytes(fout, rsp) + elif isinstance(rsp, wireprototypes.bytesresponse): + _sshv1respondbytes(fout, rsp.data) + elif isinstance(rsp, wireprototypes.streamres): + _sshv1respondstream(fout, rsp) + elif isinstance(rsp, wireprototypes.streamreslegacy): + _sshv1respondstream(fout, rsp) + elif isinstance(rsp, wireprototypes.pushres): + _sshv1respondbytes(fout, b'') + _sshv1respondbytes(fout, bytes(rsp.res)) + elif isinstance(rsp, wireprototypes.pusherr): + _sshv1respondbytes(fout, rsp.res) + elif isinstance(rsp, wireprototypes.ooberror): + _sshv1respondooberror(fout, ui.ferr, rsp.message) + else: + raise error.ProgrammingError('unhandled response type from ' + 'wire protocol command: %s' % rsp) + + elif state == 'shutdown': + break + + else: + raise error.ProgrammingError('unhandled ssh server state: %s' % + state) + class sshserver(object): def __init__(self, ui, repo): self._ui = ui @@ -423,36 +473,8 @@ util.setbinary(self._fin) util.setbinary(self._fout) - self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout) - def serve_forever(self): - while self.serve_one(): - pass + _runsshserver(self._ui, self._repo, self._fin, self._fout) sys.exit(0) - def serve_one(self): - cmd = self._fin.readline()[:-1] - if cmd and wireproto.commands.commandavailable(cmd, self._proto): - rsp = wireproto.dispatch(self._repo, self._proto, cmd) - if isinstance(rsp, bytes): - _sshv1respondbytes(self._fout, rsp) - elif isinstance(rsp, wireprototypes.bytesresponse): - _sshv1respondbytes(self._fout, rsp.data) - elif isinstance(rsp, wireprototypes.streamres): - _sshv1respondstream(self._fout, rsp) - elif isinstance(rsp, wireprototypes.streamreslegacy): - _sshv1respondstream(self._fout, rsp) - elif isinstance(rsp, wireprototypes.pushres): - _sshv1respondbytes(self._fout, b'') - _sshv1respondbytes(self._fout, bytes(rsp.res)) - elif isinstance(rsp, wireprototypes.pusherr): - _sshv1respondbytes(self._fout, rsp.res) - elif isinstance(rsp, wireprototypes.ooberror): - _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message) - else: - raise error.ProgrammingError('unhandled response type from ' - 'wire protocol command: %s' % rsp) - elif cmd: - _sshv1respondbytes(self._fout, b'') - return cmd != '' diff --git a/tests/sshprotoext.py b/tests/sshprotoext.py --- a/tests/sshprotoext.py +++ b/tests/sshprotoext.py @@ -48,7 +48,9 @@ wireprotoserver._sshv1respondbytes(self._fout, b'') l = self._fin.readline() assert l == b'between\n' - rsp = wireproto.dispatch(self._repo, self._proto, b'between') + proto = wireprotoserver.sshv1protocolhandler(self._ui, self._fin, + self._fout) + rsp = wireproto.dispatch(self._repo, proto, b'between') wireprotoserver._sshv1respondbytes(self._fout, rsp.data) super(prehelloserver, self).serve_forever() @@ -72,8 +74,10 @@ self._fin.read(81) # Send the upgrade response. + proto = wireprotoserver.sshv1protocolhandler(self._ui, self._fin, + self._fout) self._fout.write(b'upgraded %s %s\n' % (token, name)) - servercaps = wireproto.capabilities(self._repo, self._proto) + servercaps = wireproto.capabilities(self._repo, proto) rsp = b'capabilities: %s' % servercaps.data self._fout.write(b'%d\n' % len(rsp)) self._fout.write(rsp) diff --git a/tests/test-sshserver.py b/tests/test-sshserver.py --- a/tests/test-sshserver.py +++ b/tests/test-sshserver.py @@ -23,8 +23,11 @@ def assertparse(self, cmd, input, expected): server = mockserver(input) + proto = wireprotoserver.sshv1protocolhandler(server._ui, + server._fin, + server._fout) _func, spec = wireproto.commands[cmd] - self.assertEqual(server._proto.getargs(spec), expected) + self.assertEqual(proto.getargs(spec), expected) def mockserver(inbytes): ui = mockui(inbytes)