diff --git a/tests/testlib/badserverext.py b/tests/testlib/badserverext.py --- a/tests/testlib/badserverext.py +++ b/tests/testlib/badserverext.py @@ -70,6 +70,11 @@ self._all_close_after_recv_bytes = close_after_recv_bytes self._all_close_after_send_bytes = close_after_send_bytes + self.target_recv_bytes = None + self.remaining_recv_bytes = None + self.target_send_bytes = None + self.remaining_send_bytes = None + def start_next_request(self): """move to the next set of close condition""" if self._all_close_after_recv_bytes: @@ -93,6 +98,54 @@ return True return False + def forward_write(self, obj, method, data, *args, **kwargs): + """call an underlying write function until condition are met + + When the condition are met the socket is closed + """ + remaining = self.remaining_send_bytes + + orig = object.__getattribute__(obj, '_orig') + bmethod = method.encode('ascii') + func = getattr(orig, method) + # No byte limit on this operation. Call original function. + if not remaining: + result = func(data, *args, **kwargs) + obj._writelog(b'%s(%d) -> %s' % (bmethod, len(data), data)) + return result + + remaining = max(0, remaining) + + if remaining > 0: + if remaining < len(data): + newdata = data[0:remaining] + else: + newdata = data + + remaining -= len(newdata) + + obj._writelog( + b'%s(%d from %d) -> (%d) %s' + % ( + bmethod, + len(newdata), + len(data), + remaining, + newdata, + ) + ) + + result = func(newdata, *args, **kwargs) + + self.remaining_send_bytes = remaining + + if remaining <= 0: + obj._writelog(b'write limit reached; closing socket') + object.__getattribute__(obj, '_cond_close')() + raise Exception('connection closed after sending N bytes') + + return result + # We can't adjust __class__ on a socket instance. So we define a proxy type. class socketproxy(object): @@ -131,37 +184,11 @@ return fileobjectproxy(f, logfp, cond) def sendall(self, data, flags=0): - remaining = object.__getattribute__(self, '_cond').remaining_send_bytes - - # No read limit. Call original function. - if not remaining: - result = object.__getattribute__(self, '_orig').sendall(data, flags) - self._writelog(b'sendall(%d) -> %s' % (len(data), data)) - return result - - if len(data) > remaining: - newdata = data[0:remaining] - else: - newdata = data - - remaining -= len(newdata) + cond = object.__getattribute__(self, '_cond') + return cond.forward_write(self, 'sendall', data, flags) - result = object.__getattribute__(self, '_orig').sendall(newdata, flags) - - self._writelog( - b'sendall(%d from %d) -> (%d) %s' - % (len(newdata), len(data), remaining, newdata) - ) - - object.__getattribute__(self, '_cond').remaining_send_bytes = remaining - - if remaining <= 0: - self._writelog(b'write limit reached; closing socket') - object.__getattribute__(self, '_orig').shutdown(socket.SHUT_RDWR) - - raise Exception('connection closed after sending N bytes') - - return result + def _cond_close(self): + object.__getattribute__(self, '_orig').shutdown(socket.SHUT_RDWR) # We can't adjust __class__ on socket._fileobject, so define a proxy. @@ -174,7 +201,14 @@ object.__setattr__(self, '_cond', condition_tracked) def __getattribute__(self, name): - if name in ('_close', 'read', 'readline', 'write', '_writelog'): + if name in ( + '_close', + 'read', + 'readline', + 'write', + '_writelog', + '_cond_close', + ): return object.__getattribute__(self, name) return getattr(object.__getattribute__(self, '_orig'), name) @@ -280,37 +314,11 @@ return result def write(self, data): - remaining = object.__getattribute__(self, '_cond').remaining_send_bytes - - # No byte limit on this operation. Call original function. - if not remaining: - result = object.__getattribute__(self, '_orig').write(data) - self._writelog(b'write(%d) -> %s' % (len(data), data)) - return result - - if len(data) > remaining: - newdata = data[0:remaining] - else: - newdata = data - - remaining -= len(newdata) + cond = object.__getattribute__(self, '_cond') + return cond.forward_write(self, 'write', data) - result = object.__getattribute__(self, '_orig').write(newdata) - - self._writelog( - b'write(%d from %d) -> (%d) %s' - % (len(newdata), len(data), remaining, newdata) - ) - - object.__getattribute__(self, '_cond').remaining_send_bytes = remaining - - if remaining <= 0: - self._writelog(b'write limit reached; closing socket') - self._close() - - raise Exception('connection closed after sending N bytes') - - return result + def _cond_close(self): + self._close() def process_config(value):