diff --git a/mercurial/commands.py b/mercurial/commands.py --- a/mercurial/commands.py +++ b/mercurial/commands.py @@ -53,12 +53,12 @@ rewriteutil, scmutil, server, - sshserver, streamclone, tags as tagsmod, templatekw, ui as uimod, util, + wireprotoserver, ) release = lockmod.release @@ -4756,7 +4756,7 @@ if repo is None: raise error.RepoError(_("there is no Mercurial repository here" " (.hg not found)")) - s = sshserver.sshserver(ui, repo) + s = wireprotoserver.sshserver(ui, repo) s.serve_forever() service = server.createservice(ui, repo, opts) diff --git a/mercurial/sshserver.py b/mercurial/sshserver.py deleted file mode 100644 --- a/mercurial/sshserver.py +++ /dev/null @@ -1,131 +0,0 @@ -# sshserver.py - ssh protocol server support for mercurial -# -# Copyright 2005-2007 Matt Mackall -# Copyright 2006 Vadim Gelfer -# -# This software may be used and distributed according to the terms of the -# GNU General Public License version 2 or any later version. - -from __future__ import absolute_import - -import sys - -from .i18n import _ -from . import ( - encoding, - error, - hook, - util, - wireproto, -) - -class sshserver(wireproto.abstractserverproto): - def __init__(self, ui, repo): - self.ui = ui - self.repo = repo - self.lock = None - self.fin = ui.fin - self.fout = ui.fout - self.name = 'ssh' - - hook.redirect(True) - ui.fout = repo.ui.fout = ui.ferr - - # Prevent insertion/deletion of CRs - util.setbinary(self.fin) - util.setbinary(self.fout) - - def getargs(self, args): - data = {} - keys = args.split() - for n in xrange(len(keys)): - argline = self.fin.readline()[:-1] - arg, l = argline.split() - if arg not in keys: - raise error.Abort(_("unexpected parameter %r") % arg) - if arg == '*': - star = {} - for k in xrange(int(l)): - argline = self.fin.readline()[:-1] - arg, l = argline.split() - val = self.fin.read(int(l)) - star[arg] = val - data['*'] = star - else: - val = self.fin.read(int(l)) - data[arg] = val - return [data[k] for k in keys] - - def getarg(self, name): - return self.getargs(name)[0] - - def getfile(self, fpout): - self.sendresponse('') - count = int(self.fin.readline()) - while count: - fpout.write(self.fin.read(count)) - count = int(self.fin.readline()) - - def redirect(self): - pass - - def sendresponse(self, v): - self.fout.write("%d\n" % len(v)) - self.fout.write(v) - self.fout.flush() - - def sendstream(self, source): - write = self.fout.write - for chunk in source.gen: - write(chunk) - self.fout.flush() - - def sendpushresponse(self, rsp): - self.sendresponse('') - self.sendresponse(str(rsp.res)) - - def sendpusherror(self, rsp): - self.sendresponse(rsp.res) - - def sendooberror(self, rsp): - self.ui.ferr.write('%s\n-\n' % rsp.message) - self.ui.ferr.flush() - self.fout.write('\n') - self.fout.flush() - - def serve_forever(self): - try: - while self.serve_one(): - pass - finally: - if self.lock is not None: - self.lock.release() - sys.exit(0) - - handlers = { - str: sendresponse, - wireproto.streamres: sendstream, - wireproto.streamres_legacy: sendstream, - wireproto.pushres: sendpushresponse, - wireproto.pusherr: sendpusherror, - wireproto.ooberror: sendooberror, - } - - def serve_one(self): - cmd = self.fin.readline()[:-1] - if cmd and cmd in wireproto.commands: - rsp = wireproto.dispatch(self.repo, self, cmd) - self.handlers[rsp.__class__](self, rsp) - elif cmd: - impl = getattr(self, 'do_' + cmd, None) - if impl: - r = impl() - if r is not None: - self.sendresponse(r) - else: - self.sendresponse("") - return cmd != '' - - def _client(self): - client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0] - return 'remote:ssh:' + client diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py --- a/mercurial/wireprotoserver.py +++ b/mercurial/wireprotoserver.py @@ -8,9 +8,13 @@ import cgi import struct +import sys +from .i18n import _ from . import ( + encoding, error, + hook, pycompat, util, wireproto, @@ -197,3 +201,114 @@ req.respond(HTTP_OK, HGERRTYPE, body=rsp) return [] raise error.ProgrammingError('hgweb.protocol internal failure', rsp) + +class sshserver(wireproto.abstractserverproto): + def __init__(self, ui, repo): + self.ui = ui + self.repo = repo + self.lock = None + self.fin = ui.fin + self.fout = ui.fout + self.name = 'ssh' + + hook.redirect(True) + ui.fout = repo.ui.fout = ui.ferr + + # Prevent insertion/deletion of CRs + util.setbinary(self.fin) + util.setbinary(self.fout) + + def getargs(self, args): + data = {} + keys = args.split() + for n in xrange(len(keys)): + argline = self.fin.readline()[:-1] + arg, l = argline.split() + if arg not in keys: + raise error.Abort(_("unexpected parameter %r") % arg) + if arg == '*': + star = {} + for k in xrange(int(l)): + argline = self.fin.readline()[:-1] + arg, l = argline.split() + val = self.fin.read(int(l)) + star[arg] = val + data['*'] = star + else: + val = self.fin.read(int(l)) + data[arg] = val + return [data[k] for k in keys] + + def getarg(self, name): + return self.getargs(name)[0] + + def getfile(self, fpout): + self.sendresponse('') + count = int(self.fin.readline()) + while count: + fpout.write(self.fin.read(count)) + count = int(self.fin.readline()) + + def redirect(self): + pass + + def sendresponse(self, v): + self.fout.write("%d\n" % len(v)) + self.fout.write(v) + self.fout.flush() + + def sendstream(self, source): + write = self.fout.write + for chunk in source.gen: + write(chunk) + self.fout.flush() + + def sendpushresponse(self, rsp): + self.sendresponse('') + self.sendresponse(str(rsp.res)) + + def sendpusherror(self, rsp): + self.sendresponse(rsp.res) + + def sendooberror(self, rsp): + self.ui.ferr.write('%s\n-\n' % rsp.message) + self.ui.ferr.flush() + self.fout.write('\n') + self.fout.flush() + + def serve_forever(self): + try: + while self.serve_one(): + pass + finally: + if self.lock is not None: + self.lock.release() + sys.exit(0) + + handlers = { + str: sendresponse, + wireproto.streamres: sendstream, + wireproto.streamres_legacy: sendstream, + wireproto.pushres: sendpushresponse, + wireproto.pusherr: sendpusherror, + wireproto.ooberror: sendooberror, + } + + def serve_one(self): + cmd = self.fin.readline()[:-1] + if cmd and cmd in wireproto.commands: + rsp = wireproto.dispatch(self.repo, self, cmd) + self.handlers[rsp.__class__](self, rsp) + elif cmd: + impl = getattr(self, 'do_' + cmd, None) + if impl: + r = impl() + if r is not None: + self.sendresponse(r) + else: + self.sendresponse("") + return cmd != '' + + def _client(self): + client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0] + return 'remote:ssh:' + client diff --git a/tests/test-sshserver.py b/tests/test-sshserver.py --- a/tests/test-sshserver.py +++ b/tests/test-sshserver.py @@ -6,9 +6,9 @@ import silenttestrunner from mercurial import ( - sshserver, util, wireproto, + wireprotoserver, ) class SSHServerGetArgsTests(unittest.TestCase): @@ -29,7 +29,7 @@ def mockserver(inbytes): ui = mockui(inbytes) repo = mockrepo(ui) - return sshserver.sshserver(ui, repo) + return wireprotoserver.sshserver(ui, repo) class mockrepo(object): def __init__(self, ui):