extensions: extract partial application into a bind() function
authorEric Sumner <ericsumner@fb.com>
Wed, 15 Apr 2015 12:18:05 -0400
changeset 24734 fb6cb1b82f4f
parent 24733 c00e4338fa4b
child 24735 07200e3332a1
extensions: extract partial application into a bind() function We use partial function application for wrapping existing Mercurial functions, and it's implemented separately each time. This patch extracts the partial application into a new bind() function that can be used on its own by extensions when appropriate. In particular, the evolve extension needs to wrap functions in the various bundle2 processing dictionaries, which the pre-existing methods don't support.
mercurial/extensions.py
--- a/mercurial/extensions.py	Tue Apr 14 11:44:04 2015 -0400
+++ b/mercurial/extensions.py	Wed Apr 15 12:18:05 2015 -0400
@@ -152,6 +152,18 @@
     else:
         _aftercallbacks.setdefault(extension, []).append(callback)
 
+def bind(func, *args):
+    '''Partial function application
+
+      Returns a new function that is the partial application of args and kwargs
+      to func.  For example,
+
+          f(1, 2, bar=3) === bind(f, 1)(2, bar=3)'''
+    assert callable(func)
+    def closure(*a, **kw):
+        return func(*(args + a), **kw)
+    return closure
+
 def wrapcommand(table, command, wrapper, synopsis=None, docstring=None):
     '''Wrap the command named `command' in table
 
@@ -189,9 +201,7 @@
             break
 
     origfn = entry[0]
-    def wrap(*args, **kwargs):
-        return util.checksignature(wrapper)(
-            util.checksignature(origfn), *args, **kwargs)
+    wrap = bind(util.checksignature(wrapper), util.checksignature(origfn))
 
     wrap.__module__ = getattr(origfn, '__module__')
 
@@ -241,12 +251,10 @@
     subclass trick.
     '''
     assert callable(wrapper)
-    def wrap(*args, **kwargs):
-        return wrapper(origfn, *args, **kwargs)
 
     origfn = getattr(container, funcname)
     assert callable(origfn)
-    setattr(container, funcname, wrap)
+    setattr(container, funcname, bind(wrapper, origfn))
     return origfn
 
 def _disabledpaths(strip_init=False):