Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

## Unreleased

- Partial support for PEP 695-style type aliases (#690, #692)
- Fix tests to account for new `typeshed_client` release
(#694)
- Partial support for PEP 695-style type aliases (#690)
- Add option to disable all error codes (#659)
- Add hacky fix for bugs with hashability on type objects (#689)
- Show an error on calls to `typing.Any` (#688)
Expand Down
10 changes: 7 additions & 3 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
TypeVar,
Union,
)
import typing_extensions

import qcore

Expand Down Expand Up @@ -539,11 +540,14 @@ def _type_from_runtime(


def make_type_var_value(tv: TypeVarLike, ctx: Context) -> TypeVarValue:
if tv.__bound__ is not None:
if (
isinstance(tv, (TypeVar, typing_extensions.TypeVar))
and tv.__bound__ is not None
):
bound = _type_from_runtime(tv.__bound__, ctx)
else:
bound = None
if isinstance(tv, TypeVar) and tv.__constraints__:
if isinstance(tv, (TypeVar, typing_extensions.TypeVar)) and tv.__constraints__:
constraints = tuple(
_type_from_runtime(constraint, ctx) for constraint in tv.__constraints__
)
Expand Down Expand Up @@ -656,7 +660,7 @@ def _type_from_value(
return _type_from_runtime(
value.val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack
)
elif isinstance(value, TypeVarValue):
elif isinstance(value, (TypeVarValue, TypeAliasValue)):
return value
elif isinstance(value, MultiValuedValue):
return unite_values(
Expand Down
99 changes: 95 additions & 4 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sys
import traceback
import types
import typing
from argparse import ArgumentParser
from dataclasses import dataclass
from itertools import chain
Expand Down Expand Up @@ -100,6 +101,7 @@
from .predicates import EqualsPredicate, InPredicate
from .reexport import ImplicitReexportTracker
from .safe import (
all_of_type,
is_dataclass_type,
is_hashable,
safe_getattr,
Expand Down Expand Up @@ -162,6 +164,8 @@
DefiniteValueExtension,
DeprecatedExtension,
SkipDeprecatedExtension,
TypeAlias,
TypeAliasValue,
annotate_value,
AnnotatedValue,
AnySource,
Expand Down Expand Up @@ -1754,22 +1758,22 @@ def visit_ClassDef(self, node: ast.ClassDef) -> Value:
value, _ = self._set_name_in_scope(node.name, node, value)
return value

def _get_class_object(self, node: ast.ClassDef) -> Value:
def _get_local_object(self, name: str, node: ast.AST) -> Value:
if self.scopes.scope_type() == ScopeType.module_scope:
return self.scopes.get(node.name, node, self.state)
return self.scopes.get(name, node, self.state)
elif (
self.scopes.scope_type() == ScopeType.class_scope
and self.current_class is not None
and hasattr(self.current_class, "__dict__")
):
runtime_obj = self.current_class.__dict__.get(node.name)
runtime_obj = self.current_class.__dict__.get(name)
if isinstance(runtime_obj, type):
return KnownValue(runtime_obj)
return AnyValue(AnySource.inference)

def _visit_class_and_get_value(self, node: ast.ClassDef) -> Value:
if self._is_checking():
cls_obj = self._get_class_object(node)
cls_obj = self._get_local_object(node.name, node)

module = self.module
if isinstance(cls_obj, MultiValuedValue) and module is not None:
Expand Down Expand Up @@ -4506,6 +4510,93 @@ def visit_AugAssign(self, node: ast.AugAssign) -> None:
# syntax like 'x = y = 0' results in multiple targets
self.visit(node.target)

if sys.version_info >= (3, 12):

def visit_TypeAlias(self, node: ast.TypeAlias) -> Value:
assert isinstance(node.name, ast.Name)
name = node.name.id
alias_val = self._get_local_object(name, node)
if isinstance(alias_val, KnownValue) and isinstance(
alias_val.val, typing.TypeAliasType
):
alias_obj = alias_val.val
else:
alias_obj = None
type_param_values = []
if self._is_checking():
if node.type_params:
with self.scopes.add_scope(
ScopeType.annotation_scope,
scope_node=node,
scope_object=alias_obj,
):
type_param_values = [
self.visit(param) for param in node.type_params
]
assert all_of_type(type_param_values, TypeVarValue)
with self.scopes.add_scope(
ScopeType.annotation_scope,
scope_node=node,
scope_object=alias_obj,
):
value = self.visit(node.value)

else:
with self.scopes.add_scope(
ScopeType.annotation_scope,
scope_node=node,
scope_object=alias_obj,
):
value = self.visit(node.value)
else:
value = None
if alias_obj is None:
if value is None:
alias_val = AnyValue(AnySource.inference)
else:
alias_val = TypeAliasValue(
name,
self.module.__name__ if self.module is not None else "",
TypeAlias(
lambda: type_from_value(value, self, node),
lambda: tuple(val.typevar for val in type_param_values),
),
)
set_value, _ = self._set_name_in_scope(name, node, alias_val)
return set_value

def visit_TypeVar(self, node: ast.TypeVar) -> Value:
bound = constraints = None
if node.bound is not None:
if isinstance(node.bound, ast.Tuple):
constraints = [self.visit(elt) for elt in node.bound.elts]
else:
bound = self.visit(node.bound)
tv = TypeVar(node.name)
typevar = TypeVarValue(
tv,
type_from_value(bound, self, node) if bound is not None else None,
(
tuple(type_from_value(c, self, node) for c in constraints)
if constraints is not None
else ()
),
)
self._set_name_in_scope(node.name, node, typevar)
return typevar

def visit_ParamSpec(self, node: ast.ParamSpec) -> Value:
ps = typing.ParamSpec(node.name)
typevar = TypeVarValue(ps, is_paramspec=True)
self._set_name_in_scope(node.name, node, typevar)
return typevar

def visit_TypeVarTuple(self, node: ast.TypeVarTuple) -> Value:
tv = TypeVar(node.name)
typevar = TypeVarValue(tv, is_typevartuple=True)
self._set_name_in_scope(node.name, node, typevar)
return typevar

def visit_Name(self, node: ast.Name, force_read: bool = False) -> Value:
return self.composite_from_name(node, force_read=force_read).value

Expand Down
1 change: 1 addition & 0 deletions pyanalyze/stacked_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class ScopeType(enum.Enum):
module_scope = 2
class_scope = 3
function_scope = 4
annotation_scope = 5


# Nodes as used in scopes can be any object, as long as they are hashable.
Expand Down
96 changes: 96 additions & 0 deletions pyanalyze/test_type_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# static analysis: ignore
from .test_name_check_visitor import TestNameCheckVisitorBase
from .test_node_visitor import assert_passes, skip_before


class TestRecursion(TestNameCheckVisitorBase):
@assert_passes()
def test(self):
from typing import Dict, List, Union

JSON = Union[Dict[str, "JSON"], List["JSON"], int, str, float, bool, None]

def f(x: JSON):
pass

def capybara():
f([])
f([1, 2, 3])
f([[{1}]]) # TODO this should throw an error


class TestTypeAliasType(TestNameCheckVisitorBase):
@assert_passes()
def test_typing_extensions(self):
from typing_extensions import TypeAliasType, assert_type

MyType = TypeAliasType("MyType", int)

def f(x: MyType):
assert_type(x, MyType)
assert_type(x + 1, int)

def capybara(i: int, s: str):
f(i)
f(s) # E: incompatible_argument

@assert_passes()
def test_typing_extensions_generic(self):
from typing_extensions import TypeAliasType, assert_type
from typing import TypeVar, Union, List, Set

T = TypeVar("T")
MyType = TypeAliasType("MyType", Union[List[T], Set[T]], type_params=(T,))

def f(x: MyType[int]):
assert_type(x, MyType[int])
assert_type(list(x), List[int])

def capybara(i: int, s: str):
f([i])
f([s]) # E: incompatible_argument

@skip_before((3, 12))
def test_312(self):
self.assert_passes("""
from typing_extensions import assert_type
type MyType = int

def f(x: MyType):
assert_type(x, MyType)
assert_type(x + 1, int)

def capybara(i: int, s: str):
f(i)
f(s) # E: incompatible_argument
""")

@skip_before((3, 12))
def test_312_generic(self):
self.assert_passes("""
from typing_extensions import assert_type
type MyType[T] = list[T] | set[T]

def f(x: MyType[int]):
assert_type(x, MyType[int])
assert_type(list(x), list[int])

def capybara(i: int, s: str):
f([i])
f([s]) # E: incompatible_argument
""")

@skip_before((3, 12))
def test_312_local_alias(self):
self.assert_passes("""
from typing_extensions import assert_type

def capybara():
type MyType = int
def f(x: MyType):
assert_type(x, MyType)
assert_type(x + 1, int)

f(1)
f("x") # E: incompatible_argument
""")
30 changes: 27 additions & 3 deletions pyanalyze/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def function(x: int, y: list[int], z: Any):
from collections import deque
from dataclasses import dataclass, field, InitVar
from itertools import chain
import sys
from types import FunctionType
from typing import (
Any,
Expand Down Expand Up @@ -60,9 +61,31 @@ def function(x: int, y: list[int], z: Any):
KNOWN_MUTABLE_TYPES = (list, set, dict, deque)
ITERATION_LIMIT = 1000

TypeVarLike = Union[
ExternalType["typing.TypeVar"], ExternalType["typing_extensions.ParamSpec"]
]
if sys.version_info >= (3, 11):
TypeVarLike = Union[
ExternalType["typing.TypeVar"],
ExternalType["typing_extensions.TypeVar"],
ExternalType["typing.ParamSpec"],
ExternalType["typing_extensions.ParamSpec"],
ExternalType["typing.TypeVarTuple"],
ExternalType["typing_extensions.TypeVarTuple"],
]
elif sys.version_info >= (3, 10):
TypeVarLike = Union[
ExternalType["typing.TypeVar"],
ExternalType["typing_extensions.TypeVar"],
ExternalType["typing.ParamSpec"],
ExternalType["typing_extensions.ParamSpec"],
ExternalType["typing_extensions.TypeVarTuple"],
]
else:
TypeVarLike = Union[
ExternalType["typing.TypeVar"],
ExternalType["typing_extensions.TypeVar"],
ExternalType["typing_extensions.ParamSpec"],
ExternalType["typing_extensions.TypeVarTuple"],
]

TypeVarMap = Mapping[TypeVarLike, ExternalType["pyanalyze.value.Value"]]
BoundsMap = Mapping[TypeVarLike, Sequence[ExternalType["pyanalyze.value.Bound"]]]
GenericBases = Mapping[Union[type, str], TypeVarMap]
Expand Down Expand Up @@ -1737,6 +1760,7 @@ class TypeVarValue(Value):
bound: Optional[Value] = None
constraints: Sequence[Value] = ()
is_paramspec: bool = False
is_typevartuple: bool = False # unsupported

def substitute_typevars(self, typevars: TypeVarMap) -> Value:
return typevars.get(self.typevar, self)
Expand Down