82
82
FunctionInfo ,
83
83
FunctionNode ,
84
84
FunctionResult ,
85
+ GeneratorValue ,
85
86
IMPLICIT_CLASSMETHODS ,
87
+ IterableValue ,
88
+ ReturnT ,
89
+ SendT ,
90
+ YieldT ,
86
91
)
87
92
from .options import (
88
93
add_arguments ,
@@ -1021,6 +1026,7 @@ class NameCheckVisitor(node_visitor.ReplacingNodeVisitor):
1021
1026
current_class : Optional [type ]
1022
1027
current_enum_members : Optional [Dict [object , str ]]
1023
1028
current_function : Optional [object ]
1029
+ current_function_info : Optional [FunctionInfo ]
1024
1030
current_function_name : Optional [str ]
1025
1031
error_for_implicit_any : bool
1026
1032
expected_return_value : Optional [Value ]
@@ -1081,6 +1087,7 @@ def __init__(
1081
1087
# current class (for inferring the type of cls and self arguments)
1082
1088
self .current_class = None
1083
1089
self .current_function_name = None
1090
+ self .current_function_info = None
1084
1091
1085
1092
# async
1086
1093
self .async_kind = AsyncFunctionKind .non_async
@@ -1712,6 +1719,8 @@ def visit_FunctionDef(self, node: FunctionDefNode) -> Value:
1712
1719
self , "current_function" , potential_function
1713
1720
), qcore .override (
1714
1721
self , "expected_return_value" , expected_return
1722
+ ), qcore .override (
1723
+ self , "current_function_info" , info
1715
1724
):
1716
1725
result = self ._visit_function_body (info )
1717
1726
@@ -1831,7 +1840,8 @@ def record_call(self, callable: object, arguments: CallArgs) -> None:
1831
1840
def visit_Lambda (self , node : ast .Lambda ) -> Value :
1832
1841
with self .asynq_checker .set_func_name ("<lambda>" ):
1833
1842
info = compute_function_info (node , self )
1834
- result = self ._visit_function_body (info )
1843
+ with qcore .override (self , "current_function_info" , info ):
1844
+ result = self ._visit_function_body (info )
1835
1845
return compute_value_of_function (info , self , result = result .return_value )
1836
1846
1837
1847
def _visit_function_body (self , function_info : FunctionInfo ) -> FunctionResult :
@@ -3236,15 +3246,53 @@ def unpack_awaitable(self, composite: Composite, node: ast.AST) -> Value:
3236
3246
def visit_YieldFrom (self , node : ast .YieldFrom ) -> Value :
3237
3247
self .is_generator = True
3238
3248
value = self .visit (node .value )
3239
- if not TypedValue (collections .abc .Iterable ).is_assignable (
3240
- value , self
3241
- ) and not AwaitableValue .is_assignable (value , self ):
3242
- self ._show_error_if_checking (
3243
- node ,
3244
- f"Cannot use { value } in yield from" ,
3245
- error_code = ErrorCode .bad_yield_from ,
3246
- )
3247
- return AnyValue (AnySource .inference )
3249
+ tv_map = get_tv_map (GeneratorValue , value , self )
3250
+ if isinstance (tv_map , CanAssignError ):
3251
+ can_assign = get_tv_map (AwaitableValue , value , self )
3252
+ if not isinstance (can_assign , CanAssignError ):
3253
+ tv_map = {
3254
+ ReturnT : can_assign .get (T , AnyValue (AnySource .generic_argument ))
3255
+ }
3256
+ else :
3257
+ can_assign = get_tv_map (IterableValue , value , self )
3258
+ if isinstance (can_assign , CanAssignError ):
3259
+ self ._show_error_if_checking (
3260
+ node ,
3261
+ f"Cannot use { value } in yield from" ,
3262
+ error_code = ErrorCode .bad_yield_from ,
3263
+ detail = can_assign .display (),
3264
+ )
3265
+ tv_map = {ReturnT : AnyValue (AnySource .error )}
3266
+ else :
3267
+ tv_map = {
3268
+ YieldT : can_assign .get (T , AnyValue (AnySource .generic_argument ))
3269
+ }
3270
+
3271
+ if self .current_function_info is not None :
3272
+ expected_yield = self .current_function_info .get_generator_yield_type (self )
3273
+ yield_type = tv_map .get (YieldT , AnyValue (AnySource .generic_argument ))
3274
+ can_assign = expected_yield .can_assign (yield_type , self )
3275
+ if isinstance (can_assign , CanAssignError ):
3276
+ self ._show_error_if_checking (
3277
+ node ,
3278
+ f"Cannot yield from { value } (expected { expected_yield } )" ,
3279
+ error_code = ErrorCode .incompatible_yield ,
3280
+ detail = can_assign .display (),
3281
+ )
3282
+
3283
+ expected_send = self .current_function_info .get_generator_send_type (self )
3284
+ send_type = tv_map .get (SendT , AnyValue (AnySource .generic_argument ))
3285
+ can_assign = send_type .can_assign (expected_send , self )
3286
+ if isinstance (can_assign , CanAssignError ):
3287
+ self ._show_error_if_checking (
3288
+ node ,
3289
+ f"Cannot send { send_type } to a generator (expected"
3290
+ f" { expected_send } )" ,
3291
+ error_code = ErrorCode .incompatible_yield ,
3292
+ detail = can_assign .display (),
3293
+ )
3294
+
3295
+ return tv_map .get (ReturnT , AnyValue (AnySource .generic_argument ))
3248
3296
3249
3297
def visit_Yield (self , node : ast .Yield ) -> Value :
3250
3298
if self ._is_checking ():
@@ -3260,7 +3308,7 @@ def visit_Yield(self, node: ast.Yield) -> Value:
3260
3308
if node .value is not None :
3261
3309
value = self .visit (node .value )
3262
3310
else :
3263
- value = None
3311
+ value = KnownValue ( None )
3264
3312
3265
3313
if node .value is None and self .async_kind in (
3266
3314
AsyncFunctionKind .normal ,
@@ -3270,10 +3318,21 @@ def visit_Yield(self, node: ast.Yield) -> Value:
3270
3318
self .is_generator = True
3271
3319
3272
3320
# unwrap the results of async yields
3273
- if self .async_kind != AsyncFunctionKind .non_async and value is not None :
3321
+ if self .async_kind != AsyncFunctionKind .non_async :
3274
3322
return self ._unwrap_yield_result (node , value )
3275
- else :
3323
+ if self . current_function_info is None :
3276
3324
return AnyValue (AnySource .inference )
3325
+ yield_type = self .current_function_info .get_generator_yield_type (self )
3326
+ can_assign = yield_type .can_assign (value , self )
3327
+ if isinstance (can_assign , CanAssignError ):
3328
+ self ._show_error_if_checking (
3329
+ node ,
3330
+ f"Cannot assign value of type { value } to yield expression of type"
3331
+ f" { yield_type } " ,
3332
+ error_code = ErrorCode .incompatible_yield ,
3333
+ detail = can_assign .display (),
3334
+ )
3335
+ return self .current_function_info .get_generator_send_type (self )
3277
3336
3278
3337
def _unwrap_yield_result (self , node : ast .AST , value : Value ) -> Value :
3279
3338
if isinstance (value , AsyncTaskIncompleteValue ):
@@ -3335,7 +3394,7 @@ def _unwrap_yield_result(self, node: ast.AST, value: Value) -> Value:
3335
3394
else :
3336
3395
self ._show_error_if_checking (
3337
3396
node ,
3338
- "Invalid value yielded: %r" % ( value ,) ,
3397
+ f "Invalid value yielded: { value } " ,
3339
3398
error_code = ErrorCode .bad_async_yield ,
3340
3399
)
3341
3400
return AnyValue (AnySource .error )
@@ -3354,11 +3413,18 @@ def visit_Return(self, node: ast.Return) -> None:
3354
3413
self ._show_error_if_checking (
3355
3414
node , error_code = ErrorCode .no_return_may_return
3356
3415
)
3357
- elif (
3358
- # TODO check generator types properly
3359
- not (self .is_generator and self .async_kind == AsyncFunctionKind .non_async )
3360
- and self .expected_return_value is not None
3361
- ):
3416
+ elif self .is_generator and self .async_kind == AsyncFunctionKind .non_async :
3417
+ if self .current_function_info is not None :
3418
+ expected = self .current_function_info .get_generator_return_type (self )
3419
+ can_assign = expected .can_assign (value , self )
3420
+ if isinstance (can_assign , CanAssignError ):
3421
+ self ._show_error_if_checking (
3422
+ node ,
3423
+ f"Incompatible return value { value } (expected { expected } )" ,
3424
+ error_code = ErrorCode .incompatible_return_value ,
3425
+ detail = can_assign .display (),
3426
+ )
3427
+ elif self .expected_return_value is not None :
3362
3428
can_assign = self .expected_return_value .can_assign (value , self )
3363
3429
if isinstance (can_assign , CanAssignError ):
3364
3430
self ._show_error_if_checking (
0 commit comments