Blob Blame History Raw
From db75314805c2b5df57883f49b7d93c62a079189e Mon Sep 17 00:00:00 2001
From: Jared Garst <cultofjared@gmail.com>
Date: Sun, 2 Oct 2016 07:22:51 -0700
Subject: [PATCH] add format string support

format strings require support for two new nodes:
FormattedValue(expr valu, int? conversion, expr? format_spec)
JoinedStr(expr* values)
---
 astroid/as_string.py              | 15 ++++++++++++++-
 astroid/node_classes.py           | 22 ++++++++++++++++++++++
 astroid/nodes.py                  |  2 ++
 astroid/rebuilder.py              | 12 ++++++++++++
 astroid/tests/unittest_python3.py | 11 +++++++++--
 tox.ini                           |  4 ++--
 6 files changed, 61 insertions(+), 5 deletions(-)

diff --git a/astroid/as_string.py b/astroid/as_string.py
index 3d669c4..8eaebc1 100644
--- a/astroid/as_string.py
+++ b/astroid/as_string.py
@@ -481,6 +481,19 @@ def visit_asyncwith(self, node):
     def visit_asyncfor(self, node):
         return 'async %s' % self.visit_for(node)
 
+    def visit_joinedstr(self, node):
+        # Special treatment for constants,
+        # as we want to join literals not reprs
+        string = ''.join(
+            value.value if type(value).__name__ == 'Const'
+            else value.accept(self)
+            for value in node.values
+        )
+        return "f'%s'" % string
+
+    def visit_formattedvalue(self, node):
+        return '{%s}' % node.value.accept(self)
+
 
 def _import_string(names):
     """return a list of (name, asname) formatted as a string"""
@@ -490,7 +503,7 @@ def _import_string(names):
             _names.append('%s as %s' % (name, asname))
         else:
             _names.append(name)
-    return  ', '.join(_names)
+    return ', '.join(_names)
 
 
 if sys.version_info >= (3, 0):
diff --git a/astroid/node_classes.py b/astroid/node_classes.py
index 42fdf7c..9873158 100644
--- a/astroid/node_classes.py
+++ b/astroid/node_classes.py
@@ -1863,6 +1863,28 @@ class DictUnpack(NodeNG):
     """Represents the unpacking of dicts into dicts using PEP 448."""
 
 
+class FormattedValue(bases.NodeNG):
+    """Represents a PEP 498 format string."""
+    _astroid_fields = ('value', 'format_spec')
+    value = None
+    conversion = None
+    format_spec = None
+
+    def postinit(self, value, conversion=None, format_spec=None):
+        self.value = value
+        self.conversion = conversion
+        self.format_spec = format_spec
+
+
+class JoinedStr(bases.NodeNG):
+    """Represents a list of string expressions to be joined."""
+    _astroid_fields = ('values',)
+    value = None
+
+    def postinit(self, values=None):
+        self.values = values
+
+
 # constants ##############################################################
 
 CONST_CLS = {
diff --git a/astroid/nodes.py b/astroid/nodes.py
index 1c279cc..3397294 100644
--- a/astroid/nodes.py
+++ b/astroid/nodes.py
@@ -37,6 +37,7 @@
     TryExcept, TryFinally, Tuple, UnaryOp, While, With, Yield, YieldFrom,
     const_factory,
     AsyncFor, Await, AsyncWith,
+    FormattedValue, JoinedStr,
     # Backwards-compatibility aliases
     Backquote, Discard, AssName, AssAttr, Getattr, CallFunc, From,
     # Node not present in the builtin ast module.
@@ -75,4 +76,5 @@
     UnaryOp,
     While, With,
     Yield, YieldFrom,
+    FormattedValue, JoinedStr,
     )
diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py
index d80033f..91774cf 100644
--- a/astroid/rebuilder.py
+++ b/astroid/rebuilder.py
@@ -984,6 +984,24 @@ def visit_await(self, node, parent):
         return self._visit_with(new.AsyncWith, node, parent,
                                 assign_ctx=assign_ctx)
 
+    def visit_joinedstr(self, node, parent, assign_ctx=None):
+        newnode = new.JoinedStr()
+        newnode.lineno = node.lineno
+        newnode.col_offset = node.col_offset
+        newnode.parent = parent
+        newnode.postinit([self.visit(child, newnode)
+                          for child in node.values])
+        return newnode
+
+    def visit_formattedvalue(self, node, parent, assign_ctx=None):
+        newnode = new.FormattedValue()
+        newnode.lineno = node.lineno
+        newnode.col_offset = node.col_offset
+        newnode.parent = parent
+        newnode.postinit(self.visit(node.value, newnode),
+                         node.conversion,
+                         _visit_or_none(node, 'format_spec', self, newnode, assign_ctx=assign_ctx))
+        return newnode
 
 if sys.version_info >= (3, 0):
     TreeRebuilder = TreeRebuilder3k
diff --git a/astroid/tests/unittest_python3.py b/astroid/tests/unittest_python3.py
index ad8e57b..e555bb4 100644
--- a/astroid/tests/unittest_python3.py
+++ b/astroid/tests/unittest_python3.py
@@ -87,7 +87,7 @@ def test_metaclass_error(self):
     @require_version('3.0')
     def test_metaclass_imported(self):
         astroid = self.builder.string_build(dedent("""
-        from abc import ABCMeta 
+        from abc import ABCMeta
         class Test(metaclass=ABCMeta): pass"""))
         klass = astroid.body[1]
 
@@ -98,7 +98,7 @@ class Test(metaclass=ABCMeta): pass"""))
     @require_version('3.0')
     def test_as_string(self):
         body = dedent("""
-        from abc import ABCMeta 
+        from abc import ABCMeta
         class Test(metaclass=ABCMeta): pass""")
         astroid = self.builder.string_build(body)
         klass = astroid.body[1]
@@ -239,6 +239,13 @@ def test_unpacking_in_dict_getitem(self):
             self.assertIsInstance(value, nodes.Const)
             self.assertEqual(value.value, expected)
 
+    @require_version('3.6')
+    def test_format_string(self):
+        code = "f'{greetings} {person}'"
+        node = extract_node(code)
+        self.assertEqual(node.as_string(), code)
+
+
 
 if __name__ == '__main__':
     unittest.main()