diff --git a/mypy/meet.py b/mypy/meet.py index 50bd93b580e1..ee32f239df8c 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -229,6 +229,15 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: ): return original_declared return meet_types(original_declared, original_narrowed) + elif ( + isinstance(declared, CallableType) + and isinstance(narrowed, CallableType) + and has_type_vars(declared.ret_type) + ): + return narrowed.copy_modified( + ret_type=narrow_declared_type(declared.ret_type, narrowed.ret_type) + ) + return original_narrowed diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 7481eb308aa3..e6a798b77e6d 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -3826,3 +3826,26 @@ def get_owner(uid: UserId, tid: TeamId): reveal_type(INVALID) # N: Revealed type is "builtins.int" return None [builtins fixtures/primitives.pyi] + +[case testNarrowCallableTypeVarByEquality] +from typing import Callable, TypeVar + +T = TypeVar("T") + +def remove(path: str) -> None: ... +def unlink(path: str) -> None: ... + +def f1(func: Callable[..., T], arg: str) -> T: + if func == remove: + reveal_type(func) # N: Revealed type is "def (path: builtins.str) -> T`-1" + reveal_type(func(arg)) # N: Revealed type is "T`-1" + return func(arg) + return func(arg) + +def f2(func: Callable[..., T], arg: str) -> T: + if func in [unlink, remove]: + reveal_type(func) # N: Revealed type is "def (path: builtins.str) -> T`-1" + reveal_type(func(arg)) # N: Revealed type is "T`-1" + return func(arg) + return func(arg) +[builtins fixtures/primitives.pyi]