Skip to content

Commit f0ecff7

Browse files
xiezhq-hermannlifuhuang
authored andcommitted
Interface change for kvcache io to support page first layout (#8318)
1 parent b325b5a commit f0ecff7

File tree

6 files changed

+371
-171
lines changed

6 files changed

+371
-171
lines changed

python/sglang/srt/managers/cache_controller.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -231,16 +231,7 @@ def __init__(
231231
self.mem_pool_host = mem_pool_host
232232
self.write_policy = write_policy
233233
self.page_size = page_size
234-
# using kernel for small page KV cache transfer and DMA for large pages
235-
if not io_backend:
236-
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
237-
self.io_backend = (
238-
"direct"
239-
if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
240-
else "kernel"
241-
)
242-
else:
243-
self.io_backend = io_backend
234+
self.io_backend = io_backend
244235

245236
self.enable_storage = False
246237
# todo: move backend initialization to storage backend module
@@ -447,11 +438,8 @@ def write_thread_func_direct(self):
447438
host_indices, device_indices = self.move_indices(
448439
operation.host_indices, operation.device_indices
449440
)
450-
self.mem_pool_device.backup_to_host_all_layer(
451-
self.mem_pool_host,
452-
host_indices,
453-
device_indices,
454-
self.io_backend,
441+
self.mem_pool_host.backup_from_device_all_layer(
442+
self.mem_pool_device, host_indices, device_indices, self.io_backend
455443
)
456444
self.write_stream.synchronize()
457445
self.mem_pool_host.complete_io(operation.host_indices)
@@ -491,8 +479,8 @@ def load_thread_func_layer_by_layer(self):
491479
batch_operation.host_indices, batch_operation.device_indices
492480
)
493481
for i in range(self.mem_pool_host.layer_num):
494-
self.mem_pool_device.load_from_host_per_layer(
495-
self.mem_pool_host,
482+
self.mem_pool_host.load_to_device_per_layer(
483+
self.mem_pool_device,
496484
host_indices,
497485
device_indices,
498486
i,

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ def init_memory_pool_and_cache(self):
588588
== "fa3" # hot fix for incompatibility
589589
else server_args.hicache_io_backend
590590
),
591+
hicache_mem_layout=server_args.hicache_mem_layout,
591592
hicache_storage_backend=server_args.hicache_storage_backend,
592593
)
593594
self.tp_worker.register_hicache_layer_transfer_counter(

python/sglang/srt/mem_cache/hiradix_cache.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,33 @@ def __init__(
3535
hicache_size: int,
3636
hicache_write_policy: str,
3737
hicache_io_backend: str,
38+
hicache_mem_layout: str,
3839
hicache_storage_backend: Optional[str] = None,
3940
):
41+
42+
if hicache_io_backend == "direct":
43+
if hicache_mem_layout == "page_first":
44+
hicache_mem_layout = "layer_first"
45+
logger.warning(
46+
"Page first layout is not supported with direct IO backend, switching to layer first layout"
47+
)
48+
4049
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
4150
if isinstance(self.kv_cache, MHATokenToKVPool):
4251
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
43-
self.kv_cache, hicache_ratio, hicache_size, page_size
52+
self.kv_cache,
53+
hicache_ratio,
54+
hicache_size,
55+
page_size,
56+
hicache_mem_layout,
4457
)
4558
elif isinstance(self.kv_cache, MLATokenToKVPool):
4659
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
47-
self.kv_cache, hicache_ratio, hicache_size, page_size
60+
self.kv_cache,
61+
hicache_ratio,
62+
hicache_size,
63+
page_size,
64+
hicache_mem_layout,
4865
)
4966
else:
5067
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

python/sglang/srt/mem_cache/memory_pool.py

Lines changed: 15 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,17 @@
3131

3232
import numpy as np
3333
import torch
34-
import torch.distributed as dist
3534
import triton
3635
import triton.language as tl
3736

3837
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
3938
from sglang.srt.layers.radix_attention import RadixAttention
40-
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
39+
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
4140

4241
logger = logging.getLogger(__name__)
4342

4443
GB = 1024 * 1024 * 1024
4544
_is_cuda = is_cuda()
46-
_is_npu = is_npu()
47-
if not _is_npu:
48-
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
4945

5046

5147
class ReqToTokenPool:
@@ -153,18 +149,6 @@ def set_kv_buffer(
153149
) -> None:
154150
raise NotImplementedError()
155151

156-
@abc.abstractmethod
157-
def load_from_host_per_layer(
158-
self, host_pool, host_indices, device_indices, layer_id, io_backend
159-
):
160-
raise NotImplementedError()
161-
162-
@abc.abstractmethod
163-
def backup_to_host_all_layer(
164-
self, host_pool, host_indices, device_indices, io_backend
165-
):
166-
raise NotImplementedError()
167-
168152
def register_layer_transfer_counter(self, layer_transfer_counter):
169153
self.layer_transfer_counter = layer_transfer_counter
170154

@@ -253,12 +237,18 @@ def _create_buffers(self):
253237
)
254238
for _ in range(self.layer_num)
255239
]
256-
self.token_stride = self.head_num * self.head_dim
257-
self.data_ptrs = torch.tensor(
258-
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
240+
241+
self.k_data_ptrs = torch.tensor(
242+
[x.data_ptr() for x in self.k_buffer],
243+
dtype=torch.uint64,
244+
device=self.device,
245+
)
246+
self.v_data_ptrs = torch.tensor(
247+
[x.data_ptr() for x in self.v_buffer],
259248
dtype=torch.uint64,
260249
device=self.device,
261250
)
251+
self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
262252
self.data_strides = torch.tensor(
263253
[
264254
np.prod(x.shape[1:]) * x.dtype.itemsize
@@ -347,47 +337,6 @@ def load_cpu_copy(self, kv_cache_cpu, indices):
347337
self.v_buffer[layer_id][chunk_indices] = v_chunk
348338
torch.cuda.synchronize()
349339

350-
def load_from_host_per_layer(
351-
self,
352-
host_pool,
353-
host_indices,
354-
device_indices,
355-
layer_id,
356-
io_backend,
357-
):
358-
transfer_kv_per_layer(
359-
src_k=host_pool.k_buffer[layer_id],
360-
dst_k=self.k_buffer[layer_id],
361-
src_v=host_pool.v_buffer[layer_id],
362-
dst_v=self.v_buffer[layer_id],
363-
src_indices=host_indices,
364-
dst_indices=device_indices,
365-
io_backend=io_backend,
366-
page_size=self.page_size,
367-
item_size=self.token_stride,
368-
)
369-
370-
def backup_to_host_all_layer(
371-
self, host_pool, host_indices, device_indices, io_backend
372-
):
373-
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
374-
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
375-
if layer_id - self.start_layer >= len(host_pool.k_buffer):
376-
raise ValueError(
377-
f"Layer ID {layer_id} exceeds the number of layers in host pool."
378-
)
379-
transfer_kv_per_layer(
380-
src_k=self.k_buffer[layer_id],
381-
dst_k=host_pool.k_buffer[layer_id],
382-
src_v=self.v_buffer[layer_id],
383-
dst_v=host_pool.v_buffer[layer_id],
384-
src_indices=device_indices,
385-
dst_indices=host_indices,
386-
io_backend=io_backend,
387-
page_size=self.page_size,
388-
item_size=self.token_stride,
389-
)
390-
391340
def _get_key_buffer(self, layer_id: int):
392341
# for internal use of referencing
393342
if self.store_dtype != self.dtype:
@@ -602,16 +551,6 @@ def set_kv_buffer(
602551
layer_id_override=layer_id_pool,
603552
)
604553

605-
def load_from_host_per_layer(
606-
self, host_pool, host_indices, device_indices, layer_id, io_backend
607-
):
608-
raise NotImplementedError("HiCache not supported for SWAKVPool.")
609-
610-
def backup_to_host_all_layer(
611-
self, host_pool, host_indices, device_indices, io_backend
612-
):
613-
raise NotImplementedError("HiCache not supported for SWAKVPool.")
614-
615554

616555
class AscendTokenToKVPool(MHATokenToKVPool):
617556

@@ -823,7 +762,11 @@ def __init__(
823762
for _ in range(layer_num)
824763
]
825764

826-
self.token_stride = kv_lora_rank + qk_rope_head_dim
765+
self.data_ptrs = torch.tensor(
766+
[x.data_ptr() for x in self.kv_buffer],
767+
dtype=torch.uint64,
768+
device=self.device,
769+
)
827770
self.layer_transfer_counter = None
828771

829772
kv_size = self.get_kv_size_bytes()
@@ -909,38 +852,6 @@ def set_mla_kv_buffer(
909852
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
910853
)
911854

912-
def load_from_host_per_layer(
913-
self, host_pool, host_indices, device_indices, layer_id, io_backend
914-
):
915-
transfer_kv_per_layer_mla(
916-
src=host_pool.kv_buffer[layer_id],
917-
dst=self.kv_buffer[layer_id],
918-
src_indices=host_indices,
919-
dst_indices=device_indices,
920-
io_backend=io_backend,
921-
page_size=self.page_size,
922-
item_size=self.token_stride,
923-
)
924-
925-
def backup_to_host_all_layer(
926-
self, host_pool, host_indices, device_indices, io_backend
927-
):
928-
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
929-
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
930-
if layer_id - self.start_layer >= len(host_pool.kv_buffer):
931-
raise ValueError(
932-
f"Layer ID {layer_id} exceeds the number of layers in host pool."
933-
)
934-
transfer_kv_per_layer_mla(
935-
src=self.kv_buffer[layer_id],
936-
dst=host_pool.kv_buffer[layer_id],
937-
src_indices=device_indices,
938-
dst_indices=host_indices,
939-
io_backend=io_backend,
940-
page_size=self.page_size,
941-
item_size=self.token_stride,
942-
)
943-
944855
def get_cpu_copy(self, indices):
945856
torch.cuda.synchronize()
946857
kv_cache_cpu = []
@@ -1131,20 +1042,6 @@ def set_kv_buffer(
11311042
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
11321043
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
11331044

1134-
def load_from_host_per_layer(
1135-
self, host_pool, host_indices, device_indices, layer_id, io_backend
1136-
):
1137-
raise NotImplementedError(
1138-
"HiCache not supported for DoubleSparseTokenToKVPool."
1139-
)
1140-
1141-
def backup_to_host_all_layer(
1142-
self, host_pool, host_indices, device_indices, io_backend
1143-
):
1144-
raise NotImplementedError(
1145-
"HiCache not supported for DoubleSparseTokenToKVPool."
1146-
)
1147-
11481045

11491046
@triton.jit
11501047
def copy_all_layer_kv_cache(

0 commit comments

Comments
 (0)