# Patch summary (review notes; text placed before the first "diff --git"
# header is skipped by `git apply` and `patch`):
#
# mercurial/wireprotoserver.py
#   * The wire-protocol half of the old `sshserver` becomes
#     `sshprotocolhandler(baseprotocolhandler)`. Its constructor now takes
#     `(ui, fin, fout)` explicitly instead of reading `ui.fin`/`ui.fout`, and
#     it no longer stores a repo.
#   * A new `sshserver(object)` keeps the transport setup that used to live in
#     the handler's constructor: hook.redirect(True), pointing ui.fout and
#     repo.ui.fout at ui.ferr, and util.setbinary on both streams.
#   * The new `sshserver` builds `self._proto = sshprotocolhandler(...)` from
#     the fin/fout captured before that redirect. It also owns the
#     serve_forever()/serve_one() loop.
#   * serve_one() now passes `self._proto` to commandavailable(), dispatch()
#     and the response handlers. It still reaches into the handler's private
#     `_handlers` and `_sendresponse` (see the TODO in the patch).
#
# tests/sshprotoext.py, tests/test-sshserver.py
#   * Updated to call protocol methods through `self._proto` /
#     `server._proto` instead of on the server object itself.
#
# NOTE(review): `sshserver` no longer subclasses baseprotocolhandler. Any
# caller outside this patch that passed an sshserver where a protocol handler
# was expected (or called getargs()/name on it) must switch to `._proto`.
# This cannot be verified from the patch alone.
# NOTE(review): in this copy the patch's line breaks are collapsed. Each diff
# line must be restored to its own line before the patch can be applied.
diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py --- a/mercurial/wireprotoserver.py +++ b/mercurial/wireprotoserver.py @@ -336,19 +336,11 @@ return '' -class sshserver(baseprotocolhandler): - def __init__(self, ui, repo): +class sshprotocolhandler(baseprotocolhandler): + def __init__(self, ui, fin, fout): self._ui = ui - self._repo = repo - self._fin = ui.fin - self._fout = ui.fout - - hook.redirect(True) - ui.fout = repo.ui.fout = ui.ferr - - # Prevent insertion/deletion of CRs - util.setbinary(self._fin) - util.setbinary(self._fout) + self._fin = fin + self._fout = fout @property def name(self): @@ -409,11 +401,6 @@ self._fout.write('\n') self._fout.flush() - def serve_forever(self): - while self.serve_one(): - pass - sys.exit(0) - _handlers = { str: _sendresponse, wireproto.streamres: _sendstream, @@ -423,15 +410,38 @@ wireproto.ooberror: _sendooberror, } - def serve_one(self): - cmd = self._fin.readline()[:-1] - if cmd and wireproto.commands.commandavailable(cmd, self): - rsp = wireproto.dispatch(self._repo, self, cmd) - self._handlers[rsp.__class__](self, rsp) - elif cmd: - self._sendresponse("") - return cmd != '' - def _client(self): client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0] return 'remote:ssh:' + client + +class sshserver(object): + def __init__(self, ui, repo): + self._ui = ui + self._repo = repo + self._fin = ui.fin + self._fout = ui.fout + + hook.redirect(True) + ui.fout = repo.ui.fout = ui.ferr + + # Prevent insertion/deletion of CRs + util.setbinary(self._fin) + util.setbinary(self._fout) + + self._proto = sshprotocolhandler(self._ui, self._fin, self._fout) + + def serve_forever(self): + while self.serve_one(): + pass + sys.exit(0) + + def serve_one(self): + # TODO improve boundary between transport layer and protocol handler. 
+ cmd = self._fin.readline()[:-1] + if cmd and wireproto.commands.commandavailable(cmd, self._proto): + rsp = wireproto.dispatch(self._repo, self._proto, cmd) + self._proto._handlers[rsp.__class__](self._proto, rsp) + elif cmd: + self._proto._sendresponse("") + + return cmd != '' diff --git a/tests/sshprotoext.py b/tests/sshprotoext.py --- a/tests/sshprotoext.py +++ b/tests/sshprotoext.py @@ -45,11 +45,11 @@ l = self._fin.readline() assert l == b'hello\n' # Respond to unknown commands with an empty reply. - self._sendresponse(b'') + self._proto._sendresponse(b'') l = self._fin.readline() assert l == b'between\n' - rsp = wireproto.dispatch(self._repo, self, b'between') - self._handlers[rsp.__class__](self, rsp) + rsp = wireproto.dispatch(self._repo, self._proto, b'between') + self._proto._handlers[rsp.__class__](self._proto, rsp) super(prehelloserver, self).serve_forever() @@ -73,7 +73,7 @@ # Send the upgrade response. self._fout.write(b'upgraded %s %s\n' % (token, name)) - servercaps = wireproto.capabilities(self._repo, self) + servercaps = wireproto.capabilities(self._repo, self._proto) rsp = b'capabilities: %s' % servercaps self._fout.write(b'%d\n' % len(rsp)) self._fout.write(rsp) diff --git a/tests/test-sshserver.py b/tests/test-sshserver.py --- a/tests/test-sshserver.py +++ b/tests/test-sshserver.py @@ -24,7 +24,7 @@ def assertparse(self, cmd, input, expected): server = mockserver(input) _func, spec = wireproto.commands[cmd] - self.assertEqual(server.getargs(spec), expected) + self.assertEqual(server._proto.getargs(spec), expected) def mockserver(inbytes): ui = mockui(inbytes)