diff --git a/tests/testlib/badserverext.py b/tests/testlib/badserverext.py --- a/tests/testlib/badserverext.py +++ b/tests/testlib/badserverext.py @@ -108,22 +108,25 @@ orig = object.__getattribute__(obj, '_orig') bmethod = method.encode('ascii') func = getattr(orig, method) - # No byte limit on this operation. Call original function. + + if remaining: + remaining = max(0, remaining) + if not remaining: - result = func(data, *args, **kwargs) - obj._writelog(b'%s(%d) -> %s' % (bmethod, len(data), data)) - return result - - remaining = max(0, remaining) - - if remaining > 0: + newdata = data + else: if remaining < len(data): newdata = data[0:remaining] else: newdata = data + remaining -= len(newdata) + self.remaining_send_bytes = remaining - remaining -= len(newdata) + result = func(newdata, *args, **kwargs) + if remaining is None: + obj._writelog(b'%s(%d) -> %s' % (bmethod, len(data), data)) + else: obj._writelog( b'%s(%d from %d) -> (%d) %s' % ( @@ -135,11 +138,7 @@ ) ) - result = func(newdata, *args, **kwargs) - - self.remaining_send_bytes = remaining - - if remaining <= 0: + if remaining is not None and remaining <= 0: obj._writelog(b'write limit reached; closing socket') object.__getattribute__(obj, '_cond_close')() raise Exception('connection closed after sending N bytes')