diff --git a/mercurial/worker.py b/mercurial/worker.py
--- a/mercurial/worker.py
+++ b/mercurial/worker.py
@@ -83,7 +83,8 @@
     benefit = linear - (_STARTUP_COST * workers + linear / workers)
     return benefit >= 0.15
 
-def worker(ui, costperarg, func, staticargs, args, threadsafe=True):
+def worker(ui, costperarg, func, staticargs, args, threadsafe=True,
+           hasretval=False):
     '''run a function, possibly in parallel in multiple worker
     processes.
 
@@ -91,23 +92,27 @@
 
     costperarg - cost of a single task
 
-    func - function to run
+    func - function to run. It is expected to return a progress iterator.
 
     staticargs - arguments to pass to every invocation of the function
 
     args - arguments to split into chunks, to pass to individual
     workers
 
     threadsafe - whether work items are thread safe and can be executed using
     a thread-based worker. Should be disabled for CPU heavy tasks that don't
     release the GIL.
+
+    hasretval - when True, func and this function yield any number of
+    (False, ...) progress items, then exactly one (True, list). In parallel
+    mode the per-chunk lists are concatenated in partition() chunk order.
     '''
     enabled = ui.configbool('worker', 'enabled')
     if enabled and worthwhile(ui, costperarg, len(args), threadsafe=threadsafe):
-        return _platformworker(ui, func, staticargs, args)
+        return _platformworker(ui, func, staticargs, args, hasretval)
     return func(*staticargs + (args,))
 
-def _posixworker(ui, func, staticargs, args):
+def _posixworker(ui, func, staticargs, args, hasretval):
     workers = _numworkers(ui)
     oldhandler = signal.getsignal(signal.SIGINT)
     signal.signal(signal.SIGINT, signal.SIG_IGN)
@@ -157,7 +162,8 @@
     ui.flush()
     parentpid = os.getpid()
     pipes = []
-    for pargs in partition(args, workers):
+    retvals = []  # chunk index -> list returned by that chunk's func
+    for i, pargs in enumerate(partition(args, workers)):
         # Every worker gets its own pipe to send results on, so we don't have to
         # implement atomic writes larger than PIPE_BUF. Each forked process has
         # its own pipe's descriptors in the local variables, and the parent
@@ -165,6 +171,7 @@
         # care what order they're in).
         rfd, wfd = os.pipe()
         pipes.append((rfd, wfd))
+        retvals.append(None)  # filled in by chunk i's final (True, list)
         # make sure we use os._exit in all worker code paths. otherwise the
         # worker may do some clean-ups which could cause surprises like
         # deadlock. see sshpeer.cleanup for example.
@@ -185,7 +192,7 @@
                     os.close(w)
                 os.close(rfd)
                 for result in func(*(staticargs + (pargs,))):
-                    os.write(wfd, util.pickle.dumps(result))
+                    os.write(wfd, util.pickle.dumps((i, result)))
                 return 0
 
             ret = scmutil.callcatch(ui, workerfunc)
@@ -219,7 +226,11 @@
         while openpipes > 0:
             for key, events in selector.select():
                 try:
-                    yield util.pickle.load(key.fileobj)
+                    chunkidx, res = util.pickle.load(key.fileobj)
+                    if hasretval and res[0]:
+                        retvals[chunkidx] = res[1]
+                    else:
+                        yield res
                 except EOFError:
                     selector.unregister(key.fileobj)
                     key.fileobj.close()
@@ -237,6 +248,8 @@
         if status < 0:
             os.kill(os.getpid(), -status)
         sys.exit(status)
+    if hasretval:
+        yield True, [r for chunk in retvals for r in chunk]  # O(n) flatten
 
 def _posixexitstatus(code):
     '''convert a posix exit status into the same form returned by
@@ -248,7 +261,7 @@
     elif os.WIFSIGNALED(code):
         return -os.WTERMSIG(code)
 
-def _windowsworker(ui, func, staticargs, args):
+def _windowsworker(ui, func, staticargs, args, hasretval):
     class Worker(threading.Thread):
         def __init__(self, taskqueue, resultqueue, func, staticargs,
                      *args, **kwargs):
@@ -268,9 +281,9 @@
             try:
                 while not self._taskqueue.empty():
                     try:
-                        args = self._taskqueue.get_nowait()
+                        i, args = self._taskqueue.get_nowait()
                         for res in self._func(*self._staticargs + (args,)):
-                            self._resultqueue.put(res)
+                            self._resultqueue.put((i, res))
                             # threading doesn't provide a native way to
                             # interrupt execution. handle it manually at every
                             # iteration.
@@ -305,9 +318,11 @@
     workers = _numworkers(ui)
     resultqueue = pycompat.queue.Queue()
     taskqueue = pycompat.queue.Queue()
+    retvals = []  # chunk index -> list returned by that chunk's func
     # partition work to more pieces than workers to minimize the chance
     # of uneven distribution of large tasks between the workers
-    for pargs in partition(args, workers * 20):
+    for pargs in enumerate(partition(args, workers * 20)):  # (index, chunk)
+        retvals.append(None)  # filled in by the chunk's final (True, list)
         taskqueue.put(pargs)
     for _i in range(workers):
         t = Worker(taskqueue, resultqueue, func, staticargs)
@@ -316,7 +331,11 @@
     try:
         while len(threads) > 0:
             while not resultqueue.empty():
-                yield resultqueue.get()
+                chunkidx, res = resultqueue.get()
+                if hasretval and res[0]:
+                    retvals[chunkidx] = res[1]
+                else:
+                    yield res
             threads[0].join(0.05)
             finishedthreads = [_t for _t in threads if not _t.is_alive()]
             for t in finishedthreads:
@@ -327,7 +346,13 @@
         trykillworkers()
         raise
     while not resultqueue.empty():
-        yield resultqueue.get()
+        chunkidx, res = resultqueue.get()
+        if hasretval and res[0]:
+            retvals[chunkidx] = res[1]
+        else:
+            yield res
+    if hasretval:
+        yield True, [r for chunk in retvals for r in chunk]  # O(n) flatten
 
 if pycompat.iswindows:
     _platformworker = _windowsworker