diff --git a/hgext/largefiles/uisetup.py b/hgext/largefiles/uisetup.py --- a/hgext/largefiles/uisetup.py +++ b/hgext/largefiles/uisetup.py @@ -175,6 +175,7 @@ # ... and wrap some existing ones wireproto.commands['heads'].func = proto.heads + # TODO also wrap wireproto.commandsv2 once heads is implemented there. extensions.wrapfunction(webcommands, 'decodepath', overrides.decodepath) diff --git a/mercurial/wireproto.py b/mercurial/wireproto.py --- a/mercurial/wireproto.py +++ b/mercurial/wireproto.py @@ -502,7 +502,11 @@ def dispatch(repo, proto, command): repo = getdispatchrepo(repo, proto, command) - func, spec = commands[command] + + transportversion = wireprototypes.TRANSPORTS[proto.name]['version'] + commandtable = commandsv2 if transportversion == 2 else commands + func, spec = commandtable[command] + args = proto.getargs(spec) return func(repo, proto, *args) @@ -679,8 +683,12 @@ POLICY_V1_ONLY = 'v1-only' POLICY_V2_ONLY = 'v2-only' +# For version 1 transports. commands = commanddict() +# For version 2 transports. +commandsv2 = commanddict() + def wireprotocommand(name, args='', transportpolicy=POLICY_ALL, permission='push'): """Decorator to declare a wire protocol command. @@ -702,12 +710,15 @@ """ if transportpolicy == POLICY_ALL: transports = set(wireprototypes.TRANSPORTS) + transportversions = {1, 2} elif transportpolicy == POLICY_V1_ONLY: transports = {k for k, v in wireprototypes.TRANSPORTS.items() if v['version'] == 1} + transportversions = {1} elif transportpolicy == POLICY_V2_ONLY: transports = {k for k, v in wireprototypes.TRANSPORTS.items() if v['version'] == 2} + transportversions = {2} else: raise error.ProgrammingError('invalid transport policy value: %s' % transportpolicy) @@ -724,8 +735,21 @@ permission) def register(func): - commands[name] = commandentry(func, args=args, transports=transports, - permission=permission) + if 1 in transportversions: + if name in commands: + raise error.ProgrammingError('%s command already registered ' + 'for version 1' % name) + commands[name] = commandentry(func, args=args, + transports=transports, + permission=permission) + if 2 in transportversions: + if name in commandsv2: + raise error.ProgrammingError('%s command already registered ' + 'for version 2' % name) + commandsv2[name] = commandentry(func, args=args, + transports=transports, + permission=permission) + return func return register diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py --- a/mercurial/wireprotoserver.py +++ b/mercurial/wireprotoserver.py @@ -335,7 +335,7 @@ # extension. extracommands = {'multirequest'} - if command not in wireproto.commands and command not in extracommands: + if command not in wireproto.commandsv2 and command not in extracommands: res.status = b'404 Not Found' res.headers[b'Content-Type'] = b'text/plain' res.setbodybytes(_('unknown wire protocol command: %s\n') % command) @@ -346,7 +346,7 @@ proto = httpv2protocolhandler(req, ui) - if (not wireproto.commands.commandavailable(command, proto) + if (not wireproto.commandsv2.commandavailable(command, proto) and command not in extracommands): res.status = b'404 Not Found' res.headers[b'Content-Type'] = b'text/plain' @@ -502,7 +502,7 @@ proto = httpv2protocolhandler(req, ui, args=command['args']) if reqcommand == b'multirequest': - if not wireproto.commands.commandavailable(command['command'], proto): + if not wireproto.commandsv2.commandavailable(command['command'], proto): # TODO proper error mechanism res.status = b'200 OK' res.headers[b'Content-Type'] = b'text/plain' @@ -512,7 +512,7 @@ # TODO don't use assert here, since it may be elided by -O. assert authedperm in (b'ro', b'rw') - wirecommand = wireproto.commands[command['command']] + wirecommand = wireproto.commandsv2[command['command']] assert wirecommand.permission in ('push', 'pull') if authedperm == b'ro' and wirecommand.permission != 'pull': diff --git a/tests/test-wireproto.py b/tests/test-wireproto.py --- a/tests/test-wireproto.py +++ b/tests/test-wireproto.py @@ -13,6 +13,8 @@ class proto(object): def __init__(self, args): self.args = args + self.name = 'dummyproto' + def getargs(self, spec): args = self.args args.setdefault(b'*', {}) @@ -22,6 +24,11 @@ def checkperm(self, perm): pass +wireprototypes.TRANSPORTS['dummyproto'] = { + 'transport': 'dummy', + 'version': 1, +} + class clientpeer(wireproto.wirepeer): def __init__(self, serverrepo, ui): self.serverrepo = serverrepo