Skip to content

Commit 34974b9

Browse files
Improve ParamSpec support (#772)
1 parent 51c1e12 commit 34974b9

File tree

7 files changed

+232
-10
lines changed

7 files changed

+232
-10
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44

5+
- Improve `ParamSpec` support (#772)
56
- Fix handling of stub functions with positional-only parameters with
67
defaults (#769)
78
- Recognize exhaustive pattern matching (#766)

pyanalyze/arg_spec.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@
8484
KnownValue,
8585
KVPair,
8686
NewTypeValue,
87+
ParamSpecArgsValue,
88+
ParamSpecKwargsValue,
8789
SubclassValue,
8890
TypedDictEntry,
8991
TypedDictValue,
@@ -444,15 +446,25 @@ def from_signature(
444446
returns = make_coro_type(returns)
445447

446448
parameters = []
449+
seen_paramspec_args: Optional[ParamSpecArgsValue] = None
447450
for i, parameter in enumerate(sig.parameters.values()):
448-
param, make_everything_pos_only = self._make_sig_parameter(
449-
parameter, func_globals, function_object, is_wrapped, i
451+
param, make_everything_pos_only, new_ps_args = self._make_sig_parameter(
452+
parameter,
453+
func_globals,
454+
function_object,
455+
is_wrapped,
456+
i,
457+
seen_paramspec_args,
450458
)
451459
if make_everything_pos_only:
452460
parameters = [
453461
replace(param, kind=ParameterKind.POSITIONAL_ONLY)
454462
for param in parameters
455463
]
464+
if new_ps_args is not None:
465+
seen_paramspec_args = new_ps_args
466+
if param is None:
467+
continue
456468
parameters.append(param)
457469

458470
return Signature.make(
@@ -473,14 +485,20 @@ def _make_sig_parameter(
473485
function_object: Optional[object],
474486
is_wrapped: bool,
475487
index: int,
476-
) -> Tuple[SigParameter, bool]:
488+
seen_paramspec_args: Optional[ParamSpecArgsValue],
489+
) -> Tuple[Optional[SigParameter], bool, Optional[ParamSpecArgsValue]]:
477490
"""Given an inspect.Parameter, returns a Parameter object."""
478491
if is_wrapped:
479492
typ = AnyValue(AnySource.inference)
480493
else:
481494
typ = self._get_type_for_parameter(
482495
parameter, func_globals, function_object, index
483496
)
497+
if (
498+
isinstance(typ, ParamSpecArgsValue)
499+
and parameter.kind is inspect.Parameter.VAR_POSITIONAL
500+
):
501+
return (None, False, typ)
484502
if parameter.default is inspect.Parameter.empty:
485503
default = None
486504
else:
@@ -496,9 +514,18 @@ def _make_sig_parameter(
496514
else:
497515
kind = ParameterKind(parameter.kind)
498516
make_everything_pos_only = False
517+
if (
518+
seen_paramspec_args is not None
519+
and kind is ParameterKind.VAR_KEYWORD
520+
and isinstance(typ, ParamSpecKwargsValue)
521+
and seen_paramspec_args.param_spec is typ.param_spec
522+
):
523+
kind = ParameterKind.PARAM_SPEC
524+
typ = TypeVarValue(typ.param_spec)
499525
return (
500526
SigParameter(parameter.name, kind, default=default, annotation=typ),
501527
make_everything_pos_only,
528+
None,
502529
)
503530

504531
def _get_type_for_parameter(

pyanalyze/functions.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
CanAssignError,
3131
GenericValue,
3232
KnownValue,
33+
ParamSpecArgsValue,
34+
ParamSpecKwargsValue,
3335
SubclassValue,
3436
TypedValue,
3537
TypeVarValue,
@@ -250,6 +252,7 @@ def compute_parameters(
250252
params = []
251253
tv_index = 1
252254

255+
seen_paramspec_args: Optional[Tuple[ast.arg, ParamSpecArgsValue]] = None
253256
for idx, (param, default) in enumerate(zip_longest(args, defaults)):
254257
assert param is not None, "must have more args than defaults"
255258
(kind, arg) = param
@@ -301,6 +304,38 @@ def compute_parameters(
301304
value = unite_values(value, default)
302305

303306
value = translate_vararg_type(kind, value, ctx, error_ctx=ctx, node=arg)
307+
if isinstance(value, ParamSpecArgsValue):
308+
if kind is ParameterKind.VAR_POSITIONAL:
309+
seen_paramspec_args = (arg, value)
310+
else:
311+
ctx.show_error(
312+
arg,
313+
f"ParamSpec.args must be used on *args, not {arg.arg}",
314+
error_code=ErrorCode.invalid_annotation,
315+
)
316+
elif isinstance(value, ParamSpecKwargsValue):
317+
if kind is ParameterKind.VAR_KEYWORD:
318+
if seen_paramspec_args is not None:
319+
_, ps_args = seen_paramspec_args
320+
if ps_args.param_spec is not value.param_spec:
321+
ctx.show_error(
322+
arg,
323+
"The same ParamSpec must be used on *args and **kwargs",
324+
error_code=ErrorCode.invalid_annotation,
325+
)
326+
else:
327+
ctx.show_error(
328+
arg,
329+
"ParamSpec.kwargs must be used together with ParamSpec.args",
330+
error_code=ErrorCode.invalid_annotation,
331+
)
332+
else:
333+
ctx.show_error(
334+
arg,
335+
f"ParamSpec.kwargs must be used on **kwargs, not {arg.arg}",
336+
error_code=ErrorCode.invalid_annotation,
337+
)
338+
304339
param = SigParameter(arg.arg, kind, default, value)
305340
info = ParamInfo(param, arg, is_self)
306341
params.append(info)
@@ -326,6 +361,8 @@ def translate_vararg_type(
326361
)
327362
return AnyValue(AnySource.error)
328363
return typ.value
364+
elif isinstance(typ, ParamSpecArgsValue):
365+
return typ
329366
else:
330367
return GenericValue(tuple, [typ])
331368
elif kind is ParameterKind.VAR_KEYWORD:
@@ -339,6 +376,8 @@ def translate_vararg_type(
339376
)
340377
return AnyValue(AnySource.error)
341378
return typ.value
379+
elif isinstance(typ, ParamSpecKwargsValue):
380+
return typ
342381
else:
343382
return GenericValue(dict, [TypedValue(str), typ])
344383
return typ

pyanalyze/signature.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
AsyncTaskIncompleteValue,
7171
BoundsMap,
7272
CallableValue,
73+
CallValue,
7374
CanAssign,
7475
CanAssignContext,
7576
CanAssignError,
@@ -411,7 +412,10 @@ def allow_unpack(self) -> bool:
411412
ParameterKind.VAR_POSITIONAL,
412413
ParameterKind.KEYWORD_ONLY,
413414
},
414-
ParameterKind.PARAM_SPEC: {ParameterKind.POSITIONAL_ONLY},
415+
ParameterKind.PARAM_SPEC: {
416+
ParameterKind.POSITIONAL_ONLY,
417+
ParameterKind.POSITIONAL_OR_KEYWORD,
418+
},
415419
ParameterKind.ELLIPSIS: {ParameterKind.POSITIONAL_ONLY},
416420
}
417421
CAN_HAVE_DEFAULT = {
@@ -1079,8 +1083,28 @@ def bind_arguments(
10791083
)
10801084
bound_args[param.name] = KWARGS, composite
10811085
else:
1082-
self.show_call_error("Callable requires a ParamSpec argument", ctx)
1083-
return None
1086+
new_actuals = ActualArguments(
1087+
positionals=actual_args.positionals[positional_index:],
1088+
star_args=(
1089+
actual_args.star_args if not star_args_consumed else None
1090+
),
1091+
keywords={
1092+
key: value
1093+
for key, value in actual_args.keywords.items()
1094+
if key not in keywords_consumed
1095+
},
1096+
star_kwargs=(
1097+
actual_args.star_kwargs
1098+
if not star_kwargs_consumed
1099+
else None
1100+
),
1101+
kwargs_required=actual_args.kwargs_required,
1102+
pos_or_keyword_params=actual_args.pos_or_keyword_params,
1103+
)
1104+
star_args_consumed = True
1105+
star_kwargs_consumed = True
1106+
val = CallValue(new_actuals)
1107+
bound_args[param.name] = UNKNOWN, Composite(val)
10841108
else:
10851109
assert False, f"unhandled param {param.kind}"
10861110

@@ -1973,14 +1997,27 @@ def preprocess_args(
19731997
# Step 1: Split up args and kwargs if possible.
19741998
processed_args: List[Argument] = []
19751999
kwargs_requireds = []
2000+
param_spec = None
2001+
param_spec_star_arg = None
2002+
seen_param_spec_kwargs = False
19762003
for arg, label in args:
19772004
if label is ARGS:
2005+
if isinstance(arg.value, ParamSpecArgsValue):
2006+
if param_spec is not None:
2007+
ctx.on_error(
2008+
"Only a single ParamSpec.args can be passed", node=arg.node
2009+
)
2010+
param_spec = TypeVarValue(arg.value.param_spec)
2011+
param_spec_star_arg = arg
2012+
continue
19782013
concrete_values = concrete_values_from_iterable(
19792014
arg.value, ctx.can_assign_ctx
19802015
)
19812016
if isinstance(concrete_values, CanAssignError):
19822017
ctx.on_error(
1983-
f"{arg.value} is not iterable", detail=str(concrete_values)
2018+
f"{arg.value} is not iterable",
2019+
detail=str(concrete_values),
2020+
node=arg.node,
19842021
)
19852022
return None
19862023
elif isinstance(concrete_values, Value):
@@ -1995,6 +2032,23 @@ def preprocess_args(
19952032
for subval in concrete_values:
19962033
processed_args.append((Composite(subval), None))
19972034
elif label is KWARGS:
2035+
if isinstance(arg.value, ParamSpecKwargsValue):
2036+
if param_spec is None:
2037+
ctx.on_error(
2038+
"ParamSpec.kwargs cannot be passed without ParamSpec.args",
2039+
node=arg.node,
2040+
)
2041+
elif param_spec.typevar is not arg.value.param_spec:
2042+
ctx.on_error(
2043+
"ParamSpec.args and ParamSpec.kwargs must use the same ParamSpec",
2044+
node=arg.node,
2045+
)
2046+
elif seen_param_spec_kwargs:
2047+
ctx.on_error(
2048+
"Only a single ParamSpec.kwargs can be passed", node=arg.node
2049+
)
2050+
seen_param_spec_kwargs = True
2051+
continue
19982052
items = {}
19992053
extra_values = []
20002054
if arg.value is NO_RETURN_VALUE:
@@ -2036,6 +2090,11 @@ def preprocess_args(
20362090
processed_args.append((new_composite, KWARGS))
20372091
else:
20382092
processed_args.append((arg, label))
2093+
if param_spec_star_arg is not None and not seen_param_spec_kwargs:
2094+
ctx.on_error(
2095+
"ParamSpec.args cannot be passed without ParamSpec.kwargs",
2096+
node=param_spec_star_arg.node,
2097+
)
20392098

20402099
# Step 2: enforce invariants about ARGS and KWARGS placement. We dump
20412100
# any single arguments that come after *args into *args, and we merge all *args.
@@ -2097,6 +2156,9 @@ def preprocess_args(
20972156
more_processed_kwargs[label.name] = (label.is_required, arg)
20982157
more_processed_args.append((label.is_required, arg))
20992158
elif isinstance(label, TypeVarValue):
2159+
if param_spec is not None:
2160+
ctx.on_error("Multiple ParamSpecs passed")
2161+
continue
21002162
param_spec = label
21012163
elif label is ELLIPSIS:
21022164
is_ellipsis = True
@@ -2276,7 +2338,6 @@ def check_call(
22762338
actual_args = preprocess_args(args, ctx)
22772339
if actual_args is None:
22782340
return AnyValue(AnySource.error)
2279-
22802341
# We first bind the arguments for each overload, to get the obvious errors
22812342
# out of the way first.
22822343
errors_per_overload = []
@@ -2703,6 +2764,21 @@ def decompose_union(
27032764
return None
27042765

27052766

2767+
def check_call_preprocessed(
2768+
sig: ConcreteSignature, args: ActualArguments, ctx: CanAssignContext
2769+
) -> CanAssign:
2770+
if isinstance(sig, Signature):
2771+
check_ctx = _CanAssignBasedContext(ctx)
2772+
sig.check_call_preprocessed(args, check_ctx)
2773+
if check_ctx.errors:
2774+
return CanAssignError(
2775+
"Incompatible callable", [CanAssignError(e) for e in check_ctx.errors]
2776+
)
2777+
return {}
2778+
else:
2779+
return CanAssignError("Overloads are not supported")
2780+
2781+
27062782
def _extract_known_value(val: Value) -> Optional[KnownValue]:
27072783
if isinstance(val, AnnotatedValue):
27082784
val = val.value

pyanalyze/test_annotations.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,6 +1825,41 @@ def capybara():
18251825
func(1, "A")
18261826
func(1, 2) # E: incompatible_argument
18271827

1828+
@assert_passes()
1829+
def test_apply(self):
1830+
from typing import Callable, TypeVar
1831+
1832+
from typing_extensions import ParamSpec, assert_type
1833+
1834+
P = ParamSpec("P")
1835+
T = TypeVar("T")
1836+
1837+
def apply(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
1838+
return func(*args, **kwargs)
1839+
1840+
def sample(x: int) -> str:
1841+
return str(x)
1842+
1843+
def capybara() -> None:
1844+
assert_type(apply(sample, 1), str)
1845+
apply(sample, "x") # E: incompatible_call
1846+
1847+
@assert_passes()
1848+
def test_param_spec_errors(self):
1849+
from typing import Callable, TypeVar
1850+
1851+
from typing_extensions import ParamSpec
1852+
1853+
P = ParamSpec("P")
1854+
T = TypeVar("T")
1855+
1856+
def apply(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None:
1857+
func(*args) # E: incompatible_call
1858+
func(*args, *args, **kwargs) # E: incompatible_call
1859+
func(**kwargs) # E: incompatible_call
1860+
func(*args, **kwargs, **kwargs) # E: incompatible_call
1861+
func(*args, **kwargs)
1862+
18281863

18291864
class TestTypeAlias(TestNameCheckVisitorBase):
18301865
@assert_passes()

0 commit comments

Comments
 (0)