
Commit 4859658

lifuhuang authored and ssssnow committed
Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)
1 parent 885a99e commit 4859658
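
At a high level, this commit gives every LoRA adapter an immutable `lora_id` (a UUID minted when the adapter is registered), so downstream components such as the radix/prefix cache can key cached state on the ID rather than on a reusable name or path. A minimal illustration of why that matters (the `cache_key` helper below is hypothetical and not part of this commit):

from uuid import uuid4

# Two adapters registered under the same user-facing name at different times
# still receive distinct, stable IDs, so cache entries keyed on the ID can
# never be served against the wrong weights.
adapter_v1_id = uuid4().hex
adapter_v2_id = uuid4().hex  # same name reused after an unload/reload cycle


def cache_key(lora_id: str, token_ids: tuple) -> tuple:
    # Deterministic key: the same prompt under a different adapter version
    # maps to a different cache entry.
    return (lora_id, token_ids)


assert cache_key(adapter_v1_id, (1, 2, 3)) != cache_key(adapter_v2_id, (1, 2, 3))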

File tree

11 files changed: +399 -260 lines changed


python/sglang/srt/lora/lora_manager.py

Lines changed: 133 additions & 169 deletions
Large diffs are not rendered by default.
python/sglang/srt/lora/lora_registry.py

Lines changed: 124 additions & 0 deletions

@@ -0,0 +1,124 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import asyncio
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4


@dataclass(frozen=True, slots=True)
class LoRARef:
    """
    Reference record for a LoRA model.

    This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
    eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
    keys (e.g., radix cache).
    """

    lora_id: str = field(default_factory=lambda: uuid4().hex)
    lora_name: Optional[str] = None
    lora_path: Optional[str] = None

    def __post_init__(self):
        if self.lora_id is None:
            raise ValueError("lora_id cannot be None")

    def __str__(self) -> str:
        parts = [
            f"{f.name}={value}"
            for f in fields(self)
            if (value := getattr(self, f.name)) is not None
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"


class LoRARegistry:
    """
    The central registry to keep track of available LoRA adapters.

    TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
    to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
    """

    def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
        assert lora_paths is None or all(
            isinstance(lora, LoRARef) for lora in lora_paths.values()
        ), (
            "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
            "Please file an issue if you see this error."
        )

        # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
        self._registry: Dict[str, LoRARef] = dict(lora_paths or {})

    async def register(self, lora_ref: LoRARef):
        """
        Register a new LoRARef object in the registry.

        Args:
            lora_ref (LoRARef): The LoRARef object to register.
        """
        if lora_ref.lora_name in self._registry:
            raise ValueError(
                f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
            )
        self._registry[lora_ref.lora_name] = lora_ref

    async def unregister(self, lora_name: str) -> str:
        """
        Unregister a LoRARef object from the registry and returns the removed LoRA ID.

        Args:
            lora_name (str): The name of the LoRA model to unregister.
        """
        lora_ref = self._registry.get(lora_name, None)
        if lora_ref is None:
            raise ValueError(
                f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
            )
        del self._registry[lora_name]

        return lora_ref.lora_id

    async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
        """
        Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
        by incrementing its counter.

        TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
        """

        async def _acquire_single(name: str) -> str:
            lora_ref = self._registry.get(name, None)
            if lora_ref is None:
                raise ValueError(
                    f"The following requested LoRA adapters are not loaded: {name}\n"
                    f"Loaded adapters: {self._registry.keys()}."
                )
            # await self._counters[lora_ref.lora_id].increment()
            return lora_ref.lora_id

        if isinstance(lora_name, str):
            lora_id = await _acquire_single(lora_name)
            return lora_id
        elif isinstance(lora_name, list):
            lora_ids = await asyncio.gather(
                *[_acquire_single(name) for name in lora_name]
            )
            return lora_ids
        else:
            raise TypeError("lora_name must be either a string or a list of strings.")
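
A short usage sketch of the registry defined above (the adapter name and path are made-up values): construct the registry, register a `LoRARef`, resolve the name to its stable ID, and unregister it.

import asyncio

from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry


async def main():
    registry = LoRARegistry()

    # Register an adapter under a user-facing name; lora_id is auto-generated.
    ref = LoRARef(lora_name="my_adapter", lora_path="/models/loras/my_adapter")
    await registry.register(ref)

    # Requests refer to adapters by name; acquire() maps names to stable IDs.
    assert await registry.acquire("my_adapter") == ref.lora_id

    # A list of names resolves concurrently to a list of IDs.
    assert await registry.acquire(["my_adapter"]) == [ref.lora_id]

    # Unregister returns the removed ID so callers can clean up per-ID state.
    assert await registry.unregister("my_adapter") == ref.lora_id


asyncio.run(main())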

python/sglang/srt/lora/mem_pool.py

Lines changed: 2 additions & 2 deletions
@@ -153,7 +153,7 @@ def prepare_lora_batch(
         self,
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
-        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
@@ -186,7 +186,7 @@ def load_lora_weight_to_buffer(
         uid: str,
         buffer_id: int,
         lora_adapter: LoRAAdapter,
-        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
         def load_lora_weight_tensor(
             buffer_view: torch.Tensor, weight: Optional[torch.Tensor]

python/sglang/srt/managers/io_struct.py

Lines changed: 19 additions & 1 deletion
@@ -22,6 +22,7 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
     lora_name: str
     # The path of loading.
     lora_path: str
+    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
+    lora_id: Optional[str] = None
+
+    def to_ref(self) -> LoRARef:
+        return LoRARef(
+            lora_id=self.lora_id,
+            lora_name=self.lora_name,
+            lora_path=self.lora_path,
+        )


 @dataclass
 class UnloadLoRAAdapterReqInput:
     # The name of lora module to unload.
     lora_name: str
+    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
+    lora_id: Optional[str] = None
+
+    def to_ref(self) -> LoRARef:
+        return LoRARef(
+            lora_id=self.lora_id,
+            lora_name=self.lora_name,
+        )


 @dataclass
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)


 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
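
A brief sketch of how the new `to_ref()` helper is meant to be used (illustrative values; it assumes the request dataclass takes exactly the three fields shown in the diff): the `TokenizerManager` mints the stable ID and attaches it to the request, and downstream workers rebuild an equivalent `LoRARef` from it.

from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput

# The TokenizerManager generates the stable ID and places it on the request...
new_adapter = LoRARef(lora_name="my_adapter", lora_path="/models/loras/my_adapter")
req = LoadLoRAAdapterReqInput(
    lora_name=new_adapter.lora_name,
    lora_path=new_adapter.lora_path,
    lora_id=new_adapter.lora_id,
)

# ...and the TP worker reconstructs an equivalent LoRARef on the other side.
assert req.to_ref().lora_id == new_adapter.lora_id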

python/sglang/srt/managers/scheduler.py

Lines changed: 3 additions & 17 deletions
@@ -247,7 +247,7 @@ def __init__(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
-        self.lora_paths = server_args.lora_paths
+        self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
@@ -1706,13 +1706,13 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)

-        if self.lora_paths:
+        if self.enable_lora:
             lora_set = set([req.lora_path for req in self.running_batch.reqs])

         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
-                self.lora_paths
+                self.enable_lora
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])
@@ -2466,12 +2466,6 @@ def load_lora_adapter(
         """In-place loading a new lora adapter from disk or huggingface."""

         result = self.tp_worker.load_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after loading lora adapter."
-        else:
-            logger.error(result.error_message)
         return result

     def unload_lora_adapter(
@@ -2480,14 +2474,6 @@
         """Unload the lora adapter."""

         result = self.tp_worker.unload_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert (
-                flush_cache_success
-            ), "Cache flush failed after unloading LoRA weights"
-        else:
-            logger.error(result.error_message)
         return result

     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 28 additions & 25 deletions
@@ -62,6 +62,7 @@
     get_tokenizer,
     get_tokenizer_from_processor,
 )
+from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -242,11 +243,11 @@ def __init__(
             revision=server_args.revision,
         )

-        # Initialize loaded loRA adapters with the initial lora paths in the server_args.
-        # This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
-        self.loaded_lora_adapters: Dict[str, str] = dict(
-            self.server_args.lora_paths or {}
-        )
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})

         # Store states
         self.no_create_loop = False
@@ -523,6 +524,10 @@ async def _tokenize_one_request(
         else:
             mm_inputs = None

+        if self.server_args.enable_lora and obj.lora_path:
+            # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
+            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -574,8 +579,6 @@ def _validate_one_request(
                 "The server is not configured to enable custom logit processor. "
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )
-        if self.server_args.enable_lora and obj.lora_path:
-            self._validate_lora_adapters(obj)

     def _validate_input_ids_in_vocab(
         self, input_ids: List[int], vocab_size: int
@@ -689,21 +692,6 @@ def _validate_batch_tokenization_constraints(
                 "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
             )

-    def _validate_lora_adapters(self, obj: GenerateReqInput):
-        """Validate that the requested LoRA adapters are loaded."""
-        requested_adapters = (
-            set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
-        )
-        loaded_adapters = (
-            self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
-        )
-        unloaded_adapters = requested_adapters - loaded_adapters
-        if unloaded_adapters:
-            raise ValueError(
-                f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
-                f"Loaded adapters: {loaded_adapters}."
-            )
-
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1054,8 +1042,18 @@ async def load_lora_adapter(
             )

         async with self.model_update_lock.writer_lock:
+            # Generate new uniquely identifiable LoRARef object.
+            new_adapter = LoRARef(
+                lora_name=obj.lora_name,
+                lora_path=obj.lora_path,
+            )
+
+            # Register the new adapter in the registry.
+            obj.lora_id = new_adapter.lora_id
             result = (await self.update_lora_adapter_communicator(obj))[0]
-            self.loaded_lora_adapters = result.loaded_adapters
+            if result.success:
+                await self.lora_registry.register(new_adapter)
+
             return result

     async def unload_lora_adapter(
@@ -1069,6 +1067,10 @@
                 "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
             )

+        assert (
+            obj.lora_name is not None
+        ), "lora_name must be provided to unload LoRA adapter"
+
         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
         assert (
@@ -1080,8 +1082,9 @@
             )

         async with self.model_update_lock.writer_lock:
+            obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
             result = (await self.update_lora_adapter_communicator(obj))[0]
-            self.loaded_lora_adapters = result.loaded_adapters
+
             return result

     async def get_weights_by_name(
@@ -1309,7 +1312,7 @@ def dump_requests_before_crash(self):
         filename = os.path.join(
             self.crash_dump_folder,
             os.getenv("HOSTNAME", None),
-            f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
+            f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
         )

         os.makedirs(os.path.dirname(filename), exist_ok=True)
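
Taken together, the tokenizer-manager changes make the registry the single source of truth for adapter state. A condensed, simplified sketch of the new load/unload flow under the writer lock (`communicate` is a stand-in for the communicator call in the diff, not a real SGLang API):

from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry


async def load_adapter(registry: LoRARegistry, obj, communicate):
    # Mint a stable ID for the new adapter and attach it to the request.
    new_adapter = LoRARef(lora_name=obj.lora_name, lora_path=obj.lora_path)
    obj.lora_id = new_adapter.lora_id

    # Publish the adapter to the registry only after the workers report success,
    # so a request can never acquire an ID whose weights failed to load.
    result = await communicate(obj)
    if result.success:
        await registry.register(new_adapter)
    return result


async def unload_adapter(registry: LoRARegistry, obj, communicate):
    # Unregister first so no new request can resolve the name while the workers
    # drop the weights; the returned ID travels with the unload request.
    obj.lora_id = await registry.unregister(obj.lora_name)
    return await communicate(obj)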

python/sglang/srt/managers/tp_worker.py

Lines changed: 2 additions & 4 deletions
@@ -293,11 +293,9 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         return parameter

     def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-        result = self.model_runner.load_lora_adapter(
-            recv_req.lora_name, recv_req.lora_path
-        )
+        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
         return result

     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-        result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
+        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
