Skip to content

Commit 2d8bf2c

Browse files
DarkSharpnessssssnow
authored and committed
[Feature] Radix Tree in C++ (#7369)
1 parent f3275cd commit 2d8bf2c

File tree

12 files changed

+1466
-1
lines changed

12 files changed

+1466
-1
lines changed

python/sglang/srt/managers/scheduler.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,23 @@ def init_memory_pool_and_cache(self):
569569
page_size=self.page_size,
570570
)
571571
else:
572-
if self.enable_hierarchical_cache:
572+
if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
573+
# lazy import to avoid JIT overhead
574+
from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
575+
576+
self.tree_cache = RadixCacheCpp(
577+
disable=False,
578+
use_hicache=self.enable_hierarchical_cache,
579+
req_to_token_pool=self.req_to_token_pool,
580+
token_to_kv_pool=self.token_to_kv_pool_allocator,
581+
tp_cache_group=self.tp_cpu_group,
582+
page_size=self.page_size,
583+
hicache_ratio=server_args.hicache_ratio,
584+
hicache_size=server_args.hicache_size,
585+
hicache_write_policy=server_args.hicache_write_policy,
586+
enable_kv_cache_events=self.enable_kv_cache_events,
587+
)
588+
elif self.enable_hierarchical_cache:
573589
self.tree_cache = HiRadixCache(
574590
req_to_token_pool=self.req_to_token_pool,
575591
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../../../../sgl-kernel/.clang-format
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
#include <cstddef>
3+
#include <cstdint>
4+
#include <source_location>
5+
#include <span>
6+
#include <stdexcept>
7+
#include <string>
8+
#include <vector>
9+
10+
namespace radix_tree_v2 {
11+
12+
using token_t = std::int32_t;
13+
using token_vec_t = std::vector<token_t>;
14+
using token_slice = std::span<const token_t>;
15+
using NodeHandle = std::size_t;
16+
using IOTicket = std::uint32_t;
17+
18+
inline void _assert(
19+
bool condition,
20+
const char* message = "Assertion failed",
21+
std::source_location loc = std::source_location::current()) {
22+
if (!condition) [[unlikely]] {
23+
std::string msg = message;
24+
msg = msg + " at " + loc.file_name() + ":" + std::to_string(loc.line()) + " in " + loc.function_name();
25+
throw std::runtime_error(msg);
26+
}
27+
}
28+
29+
} // namespace radix_tree_v2
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from typing import TYPE_CHECKING, List, Optional, Tuple
5+
6+
import torch
7+
from torch.utils.cpp_extension import load
8+
9+
# Directory containing this module; the C++ sources are referenced relative to it.
_abs_path = os.path.dirname(os.path.abspath(__file__))

# Build and load the C++ extension at import time via torch's JIT loader.
# The sources are compiled as C++20 with -O3.
radix_tree_cpp = load(
    name="radix_tree_cpp",
    extra_cflags=["-O3", "-std=c++20"],
    sources=[
        f"{_abs_path}/tree_v2_binding.cpp",
        f"{_abs_path}/tree_v2_debug.cpp",
        f"{_abs_path}/tree_v2.cpp",
    ],
)
19+
20+
if TYPE_CHECKING:
21+
22+
class TreeNodeCpp:
    """
    Type-checking-only stand-in for the tree node handle used by RadixTreeCpp.
    It is not meant to be constructed by Python code.
    """
26+
27+
class IOHandle:
    """
    Type-checking-only stand-in for the IO handle used by RadixTreeCpp.
    It is not meant to be constructed by Python code.
    """
31+
32+
class RadixTreeCpp:
    """
    Typed facade over ``radix_tree_cpp.RadixTree``.

    Every method forwards its arguments unchanged to the wrapped extension
    object stored in ``self.tree`` and returns whatever that object returns.
    """

    def __init__(
        self,
        disabled: bool,
        host_size: Optional[int],
        page_size: int,
        write_through_threshold: int,
    ):
        """
        Initializes the RadixTreeCpp instance.
        Args:
            disabled (bool): If True, the radix tree is disabled.
            host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
            page_size (int): Size of the page for the radix tree.
            write_through_threshold (int): Threshold for writing through from GPU to CPU.
        """
        self.tree = radix_tree_cpp.RadixTree(  # type: ignore
            disabled, host_size, page_size, write_through_threshold
        )

    def match_prefix(
        self, prefix: List[int]
    ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
        """
        Matches a prefix in the radix tree.
        Args:
            prefix (List[int]): The prefix to match.
        Returns:
            Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
                0. A list of indices that is matched by the prefix on the GPU.
                1. Sum length of the indices matched on the CPU.
                2. The last node of the prefix matched on the GPU.
                3. The last node of the prefix matched on the CPU.
        """
        return self.tree.match_prefix(prefix)

    def evict(self, num_tokens: int) -> List[torch.Tensor]:
        """
        Evicts a number of tokens from the radix tree.
        Args:
            num_tokens (int): The number of tokens to evict.
        Returns:
            List[torch.Tensor]: A list of indices that were evicted.
        """
        return self.tree.evict(num_tokens)

    def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
        """
        Locks or unlocks a reference to a tree node.
        After locking, the node will not be evicted from the radix tree.
        Args:
            handle (TreeNodeCpp): The tree node to lock or unlock.
            lock (bool): If True, locks the node; if False, unlocks it.
        """
        return self.tree.lock_ref(handle, lock)

    def writing_through(
        self, key: List[int], indices: torch.Tensor
    ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
        """
        Inserts a key-value pair into the radix tree and performs the write-through check.
        Args:
            key (List[int]): The key to insert.
            indices (torch.Tensor): The value associated with the key.
        Returns:
            Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
                0. A list of (IOHandle, device indices, host indices) tuples.
                   These IOHandles require write-through to the CPU on the Python side.
                1. The number of indices that are matched on device.
        """
        return self.tree.writing_through(key, indices)

    def loading_onboard(
        self,
        host_node: TreeNodeCpp,
        new_device_indices: torch.Tensor,
    ) -> Tuple[IOHandle, List[torch.Tensor]]:
        """
        Updates the device indices of tree nodes within a range on the tree.
        Args:
            host_node (TreeNodeCpp): The tree node on the host that ends the range.
            new_device_indices (torch.Tensor): The new device indices to set.
                The length of this tensor must be exactly host indices length.
        Returns:
            Tuple[IOHandle, List[torch.Tensor]]:
                0. An IOHandle for the load that must be carried out on the Python side.
                1. A list of host indices corresponding to the new device indices.
        """
        return self.tree.loading_onboard(host_node, new_device_indices)

    def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
        """
        Commits the write-through process for a tree node.
        Args:
            handle (IOHandle): The IOHandle to commit.
            success (bool): If True, commits the write-through; if False, just indicates failure.
        """
        return self.tree.commit_writing_through(handle, success)

    def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
        """
        Commits the load onboard process for tree nodes within a range on the tree.
        Args:
            handle (IOHandle): The IOHandle to commit.
            success (bool): If True, commits the load-onboard; if False, just indicates failure.
        """
        return self.tree.commit_loading_onboard(handle, success)

    def evictable_size(self) -> int:
        """
        Returns the size of the evictable part of the radix tree.
        This is the size of the part that can be evicted from the GPU (ref_count = 0).
        Returns:
            int: The size of the evictable part.
        """
        return self.tree.evictable_size()

    def protected_size(self) -> int:
        """
        Returns the size of the protected part of the radix tree.
        This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
        Returns:
            int: The size of the protected part.
        """
        return self.tree.protected_size()

    def total_size(self) -> int:
        """
        Returns the total size of the radix tree (including CPU nodes).
        Returns:
            int: The total size of the radix tree.
        """
        return self.tree.total_size()

    def reset(self) -> None:
        """
        Resets the radix tree, clearing all nodes and indices.
        """
        return self.tree.reset()

    def debug_print(self) -> None:
        """
        Prints the internal state of the radix tree for debugging purposes.
        """
        return self.tree.debug_print()
177+
178+
else:
179+
# Runtime bindings: the tree class comes straight from the compiled extension,
# while the two handle types are plain `object` aliases.
RadixTreeCpp = radix_tree_cpp.RadixTree
TreeNodeCpp = IOHandle = object
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#include "tree_v2.h"
2+
3+
#include <ATen/core/TensorBody.h>
4+
#include <ATen/ops/empty.h>
5+
#include <ATen/ops/tensor.h>
6+
#include <ATen/ops/zeros.h>
7+
#include <c10/util/irange.h>
8+
9+
#include <cstddef>
10+
#include <memory>
11+
#include <queue>
12+
#include <stdexcept>
13+
#include <utility>
14+
#include <vector>
15+
16+
#include "common.h"
17+
#include "tree_v2_impl.h"
18+
#include "tree_v2_node.h"
19+
20+
namespace radix_tree_v2 {
21+
22+
/// Returns the `node_id` of `tree_node`, used as the handle handed back to callers.
static NodeHandle node2id(TreeNode* tree_node) { return tree_node->node_id; }
25+
26+
// compare function for the TreeNode pointers based on their time
27+
// we use LRU, so we want to evict the least recently used nodes
28+
// since std::priority_queue is a max-heap, we need to reverse the comparison
29+
static constexpr auto cmp = [](TreeNode* lhs, TreeNode* rhs) { return lhs->time() > rhs->time(); };
30+
31+
/// Builds the implementation object behind the tree.
///
/// @param disabled   Forwarded to Impl; the public methods short-circuit when it is set.
/// @param host_size  Optional host size. Impl receives both whether a value
///                   was supplied and the value itself (0 when absent).
/// @param page_size  Forwarded to Impl.
/// @param threshold  Forwarded to Impl.
RadixTree::RadixTree(bool disabled, std::optional<std::size_t> host_size, std::size_t page_size, std::size_t threshold)
    : m_impl(std::make_unique<Impl>(
          disabled,
          host_size.has_value(),
          page_size,
          host_size.value_or(0),
          threshold)) {}

// Defaulted here rather than in the header (presumably so Impl can remain an
// incomplete type where RadixTree is declared -- confirm against tree_v2.h).
RadixTree::~RadixTree() = default;
35+
36+
std::tuple<std::vector<at::Tensor>, std::size_t, NodeHandle, NodeHandle>
37+
RadixTree::match_prefix(const token_vec_t& _key) {
38+
if (m_impl->disabled) return {};
39+
40+
const auto key = token_slice{_key.data(), m_impl->align(_key.size())};
41+
const auto [host_node, _] = m_impl->tree_walk(key);
42+
43+
// walk up to the first non-evicted node
44+
std::size_t host_hit_length = 0;
45+
const auto device_node = host_node;
46+
47+
// collect all the device indices
48+
std::vector<at::Tensor> indices{};
49+
walk_to_root(device_node, [&](TreeNode* n) { indices.push_back(n->device_indices()); });
50+
std::reverse(indices.begin(), indices.end());
51+
52+
return {std::move(indices), host_hit_length, node2id(device_node), node2id(host_node)};
53+
}
54+
55+
std::vector<at::Tensor> RadixTree::evict(std::size_t num_tokens) {
56+
if (m_impl->disabled || num_tokens == 0) return {};
57+
auto heap = std::priority_queue{cmp, m_impl->collect_leaves_device()};
58+
std::vector<at::Tensor> evicted_values;
59+
// evict nodes until we reach the desired number of tokens
60+
std::size_t num_evict = 0;
61+
while (num_evict < num_tokens && !heap.empty()) {
62+
const auto node = heap.top();
63+
heap.pop();
64+
// when ref_count == 0, can't be writing through
65+
_assert(node->on_gpu() && node->ref_count == 0);
66+
if (!node->is_io_free()) continue; // skip nodes that are undergoing IO (i.e. indices protected)
67+
evicted_values.push_back(node->device_indices());
68+
num_evict += node->length();
69+
const auto parent = node->parent();
70+
m_impl->remove_device_node(node);
71+
if (parent->is_leaf_device() && parent->ref_count == 0)
72+
heap.push(parent); // push parent to the heap if it is now a free leaf
73+
}
74+
75+
return evicted_values;
76+
}
77+
78+
/// Inserts the page-aligned part of `_key` with its `value` into the tree.
///
/// @param _key   Tokens to insert; the unaligned tail is clipped.
/// @param value  Tensor whose first dimension must equal `_key.size()`.
/// @return (list of write-through jobs, length of the prefix that was already
///         present). The job list is always empty on the path implemented here.
///         A value-initialized tuple is returned when the tree is disabled.
/// @throws std::runtime_error if key and value sizes differ, or if hicache is
///         enabled (that path is not implemented yet).
std::tuple<std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>>, std::size_t>
RadixTree::writing_through(const token_vec_t& _key, at::Tensor value) {
  if (m_impl->disabled) {
    return {};
  }
  _assert(_key.size() == std::size_t(value.size(0)), "Key and value must have the same size");

  // Keep only the page-aligned part of the key.
  const auto aligned_key = token_slice{_key.data(), m_impl->align(_key.size())};

  // Find the deepest existing node and how much of the key it already covers.
  const auto [host_node, host_prefix_length] = m_impl->tree_walk(aligned_key);

  // Whatever is not covered yet becomes a new device node below the match.
  const bool has_unmatched_tail = host_prefix_length != aligned_key.size();
  if (has_unmatched_tail) {
    m_impl->create_device_node(
        host_node,
        {aligned_key.begin() + host_prefix_length, aligned_key.end()},
        value.slice(/*dim=*/0, host_prefix_length, aligned_key.size()));
  }

  // Bump hit_count on every node visited by walk_to_root from the matched node.
  walk_to_root(host_node, [&](TreeNode* visited) { visited->hit_count++; });

  // With host memory in use a write-through would be required, which is not
  // available yet; without it there is nothing to write through (fast path).
  if (m_impl->use_hicache) {
    throw std::runtime_error("Not implemented yet");
  }

  std::vector<std::tuple<IOTicket, at::Tensor, at::Tensor>> write_jobs;
  return {std::move(write_jobs), host_prefix_length};
}
106+
107+
/// Placeholder: returns a value-initialized result when the tree is disabled
/// and otherwise reports that the operation is not implemented.
///
/// @throws std::runtime_error whenever the tree is enabled.
std::tuple<IOTicket, std::vector<at::Tensor>> RadixTree::loading_onboard(NodeHandle, at::Tensor) {
  if (!m_impl->disabled) {
    throw std::runtime_error("Not implemented yet");
  }
  return {};
}
111+
112+
/// Placeholder: does nothing when the tree is disabled and otherwise reports
/// that the operation is not implemented.
///
/// @throws std::runtime_error whenever the tree is enabled.
void RadixTree::commit_writing_through(IOTicket, bool) {
  if (!m_impl->disabled) {
    throw std::runtime_error("Not implemented yet");
  }
}
116+
117+
/// Placeholder: does nothing when the tree is disabled and otherwise reports
/// that the operation is not implemented.
///
/// @throws std::runtime_error whenever the tree is enabled.
void RadixTree::commit_loading_onboard(IOTicket, bool) {
  if (!m_impl->disabled) {
    throw std::runtime_error("Not implemented yet");
  }
}
121+
122+
/// Forwards to Impl::reset(). Unlike most other methods, this one does not
/// check the `disabled` flag first.
void RadixTree::reset() { m_impl->reset(); }
125+
126+
/// Forwards a lock/unlock request for `node_id` to the implementation.
/// Does nothing when the tree is disabled.
///
/// @param node_id    Handle of the node to operate on.
/// @param increment  Passed through unchanged to Impl::lock_ref.
void RadixTree::lock_ref(NodeHandle node_id, bool increment) {
  if (!m_impl->disabled) {
    m_impl->lock_ref(node_id, increment);
  }
}
130+
131+
std::size_t RadixTree::evictable_size() const {
132+
return m_impl->evictable_size();
133+
}
134+
135+
std::size_t RadixTree::protected_size() const {
136+
return m_impl->protected_size();
137+
}
138+
139+
std::size_t RadixTree::total_size() const {
140+
return m_impl->total_size();
141+
}
142+
143+
} // namespace radix_tree_v2

0 commit comments

Comments
 (0)