comparison mercurial/worker.py @ 42455:5ca136bbd3f6

worker: support parallelization of functions with return values

Currently worker supports running functions that return a progress
iterator. Generalize it to handle functions that return a progress
iterator and then a return value. It's unused in this commit, but will
be used in the next one.

Differential Revision: https://phab.mercurial-scm.org/D6515
author Valentin Gatien-Baron <vgatien-baron@janestreet.com>
date Wed, 12 Jun 2019 13:10:52 -0400
parents e10adebf8176
children d29db0a0c4eb
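The new calling convention, sketched from the docstring changed below (a minimal, hypothetical caller; ui, repo, getsizes, and allfiles are assumptions, not part of this change): with hasretval=True the iterator returned by worker yields (False, ..) progress items and ends with a single (True, list) carrying the combined return value.

    from mercurial import worker

    results = None
    for final, value in worker.worker(ui, 0.001, getsizes, (repo,), allfiles,
                                      hasretval=True):
        if final:
            results = value  # concatenation of every chunk's list
        else:
            pass             # a (False, ..) progress item from one chunk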
--- a/mercurial/worker.py (42454:0eb8c61c306b)
+++ b/mercurial/worker.py (42455:5ca136bbd3f6)
@@ -81,35 +81,40 @@
     linear = costperop * nops
     workers = _numworkers(ui)
     benefit = linear - (_STARTUP_COST * workers + linear / workers)
     return benefit >= 0.15

-def worker(ui, costperarg, func, staticargs, args, threadsafe=True):
+def worker(ui, costperarg, func, staticargs, args, hasretval=False,
+           threadsafe=True):
     '''run a function, possibly in parallel in multiple worker
     processes.

     returns a progress iterator

     costperarg - cost of a single task

-    func - function to run
+    func - function to run. It is expected to return a progress iterator.

     staticargs - arguments to pass to every invocation of the function

     args - arguments to split into chunks, to pass to individual
     workers
+
+    hasretval - when True, func and the current function return a progress
+    iterator then a list (encoded as an iterator that yields many (False, ..)
+    then a (True, list)). The resulting list is in the natural order.

     threadsafe - whether work items are thread safe and can be executed using
     a thread-based worker. Should be disabled for CPU heavy tasks that don't
     release the GIL.
     '''
     enabled = ui.configbool('worker', 'enabled')
     if enabled and worthwhile(ui, costperarg, len(args), threadsafe=threadsafe):
-        return _platformworker(ui, func, staticargs, args)
+        return _platformworker(ui, func, staticargs, args, hasretval)
     return func(*staticargs + (args,))

-def _posixworker(ui, func, staticargs, args):
+def _posixworker(ui, func, staticargs, args, hasretval):
     workers = _numworkers(ui)
     oldhandler = signal.getsignal(signal.SIGINT)
     signal.signal(signal.SIGINT, signal.SIG_IGN)
     pids, problem = set(), [0]
     def killworkers():
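To satisfy the hasretval contract documented above, func itself must yield the same tagged shape for its chunk: any number of (False, ..) progress items followed by exactly one (True, list). A minimal sketch of a conforming func (the name sizeworker and its arguments are hypothetical):

    import os

    def sizeworker(repo, datafiles):
        sizes = []
        for name in datafiles:
            sizes.append((name, os.path.getsize(name)))
            yield False, 1       # one unit of progress
        yield True, sizes        # this chunk's return value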
@@ -155,18 +160,20 @@
             killworkers()
     oldchldhandler = signal.signal(signal.SIGCHLD, sigchldhandler)
     ui.flush()
     parentpid = os.getpid()
     pipes = []
-    for pargs in partition(args, workers):
+    retvals = []
+    for i, pargs in enumerate(partition(args, workers)):
         # Every worker gets its own pipe to send results on, so we don't have to
         # implement atomic writes larger than PIPE_BUF. Each forked process has
         # its own pipe's descriptors in the local variables, and the parent
         # process has the full list of pipe descriptors (and it doesn't really
         # care what order they're in).
         rfd, wfd = os.pipe()
         pipes.append((rfd, wfd))
+        retvals.append(None)
         # make sure we use os._exit in all worker code paths. otherwise the
         # worker may do some clean-ups which could cause surprises like
         # deadlock. see sshpeer.cleanup for example.
         # override error handling *before* fork. this is necessary because
         # exception (signal) may arrive after fork, before "pid =" assignment
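One retvals slot is reserved per chunk up front, so chunk i's list can be stored at index i no matter which worker finishes first. For illustration, assuming the stride-based partition helper in this module:

    from mercurial import worker

    retvals = []
    for i, pargs in enumerate(worker.partition(list(range(10)), 3)):
        retvals.append(None)     # reserve the slot for chunk i
        print(i, pargs)          # 0 [0, 3, 6, 9] / 1 [1, 4, 7] / 2 [2, 5, 8]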
@@ -183,11 +190,11 @@
                 for r, w in pipes[:-1]:
                     os.close(r)
                     os.close(w)
                 os.close(rfd)
                 for result in func(*(staticargs + (pargs,))):
-                    os.write(wfd, util.pickle.dumps(result))
+                    os.write(wfd, util.pickle.dumps((i, result)))
                 return 0

            ret = scmutil.callcatch(ui, workerfunc)
        except: # parent re-raises, child never returns
            if os.getpid() == parentpid:
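Each child now pickles (i, result) tuples onto its own pipe. Pickle frames are self-delimiting, so the parent can load them back one at a time until EOF. A standalone, POSIX-only illustration of the mechanism (stdlib pickle standing in for util.pickle):

    import os
    import pickle

    rfd, wfd = os.pipe()
    if os.fork() == 0:                  # child: write tagged results
        os.close(rfd)
        for item in [(0, 'a'), (0, 'b')]:
            os.write(wfd, pickle.dumps(item))
        os._exit(0)
    os.close(wfd)                       # parent: read until EOF
    with os.fdopen(rfd, 'rb', 0) as rf:
        while True:
            try:
                print(pickle.load(rf))  # (0, 'a') then (0, 'b')
            except EOFError:
                break
    os.wait()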
@@ -217,11 +224,15 @@
     try:
         openpipes = len(pipes)
         while openpipes > 0:
             for key, events in selector.select():
                 try:
-                    yield util.pickle.load(key.fileobj)
+                    i, res = util.pickle.load(key.fileobj)
+                    if hasretval and res[0]:
+                        retvals[i] = res[1]
+                    else:
+                        yield res
                 except EOFError:
                     selector.unregister(key.fileobj)
                     key.fileobj.close()
                     openpipes -= 1
                 except IOError as e:
@@ -235,10 +246,12 @@
     status = cleanup()
     if status:
         if status < 0:
             os.kill(os.getpid(), -status)
         sys.exit(status)
+    if hasretval:
+        yield True, sum(retvals, [])

 def _posixexitstatus(code):
     '''convert a posix exit status into the same form returned by
     os.spawnv

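The final (True, ..) item concatenates the per-chunk lists in chunk order; sum with a list start value performs the concatenation:

    retvals = [['a0', 'a1'], ['b0'], ['c0', 'c1']]
    assert sum(retvals, []) == ['a0', 'a1', 'b0', 'c0', 'c1']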
@@ -246,11 +259,11 @@
     if os.WIFEXITED(code):
         return os.WEXITSTATUS(code)
     elif os.WIFSIGNALED(code):
         return -os.WTERMSIG(code)

-def _windowsworker(ui, func, staticargs, args):
+def _windowsworker(ui, func, staticargs, args, hasretval):
     class Worker(threading.Thread):
         def __init__(self, taskqueue, resultqueue, func, staticargs, *args,
                      **kwargs):
             threading.Thread.__init__(self, *args, **kwargs)
             self._taskqueue = taskqueue
@@ -266,13 +279,13 @@

         def run(self):
             try:
                 while not self._taskqueue.empty():
                     try:
-                        args = self._taskqueue.get_nowait()
+                        i, args = self._taskqueue.get_nowait()
                         for res in self._func(*self._staticargs + (args,)):
-                            self._resultqueue.put(res)
+                            self._resultqueue.put((i, res))
                         # threading doesn't provide a native way to
                         # interrupt execution. handle it manually at every
                         # iteration.
                         if self._interrupted:
                             return
@@ -303,33 +316,45 @@
                 return

     workers = _numworkers(ui)
     resultqueue = pycompat.queue.Queue()
     taskqueue = pycompat.queue.Queue()
+    retvals = []
     # partition work to more pieces than workers to minimize the chance
     # of uneven distribution of large tasks between the workers
-    for pargs in partition(args, workers * 20):
+    for pargs in enumerate(partition(args, workers * 20)):
+        retvals.append(None)
         taskqueue.put(pargs)
     for _i in range(workers):
         t = Worker(taskqueue, resultqueue, func, staticargs)
         threads.append(t)
         t.start()
     try:
         while len(threads) > 0:
             while not resultqueue.empty():
-                yield resultqueue.get()
+                (i, res) = resultqueue.get()
+                if hasretval and res[0]:
+                    retvals[i] = res[1]
+                else:
+                    yield res
             threads[0].join(0.05)
             finishedthreads = [_t for _t in threads if not _t.is_alive()]
             for t in finishedthreads:
                 if t.exception is not None:
                     raise t.exception
                 threads.remove(t)
     except (Exception, KeyboardInterrupt): # re-raises
         trykillworkers()
         raise
     while not resultqueue.empty():
-        yield resultqueue.get()
+        (i, res) = resultqueue.get()
+        if hasretval and res[0]:
+            retvals[i] = res[1]
+        else:
+            yield res
+    if hasretval:
+        yield True, sum(retvals, [])

 if pycompat.iswindows:
     _platformworker = _windowsworker
 else:
     _platformworker = _posixworker
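Putting the pieces together, a toy, single-threaded run of the protocol (all names here are illustrative); chunk 1 deliberately finishes before chunk 0 to show that the indexed slots restore the natural chunk order:

    import itertools

    def chunkfunc(i, items):           # stand-in for func on one chunk
        for item in items:
            yield i, (False, item)     # tagged progress
        yield i, (True, [x * 10 for x in items])

    stream = itertools.chain(chunkfunc(1, [3, 4]), chunkfunc(0, [1, 2]))
    retvals = [None, None]
    for i, res in stream:              # what the parent demultiplexes
        if res[0]:
            retvals[i] = res[1]
        # a real caller would see the (False, ..) items as progress
    print(sum(retvals, []))            # [10, 20, 30, 40]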