diff --git a/mercurial/sshpeer.py b/mercurial/sshpeer.py
--- a/mercurial/sshpeer.py
+++ b/mercurial/sshpeer.py
@@ -115,35 +115,27 @@
         return self._main.flush()
 
 class sshpeer(wireproto.wirepeer):
-    def __init__(self, ui, path, create=False):
+    def __init__(self, ui, path, create=False, sshstate=None):
+        """Create a peer for ``path`` from pre-computed SSH state.
+
+        ``sshstate`` is the ``(sshcmd, args, remotecmd, sshenv)`` tuple
+        built by ``instance()`` and is passed unchanged to
+        ``_validaterepo()``. ``create`` is no longer used here:
+        ``instance()`` creates the remote repository when asked to.
+        """
         self._url = path
         self._ui = ui
         self._pipeo = self._pipei = self._pipee = None
 
         u = util.url(path, parsequery=False, parsefragment=False)
-
-        self._user = u.user
-        self._host = u.host
-        self._port = u.port
         self._path = u.path or '.'
 
-        sshcmd = self.ui.config("ui", "ssh")
-        remotecmd = self.ui.config("ui", "remotecmd")
-        sshaddenv = dict(self.ui.configitems("sshenv"))
-        sshenv = util.shellenviron(sshaddenv)
-
-        args = util.sshargs(sshcmd, self._host, self._user, self._port)
-
-        if create:
-            cmd = '%s %s %s' % (sshcmd, args,
-                util.shellquote("%s init %s" %
-                    (_serverquote(remotecmd), _serverquote(self._path))))
-            ui.debug('running %s\n' % cmd)
-            res = ui.system(cmd, blockedtag='sshpeer', environ=sshenv)
-            if res != 0:
-                self._abort(error.RepoError(_("could not create remote repo")))
-
-        self._validaterepo(sshcmd, args, remotecmd, sshenv)
+        if sshstate is None:
+            # Fail clearly instead of with the opaque TypeError that
+            # unpacking None would raise below.
+            raise error.ProgrammingError('sshpeer requires sshstate; '
+                                         'construct it via instance()')
+        self._validaterepo(*sshstate)
 
     # Begin of _basepeer interface.
 
@@ -377,4 +369,25 @@
     if u.passwd is not None:
         raise error.RepoError(_('password in URL not supported'))
 
-    return sshpeer(ui, path, create=create)
+    # Compute the SSH invocation (and create the remote repository when
+    # requested) here; sshpeer.__init__ only receives the prepared state.
+    sshcmd = ui.config('ui', 'ssh')
+    remotecmd = ui.config('ui', 'remotecmd')
+    sshaddenv = dict(ui.configitems('sshenv'))
+    sshenv = util.shellenviron(sshaddenv)
+    remotepath = u.path or '.'
+
+    args = util.sshargs(sshcmd, u.host, u.user, u.port)
+
+    if create:
+        cmd = '%s %s %s' % (sshcmd, args,
+            util.shellquote('%s init %s' %
+                (_serverquote(remotecmd), _serverquote(remotepath))))
+        ui.debug('running %s\n' % cmd)
+        res = ui.system(cmd, blockedtag='sshpeer', environ=sshenv)
+        if res != 0:
+            raise error.RepoError(_('could not create remote repo'))
+
+    sshstate = (sshcmd, args, remotecmd, sshenv)
+
+    return sshpeer(ui, path, create=create, sshstate=sshstate)
diff --git a/tests/test-check-interfaces.py b/tests/test-check-interfaces.py
--- a/tests/test-check-interfaces.py
+++ b/tests/test-check-interfaces.py
@@ -69,7 +69,7 @@
     checkobject(badpeer())
     checkobject(httppeer.httppeer(ui, 'http://localhost'))
     checkobject(localrepo.localpeer(dummyrepo()))
-    checkobject(testingsshpeer(ui, 'ssh://localhost/foo'))
+    checkobject(testingsshpeer(ui, 'ssh://localhost/foo', False, ()))
     checkobject(bundlerepo.bundlepeer(dummyrepo()))
     checkobject(statichttprepo.statichttppeer(dummyrepo()))
     checkobject(unionrepo.unionpeer(dummyrepo()))