diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py --- a/mercurial/wireprotoserver.py +++ b/mercurial/wireprotoserver.py @@ -354,19 +354,12 @@ fout.write(b'\n') fout.flush() -class sshserver(baseprotocolhandler): - def __init__(self, ui, repo): +class sshv1protocolhandler(baseprotocolhandler): + """Handler for requests serviced via version 1 of the SSH protocol.""" + def __init__(self, ui, fin, fout): self._ui = ui - self._repo = repo - self._fin = ui.fin - self._fout = ui.fout - - hook.redirect(True) - ui.fout = repo.ui.fout = ui.ferr - - # Prevent insertion/deletion of CRs - util.setbinary(self._fin) - util.setbinary(self._fout) + self._fin = fin + self._fout = fout @property def name(self): @@ -403,6 +396,26 @@ def redirect(self): pass + def _client(self): + client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0] + return 'remote:ssh:' + client + +class sshserver(object): + def __init__(self, ui, repo): + self._ui = ui + self._repo = repo + self._fin = ui.fin + self._fout = ui.fout + + hook.redirect(True) + ui.fout = repo.ui.fout = ui.ferr + + # Prevent insertion/deletion of CRs + util.setbinary(self._fin) + util.setbinary(self._fout) + + self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout) + def serve_forever(self): while self.serve_one(): pass @@ -410,8 +423,8 @@ def serve_one(self): cmd = self._fin.readline()[:-1] - if cmd and wireproto.commands.commandavailable(cmd, self): - rsp = wireproto.dispatch(self._repo, self, cmd) + if cmd and wireproto.commands.commandavailable(cmd, self._proto): + rsp = wireproto.dispatch(self._repo, self._proto, cmd) if isinstance(rsp, bytes): _sshv1respondbytes(self._fout, rsp) @@ -432,7 +445,3 @@ elif cmd: _sshv1respondbytes(self._fout, b'') return cmd != '' - - def _client(self): - client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0] - return 'remote:ssh:' + client diff --git a/tests/sshprotoext.py b/tests/sshprotoext.py --- a/tests/sshprotoext.py +++ b/tests/sshprotoext.py 
@@ -48,7 +48,7 @@ wireprotoserver._sshv1respondbytes(self._fout, b'') l = self._fin.readline() assert l == b'between\n' - rsp = wireproto.dispatch(self._repo, self, b'between') + rsp = wireproto.dispatch(self._repo, self._proto, b'between') wireprotoserver._sshv1respondbytes(self._fout, rsp) super(prehelloserver, self).serve_forever() @@ -73,7 +73,7 @@ # Send the upgrade response. self._fout.write(b'upgraded %s %s\n' % (token, name)) - servercaps = wireproto.capabilities(self._repo, self) + servercaps = wireproto.capabilities(self._repo, self._proto) rsp = b'capabilities: %s' % servercaps self._fout.write(b'%d\n' % len(rsp)) self._fout.write(rsp) diff --git a/tests/test-sshserver.py b/tests/test-sshserver.py --- a/tests/test-sshserver.py +++ b/tests/test-sshserver.py @@ -24,7 +24,7 @@ def assertparse(self, cmd, input, expected): server = mockserver(input) _func, spec = wireproto.commands[cmd] - self.assertEqual(server.getargs(spec), expected) + self.assertEqual(server._proto.getargs(spec), expected) def mockserver(inbytes): ui = mockui(inbytes)