Skip to content

Commit 20decdf

Browse files
updated
Signed-off-by: [email protected] <[email protected]>
1 parent 5145566 commit 20decdf

File tree

2 files changed

+75
-72
lines changed

2 files changed

+75
-72
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from typing import TYPE_CHECKING, Union
3+
4+
from vllm import envs
5+
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
6+
from vllm.distributed.kv_transfer.kv_connector.factory import (
7+
KVConnectorFactory)
8+
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
9+
KVConnectorRole)
10+
from vllm.distributed.parallel_state import get_world_group
11+
12+
if TYPE_CHECKING:
13+
from vllm.config import VllmConfig
14+
15+
_KV_CONNECTOR_AGENT: Union[KVConnectorBase, KVConnectorBase_V1, None] = None
16+
17+
18+
def get_kv_transfer_group() -> Union[KVConnectorBase, KVConnectorBase_V1]:
19+
assert _KV_CONNECTOR_AGENT is not None, (
20+
"disaggregated KV cache transfer parallel group is not initialized")
21+
return _KV_CONNECTOR_AGENT
22+
23+
24+
def has_kv_transfer_group() -> bool:
25+
return _KV_CONNECTOR_AGENT is not None
26+
27+
28+
def is_v1_kv_transfer_group(
29+
connector: Union[KVConnectorBase_V1, KVConnectorBase,
30+
None] = None) -> bool:
31+
"""Check if the KV connector is the v1 connector.
32+
If the argument is None, it will check the global KV connector
33+
34+
Args:
35+
connector: The KV connector to check. If None, it will check the
36+
global KV connector.
37+
38+
Note:
39+
This function will no-longer be needed after the v1 KV connector
40+
becomes the default.
41+
"""
42+
if connector is None:
43+
connector = _KV_CONNECTOR_AGENT
44+
45+
if connector is None:
46+
return False
47+
48+
return isinstance(connector, KVConnectorBase_V1)
49+
50+
51+
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
52+
"""
53+
Initialize KV cache transfer parallel group.
54+
"""
55+
56+
global _KV_CONNECTOR_AGENT
57+
58+
if vllm_config.kv_transfer_config is None:
59+
return
60+
61+
if all([
62+
vllm_config.kv_transfer_config.is_kv_transfer_instance,
63+
_KV_CONNECTOR_AGENT is None
64+
]):
65+
66+
if envs.VLLM_USE_V1:
67+
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
68+
config=vllm_config, role=KVConnectorRole.WORKER)
69+
else:
70+
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
71+
rank=get_world_group().rank,
72+
local_rank=get_world_group().local_rank,
73+
config=vllm_config,
74+
)

vllm/distributed/parallel_state.py

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,13 @@
4040
import vllm.envs as envs
4141
from vllm.distributed.device_communicators.base_device_communicator import (
4242
DeviceCommunicatorBase)
43-
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
44-
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
4543
from vllm.distributed.utils import StatelessProcessGroup
4644
from vllm.logger import init_logger
4745
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
4846
supports_custom_op)
4947

5048
if TYPE_CHECKING:
51-
from vllm.config import VllmConfig
49+
pass
5250

5351

5452
@dataclass
@@ -771,44 +769,6 @@ def get_pp_group() -> GroupCoordinator:
771769
# kept for backward compatibility
772770
get_pipeline_model_parallel_group = get_pp_group
773771

774-
# TODO: once we deprecate V0 KV transfer, we can move this to
775-
# be a non-global object.
776-
_KV_CONNECTOR_AGENT: Union[KVConnectorBase, KVConnectorBase_V1, None] = None
777-
778-
779-
def get_kv_transfer_group() -> Union[KVConnectorBase, KVConnectorBase_V1]:
780-
assert _KV_CONNECTOR_AGENT is not None, (
781-
"disaggregated KV cache transfer parallel group is not initialized")
782-
return _KV_CONNECTOR_AGENT
783-
784-
785-
def has_kv_transfer_group() -> bool:
786-
return _KV_CONNECTOR_AGENT is not None
787-
788-
789-
def is_v1_kv_transfer_group(
790-
connector: Union[KVConnectorBase_V1, KVConnectorBase,
791-
None] = None) -> bool:
792-
"""Check if the KV connector is the v1 connector.
793-
If the argument is None, it will check the global KV connector
794-
795-
Args:
796-
connector: The KV connector to check. If None, it will check the
797-
global KV connector.
798-
799-
Note:
800-
This function will no-longer be needed after the v1 KV connector
801-
becomes the default.
802-
"""
803-
if connector is None:
804-
connector = _KV_CONNECTOR_AGENT
805-
806-
if connector is None:
807-
# Global KV connector is not set
808-
return False
809-
810-
return isinstance(connector, KVConnectorBase_V1)
811-
812772

813773
@contextmanager
814774
def graph_capture(device: torch.device):
@@ -991,37 +951,6 @@ def initialize_model_parallel(
991951
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
992952

993953

994-
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
995-
"""
996-
Initialize KV cache transfer parallel group.
997-
"""
998-
999-
global _KV_CONNECTOR_AGENT
1000-
1001-
if vllm_config.kv_transfer_config is None:
1002-
return
1003-
1004-
if all([
1005-
vllm_config.kv_transfer_config.is_kv_transfer_instance,
1006-
_KV_CONNECTOR_AGENT is None
1007-
]):
1008-
from vllm.distributed.kv_transfer.kv_connector.factory import (
1009-
KVConnectorFactory)
1010-
from vllm.distributed.kv_transfer.kv_connector.v1 import (
1011-
KVConnectorRole as KVConnectorRole_V1)
1012-
1013-
kwargs = {
1014-
"rank": get_world_group().rank,
1015-
"local_rank": get_world_group().local_rank,
1016-
"config": vllm_config,
1017-
# NOTE(Kuntai):
1018-
# Parallel state is initialized in v1 worker,
1019-
# so this connector is for sure worker connector.
1020-
"role": KVConnectorRole_V1.WORKER,
1021-
}
1022-
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(**kwargs)
1023-
1024-
1025954
def ensure_model_parallel_initialized(
1026955
tensor_model_parallel_size: int,
1027956
pipeline_model_parallel_size: int,

0 commit comments

Comments
 (0)