mercurial/templater.py
changeset 10852 0d50586a9d31
parent 10850 a63391e26284
child 10853 b6f6d9fd53d6
--- a/mercurial/templater.py	Mon Apr 05 15:25:08 2010 -0500
+++ b/mercurial/templater.py	Mon Apr 05 15:25:08 2010 -0500
@@ -17,13 +17,15 @@
     if isinstance(thing, str):
         yield thing
     elif not hasattr(thing, '__iter__'):
-        yield str(thing)
-    elif thing is not None:
+        if i is not None:
+            yield str(thing)
+    else:
         for i in thing:
             if isinstance(i, str):
                 yield i
             elif not hasattr(i, '__iter__'):
-                yield str(i)
+                if i is not None:
+                    yield str(i)
             elif i is not None:
                 for j in _flatten(i):
                     yield j
@@ -64,77 +66,85 @@
         self._defaults = defaults
         self._cache = {}
 
-    def process(self, t, mapping):
-        '''Perform expansion. t is name of map element to expand. mapping contains
-        added elements for use during expansion. Is a generator.'''
+    def _load(self, t):
+        '''load, parse, and cache a template'''
         if t not in self._cache:
             self._cache[t] = self._parse(self._loader(t))
-        return _flatten(self._process(self._cache[t], mapping))
+        return self._cache[t]
+
+    def process(self, t, mapping):
+        '''Perform expansion. t is name of map element to expand.
+        mapping contains added elements for use during expansion. Is a
+        generator.'''
+
+        return _flatten(self._process(self._load(t), mapping))
+
+    def _get(self, mapping, key):
+        v = mapping.get(key)
+        if v is None:
+            v = self._defaults.get(key, '')
+        if hasattr(v, '__call__'):
+            v = v(**mapping)
+        return v
+
+    def _raw(self, mapping, x):
+        return x
+
+    def _filter(self, mapping, parts):
+        filters, val = parts
+        x = self._get(mapping, val)
+        for f in filters:
+            x = f(x)
+        return x
+
+    def _format(self, mapping, args):
+        key, parsed = args
+        v = self._get(mapping, key)
+        if not hasattr(v, '__iter__'):
+            raise SyntaxError(_("error expanding '%s%%%s'")
+                              % (key, format))
+        lm = mapping.copy()
+        for i in v:
+            if isinstance(i, dict):
+                lm.update(i)
+                yield self._process(parsed, lm)
+            else:
+                # v is not an iterable of dicts, this happen when 'key'
+                # has been fully expanded already and format is useless.
+                # If so, return the expanded value.
+                yield i
 
     def _parse(self, tmpl):
         '''preparse a template'''
 
-        defget = self._defaults.get
-        def getter(mapping, key):
-            v = mapping.get(key)
-            if v is None:
-                v = defget(key, '')
-            if hasattr(v, '__call__'):
-                v = v(**mapping)
-            return v
-
-        def raw(mapping, x):
-            return x
-        def filt(mapping, parts):
-            filters, val = parts
-            x = getter(mapping, val)
-            for f in filters:
-                x = f(x)
-            return x
-        def formatter(mapping, args):
-            key, format = args
-            v = getter(mapping, key)
-            if not hasattr(v, '__iter__'):
-                raise SyntaxError(_("error expanding '%s%%%s'")
-                                  % (key, format))
-            lm = mapping.copy()
-            for i in v:
-                if isinstance(i, dict):
-                    lm.update(i)
-                    yield self.process(format, lm)
-                else:
-                    # v is not an iterable of dicts, this happen when 'key'
-                    # has been fully expanded already and format is useless.
-                    # If so, return the expanded value.
-                    yield i
-
         parsed = []
         pos, stop = 0, len(tmpl)
         while pos < stop:
             n = tmpl.find('{', pos)
             if n < 0:
-                parsed.append((raw, tmpl[pos:stop]))
+                parsed.append((self._raw, tmpl[pos:stop]))
                 break
             if n > 0 and tmpl[n - 1] == '\\':
                 # escaped
-                parsed.append((raw, tmpl[pos:n + 1]))
+                parsed.append((self._raw, tmpl[pos:n + 1]))
                 pos = n + 1
                 continue
             if n > pos:
-                parsed.append((raw, tmpl[pos:n]))
+                parsed.append((self._raw, tmpl[pos:n]))
 
             pos = n
             n = tmpl.find('}', pos)
             if n < 0:
                 # no closing
-                parsed.append((raw, tmpl[pos:stop]))
+                parsed.append((self._raw, tmpl[pos:stop]))
                 break
 
             expr = tmpl[pos + 1:n]
             pos = n + 1
 
             if '%' in expr:
-                parsed.append((formatter, expr.split('%')))
+                key, t = expr.split('%')
+                parsed.append((self._format, (key, self._load(t))))
             elif '|' in expr:
                 parts = expr.split('|')
                 val = parts[0]
@@ -142,9 +152,9 @@
                     filters = [self._filters[f] for f in parts[1:]]
                 except KeyError, i:
                     raise SyntaxError(_("unknown filter '%s'") % i[0])
-                parsed.append((filt, (filters, val)))
+                parsed.append((self._filter, (filters, val)))
             else:
-                parsed.append((getter, expr))
+                parsed.append((self._get, expr))
 
         return parsed