diff --git a/tests/testlib/badserverext.py b/tests/testlib/badserverext.py
--- a/tests/testlib/badserverext.py
+++ b/tests/testlib/badserverext.py
@@ -31,10 +31,16 @@
    If defined, close the client socket after sending this many bytes.
    (The value is a list, multiple values can use used to close a series of requests
    request)
+
+close-after-send-patterns
+   If defined, close the client socket after the configured regexp is seen.
+   (The value is a list, multiple values can be used to close a series of
+   requests)
 """
 
 from __future__ import absolute_import
 
+import re
 import socket
 
 from mercurial import (
@@ -64,20 +70,33 @@
 )
 configitem(
     b'badserver',
+    b'close-after-send-patterns',
+    default=b'',
+)
+configitem(
+    b'badserver',
     b'close-before-accept',
     default=False,
 )
 
 
 class ConditionTracker(object):
-    def __init__(self, close_after_recv_bytes, close_after_send_bytes):
+    def __init__(
+        self,
+        close_after_recv_bytes,
+        close_after_send_bytes,
+        close_after_send_patterns,
+    ):
         self._all_close_after_recv_bytes = close_after_recv_bytes
         self._all_close_after_send_bytes = close_after_send_bytes
+        self._all_close_after_send_patterns = close_after_send_patterns
 
         self.target_recv_bytes = None
         self.remaining_recv_bytes = None
         self.target_send_bytes = None
         self.remaining_send_bytes = None
+        self.send_pattern = None
+        self.send_data = b''
 
     def start_next_request(self):
         """move to the next set of close condition"""
@@ -87,6 +106,7 @@
         else:
             self.target_recv_bytes = None
             self.remaining_recv_bytes = None
+
         if self._all_close_after_send_bytes:
             self.target_send_bytes = self._all_close_after_send_bytes.pop(0)
             self.remaining_send_bytes = self.target_send_bytes
@@ -94,12 +114,20 @@
             self.target_send_bytes = None
             self.remaining_send_bytes = None
 
+        self.send_data = b''
+        if self._all_close_after_send_patterns:
+            self.send_pattern = self._all_close_after_send_patterns.pop(0)
+        else:
+            self.send_pattern = None
+
     def might_close(self):
         """True, if any processing will be needed"""
         if self.remaining_recv_bytes is not None:
             return True
         if self.remaining_send_bytes is not None:
             return True
+        if self.send_pattern is not None:
+            return True
         return False
 
     def forward_write(self, obj, method, data, *args, **kwargs):
@@ -108,11 +136,19 @@
         When the condition are met the socket is closed
         """
         remaining = self.remaining_send_bytes
+        pattern = self.send_pattern
 
         orig = object.__getattribute__(obj, '_orig')
         bmethod = method.encode('ascii')
         func = getattr(orig, method)
 
+        if pattern:
+            self.send_data += data
+            pieces = pattern.split(self.send_data, maxsplit=1)
+            if len(pieces) > 1:
+                dropped = len(pieces[-1])
+                remaining = len(data) - dropped
+
         if remaining:
             remaining = max(0, remaining)
 
@@ -131,16 +167,9 @@
         if remaining is None:
             obj._writelog(b'%s(%d) -> %s' % (bmethod, len(data), data))
         else:
-            obj._writelog(
-                b'%s(%d from %d) -> (%d) %s'
-                % (
-                    bmethod,
-                    len(newdata),
-                    len(data),
-                    remaining,
-                    newdata,
-                )
-            )
+            msg = b'%s(%d from %d) -> (%d) %s'
+            msg %= (bmethod, len(newdata), len(data), remaining, newdata)
+            obj._writelog(msg)
 
         if remaining is not None and remaining <= 0:
             obj._writelog(b'write limit reached; closing socket')
@@ -305,12 +334,23 @@
         self._close()
 
 
-def process_config(value):
+def process_bytes_config(value):
     parts = value.split(b',')
     integers = [int(v) for v in parts if v]
     return [v if v else None for v in integers]
 
 
+def process_pattern_config(value):
+    patterns = []
+    for p in value.split(b','):
+        if not p:
+            p = None
+        else:
+            p = re.compile(p, re.DOTALL | re.MULTILINE)
+        patterns.append(p)
+    return patterns
+
+
 def extsetup(ui):
     # Change the base HTTP server class so various events can be performed.
     # See SocketServer.BaseServer for how the specially named methods work.
@@ -322,12 +362,20 @@
             all_recv_bytes = self._ui.config(
                 b'badserver', b'close-after-recv-bytes'
             )
-            all_recv_bytes = process_config(all_recv_bytes)
+            all_recv_bytes = process_bytes_config(all_recv_bytes)
             all_send_bytes = self._ui.config(
                 b'badserver', b'close-after-send-bytes'
             )
-            all_send_bytes = process_config(all_send_bytes)
-            self._cond = ConditionTracker(all_recv_bytes, all_send_bytes)
+            all_send_bytes = process_bytes_config(all_send_bytes)
+            all_send_patterns = self._ui.config(
+                b'badserver', b'close-after-send-patterns'
+            )
+            all_send_patterns = process_pattern_config(all_send_patterns)
+            self._cond = ConditionTracker(
+                all_recv_bytes,
+                all_send_bytes,
+                all_send_patterns,
+            )
 
             # Need to inherit object so super() works.
             class badrequesthandler(self.RequestHandlerClass, object):