diff --git a/mercurial/sshpeer.py b/mercurial/sshpeer.py
--- a/mercurial/sshpeer.py
+++ b/mercurial/sshpeer.py
@@ -156,13 +156,69 @@
     return proc, stdin, stdout, stderr
 
+def _performhandshake(ui, stdin, stdout, stderr):
+    def badresponse():
+        msg = _('no suitable response from remote hg')
+        hint = ui.config('ui', 'ssherrorhint')
+        raise error.RepoError(msg, hint=hint)
+
+    requestlog = ui.configbool('devel', 'debug.peer-request')
+
+    try:
+        pairsarg = '%s-%s' % ('0' * 40, '0' * 40)
+        handshake = [
+            'hello\n',
+            'between\n',
+            'pairs %d\n' % len(pairsarg),
+            pairsarg,
+        ]
+
+        if requestlog:
+            ui.debug('devel-peer-request: hello\n')
+        ui.debug('sending hello command\n')
+        if requestlog:
+            ui.debug('devel-peer-request: between\n')
+            ui.debug('devel-peer-request: pairs: %d bytes\n' % len(pairsarg))
+        ui.debug('sending between command\n')
+
+        stdin.write(''.join(handshake))
+        stdin.flush()
+    except IOError:
+        badresponse()
+
+    lines = ['', 'dummy']
+    max_noise = 500
+    while lines[-1] and max_noise:
+        try:
+            l = stdout.readline()
+            _forwardoutput(ui, stderr)
+            if lines[-1] == '1\n' and l == '\n':
+                break
+            if l:
+                ui.debug('remote: ', l)
+            lines.append(l)
+            max_noise -= 1
+        except IOError:
+            badresponse()
+    else:
+        badresponse()
+
+    caps = set()
+    for l in reversed(lines):
+        if l.startswith('capabilities:'):
+            caps.update(l[:-1].split(':')[1].split())
+            break
+
+    return caps
+
 
 class sshpeer(wireproto.wirepeer):
-    def __init__(self, ui, url, proc, stdin, stdout, stderr):
+    def __init__(self, ui, url, proc, stdin, stdout, stderr, caps):
         """Create a peer from an existing SSH connection.
 
         ``proc`` is a handle on the underlying SSH process.
         ``stdin``, ``stdout``, and ``stderr`` are handles on the stdio
         pipes for that process.
+        ``caps`` is a set of capabilities supported by the remote.
         """
         self._url = url
         self._ui = ui
@@ -172,8 +228,7 @@
         self._pipeo = stdin
         self._pipei = stdout
         self._pipee = stderr
-
-        self._validaterepo()
+        self._caps = caps
 
     # Begin of _basepeer interface.
 
@@ -205,61 +260,6 @@
 
     # End of _basewirecommands interface.
 
-    def _validaterepo(self):
-        def badresponse():
-            msg = _("no suitable response from remote hg")
-            hint = self.ui.config("ui", "ssherrorhint")
-            self._abort(error.RepoError(msg, hint=hint))
-
-        try:
-            pairsarg = '%s-%s' % ('0' * 40, '0' * 40)
-
-            handshake = [
-                'hello\n',
-                'between\n',
-                'pairs %d\n' % len(pairsarg),
-                pairsarg,
-            ]
-
-            requestlog = self.ui.configbool('devel', 'debug.peer-request')
-
-            if requestlog:
-                self.ui.debug('devel-peer-request: hello\n')
-            self.ui.debug('sending hello command\n')
-            if requestlog:
-                self.ui.debug('devel-peer-request: between\n')
-                self.ui.debug('devel-peer-request: pairs: %d bytes\n' %
-                              len(pairsarg))
-            self.ui.debug('sending between command\n')
-
-            self._pipeo.write(''.join(handshake))
-            self._pipeo.flush()
-        except IOError:
-            badresponse()
-
-        lines = ["", "dummy"]
-        max_noise = 500
-        while lines[-1] and max_noise:
-            try:
-                l = self._pipei.readline()
-                _forwardoutput(self.ui, self._pipee)
-                if lines[-1] == "1\n" and l == "\n":
-                    break
-                if l:
-                    self.ui.debug("remote: ", l)
-                lines.append(l)
-                max_noise -= 1
-            except IOError:
-                badresponse()
-        else:
-            badresponse()
-
-        self._caps = set()
-        for l in reversed(lines):
-            if l.startswith("capabilities:"):
-                self._caps.update(l[:-1].split(":")[1].split())
-                break
-
     def _readerr(self):
         _forwardoutput(self.ui, self._pipee)
 
@@ -414,4 +414,10 @@
     proc, stdin, stdout, stderr = _makeconnection(ui, sshcmd, args, remotecmd,
                                                   remotepath, sshenv)
 
-    return sshpeer(ui, path, proc, stdin, stdout, stderr)
+    try:
+        caps = _performhandshake(ui, stdin, stdout, stderr)
+    except Exception:
+        _cleanuppipes(ui, stdout, stdin, stderr)
+        raise
+
+    return sshpeer(ui, path, proc, stdin, stdout, stderr, caps)
diff --git a/tests/sshprotoext.py b/tests/sshprotoext.py
--- a/tests/sshprotoext.py
+++ b/tests/sshprotoext.py
@@ -12,6 +12,7 @@
 
 from mercurial import (
     error,
+    extensions,
     registrar,
     sshpeer,
     wireproto,
@@ -52,30 +53,26 @@
super(prehelloserver, self).serve_forever() -class extrahandshakecommandspeer(sshpeer.sshpeer): - """An ssh peer that sends extra commands as part of initial handshake.""" - def _validaterepo(self): - mode = self._ui.config(b'sshpeer', b'handshake-mode') - if mode == b'pre-no-args': - self._callstream(b'no-args') - return super(extrahandshakecommandspeer, self)._validaterepo() - elif mode == b'pre-multiple-no-args': - self._callstream(b'unknown1') - self._callstream(b'unknown2') - self._callstream(b'unknown3') - return super(extrahandshakecommandspeer, self)._validaterepo() - else: - raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' % - mode) - -def registercommands(): - def dummycommand(repo, proto): - raise error.ProgrammingError('this should never be called') - - wireproto.wireprotocommand(b'no-args', b'')(dummycommand) - wireproto.wireprotocommand(b'unknown1', b'')(dummycommand) - wireproto.wireprotocommand(b'unknown2', b'')(dummycommand) - wireproto.wireprotocommand(b'unknown3', b'')(dummycommand) +def performhandshake(orig, ui, stdin, stdout, stderr): + """Wrapped version of sshpeer._performhandshake to send extra commands.""" + mode = ui.config(b'sshpeer', b'handshake-mode') + if mode == b'pre-no-args': + ui.debug(b'sending no-args command\n') + stdin.write(b'no-args\n') + stdin.flush() + return orig(ui, stdin, stdout, stderr) + elif mode == b'pre-multiple-no-args': + ui.debug(b'sending unknown1 command\n') + stdin.write(b'unknown1\n') + ui.debug(b'sending unknown2 command\n') + stdin.write(b'unknown2\n') + ui.debug(b'sending unknown3 command\n') + stdin.write(b'unknown3\n') + stdin.flush() + return orig(ui, stdin, stdout, stderr) + else: + raise error.ProgrammingError(b'unknown HANDSHAKECOMMANDMODE: %s' % + mode) def extsetup(ui): # It's easier for tests to define the server behavior via environment @@ -94,7 +91,6 @@ peermode = ui.config(b'sshpeer', b'mode') if peermode == b'extra-handshake-commands': - sshpeer.sshpeer = 
extrahandshakecommandspeer - registercommands() + extensions.wrapfunction(sshpeer, '_performhandshake', performhandshake) elif peermode: raise error.ProgrammingError(b'unknown peer mode: %s' % peermode) diff --git a/tests/test-check-interfaces.py b/tests/test-check-interfaces.py --- a/tests/test-check-interfaces.py +++ b/tests/test-check-interfaces.py @@ -51,10 +51,6 @@ pass # Facilitates testing sshpeer without requiring an SSH server. -class testingsshpeer(sshpeer.sshpeer): - def _validaterepo(self, *args, **kwargs): - pass - class badpeer(httppeer.httppeer): def __init__(self): super(badpeer, self).__init__(uimod.ui(), 'http://localhost') @@ -69,8 +65,8 @@ checkobject(badpeer()) checkobject(httppeer.httppeer(ui, 'http://localhost')) checkobject(localrepo.localpeer(dummyrepo())) - checkobject(testingsshpeer(ui, 'ssh://localhost/foo', None, None, None, - None)) + checkobject(sshpeer.sshpeer(ui, 'ssh://localhost/foo', None, None, None, + None, None)) checkobject(bundlerepo.bundlepeer(dummyrepo())) checkobject(statichttprepo.statichttppeer(dummyrepo())) checkobject(unionrepo.unionpeer(dummyrepo())) diff --git a/tests/test-ssh-proto.t b/tests/test-ssh-proto.t --- a/tests/test-ssh-proto.t +++ b/tests/test-ssh-proto.t @@ -146,7 +146,6 @@ $ hg --config sshpeer.mode=extra-handshake-commands --config sshpeer.handshake-mode=pre-no-args --debug debugpeer ssh://user@dummy/server running * "*/tests/dummyssh" 'user@dummy' 'hg -R server serve --stdio' (glob) - devel-peer-request: no-args sending no-args command devel-peer-request: hello sending hello command @@ -182,11 +181,8 @@ $ hg --config sshpeer.mode=extra-handshake-commands --config sshpeer.handshake-mode=pre-multiple-no-args --debug debugpeer ssh://user@dummy/server running * "*/tests/dummyssh" 'user@dummy' 'hg -R server serve --stdio' (glob) - devel-peer-request: unknown1 sending unknown1 command - devel-peer-request: unknown2 sending unknown2 command - devel-peer-request: unknown3 sending unknown3 command 
devel-peer-request: hello sending hello command