Skip to content

Commit 15e71ff

Browse files
Improve type inference for f-strings containing literals (#571)
1 parent 57e6e0d commit 15e71ff

File tree

3 files changed

+100
-6
lines changed

3 files changed

+100
-6
lines changed

docs/changelog.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
## Unreleased
44

5-
- Add experimental `@has_extra_keys` decorator for `TypedDictt` types (#568)
5+
- Improve type inference for f-strings containing literals (#571)
6+
- Add experimental `@has_extra_keys` decorator for `TypedDict` types (#568)
67
- Fix crash on recursive type aliases. Recursive type aliases now fall back to `Any` (#565)
78
- Support `in` on objects with only `__getitem__` (#564)
89
- Add support for `except*` (PEP 654) (#562)

pyanalyze/name_check_visitor.py

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2579,11 +2579,77 @@ def _visit_comprehension_inner(
25792579
# Literals and displays
25802580

25812581
def visit_JoinedStr(self, node: ast.JoinedStr) -> Value:
    """Infer the value of an f-string (``JoinedStr`` is the f-string node type).

    Every part of the f-string is visited. If each part is made up solely of
    ``KnownValue`` members wrapping ``str`` objects, the result is the union
    of every possible concatenation. Otherwise, or when the number of
    combinations would exceed ``UnionSimplificationLimit``, the result is
    ``TypedValue(str)``.
    """
    elements = self._generic_visit_list(node.values)
    max_combinations = self.options.get_value_for(UnionSimplificationLimit)
    # Each entry holds the pieces of one candidate result string, in order.
    candidates: List[List[str]] = [[]]
    for element in elements:
        alternatives = list(flatten_values(element))
        # Give up before the cross product of alternatives grows too large.
        if len(candidates) * len(alternatives) > max_combinations:
            return TypedValue(str)
        pieces = []
        for alternative in alternatives:
            is_literal_str = isinstance(alternative, KnownValue) and isinstance(
                alternative.val, str
            )
            if not is_literal_str:
                return TypedValue(str)
            pieces.append(alternative.val)
        candidates = [
            prefix + [piece] for prefix in candidates for piece in pieces
        ]
    joined = [KnownValue("".join(parts)) for parts in candidates]
    return unite_values(*joined)
2601+
2602+
def visit_FormattedValue(self, node: ast.FormattedValue) -> Value:
    """Infer the value of a single ``{...}`` replacement field in an f-string.

    The interpolated expression and the optional format spec are both
    visited. When the format spec is not a known ``str`` the result is
    ``TypedValue(str)``. Otherwise each member of the expression's value is
    formatted on its own and the results are united, subject to
    ``UnionSimplificationLimit``.
    """
    val = self.visit(node.value)
    if node.format_spec:
        spec_value = self.visit(node.format_spec)
    else:
        # No format spec behaves like an empty one.
        spec_value = KnownValue("")
    spec_is_literal = isinstance(spec_value, KnownValue) and isinstance(
        spec_value.val, str
    )
    if not spec_is_literal:
        # TODO: statically check whether the format specifier is valid.
        return TypedValue(str)
    spec = spec_value.val
    formatted = [
        self._visit_single_formatted_value(member, node, spec)
        for member in flatten_values(val)
    ]
    return unite_and_simplify(
        *formatted, limit=self.options.get_value_for(UnionSimplificationLimit)
    )
2622+
2623+
def _visit_single_formatted_value(
    self, val: Value, node: ast.FormattedValue, format_spec: str
) -> Value:
    """Infer the string produced by formatting one (non-union) value.

    Args:
        val: One member of the interpolated expression's value.
        node: The ``FormattedValue`` node; only ``node.conversion`` is read.
        format_spec: The literal format specifier (``""`` when absent).

    Returns:
        A ``KnownValue`` wrapping the formatted string when ``val`` is a
        ``KnownValue`` and both the conversion and ``format()`` succeed;
        ``TypedValue(str)`` otherwise.

    Raises:
        NotImplementedError: If ``node.conversion`` is neither -1 (no
            conversion) nor the code point of ``a``, ``s`` or ``r``.
    """
    if not isinstance(val, KnownValue):
        return TypedValue(str)
    output = val.val
    if node.conversion != -1:
        # ast stores the conversion as the code point of the flag character.
        converters = {ord("a"): ascii, ord("s"): str, ord("r"): repr}
        convert = converters.get(node.conversion)
        if convert is None:
            # Raised outside the try block so the broad except below cannot
            # swallow it.
            raise NotImplementedError(
                f"Unsupported conversion specifier {node.conversion}"
            )
        try:
            output = convert(output)
        except Exception:
            # str/repr/ascii failed
            return TypedValue(str)
    try:
        output = format(output, format_spec)
    except Exception:
        # format failed
        return TypedValue(str)
    return KnownValue(output)
25872653

25882654
def visit_Constant(self, node: ast.Constant) -> Value:
25892655
# replaces Num, Str, etc. in 3.8+

pyanalyze/test_format_strings.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,30 @@ def capybara(x):
478478
def test_complicated_expression(self):
479479
def capybara(x):
480480
"foo %s" % len(x)
481+
482+
483+
class TestFStringLiteral(TestNameCheckVisitorBase):
    """Checks the types inferred for f-strings built from literal values."""

    @assert_passes()
    def test_basic(self):
        """Literal parts combine into a Literal union, up to a size limit."""
        from typing_extensions import Literal, assert_type

        def capybara(a: Literal["a"], ab: Literal["a", "b"]):
            assert_type(f"a{a}", Literal["aa"])
            assert_type(f"a{ab}", Literal["aa", "ab"])
            assert_type(f"a{ab}b{ab}", Literal["aaba", "aabb", "abba", "abbb"])
            # Make sure we don't infer a 2**16-size union
            assert_type(
                f"{ab}{ab}{ab}{ab}{ab}{ab}{ab}{ab}{ab}{ab}{ab}{ab}{ab}{ab}{ab}{ab}", str
            )

    @assert_passes()
    def test_conversions(self):
        """The !r, !s and !a conversions are applied to literal values."""
        from typing_extensions import Literal, assert_type

        def capybara(a: Literal["á"], ab: Literal["á", "ê"]):
            assert_type(f"a{a!r}", Literal["a'á'"])
            assert_type(f"a{ab!r}", Literal["a'á'", "a'ê'"])
            assert_type(f"a{a!s}", Literal["aá"])
            assert_type(f"a{ab!s}", Literal["aá", "aê"])
            assert_type(f"a{a!a}", Literal["a'\\xe1'"])
            assert_type(f"a{ab!a}", Literal["a'\\xe1'", "a'\\xea'"])

0 commit comments

Comments
 (0)