@@ -31,21 +31,17 @@

 import numpy as np
 import torch
-import torch.distributed as dist
 import triton
 import triton.language as tl

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
-_is_npu = is_npu()
-if not _is_npu:
-    from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla


 class ReqToTokenPool:
@@ -153,18 +149,6 @@ def set_kv_buffer(
     ) -> None:
         raise NotImplementedError()

-    @abc.abstractmethod
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError()
-
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter

@@ -253,12 +237,18 @@ def _create_buffers(self):
                 )
                 for _ in range(self.layer_num)
             ]
-            self.token_stride = self.head_num * self.head_dim
-            self.data_ptrs = torch.tensor(
-                [x.data_ptr() for x in self.k_buffer + self.v_buffer],
+
+            self.k_data_ptrs = torch.tensor(
+                [x.data_ptr() for x in self.k_buffer],
+                dtype=torch.uint64,
+                device=self.device,
+            )
+            self.v_data_ptrs = torch.tensor(
+                [x.data_ptr() for x in self.v_buffer],
                 dtype=torch.uint64,
                 device=self.device,
             )
+            self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
             self.data_strides = torch.tensor(
                 [
                     np.prod(x.shape[1:]) * x.dtype.itemsize
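Reviewer note (not part of the diff): splitting the pointer table into `k_data_ptrs` and `v_data_ptrs` lets K-only or V-only transfer paths index their buffers directly, while the `torch.cat` preserves the combined `[K layers..., V layers...]` layout that existing `data_ptrs` consumers (presumably the `copy_all_layer_kv_cache` kernel at the end of this diff) expect. A minimal standalone sketch of the pattern, with made-up shapes and CPU tensors for portability:

```python
import torch

# Hypothetical stand-ins for self.k_buffer / self.v_buffer: one tensor per layer.
layer_num, num_tokens, head_num, head_dim = 4, 16, 8, 128
k_buffer = [torch.zeros(num_tokens, head_num, head_dim) for _ in range(layer_num)]
v_buffer = [torch.zeros(num_tokens, head_num, head_dim) for _ in range(layer_num)]

# One base address per layer buffer, stored as uint64 so a GPU kernel can load it.
k_data_ptrs = torch.tensor([x.data_ptr() for x in k_buffer], dtype=torch.uint64)
v_data_ptrs = torch.tensor([x.data_ptr() for x in v_buffer], dtype=torch.uint64)

# Combined table keeps the old layout: all K pointers first, then all V pointers.
data_ptrs = torch.cat([k_data_ptrs, v_data_ptrs], dim=0)
assert data_ptrs.shape == (2 * layer_num,)
```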
@@ -347,47 +337,6 @@ def load_cpu_copy(self, kv_cache_cpu, indices):
             self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()

-    def load_from_host_per_layer(
-        self,
-        host_pool,
-        host_indices,
-        device_indices,
-        layer_id,
-        io_backend,
-    ):
-        transfer_kv_per_layer(
-            src_k=host_pool.k_buffer[layer_id],
-            dst_k=self.k_buffer[layer_id],
-            src_v=host_pool.v_buffer[layer_id],
-            dst_v=self.v_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.k_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer(
-                src_k=self.k_buffer[layer_id],
-                dst_k=host_pool.k_buffer[layer_id],
-                src_v=self.v_buffer[layer_id],
-                dst_v=host_pool.v_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def _get_key_buffer(self, layer_id: int):
         # for internal use of referencing
         if self.store_dtype != self.dtype:
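Reviewer note (not part of the diff): the removed MHA methods were thin wrappers around `sgl_kernel.kvcacheio.transfer_kv_per_layer`. For readers without that kernel, below is a rough pure-PyTorch sketch of the per-layer host-to-device load they performed; the buffer names, shapes, and helper are assumptions for illustration, and the real kernel additionally honors `page_size` and the selected `io_backend` without a host-side staging copy.

```python
import torch

def load_kv_per_layer_sketch(
    host_k: torch.Tensor,          # [host_slots, head_num, head_dim], CPU (ideally pinned)
    host_v: torch.Tensor,
    dev_k: torch.Tensor,           # [device_slots, head_num, head_dim], GPU or CPU
    dev_v: torch.Tensor,
    host_indices: torch.Tensor,    # int64 slots to read from the host pool
    device_indices: torch.Tensor,  # int64 slots to write in the device pool
) -> None:
    # Gather the requested slots on the host, move them to the device buffer's
    # location, then scatter them into the destination slots.
    dev_k[device_indices] = host_k[host_indices].to(dev_k.device)
    dev_v[device_indices] = host_v[host_indices].to(dev_v.device)

# Example: copy two token slots from a toy host pool into a toy device pool.
host_k, host_v = torch.randn(32, 8, 128), torch.randn(32, 8, 128)
dev_k, dev_v = torch.zeros(16, 8, 128), torch.zeros(16, 8, 128)
load_kv_per_layer_sketch(
    host_k, host_v, dev_k, dev_v, torch.tensor([3, 7]), torch.tensor([0, 1])
)
```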
@@ -602,16 +551,6 @@ def set_kv_buffer(
             layer_id_override=layer_id_pool,
         )

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-

 class AscendTokenToKVPool(MHATokenToKVPool):

@@ -823,7 +762,11 @@ def __init__(
                 for _ in range(layer_num)
             ]

-        self.token_stride = kv_lora_rank + qk_rope_head_dim
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.kv_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
         self.layer_transfer_counter = None

         kv_size = self.get_kv_size_bytes()
@@ -909,38 +852,6 @@ def set_mla_kv_buffer(
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        transfer_kv_per_layer_mla(
-            src=host_pool.kv_buffer[layer_id],
-            dst=self.kv_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer_mla(
-                src=self.kv_buffer[layer_id],
-                dst=host_pool.kv_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -1131,20 +1042,6 @@ def set_kv_buffer(
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-

 @triton.jit
 def copy_all_layer_kv_cache(