diff --git a/mercurial/wireproto.py b/mercurial/wireproto.py
--- a/mercurial/wireproto.py
+++ b/mercurial/wireproto.py
@@ -592,9 +592,12 @@
 
 class commandentry(object):
     """Represents a declared wire protocol command."""
-    def __init__(self, func, args=''):
+    def __init__(self, func, args='', transports=None):
         self.func = func
         self.args = args
+        # Names of the transports (keys of ``wireprototypes.TRANSPORTS``)
+        # this command is exposed on. An empty set exposes it nowhere.
+        self.transports = transports or set()
 
     def _merge(self, func, args):
         """Merge this instance with an incoming 2-tuple.
@@ -604,7 +607,8 @@
         data not captured by the 2-tuple and a new instance containing
         the union of the two objects is returned.
         """
-        return commandentry(func, args=args)
+        # Copy the set so the merged entry never aliases self.transports.
+        return commandentry(func, args=args, transports=set(self.transports))
 
     # Old code treats instances as 2-tuples. So expose that interface.
     def __iter__(self):
@@ -640,7 +644,9 @@
             if k in self:
                 v = self[k]._merge(v[0], v[1])
             else:
-                v = commandentry(v[0], args=v[1])
+                # Use default values from @wireprotocommand.
+                v = commandentry(v[0], args=v[1],
+                                 transports=set(wireprototypes.TRANSPORTS))
         else:
             raise ValueError('command entries must be commandentry instances '
                              'or 2-tuples')
@@ -649,22 +655,53 @@
 
     def commandavailable(self, command, proto):
         """Determine if a command is available for the requested protocol."""
-        # For now, commands are available for all protocols. So do a simple
-        # membership test.
-        return command in self
+        assert proto.name in wireprototypes.TRANSPORTS
+
+        entry = self.get(command)
+
+        if entry is None:
+            return False
+
+        if proto.name not in entry.transports:
+            return False
+
+        return True
+
+# Constants specifying which transports a wire protocol command should be
+# available on. For use with @wireprotocommand.
+POLICY_ALL = 'all'
+POLICY_V1_ONLY = 'v1-only'
+POLICY_V2_ONLY = 'v2-only'
 
 commands = commanddict()
 
-def wireprotocommand(name, args=''):
+def wireprotocommand(name, args='', transportpolicy=POLICY_ALL):
     """Decorator to declare a wire protocol command.
 
     ``name`` is the name of the wire protocol command being provided.
 
     ``args`` is a space-delimited list of named arguments that the command
     accepts. ``*`` is a special value that says to accept all arguments.
+
+    ``transportpolicy`` is a POLICY_* constant denoting which transports
+    this wire protocol command should be exposed to. By default, commands
+    are exposed to all wire protocol transports.
     """
+    # Resolve the policy to the set of transport names it covers.
+    if transportpolicy == POLICY_ALL:
+        transports = set(wireprototypes.TRANSPORTS)
+    elif transportpolicy == POLICY_V1_ONLY:
+        transports = {k for k, v in wireprototypes.TRANSPORTS.items()
+                      if v['version'] == 1}
+    elif transportpolicy == POLICY_V2_ONLY:
+        transports = {k for k, v in wireprototypes.TRANSPORTS.items()
+                      if v['version'] == 2}
+    else:
+        raise error.ProgrammingError('invalid transport policy value: %s' %
+                                     transportpolicy)
+
     def register(func):
-        commands[name] = commandentry(func, args=args)
+        commands[name] = commandentry(func, args=args, transports=transports)
         return func
     return register
 
diff --git a/mercurial/wireprototypes.py b/mercurial/wireprototypes.py
--- a/mercurial/wireprototypes.py
+++ b/mercurial/wireprototypes.py
@@ -13,6 +13,24 @@
 # to reflect BC breakages.
 SSHV2 = 'exp-ssh-v2-0001'
 
+# All available wire protocol transports. Keys are the transport names
+# matched against ``proto.name``; values record the underlying transport
+# and the wire protocol version (used to resolve the POLICY_* constants).
+TRANSPORTS = {
+    SSHV1: {
+        'transport': 'ssh',
+        'version': 1,
+    },
+    SSHV2: {
+        'transport': 'ssh',
+        'version': 2,
+    },
+    'http-v1': {
+        'transport': 'http',
+        'version': 1,
+    },
+}
+
 class bytesresponse(object):
     """A wire protocol response consisting of raw bytes."""
     def __init__(self, data):