worker: support parallelization of functions with return values
author Valentin Gatien-Baron <vgatien-baron@janestreet.com>
date Wed, 12 Jun 2019 13:10:52 -0400
changeset 42455 5ca136bbd3f6
parent 42454 0eb8c61c306b
child 42456 87a34c767384
worker: support parallelization of functions with return values

Currently worker supports running functions that return a progress
iterator. Generalize it to handle functions that return a progress
iterator followed by a return value. It's unused in this commit, but
will be used in the next one.

Differential Revision: https://phab.mercurial-scm.org/D6515
mercurial/worker.py
--- a/mercurial/worker.py	Sun May 19 16:06:06 2019 -0400
+++ b/mercurial/worker.py	Wed Jun 12 13:10:52 2019 -0400
@@ -83,7 +83,8 @@
     benefit = linear - (_STARTUP_COST * workers + linear / workers)
     return benefit >= 0.15
 
-def worker(ui, costperarg, func, staticargs, args, threadsafe=True):
+def worker(ui, costperarg, func, staticargs, args, hasretval=False,
+           threadsafe=True):
     '''run a function, possibly in parallel in multiple worker
     processes.
 
@@ -91,23 +92,27 @@
 
     costperarg - cost of a single task
 
-    func - function to run
+    func - function to run. It is expected to return a progress iterator.
 
     staticargs - arguments to pass to every invocation of the function
 
     args - arguments to split into chunks, to pass to individual
     workers
 
+    hasretval - when True, func and the current function return a progress
+    iterator, then a list (encoded as an iterator that yields many (False, ..)
+    then a (True, list)). The resulting list is in the natural order.
+
     threadsafe - whether work items are thread safe and can be executed using
     a thread-based worker. Should be disabled for CPU heavy tasks that don't
     release the GIL.
     '''
     enabled = ui.configbool('worker', 'enabled')
     if enabled and worthwhile(ui, costperarg, len(args), threadsafe=threadsafe):
-        return _platformworker(ui, func, staticargs, args)
+        return _platformworker(ui, func, staticargs, args, hasretval)
     return func(*staticargs + (args,))
 
-def _posixworker(ui, func, staticargs, args):
+def _posixworker(ui, func, staticargs, args, hasretval):
     workers = _numworkers(ui)
     oldhandler = signal.getsignal(signal.SIGINT)
     signal.signal(signal.SIGINT, signal.SIG_IGN)
@@ -157,7 +162,8 @@
     ui.flush()
     parentpid = os.getpid()
     pipes = []
-    for pargs in partition(args, workers):
+    retvals = []
+    for i, pargs in enumerate(partition(args, workers)):
         # Every worker gets its own pipe to send results on, so we don't have to
         # implement atomic writes larger than PIPE_BUF. Each forked process has
         # its own pipe's descriptors in the local variables, and the parent
@@ -165,6 +171,7 @@
         # care what order they're in).
         rfd, wfd = os.pipe()
         pipes.append((rfd, wfd))
+        retvals.append(None)
         # make sure we use os._exit in all worker code paths. otherwise the
         # worker may do some clean-ups which could cause surprises like
         # deadlock. see sshpeer.cleanup for example.
@@ -185,7 +192,7 @@
                         os.close(w)
                     os.close(rfd)
                     for result in func(*(staticargs + (pargs,))):
-                        os.write(wfd, util.pickle.dumps(result))
+                        os.write(wfd, util.pickle.dumps((i, result)))
                     return 0
 
                 ret = scmutil.callcatch(ui, workerfunc)
@@ -219,7 +226,11 @@
         while openpipes > 0:
             for key, events in selector.select():
                 try:
-                    yield util.pickle.load(key.fileobj)
+                    i, res = util.pickle.load(key.fileobj)
+                    if hasretval and res[0]:
+                        retvals[i] = res[1]
+                    else:
+                        yield res
                 except EOFError:
                     selector.unregister(key.fileobj)
                     key.fileobj.close()
@@ -237,6 +248,8 @@
         if status < 0:
             os.kill(os.getpid(), -status)
         sys.exit(status)
+    if hasretval:
+        yield True, sum(retvals, [])
 
 def _posixexitstatus(code):
     '''convert a posix exit status into the same form returned by
@@ -248,7 +261,7 @@
     elif os.WIFSIGNALED(code):
         return -os.WTERMSIG(code)
 
-def _windowsworker(ui, func, staticargs, args):
+def _windowsworker(ui, func, staticargs, args, hasretval):
     class Worker(threading.Thread):
         def __init__(self, taskqueue, resultqueue, func, staticargs, *args,
                      **kwargs):
@@ -268,9 +281,9 @@
             try:
                 while not self._taskqueue.empty():
                     try:
-                        args = self._taskqueue.get_nowait()
+                        i, args = self._taskqueue.get_nowait()
                         for res in self._func(*self._staticargs + (args,)):
-                            self._resultqueue.put(res)
+                            self._resultqueue.put((i, res))
                             # threading doesn't provide a native way to
                             # interrupt execution. handle it manually at every
                             # iteration.
@@ -305,9 +318,11 @@
     workers = _numworkers(ui)
     resultqueue = pycompat.queue.Queue()
     taskqueue = pycompat.queue.Queue()
+    retvals = []
     # partition work to more pieces than workers to minimize the chance
     # of uneven distribution of large tasks between the workers
-    for pargs in partition(args, workers * 20):
+    for pargs in enumerate(partition(args, workers * 20)):
+        retvals.append(None)
         taskqueue.put(pargs)
     for _i in range(workers):
         t = Worker(taskqueue, resultqueue, func, staticargs)
@@ -316,7 +331,11 @@
     try:
         while len(threads) > 0:
             while not resultqueue.empty():
-                yield resultqueue.get()
+                (i, res) = resultqueue.get()
+                if hasretval and res[0]:
+                    retvals[i] = res[1]
+                else:
+                    yield res
             threads[0].join(0.05)
             finishedthreads = [_t for _t in threads if not _t.is_alive()]
             for t in finishedthreads:
@@ -327,7 +346,13 @@
         trykillworkers()
         raise
     while not resultqueue.empty():
-        yield resultqueue.get()
+        (i, res) = resultqueue.get()
+        if hasretval and res[0]:
+            retvals[i] = res[1]
+        else:
+            yield res
+    if hasretval:
+        yield True, sum(retvals, [])
 
 if pycompat.iswindows:
     _platformworker = _windowsworker