class ConditionTracker(object):
    """Track the byte limits after which a connection should be closed.

    The tracker owns two queues of limits, one for received bytes and one
    for sent bytes. Each call to ``start_next_request()`` consumes the next
    entry of each queue and publishes it through four plain attributes:

    * ``target_recv_bytes`` / ``target_send_bytes``: the limit as taken
      from the queue, or ``None`` when there is no limit.
    * ``remaining_recv_bytes`` / ``remaining_send_bytes``: initialised to
      the matching target; users of the tracker may overwrite them as bytes
      are consumed.
    """

    def __init__(self, close_after_recv_bytes, close_after_send_bytes):
        """Remember the queues of limits to hand out, one entry per request.

        Both arguments are sequences of limits consumed front to back. A
        falsy argument (``None`` or an empty sequence) means "no limits".
        """
        # Work on private copies: the queues are drained with pop(0) and the
        # caller's lists must not shrink as a side effect.
        self._all_close_after_recv_bytes = list(close_after_recv_bytes or ())
        self._all_close_after_send_bytes = list(close_after_send_bytes or ())
        # Start in the "no limit" state so that might_close() and readers of
        # the remaining_* attributes work before the first request starts.
        self.target_recv_bytes = None
        self.remaining_recv_bytes = None
        self.target_send_bytes = None
        self.remaining_send_bytes = None

    @staticmethod
    def _next_limit(pending):
        """Pop and return the first limit in ``pending``, or None if empty."""
        return pending.pop(0) if pending else None

    def start_next_request(self):
        """Move to the next set of close conditions.

        Once a queue is exhausted, the matching target and remaining values
        are reset to ``None`` (no limit).
        """
        self.target_recv_bytes = self._next_limit(
            self._all_close_after_recv_bytes
        )
        self.remaining_recv_bytes = self.target_recv_bytes
        self.target_send_bytes = self._next_limit(
            self._all_close_after_send_bytes
        )
        self.remaining_send_bytes = self.target_send_bytes

    def might_close(self):
        """Return True if a recv or send limit is currently being tracked."""
        return (
            self.remaining_recv_bytes is not None
            or self.remaining_send_bytes is not None
        )


# We can't adjust __class__ on a socket instance. So we define a proxy type.
class socketproxy(object): - __slots__ = ( - '_orig', - '_logfp', - '_close_after_recv_bytes', - '_close_after_send_bytes', - ) + __slots__ = ('_orig', '_logfp', '_cond') - def __init__( - self, obj, logfp, close_after_recv_bytes=0, close_after_send_bytes=0 - ): + def __init__(self, obj, logfp, condition_tracked): object.__setattr__(self, '_orig', obj) object.__setattr__(self, '_logfp', logfp) - object.__setattr__( - self, '_close_after_recv_bytes', close_after_recv_bytes - ) - object.__setattr__( - self, '_close_after_send_bytes', close_after_send_bytes - ) + object.__setattr__(self, '_cond', condition_tracked) def __getattribute__(self, name): - if name in ('makefile', 'sendall', '_writelog'): + if name in ('makefile', 'sendall', '_writelog', '_cond_close'): return object.__getattribute__(self, name) return getattr(object.__getattribute__(self, '_orig'), name) @@ -108,22 +126,12 @@ f = object.__getattribute__(self, '_orig').makefile(mode, bufsize) logfp = object.__getattribute__(self, '_logfp') - close_after_recv_bytes = object.__getattribute__( - self, '_close_after_recv_bytes' - ) - close_after_send_bytes = object.__getattribute__( - self, '_close_after_send_bytes' - ) + cond = object.__getattribute__(self, '_cond') - return fileobjectproxy( - f, - logfp, - close_after_recv_bytes=close_after_recv_bytes, - close_after_send_bytes=close_after_send_bytes, - ) + return fileobjectproxy(f, logfp, cond) def sendall(self, data, flags=0): - remaining = object.__getattribute__(self, '_close_after_send_bytes') + remaining = object.__getattribute__(self, '_cond').remaining_send_bytes # No read limit. Call original function. 
if not remaining: @@ -145,7 +153,7 @@ % (len(newdata), len(data), remaining, newdata) ) - object.__setattr__(self, '_close_after_send_bytes', remaining) + object.__getattribute__(self, '_cond').remaining_send_bytes = remaining if remaining <= 0: self._writelog(b'write limit reached; closing socket') @@ -158,24 +166,12 @@ # We can't adjust __class__ on socket._fileobject, so define a proxy. class fileobjectproxy(object): - __slots__ = ( - '_orig', - '_logfp', - '_close_after_recv_bytes', - '_close_after_send_bytes', - ) + __slots__ = ('_orig', '_logfp', '_cond') - def __init__( - self, obj, logfp, close_after_recv_bytes=0, close_after_send_bytes=0 - ): + def __init__(self, obj, logfp, condition_tracked): object.__setattr__(self, '_orig', obj) object.__setattr__(self, '_logfp', logfp) - object.__setattr__( - self, '_close_after_recv_bytes', close_after_recv_bytes - ) - object.__setattr__( - self, '_close_after_send_bytes', close_after_send_bytes - ) + object.__setattr__(self, '_cond', condition_tracked) def __getattribute__(self, name): if name in ('_close', 'read', 'readline', 'write', '_writelog'): @@ -210,7 +206,7 @@ self._sock.shutdown(socket.SHUT_RDWR) def read(self, size=-1): - remaining = object.__getattribute__(self, '_close_after_recv_bytes') + remaining = object.__getattribute__(self, '_cond').remaining_recv_bytes # No read limit. Call original function. if not remaining: @@ -235,7 +231,7 @@ % (size, origsize, len(result), result) ) - object.__setattr__(self, '_close_after_recv_bytes', remaining) + object.__getattribute__(self, '_cond').remaining_recv_bytes = remaining if remaining <= 0: self._writelog(b'read limit reached; closing socket') @@ -247,7 +243,7 @@ return result def readline(self, size=-1): - remaining = object.__getattribute__(self, '_close_after_recv_bytes') + remaining = object.__getattribute__(self, '_cond').remaining_recv_bytes # No read limit. Call original function. 
def process_config(value):
    """Parse a comma separated config value into a list of byte limits.

    ``value`` is the raw bytes value of the option, e.g. ``b'10,0,25'``.
    Empty items are skipped, so ``b''`` and ``b'1,,2,'`` are accepted. A
    limit of ``0`` is returned as ``None``. An unset option (``None``) is
    treated like an empty one and yields ``[]``.

    Raises ValueError if a non-empty item is not an integer.
    """
    if not value:
        # Unset or empty option: nothing to parse.
        return []
    # `int(part) or None` maps a parsed 0 to None in a single pass.
    return [int(part) or None for part in value.split(b',') if part]
@@ -325,12 +327,15 @@ self._ui = ui super(badserver, self).__init__(ui, *args, **kwargs) - recvbytes = self._ui.config(b'badserver', b'close-after-recv-bytes') - recvbytes = recvbytes.split(b',') - self.close_after_recv_bytes = [int(v) for v in recvbytes if v] - sendbytes = self._ui.config(b'badserver', b'close-after-send-bytes') - sendbytes = sendbytes.split(b',') - self.close_after_send_bytes = [int(v) for v in sendbytes if v] + all_recv_bytes = self._ui.config( + b'badserver', b'close-after-recv-bytes' + ) + all_recv_bytes = process_config(all_recv_bytes) + all_send_bytes = self._ui.config( + b'badserver', b'close-after-send-bytes' + ) + all_send_bytes = process_config(all_send_bytes) + self._cond = ConditionTracker(all_recv_bytes, all_send_bytes) # Need to inherit object so super() works. class badrequesthandler(self.RequestHandlerClass, object): @@ -370,21 +375,11 @@ # is a hgweb.server._httprequesthandler. def process_request(self, socket, address): # Wrap socket in a proxy if we need to count bytes. - if self.close_after_recv_bytes: - close_after_recv_bytes = self.close_after_recv_bytes.pop(0) - else: - close_after_recv_bytes = 0 - if self.close_after_send_bytes: - close_after_send_bytes = self.close_after_send_bytes.pop(0) - else: - close_after_send_bytes = 0 + self._cond.start_next_request() - if close_after_recv_bytes or close_after_send_bytes: + if self._cond.might_close(): socket = socketproxy( - socket, - self.errorlog, - close_after_recv_bytes=close_after_recv_bytes, - close_after_send_bytes=close_after_send_bytes, + socket, self.errorlog, condition_tracked=self._cond ) return super(badserver, self).process_request(socket, address)