Skip to content

Commit 9446c35

Browse files
Support for typechecking generators (#529)
1 parent 651dd47 commit 9446c35

File tree

8 files changed

+186
-24
lines changed

8 files changed

+186
-24
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44

5+
- Type check `yield`, `yield from`, and `return` nodes in generators (#529)
56
- Type check calls to comparison operators (#527)
67
- Retrieve attributes from stubs even when a runtime
78
equivalent exists (#526)

pyanalyze/annotations.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
)
7777
from .value import (
7878
_HashableValue,
79+
TypeVarLike,
7980
annotate_value,
8081
AnnotatedValue,
8182
AnySource,
@@ -610,12 +611,12 @@ def _type_from_runtime(
610611
return AnyValue(AnySource.error)
611612

612613

613-
def make_type_var_value(tv: TypeVar, ctx: Context) -> TypeVarValue:
614+
def make_type_var_value(tv: TypeVarLike, ctx: Context) -> TypeVarValue:
614615
if tv.__bound__ is not None:
615616
bound = _type_from_runtime(tv.__bound__, ctx)
616617
else:
617618
bound = None
618-
if tv.__constraints__:
619+
if isinstance(tv, TypeVar) and tv.__constraints__:
619620
constraints = tuple(
620621
_type_from_runtime(constraint, ctx) for constraint in tv.__constraints__
621622
)

pyanalyze/error_code.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class ErrorCode(enum.Enum):
9898
already_declared = 78
9999
invalid_annotated_assignment = 79
100100
unused_assignment = 80
101+
incompatible_yield = 81
101102

102103

103104
# Allow testing unannotated functions without too much fuss
@@ -217,6 +218,7 @@ class ErrorCode(enum.Enum):
217218
ErrorCode.already_declared: "Name is already declared",
218219
ErrorCode.invalid_annotated_assignment: "Invalid annotated assignment",
219220
ErrorCode.unused_assignment: "Assigned value is never used",
221+
ErrorCode.incompatible_yield: "Incompatible yield type",
220222
}
221223

222224

pyanalyze/functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@
2323
from .options import Options, PyObjectSequenceOption
2424
from .signature import ParameterKind, Signature, SigParameter
2525
from .stacked_scopes import Composite
26+
from .typevar import resolve_bounds_map
2627
from .value import (
2728
AnySource,
2829
AnyValue,
2930
CallableValue,
3031
CanAssignContext,
3132
CanAssignError,
3233
GenericValue,
34+
get_tv_map,
3335
KnownValue,
3436
SubclassValue,
3537
TypedValue,
@@ -43,6 +45,15 @@
4345
FunctionNode = Union[FunctionDefNode, ast.Lambda]
4446
IMPLICIT_CLASSMETHODS = ("__init_subclass__", "__new__")
4547

48+
YieldT = TypeVar("YieldT")
49+
SendT = TypeVar("SendT")
50+
ReturnT = TypeVar("ReturnT")
51+
IterableValue = GenericValue(collections.abc.Iterable, [TypeVarValue(YieldT)])
52+
GeneratorValue = GenericValue(
53+
collections.abc.Generator,
54+
[TypeVarValue(YieldT), TypeVarValue(SendT), TypeVarValue(ReturnT)],
55+
)
56+
4657

4758
class AsyncFunctionKind(enum.Enum):
4859
non_async = 0
@@ -78,6 +89,41 @@ class FunctionInfo:
7889
return_annotation: Optional[Value]
7990
potential_function: Optional[object]
8091

92+
def get_generator_yield_type(self, ctx: CanAssignContext) -> Value:
93+
if self.return_annotation is None:
94+
return AnyValue(AnySource.unannotated)
95+
can_assign = IterableValue.can_assign(self.return_annotation, ctx)
96+
if isinstance(can_assign, CanAssignError):
97+
return AnyValue(AnySource.error)
98+
tv_map, _ = resolve_bounds_map(can_assign, ctx)
99+
return tv_map.get(YieldT, AnyValue(AnySource.generic_argument))
100+
101+
def get_generator_send_type(self, ctx: CanAssignContext) -> Value:
102+
if self.return_annotation is None:
103+
return AnyValue(AnySource.unannotated)
104+
tv_map = get_tv_map(GeneratorValue, self.return_annotation, ctx)
105+
if not isinstance(tv_map, CanAssignError):
106+
return tv_map.get(SendT, AnyValue(AnySource.generic_argument))
107+
# If the return annotation is a non-Generator Iterable, assume the send
108+
# type is None.
109+
can_assign = IterableValue.can_assign(self.return_annotation, ctx)
110+
if isinstance(can_assign, CanAssignError):
111+
return AnyValue(AnySource.error)
112+
return KnownValue(None)
113+
114+
def get_generator_return_type(self, ctx: CanAssignContext) -> Value:
115+
if self.return_annotation is None:
116+
return AnyValue(AnySource.unannotated)
117+
tv_map = get_tv_map(GeneratorValue, self.return_annotation, ctx)
118+
if not isinstance(tv_map, CanAssignError):
119+
return tv_map.get(ReturnT, AnyValue(AnySource.generic_argument))
120+
# If the return annotation is a non-Generator Iterable, assume the return
121+
# type is None.
122+
can_assign = IterableValue.can_assign(self.return_annotation, ctx)
123+
if isinstance(can_assign, CanAssignError):
124+
return AnyValue(AnySource.error)
125+
return KnownValue(None)
126+
81127

82128
@dataclass
83129
class FunctionResult:

pyanalyze/name_check_visitor.py

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,12 @@
8282
FunctionInfo,
8383
FunctionNode,
8484
FunctionResult,
85+
GeneratorValue,
8586
IMPLICIT_CLASSMETHODS,
87+
IterableValue,
88+
ReturnT,
89+
SendT,
90+
YieldT,
8691
)
8792
from .options import (
8893
add_arguments,
@@ -1021,6 +1026,7 @@ class NameCheckVisitor(node_visitor.ReplacingNodeVisitor):
10211026
current_class: Optional[type]
10221027
current_enum_members: Optional[Dict[object, str]]
10231028
current_function: Optional[object]
1029+
current_function_info: Optional[FunctionInfo]
10241030
current_function_name: Optional[str]
10251031
error_for_implicit_any: bool
10261032
expected_return_value: Optional[Value]
@@ -1081,6 +1087,7 @@ def __init__(
10811087
# current class (for inferring the type of cls and self arguments)
10821088
self.current_class = None
10831089
self.current_function_name = None
1090+
self.current_function_info = None
10841091

10851092
# async
10861093
self.async_kind = AsyncFunctionKind.non_async
@@ -1712,6 +1719,8 @@ def visit_FunctionDef(self, node: FunctionDefNode) -> Value:
17121719
self, "current_function", potential_function
17131720
), qcore.override(
17141721
self, "expected_return_value", expected_return
1722+
), qcore.override(
1723+
self, "current_function_info", info
17151724
):
17161725
result = self._visit_function_body(info)
17171726

@@ -1831,7 +1840,8 @@ def record_call(self, callable: object, arguments: CallArgs) -> None:
18311840
def visit_Lambda(self, node: ast.Lambda) -> Value:
18321841
with self.asynq_checker.set_func_name("<lambda>"):
18331842
info = compute_function_info(node, self)
1834-
result = self._visit_function_body(info)
1843+
with qcore.override(self, "current_function_info", info):
1844+
result = self._visit_function_body(info)
18351845
return compute_value_of_function(info, self, result=result.return_value)
18361846

18371847
def _visit_function_body(self, function_info: FunctionInfo) -> FunctionResult:
@@ -3236,15 +3246,53 @@ def unpack_awaitable(self, composite: Composite, node: ast.AST) -> Value:
32363246
def visit_YieldFrom(self, node: ast.YieldFrom) -> Value:
32373247
self.is_generator = True
32383248
value = self.visit(node.value)
3239-
if not TypedValue(collections.abc.Iterable).is_assignable(
3240-
value, self
3241-
) and not AwaitableValue.is_assignable(value, self):
3242-
self._show_error_if_checking(
3243-
node,
3244-
f"Cannot use {value} in yield from",
3245-
error_code=ErrorCode.bad_yield_from,
3246-
)
3247-
return AnyValue(AnySource.inference)
3249+
tv_map = get_tv_map(GeneratorValue, value, self)
3250+
if isinstance(tv_map, CanAssignError):
3251+
can_assign = get_tv_map(AwaitableValue, value, self)
3252+
if not isinstance(can_assign, CanAssignError):
3253+
tv_map = {
3254+
ReturnT: can_assign.get(T, AnyValue(AnySource.generic_argument))
3255+
}
3256+
else:
3257+
can_assign = get_tv_map(IterableValue, value, self)
3258+
if isinstance(can_assign, CanAssignError):
3259+
self._show_error_if_checking(
3260+
node,
3261+
f"Cannot use {value} in yield from",
3262+
error_code=ErrorCode.bad_yield_from,
3263+
detail=can_assign.display(),
3264+
)
3265+
tv_map = {ReturnT: AnyValue(AnySource.error)}
3266+
else:
3267+
tv_map = {
3268+
YieldT: can_assign.get(T, AnyValue(AnySource.generic_argument))
3269+
}
3270+
3271+
if self.current_function_info is not None:
3272+
expected_yield = self.current_function_info.get_generator_yield_type(self)
3273+
yield_type = tv_map.get(YieldT, AnyValue(AnySource.generic_argument))
3274+
can_assign = expected_yield.can_assign(yield_type, self)
3275+
if isinstance(can_assign, CanAssignError):
3276+
self._show_error_if_checking(
3277+
node,
3278+
f"Cannot yield from {value} (expected {expected_yield})",
3279+
error_code=ErrorCode.incompatible_yield,
3280+
detail=can_assign.display(),
3281+
)
3282+
3283+
expected_send = self.current_function_info.get_generator_send_type(self)
3284+
send_type = tv_map.get(SendT, AnyValue(AnySource.generic_argument))
3285+
can_assign = send_type.can_assign(expected_send, self)
3286+
if isinstance(can_assign, CanAssignError):
3287+
self._show_error_if_checking(
3288+
node,
3289+
f"Cannot send {send_type} to a generator (expected"
3290+
f" {expected_send})",
3291+
error_code=ErrorCode.incompatible_yield,
3292+
detail=can_assign.display(),
3293+
)
3294+
3295+
return tv_map.get(ReturnT, AnyValue(AnySource.generic_argument))
32483296

32493297
def visit_Yield(self, node: ast.Yield) -> Value:
32503298
if self._is_checking():
@@ -3260,7 +3308,7 @@ def visit_Yield(self, node: ast.Yield) -> Value:
32603308
if node.value is not None:
32613309
value = self.visit(node.value)
32623310
else:
3263-
value = None
3311+
value = KnownValue(None)
32643312

32653313
if node.value is None and self.async_kind in (
32663314
AsyncFunctionKind.normal,
@@ -3270,10 +3318,21 @@ def visit_Yield(self, node: ast.Yield) -> Value:
32703318
self.is_generator = True
32713319

32723320
# unwrap the results of async yields
3273-
if self.async_kind != AsyncFunctionKind.non_async and value is not None:
3321+
if self.async_kind != AsyncFunctionKind.non_async:
32743322
return self._unwrap_yield_result(node, value)
3275-
else:
3323+
if self.current_function_info is None:
32763324
return AnyValue(AnySource.inference)
3325+
yield_type = self.current_function_info.get_generator_yield_type(self)
3326+
can_assign = yield_type.can_assign(value, self)
3327+
if isinstance(can_assign, CanAssignError):
3328+
self._show_error_if_checking(
3329+
node,
3330+
f"Cannot assign value of type {value} to yield expression of type"
3331+
f" {yield_type}",
3332+
error_code=ErrorCode.incompatible_yield,
3333+
detail=can_assign.display(),
3334+
)
3335+
return self.current_function_info.get_generator_send_type(self)
32773336

32783337
def _unwrap_yield_result(self, node: ast.AST, value: Value) -> Value:
32793338
if isinstance(value, AsyncTaskIncompleteValue):
@@ -3335,7 +3394,7 @@ def _unwrap_yield_result(self, node: ast.AST, value: Value) -> Value:
33353394
else:
33363395
self._show_error_if_checking(
33373396
node,
3338-
"Invalid value yielded: %r" % (value,),
3397+
f"Invalid value yielded: {value}",
33393398
error_code=ErrorCode.bad_async_yield,
33403399
)
33413400
return AnyValue(AnySource.error)
@@ -3354,11 +3413,18 @@ def visit_Return(self, node: ast.Return) -> None:
33543413
self._show_error_if_checking(
33553414
node, error_code=ErrorCode.no_return_may_return
33563415
)
3357-
elif (
3358-
# TODO check generator types properly
3359-
not (self.is_generator and self.async_kind == AsyncFunctionKind.non_async)
3360-
and self.expected_return_value is not None
3361-
):
3416+
elif self.is_generator and self.async_kind == AsyncFunctionKind.non_async:
3417+
if self.current_function_info is not None:
3418+
expected = self.current_function_info.get_generator_return_type(self)
3419+
can_assign = expected.can_assign(value, self)
3420+
if isinstance(can_assign, CanAssignError):
3421+
self._show_error_if_checking(
3422+
node,
3423+
f"Incompatible return value {value} (expected {expected})",
3424+
error_code=ErrorCode.incompatible_return_value,
3425+
detail=can_assign.display(),
3426+
)
3427+
elif self.expected_return_value is not None:
33623428
can_assign = self.expected_return_value.can_assign(value, self)
33633429
if isinstance(can_assign, CanAssignError):
33643430
self._show_error_if_checking(

pyanalyze/signature.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
)
6060
from .typevar import resolve_bounds_map
6161
from .value import (
62+
TypeVarLike,
6263
annotate_value,
6364
AnnotatedValue,
6465
AnySource,
@@ -522,10 +523,10 @@ class Signature:
522523
"""Whether type checking can call the actual function to retrieve a precise return value."""
523524
evaluator: Optional[Evaluator] = None
524525
"""Type evaluator for this function."""
525-
typevars_of_params: Dict[str, List["TypeVar"]] = field(
526+
typevars_of_params: Dict[str, List[TypeVarLike]] = field(
526527
init=False, default_factory=dict, repr=False, compare=False, hash=False
527528
)
528-
all_typevars: Set["TypeVar"] = field(
529+
all_typevars: Set[TypeVarLike] = field(
529530
init=False, default_factory=set, repr=False, compare=False, hash=False
530531
)
531532

pyanalyze/test_generators.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# static analysis: ignore
2+
from .implementation import assert_is_value
3+
from .test_name_check_visitor import TestNameCheckVisitorBase
4+
from .test_node_visitor import assert_passes
5+
from .value import AnySource, AnyValue, KnownValue, TypedValue
6+
7+
8+
class TestGenerator(TestNameCheckVisitorBase):
9+
@assert_passes()
10+
def test_generator_return(self):
11+
from typing import Generator
12+
13+
def gen(cond) -> Generator[int, str, float]:
14+
x = yield 1
15+
assert_is_value(x, TypedValue(str))
16+
yield "x" # E: incompatible_yield
17+
if cond:
18+
return 3.0
19+
else:
20+
return "hello" # E: incompatible_return_value
21+
22+
def capybara() -> Generator[int, int, int]:
23+
x = yield from gen(True) # E: incompatible_yield
24+
assert_is_value(x, TypedValue(float))
25+
26+
return 3
27+
28+
@assert_passes()
29+
def test_iterable_return(self):
30+
from typing import Iterable
31+
32+
def gen(cond) -> Iterable[int]:
33+
x = yield 1
34+
assert_is_value(x, KnownValue(None))
35+
36+
yield "x" # E: incompatible_yield
37+
38+
if cond:
39+
return
40+
else:
41+
return 3 # E: incompatible_return_value
42+
43+
def caller() -> Iterable[int]:
44+
x = yield from gen(True)
45+
assert_is_value(x, AnyValue(AnySource.generic_argument))

pyanalyze/value.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2604,7 +2604,7 @@ def replace_known_sequence_value(value: Value) -> Value:
26042604
return value
26052605

26062606

2607-
def extract_typevars(value: Value) -> Iterable["TypeVar"]:
2607+
def extract_typevars(value: Value) -> Iterable[TypeVarLike]:
26082608
for val in value.walk_values():
26092609
if isinstance(val, TypeVarValue):
26102610
yield val.typevar

0 commit comments

Comments
 (0)