contrib/import-checker.py
branchstable
changeset 27909 3203dfe341f9
parent 27621 39845b064041
child 28330 f3fb24e36d61
--- a/contrib/import-checker.py	Wed Jan 06 11:01:55 2016 -0800
+++ b/contrib/import-checker.py	Sun Jan 17 21:40:21 2016 -0600
@@ -1,4 +1,7 @@
+#!/usr/bin/env python
+
 import ast
+import collections
 import os
 import sys
 
@@ -11,6 +14,8 @@
 # Whitelist of modules that symbols can be directly imported from.
 allowsymbolimports = (
     '__future__',
+    'mercurial.hgweb.common',
+    'mercurial.hgweb.request',
     'mercurial.i18n',
     'mercurial.node',
 )
@@ -35,6 +40,17 @@
 
     return False
 
+def walklocal(root):
+    """Recursively yield all descendant nodes but not in a different scope"""
+    todo = collections.deque(ast.iter_child_nodes(root))
+    yield root, False
+    while todo:
+        node = todo.popleft()
+        newscope = isinstance(node, ast.FunctionDef)
+        if not newscope:
+            todo.extend(ast.iter_child_nodes(node))
+        yield node, newscope
+
 def dotted_name_of_path(path, trimpure=False):
     """Given a relative path to a source file, return its dotted module name.
 
@@ -45,7 +61,7 @@
     >>> dotted_name_of_path('zlibmodule.so')
     'zlib'
     """
-    parts = path.split('/')
+    parts = path.replace(os.sep, '/').split('/')
     parts[-1] = parts[-1].split('.', 1)[0] # remove .py and .so and .ARCH.so
     if parts[-1].endswith('module'):
         parts[-1] = parts[-1][:-6]
@@ -163,9 +179,6 @@
     # consider them stdlib.
     for m in ['msvcrt', '_winreg']:
         yield m
-    # These get missed too
-    for m in 'ctypes', 'email', 'multiprocessing':
-        yield m
     yield 'builtins' # python3 only
     for m in 'fcntl', 'grp', 'pwd', 'termios':  # Unix only
         yield m
@@ -198,11 +211,12 @@
                     or top == libpath and d in ('hgext', 'mercurial')):
                     del dirs[i]
             for name in files:
-                if name == '__init__.py':
-                    continue
                 if not name.endswith(('.py', '.so', '.pyc', '.pyo', '.pyd')):
                     continue
-                full_path = os.path.join(top, name)
+                if name.startswith('__init__.py'):
+                    full_path = top
+                else:
+                    full_path = os.path.join(top, name)
                 rel_path = full_path[len(libpath) + 1:]
                 mod = dotted_name_of_path(rel_path)
                 yield mod
@@ -237,7 +251,7 @@
     >>> sorted(imported_modules(
     ...        'import foo1; from bar import bar1',
     ...        modulename, localmods))
-    ['foo.bar.__init__', 'foo.bar.bar1', 'foo.foo1']
+    ['foo.bar.bar1', 'foo.foo1']
     >>> sorted(imported_modules(
     ...        'from bar.bar1 import name1, name2, name3',
     ...        modulename, localmods))
@@ -284,21 +298,28 @@
                 continue
 
             absname, dottedpath, hassubmod = found
-            yield dottedpath
             if not hassubmod:
+                # "dottedpath" is not a package; must be imported
+                yield dottedpath
                 # examination of "node.names" should be redundant
                 # e.g.: from mercurial.node import nullid, nullrev
                 continue
 
+            modnotfound = False
             prefix = absname + '.'
             for n in node.names:
                 found = fromlocal(prefix + n.name)
                 if not found:
                     # this should be a function or a property of "node.module"
+                    modnotfound = True
                     continue
                 yield found[1]
+            if modnotfound:
+                # "dottedpath" is a package, but imported because of non-module
+                # lookup
+                yield dottedpath
 
-def verify_import_convention(module, source):
+def verify_import_convention(module, source, localmods):
     """Verify imports match our established coding convention.
 
     We have 2 conventions: legacy and modern. The modern convention is in
@@ -311,11 +332,11 @@
     absolute = usingabsolute(root)
 
     if absolute:
-        return verify_modern_convention(module, root)
+        return verify_modern_convention(module, root, localmods)
     else:
         return verify_stdlib_on_own_line(root)
 
-def verify_modern_convention(module, root):
+def verify_modern_convention(module, root, localmods, root_col_offset=0):
     """Verify a file conforms to the modern import convention rules.
 
     The rules of the modern convention are:
@@ -342,6 +363,7 @@
       and readability problems. See `requirealias`.
     """
     topmodule = module.split('.')[0]
+    fromlocal = fromlocalfunc(module, localmods)
 
     # Whether a local/non-stdlib import has been performed.
     seenlocal = False
@@ -352,29 +374,36 @@
     # Relative import levels encountered so far.
     seenlevels = set()
 
-    for node in ast.walk(root):
-        if isinstance(node, ast.Import):
+    for node, newscope in walklocal(root):
+        def msg(fmt, *args):
+            return (fmt % args, node.lineno)
+        if newscope:
+            # Check for local imports in function
+            for r in verify_modern_convention(module, node, localmods,
+                                              node.col_offset + 4):
+                yield r
+        elif isinstance(node, ast.Import):
             # Disallow "import foo, bar" and require separate imports
             # for each module.
             if len(node.names) > 1:
-                yield 'multiple imported names: %s' % ', '.join(
-                    n.name for n in node.names)
+                yield msg('multiple imported names: %s',
+                          ', '.join(n.name for n in node.names))
 
             name = node.names[0].name
             asname = node.names[0].asname
 
             # Ignore sorting rules on imports inside blocks.
-            if node.col_offset == 0:
+            if node.col_offset == root_col_offset:
                 if lastname and name < lastname:
-                    yield 'imports not lexically sorted: %s < %s' % (
-                           name, lastname)
+                    yield msg('imports not lexically sorted: %s < %s',
+                              name, lastname)
 
                 lastname = name
 
             # stdlib imports should be before local imports.
             stdlib = name in stdlib_modules
-            if stdlib and seenlocal and node.col_offset == 0:
-                yield 'stdlib import follows local import: %s' % name
+            if stdlib and seenlocal and node.col_offset == root_col_offset:
+                yield msg('stdlib import follows local import: %s', name)
 
             if not stdlib:
                 seenlocal = True
@@ -382,11 +411,11 @@
             # Import of sibling modules should use relative imports.
             topname = name.split('.')[0]
             if topname == topmodule:
-                yield 'import should be relative: %s' % name
+                yield msg('import should be relative: %s', name)
 
             if name in requirealias and asname != requirealias[name]:
-                yield '%s module must be "as" aliased to %s' % (
-                       name, requirealias[name])
+                yield msg('%s module must be "as" aliased to %s',
+                          name, requirealias[name])
 
         elif isinstance(node, ast.ImportFrom):
             # Resolve the full imported module name.
@@ -400,39 +429,49 @@
 
                 topname = fullname.split('.')[0]
                 if topname == topmodule:
-                    yield 'import should be relative: %s' % fullname
+                    yield msg('import should be relative: %s', fullname)
 
             # __future__ is special since it needs to come first and use
             # symbol import.
             if fullname != '__future__':
                 if not fullname or fullname in stdlib_modules:
-                    yield 'relative import of stdlib module'
+                    yield msg('relative import of stdlib module')
                 else:
                     seenlocal = True
 
             # Direct symbol import is only allowed from certain modules and
             # must occur before non-symbol imports.
-            if node.module and node.col_offset == 0:
-                if fullname not in allowsymbolimports:
-                    yield 'direct symbol import from %s' % fullname
+            if node.module and node.col_offset == root_col_offset:
+                found = fromlocal(node.module, node.level)
+                if found and found[2]:  # node.module is a package
+                    prefix = found[0] + '.'
+                    symbols = [n.name for n in node.names
+                               if not fromlocal(prefix + n.name)]
+                else:
+                    symbols = [n.name for n in node.names]
 
-                if seennonsymbolrelative:
-                    yield ('symbol import follows non-symbol import: %s' %
-                           fullname)
+                if symbols and fullname not in allowsymbolimports:
+                    yield msg('direct symbol import %s from %s',
+                              ', '.join(symbols), fullname)
+
+                if symbols and seennonsymbolrelative:
+                    yield msg('symbol import follows non-symbol import: %s',
+                              fullname)
 
             if not node.module:
                 assert node.level
                 seennonsymbolrelative = True
 
                 # Only allow 1 group per level.
-                if node.level in seenlevels and node.col_offset == 0:
-                    yield 'multiple "from %s import" statements' % (
-                           '.' * node.level)
+                if (node.level in seenlevels
+                    and node.col_offset == root_col_offset):
+                    yield msg('multiple "from %s import" statements',
+                              '.' * node.level)
 
                 # Higher-level groups come before lower-level groups.
                 if any(node.level > l for l in seenlevels):
-                    yield 'higher-level import should come first: %s' % (
-                           fullname)
+                    yield msg('higher-level import should come first: %s',
+                              fullname)
 
                 seenlevels.add(node.level)
 
@@ -442,14 +481,14 @@
 
             for n in node.names:
                 if lastentryname and n.name < lastentryname:
-                    yield 'imports from %s not lexically sorted: %s < %s' % (
-                           fullname, n.name, lastentryname)
+                    yield msg('imports from %s not lexically sorted: %s < %s',
+                              fullname, n.name, lastentryname)
 
                 lastentryname = n.name
 
                 if n.name in requirealias and n.asname != requirealias[n.name]:
-                    yield '%s from %s must be "as" aliased to %s' % (
-                          n.name, fullname, requirealias[n.name])
+                    yield msg('%s from %s must be "as" aliased to %s',
+                              n.name, fullname, requirealias[n.name])
 
 def verify_stdlib_on_own_line(root):
     """Given some python source, verify that stdlib imports are done
@@ -460,7 +499,7 @@
     http://bugs.python.org/issue19510.
 
     >>> list(verify_stdlib_on_own_line(ast.parse('import sys, foo')))
-    ['mixed imports\\n   stdlib:    sys\\n   relative:  foo']
+    [('mixed imports\\n   stdlib:    sys\\n   relative:  foo', 1)]
     >>> list(verify_stdlib_on_own_line(ast.parse('import sys, os')))
     []
     >>> list(verify_stdlib_on_own_line(ast.parse('import foo, bar')))
@@ -474,7 +513,7 @@
             if from_stdlib[True] and from_stdlib[False]:
                 yield ('mixed imports\n   stdlib:    %s\n   relative:  %s' %
                        (', '.join(sorted(from_stdlib[True])),
-                        ', '.join(sorted(from_stdlib[False]))))
+                        ', '.join(sorted(from_stdlib[False]))), node.lineno)
 
 class CircularImport(Exception):
     pass
@@ -546,9 +585,9 @@
         src = f.read()
         used_imports[modname] = sorted(
             imported_modules(src, modname, localmods, ignore_nested=True))
-        for error in verify_import_convention(modname, src):
+        for error, lineno in verify_import_convention(modname, src, localmods):
             any_errors = True
-            print source_path, error
+            print '%s:%d: %s' % (source_path, lineno, error)
         f.close()
     cycles = find_cycles(used_imports)
     if cycles: