Skip to content

Commit 8f0bdc3

Browse files
fix recursive base (#402)
1 parent b997b37 commit 8f0bdc3

File tree

4 files changed

+38
-2
lines changed

4 files changed

+38
-2
lines changed

docs/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Unreleased
44

5+
- Fix stub classes with references to themselves in their
6+
base classes, such as `os._ScandirIterator` in typeshed (#402)
57
- Fix type narrowing on the `else` case of `issubclass()`
68
(#401)
79
- Fix indexing a list with an index typed as a
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from typing import ContextManager, AnyStr
2+
3+
class _ScandirIterator(ContextManager[_ScandirIterator[AnyStr]]):
4+
def close(self) -> None: ...

pyanalyze/test_typeshed.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,20 @@ def capybara(unannotated):
236236
TypedValue(TextIO) | TypedValue(BinaryIO),
237237
)
238238

239+
@assert_passes()
240+
def test_recursive_base(self):
241+
from typing import Any, ContextManager
242+
243+
def capybara():
244+
from _pyanalyze_tests.recursion import _ScandirIterator
245+
246+
def want_cm(cm: ContextManager[Any]) -> None:
247+
pass
248+
249+
def f(x: _ScandirIterator):
250+
want_cm(x)
251+
len(x) # E: incompatible_argument
252+
239253

240254
class Parent(Generic[T]):
241255
pass

pyanalyze/typeshed.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ class TypeshedFinder:
155155
_attribute_cache: Dict[Tuple[str, str, bool], Value] = field(
156156
default_factory=dict, repr=False, init=False
157157
)
158+
_active_infos: List[typeshed_client.resolver.ResolvedName] = field(
159+
default_factory=list, repr=False, init=False
160+
)
158161

159162
@classmethod
160163
def make(cls, options: Options, *, verbose: bool = False) -> "TypeshedFinder":
@@ -860,7 +863,7 @@ def _parse_call_assignment(
860863

861864
def make_synthetic_type(self, module: str, info: typeshed_client.NameInfo) -> Value:
862865
fq_name = f"{module}.{info.name}"
863-
bases = self.get_bases_for_fq_name(fq_name)
866+
bases = self._get_bases_from_info(info, module)
864867
typ = TypedValue(fq_name)
865868
if bases is not None:
866869
if any(
@@ -908,6 +911,19 @@ def _make_td_value(self, field: Value, total: bool) -> Tuple[bool, Value]:
908911

909912
def _value_from_info(
910913
self, info: typeshed_client.resolver.ResolvedName, module: str
914+
) -> Value:
915+
# This guard against infinite recursion if a type refers to itself
916+
# (real-world example: os._ScandirIterator).
917+
if info in self._active_infos:
918+
return AnyValue(AnySource.inference)
919+
self._active_infos.append(info)
920+
try:
921+
return self._value_from_info_inner(info, module)
922+
finally:
923+
self._active_infos.pop()
924+
925+
def _value_from_info_inner(
926+
self, info: typeshed_client.resolver.ResolvedName, module: str
911927
) -> Value:
912928
if isinstance(info, typeshed_client.ImportedInfo):
913929
return self._value_from_info(info.info, ".".join(info.source_module))
@@ -945,7 +961,7 @@ def _value_from_info(
945961
return val
946962
if info.ast.value:
947963
return self._parse_expr(info.ast.value, module)
948-
elif isinstance(info.ast, ast.FunctionDef):
964+
elif isinstance(info.ast, (ast.FunctionDef, ast.AsyncFunctionDef)):
949965
sig = self._get_signature_from_info(info, None, fq_name, module)
950966
if sig is not None:
951967
return CallableValue(sig)

0 commit comments

Comments
 (0)