diff --git a/mercurial/hgweb/hgweb_mod.py b/mercurial/hgweb/hgweb_mod.py --- a/mercurial/hgweb/hgweb_mod.py +++ b/mercurial/hgweb/hgweb_mod.py @@ -350,18 +350,15 @@ # Route it to a wire protocol handler if it looks like a wire protocol # request. - protohandler = wireprotoserver.parsehttprequest(rctx.repo, req, query) + protohandler = wireprotoserver.parsehttprequest(rctx, req, query, + self.check_perm) if protohandler: try: if query: raise ErrorResponse(HTTP_NOT_FOUND) - # TODO fold this into parsehttprequest - checkperm = lambda op: self.check_perm(rctx, req, op) - protohandler['proto'].checkperm = checkperm - - return protohandler['dispatch'](checkperm) + return protohandler['dispatch']() except ErrorResponse as inst: return protohandler['handleerror'](inst) diff --git a/mercurial/wireproto.py b/mercurial/wireproto.py --- a/mercurial/wireproto.py +++ b/mercurial/wireproto.py @@ -731,13 +731,10 @@ vals[unescapearg(n)] = unescapearg(v) func, spec = commands[op] - # If the protocol supports permissions checking, perform that - # checking on each batched command. - # TODO formalize permission checking as part of protocol interface. - if util.safehasattr(proto, 'checkperm'): - perm = commands[op].permission - assert perm in ('push', 'pull') - proto.checkperm(perm) + # Validate that client has permissions to perform this command. 
+ perm = commands[op].permission + assert perm in ('push', 'pull') + proto.checkperm(perm) if spec: keys = spec.split() diff --git a/mercurial/wireprotoserver.py b/mercurial/wireprotoserver.py --- a/mercurial/wireprotoserver.py +++ b/mercurial/wireprotoserver.py @@ -54,9 +54,10 @@ return ''.join(chunks) class httpv1protocolhandler(wireprototypes.baseprotocolhandler): - def __init__(self, req, ui): + def __init__(self, req, ui, checkperm): self._req = req self._ui = ui + self._checkperm = checkperm @property def name(self): @@ -139,6 +140,9 @@ return caps + def checkperm(self, perm): + return self._checkperm(perm) + # This method exists mostly so that extensions like remotefilelog can # disable a kludgey legacy method only over http. As of early 2018, # there are no other known users, so with any luck we can discard this @@ -146,7 +150,7 @@ def iscmd(cmd): return cmd in wireproto.commands -def parsehttprequest(repo, req, query): +def parsehttprequest(rctx, req, query, checkperm): """Parse the HTTP request for a wire protocol request. If the current request appears to be a wire protocol request, this @@ -156,6 +160,8 @@ ``req`` is a ``wsgirequest`` instance. """ + repo = rctx.repo + # HTTP version 1 wire protocol requests are denoted by a "cmd" query # string parameter. If it isn't present, this isn't a wire protocol # request. 
@@ -174,13 +180,13 @@ if not iscmd(cmd): return None - proto = httpv1protocolhandler(req, repo.ui) + proto = httpv1protocolhandler(req, repo.ui, + lambda perm: checkperm(rctx, req, perm)) return { 'cmd': cmd, 'proto': proto, - 'dispatch': lambda checkperm: _callhttp(repo, req, proto, cmd, - checkperm), + 'dispatch': lambda: _callhttp(repo, req, proto, cmd), 'handleerror': lambda ex: _handlehttperror(ex, req, cmd), } @@ -224,7 +230,7 @@ opts = {'level': ui.configint('server', 'zliblevel')} return HGTYPE, util.compengines['zlib'], opts -def _callhttp(repo, req, proto, cmd, checkperm): +def _callhttp(repo, req, proto, cmd): def genversion2(gen, engine, engineopts): # application/mercurial-0.2 always sends a payload header # identifying the compression engine. @@ -242,7 +248,7 @@ 'over HTTP')) return [] - checkperm(wireproto.commands[cmd].permission) + proto.checkperm(wireproto.commands[cmd].permission) rsp = wireproto.dispatch(repo, proto, cmd) @@ -392,6 +398,9 @@ def addcapabilities(self, repo, caps): return caps + def checkperm(self, perm): + pass + class sshv2protocolhandler(sshv1protocolhandler): """Protocol handler for version 2 of the SSH protocol.""" diff --git a/mercurial/wireprototypes.py b/mercurial/wireprototypes.py --- a/mercurial/wireprototypes.py +++ b/mercurial/wireprototypes.py @@ -146,3 +146,12 @@ Returns a list of capabilities. The passed in argument can be returned. """ + + @abc.abstractmethod + def checkperm(self, perm): + """Validate that the client has permissions to perform a request. + + The argument is the permission required to proceed. If the client + doesn't have that permission, the method should raise an exception or + abort in a protocol-specific manner.
+ """ diff --git a/tests/test-wireproto.py b/tests/test-wireproto.py --- a/tests/test-wireproto.py +++ b/tests/test-wireproto.py @@ -18,6 +18,9 @@ names = spec.split() return [args[n] for n in names] + def checkperm(self, perm): + pass + class clientpeer(wireproto.wirepeer): def __init__(self, serverrepo): self.serverrepo = serverrepo