|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import os |
| 4 | +from typing import TYPE_CHECKING, List, Optional, Tuple |
| 5 | + |
| 6 | +import torch |
| 7 | +from torch.utils.cpp_extension import load |
| 8 | + |
| 9 | +_abs_path = os.path.dirname(os.path.abspath(__file__)) |
| 10 | +radix_tree_cpp = load( |
| 11 | + name="radix_tree_cpp", |
| 12 | + sources=[ |
| 13 | + f"{_abs_path}/tree_v2_binding.cpp", |
| 14 | + f"{_abs_path}/tree_v2_debug.cpp", |
| 15 | + f"{_abs_path}/tree_v2.cpp", |
| 16 | + ], |
| 17 | + extra_cflags=["-O3", "-std=c++20"], |
| 18 | +) |
| 19 | + |
| 20 | +if TYPE_CHECKING: |
| 21 | + |
| 22 | + class TreeNodeCpp: |
| 23 | + """ |
| 24 | + A placeholder for the TreeNode class. Cannot be constructed elsewhere. |
| 25 | + """ |
| 26 | + |
| 27 | + class IOHandle: |
| 28 | + """ |
| 29 | + A placeholder for the IOHandle class. Cannot be constructed elsewhere. |
| 30 | + """ |
| 31 | + |
| 32 | + class RadixTreeCpp: |
| 33 | + def __init__( |
| 34 | + self, |
| 35 | + disabled: bool, |
| 36 | + host_size: Optional[int], |
| 37 | + page_size: int, |
| 38 | + write_through_threshold: int, |
| 39 | + ): |
| 40 | + """ |
| 41 | + Initializes the RadixTreeCpp instance. |
| 42 | + Args: |
| 43 | + disabled (bool): If True, the radix tree is disabled. |
| 44 | + host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree. |
| 45 | + page_size (int): Size of the page for the radix tree. |
| 46 | + write_through_threshold (int): Threshold for writing through from GPU to CPU. |
| 47 | + """ |
| 48 | + self.tree = radix_tree_cpp.RadixTree( # type: ignore |
| 49 | + disabled, host_size, page_size, write_through_threshold |
| 50 | + ) |
| 51 | + |
| 52 | + def match_prefix( |
| 53 | + self, prefix: List[int] |
| 54 | + ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]: |
| 55 | + """ |
| 56 | + Matches a prefix in the radix tree. |
| 57 | + Args: |
| 58 | + prefix (List[int]): The prefix to match. |
| 59 | + Returns: |
| 60 | + Tuple[List[torch.Tensor], TreeNodeCpp, TreeNodeCpp]: |
| 61 | + 0. A list of indices that is matched by the prefix on the GPU. |
| 62 | + 1. Sum length of the indices matched on the CPU. |
| 63 | + 2. The last node of the prefix matched on the GPU. |
| 64 | + 3. The last node of the prefix matched on the CPU. |
| 65 | + """ |
| 66 | + return self.tree.match_prefix(prefix) |
| 67 | + |
| 68 | + def evict(self, num_tokens: int) -> List[torch.Tensor]: |
| 69 | + """ |
| 70 | + Evicts a number of tokens from the radix tree. |
| 71 | + Args: |
| 72 | + num_tokens (int): The number of tokens to evict. |
| 73 | + Returns: |
| 74 | + List[torch.Tensor]: A list of indices that were evicted. |
| 75 | + """ |
| 76 | + return self.tree.evict(num_tokens) |
| 77 | + |
| 78 | + def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None: |
| 79 | + """ |
| 80 | + Locks or unlocks a reference to a tree node. |
| 81 | + After locking, the node will not be evicted from the radix tree. |
| 82 | + Args: |
| 83 | + handle (TreeNodeCpp): The tree node to lock or unlock. |
| 84 | + lock (bool): If True, locks the node; if False, unlocks it. |
| 85 | + """ |
| 86 | + return self.tree.lock_ref(handle, lock) |
| 87 | + |
| 88 | + def writing_through( |
| 89 | + self, key: List[int], indices: torch.Tensor |
| 90 | + ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]: |
| 91 | + """ |
| 92 | + Inserts a key-value pair into the radix tree and perform write-through check. |
| 93 | + Args: |
| 94 | + key (List[int]): The key to insert. |
| 95 | + indices (torch.Tensor): The value associated with the key. |
| 96 | + Returns: |
| 97 | + Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]: |
| 98 | + 0. A list of (IOHandle, device indices, host indices) tuples. |
| 99 | + These IOhandles require write-through to the CPU in python side. |
| 100 | + 1. The number of indices that are matched on device. |
| 101 | + """ |
| 102 | + return self.tree.writing_through(key, indices) |
| 103 | + |
| 104 | + def loading_onboard( |
| 105 | + self, |
| 106 | + host_node: TreeNodeCpp, |
| 107 | + new_device_indices: torch.Tensor, |
| 108 | + ) -> Tuple[IOHandle, List[torch.Tensor]]: |
| 109 | + """ |
| 110 | + Updates the device indices of tree nodes within a range on the tree. |
| 111 | + Args: |
| 112 | + host_node (TreeNodeCpp): The tree node on the host, must be descendant of device_node. |
| 113 | + new_device_indices (torch.Tensor): The new device indices to set. |
| 114 | + The length of this tensor must be exactly host indices length. |
| 115 | + Returns: |
| 116 | + Tuple[IOHandle, List[torch.Tensor]]: |
| 117 | + 0. An IOHandle that requires loading to the CPU in python side. |
| 118 | + 1. A list of host indices corresponding to the new device indices. |
| 119 | + """ |
| 120 | + return self.tree.loading_onboard(host_node, new_device_indices) |
| 121 | + |
| 122 | + def commit_writing_through(self, handle: IOHandle, success: bool) -> None: |
| 123 | + """ |
| 124 | + Commits the write-through process for a tree node. |
| 125 | + Args: |
| 126 | + handle (IOHandle): The IOHandle to commit. |
| 127 | + success (bool): If True, commits the write-through; if False, just indicates failure. |
| 128 | + """ |
| 129 | + return self.tree.commit_writing_through(handle, success) |
| 130 | + |
| 131 | + def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None: |
| 132 | + """ |
| 133 | + Commits the load onboard process for tree nodes within a range on the tree. |
| 134 | + Args: |
| 135 | + handle (IOHandle): The IOHandle to commit. |
| 136 | + success (bool): If True, commits the load-onboard; if False, just indicates failure. |
| 137 | + """ |
| 138 | + return self.tree.commit_loading_onboard(handle, success) |
| 139 | + |
| 140 | + def evictable_size(self) -> int: |
| 141 | + """ |
| 142 | + Returns the size of the evictable part of the radix tree. |
| 143 | + This is the size of the part that can be evicted from the GPU (ref_count = 0). |
| 144 | + Returns: |
| 145 | + int: The size of the evictable part. |
| 146 | + """ |
| 147 | + return self.tree.evictable_size() |
| 148 | + |
| 149 | + def protected_size(self) -> int: |
| 150 | + """ |
| 151 | + Returns the size of the protected part of the radix tree. |
| 152 | + This is the size of the part that cannot be evicted from the GPU (ref_count > 0). |
| 153 | + Returns: |
| 154 | + int: The size of the protected part. |
| 155 | + """ |
| 156 | + return self.tree.protected_size() |
| 157 | + |
| 158 | + def total_size(self) -> int: |
| 159 | + """ |
| 160 | + Returns the total size of the radix tree (including CPU nodes). |
| 161 | + Returns: |
| 162 | + int: The total size of the radix tree. |
| 163 | + """ |
| 164 | + return self.tree.total_size() |
| 165 | + |
| 166 | + def reset(self) -> None: |
| 167 | + """ |
| 168 | + Resets the radix tree, clearing all nodes and indices. |
| 169 | + """ |
| 170 | + return self.tree.reset() |
| 171 | + |
| 172 | + def debug_print(self) -> None: |
| 173 | + """ |
| 174 | + Prints the internal state of the radix tree for debugging purposes. |
| 175 | + """ |
| 176 | + return self.tree.debug_print() |
| 177 | + |
| 178 | +else: |
| 179 | + # Real implementation of the classes for runtime |
| 180 | + RadixTreeCpp = radix_tree_cpp.RadixTree |
| 181 | + TreeNodeCpp = object |
| 182 | + IOHandle = object |
0 commit comments