
Commit e638d5e

lifuhuang authored and ssssnow committed
Support limiting max loaded loras in CPU. (#8650)
1 parent f82c54d commit e638d5e

File tree: 8 files changed, +160, -55 lines


docs/backend/lora.ipynb

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@
 "\n",
 "* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
 "\n",
+"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
+"\n",
 "* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
 "\n",
 "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",

docs/backend/server_arguments.md

Lines changed: 1 addition & 0 deletions
@@ -181,6 +181,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
 | `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None |
 | `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 |
+| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
 | `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |

 ## Kernel backend
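
For illustration (a hedged sketch, not from the commit): launching a test server with the new flag via the `popen_launch_server` helper that the test diff further below also uses. The helper signature, model, and adapter paths here are assumptions.

from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    popen_launch_server,
)

# Assumed signature: popen_launch_server(model, base_url, timeout=..., other_args=[...]).
process = popen_launch_server(
    "meta-llama/Llama-3.1-8B-Instruct",        # placeholder base model
    "http://127.0.0.1:30000",
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--enable-lora",
        "--lora-paths", "sql_adapter=/path/to/sql_adapter",  # placeholder adapter
        "--max-loras-per-batch", "2",
        "--max-loaded-loras", "4",
    ],
)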

python/sglang/srt/lora/lora_registry.py

Lines changed: 7 additions & 0 deletions
@@ -186,3 +186,10 @@ def _register_adapter(self, lora_ref: LoRARef):
         self._registry[lora_ref.lora_name] = lora_ref
         self._counters[lora_ref.lora_id] = ConcurrentCounter()
         return lora_ref
+
+    @property
+    def num_registered_loras(self) -> int:
+        """
+        Returns the total number of LoRA adapters currently registered.
+        """
+        return len(self._registry)
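
The new property gives callers an O(1) view of how many adapters are currently registered. A minimal sketch of the capacity check it enables (mirroring the tokenizer_manager change further down; `registry` and `max_loaded_loras` are assumed to be in scope):

def can_load_another_adapter(registry, max_loaded_loras):
    """Return True if loading one more adapter stays within the CPU budget."""
    # None means no limit was configured.
    if max_loaded_loras is None:
        return True
    return registry.num_registered_loras < max_loaded_loras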

python/sglang/srt/managers/io_struct.py

Lines changed: 1 addition & 1 deletion
@@ -1097,7 +1097,7 @@ def to_ref(self) -> LoRARef:
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
+    loaded_adapters: Optional[Dict[str, LoRARef]] = None


 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
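
Making `loaded_adapters` optional lets the new early-rejection path return a bare failure result without fabricating an empty adapter map. A small sketch of the failure shape (field names taken from the diff above; the error text is illustrative):

# Failure result produced by the capacity check: no adapter map is attached.
rejected = LoRAUpdateResult(
    success=False,
    error_message="Maximum number of loaded LoRA adapters is 2.",
)
assert rejected.loaded_adapters is None
# Successful results still carry the mapping of adapter name -> LoRARef.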

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 74 additions & 52 deletions
@@ -1084,76 +1084,98 @@ async def load_lora_adapter(
         _: Optional[fastapi.Request] = None,
     ) -> LoadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )

-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start load Lora adapter. Lora name=%s, path=%s",
-            obj.lora_name,
-            obj.lora_path,
-        )
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )

-        async with self.lora_update_lock:
-            # Generate new uniquely identifiable LoRARef object.
-            new_adapter = LoRARef(
-                lora_name=obj.lora_name,
-                lora_path=obj.lora_path,
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start load Lora adapter. Lora name=%s, path=%s",
+                obj.lora_name,
+                obj.lora_path,
             )

-            # Trigger the actual loading operation at the backend processes.
-            obj.lora_id = new_adapter.lora_id
-            result = (await self.update_lora_adapter_communicator(obj))[0]
+            async with self.lora_update_lock:
+                if (
+                    self.server_args.max_loaded_loras is not None
+                    and self.lora_registry.num_registered_loras
+                    >= self.server_args.max_loaded_loras
+                ):
+                    raise ValueError(
+                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
+                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
+                        "Please unload some LoRA adapters before loading new ones."
+                    )

-            # Register the LoRA adapter only after loading is successful.
-            if result.success:
-                await self.lora_registry.register(new_adapter)
+                # Generate new uniquely identifiable LoRARef object.
+                new_adapter = LoRARef(
+                    lora_name=obj.lora_name,
+                    lora_path=obj.lora_path,
+                )

-            return result
+                # Trigger the actual loading operation at the backend processes.
+                obj.lora_id = new_adapter.lora_id
+                result = (await self.update_lora_adapter_communicator(obj))[0]
+
+                # Register the LoRA adapter only after loading is successful.
+                if result.success:
+                    await self.lora_registry.register(new_adapter)
+
+                return result
+        except ValueError as e:
+            return LoadLoRAAdapterReqOutput(
+                success=False,
+                error_message=str(e),
+            )

     async def unload_lora_adapter(
         self,
         obj: UnloadLoRAAdapterReqInput,
         _: Optional[fastapi.Request] = None,
     ) -> UnloadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )

-        assert (
-            obj.lora_name is not None
-        ), "lora_name must be provided to unload LoRA adapter"
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )

-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start unload Lora adapter. Lora name=%s",
-            obj.lora_name,
-        )
+            assert (
+                obj.lora_name is not None
+            ), "lora_name must be provided to unload LoRA adapter"
+
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start unload Lora adapter. Lora name=%s",
+                obj.lora_name,
+            )

-        async with self.lora_update_lock:
-            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-            # from being started.
-            lora_id = await self.lora_registry.unregister(obj.lora_name)
-            obj.lora_id = lora_id
+            async with self.lora_update_lock:
+                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+                # from being started.
+                lora_id = await self.lora_registry.unregister(obj.lora_name)
+                obj.lora_id = lora_id

-            # Initiate the actual unloading operation at the backend processes only after all
-            # ongoing requests using this LoRA adapter are finished.
-            await self.lora_registry.wait_for_unload(lora_id)
-            result = (await self.update_lora_adapter_communicator(obj))[0]
+                # Initiate the actual unloading operation at the backend processes only after all
+                # ongoing requests using this LoRA adapter are finished.
+                await self.lora_registry.wait_for_unload(lora_id)
+                result = (await self.update_lora_adapter_communicator(obj))[0]

-            return result
+                return result
+        except ValueError as e:
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
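
Because a capacity violation is now surfaced as a failed result rather than an unhandled exception, clients can react to it directly. A hedged client-side sketch, assuming the server exposes the `/load_lora_adapter` and `/unload_lora_adapter` HTTP endpoints that back these handlers and that the response mirrors the `LoRAUpdateResult` fields; adapter names and the server address are placeholders.

import requests

BASE_URL = "http://127.0.0.1:30000"  # placeholder server address

# Attempt to load an adapter; the server may reject it if the CPU budget is full.
resp = requests.post(
    f"{BASE_URL}/load_lora_adapter",
    json={"lora_name": "ocr_fix", "lora_path": "/path/to/ocr_fix"},  # placeholders
).json()

if not resp.get("success", False):
    # e.g. "Maximum number of loaded LoRA adapters is 2. Please unload some ..."
    print("Load rejected:", resp.get("error_message"))
    # Free a slot, then retry the load.
    requests.post(f"{BASE_URL}/unload_lora_adapter", json={"lora_name": "old_adapter"})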

python/sglang/srt/server_args.py

Lines changed: 20 additions & 0 deletions
@@ -149,6 +149,7 @@ class ServerArgs:
     max_lora_rank: Optional[int] = None
     lora_target_modules: Optional[Union[set[str], List[str]]] = None
     lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
+    max_loaded_loras: Optional[int] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"

@@ -1237,6 +1238,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=8,
             help="Maximum number of adapters for a running batch, include base-only request.",
         )
+        parser.add_argument(
+            "--max-loaded-loras",
+            type=int,
+            default=ServerArgs.max_loaded_loras,
+            help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
+        )
         parser.add_argument(
             "--lora-backend",
             type=str,
@@ -2008,6 +2015,19 @@ def check_lora_server_args(self):
             self.max_lora_rank and self.lora_target_modules
         ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

+        # Validate max_loaded_loras
+        if self.max_loaded_loras is not None:
+            assert self.max_loaded_loras >= self.max_loras_per_batch, (
+                "max_loaded_loras should be greater than or equal to max_loras_per_batch. "
+                f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
+            )
+            assert (
+                not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
+            ), (
+                "The number of LoRA paths should not exceed max_loaded_loras. "
+                f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
+            )
+
     def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
         larger_tp = max(decode_tp, prefill_tp)
         smaller_tp = min(decode_tp, prefill_tp)
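
To make the new startup constraints concrete, here is a standalone sketch of the same two checks (illustrative only; the real validation lives in `check_lora_server_args` above, and the function name here is hypothetical):

def validate_max_loaded_loras(max_loaded_loras, max_loras_per_batch, lora_paths):
    """Illustrative re-statement of the two assertions added above."""
    if max_loaded_loras is None:
        return  # no CPU-side limit configured
    if max_loaded_loras < max_loras_per_batch:
        raise ValueError("max_loaded_loras must be >= max_loras_per_batch")
    if lora_paths and len(lora_paths) > max_loaded_loras:
        raise ValueError("number of initial --lora-paths must not exceed max_loaded_loras")

validate_max_loaded_loras(4, 2, ["a", "b"])  # ok
# validate_max_loaded_loras(1, 2, [])        # would raise: limit below max_loras_per_batch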

python/sglang/test/runners.py

Lines changed: 2 additions & 0 deletions
@@ -514,6 +514,7 @@ def __init__(
         max_lora_rank: Optional[int] = None,
         lora_target_modules: Optional[List[str]] = None,
         enable_lora: Optional[bool] = None,
+        max_loaded_loras: Optional[int] = None,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -556,6 +557,7 @@ def __init__(
             max_lora_rank=max_lora_rank,
             lora_target_modules=lora_target_modules,
             enable_lora=enable_lora,
+            max_loaded_loras=max_loaded_loras,
             **spec_kwargs,
         )

test/srt/models/lora/test_lora_update.py

Lines changed: 53 additions & 2 deletions
@@ -70,6 +70,7 @@ class TestCase:
     max_lora_rank: Optional[int] = None
     lora_target_modules: Optional[List] = None
     max_new_tokens: int = 32
+    max_loaded_loras: Optional[int] = None


 def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
@@ -559,7 +560,43 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
         ],
     ),
 ]
-ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS
+MAX_LOADED_LORAS_TESTS = [
+    TestCase(
+        description="Test max_loaded_loras limit",
+        base="meta-llama/Llama-3.1-8B-Instruct",
+        max_loras_per_batch=2,
+        max_loaded_loras=2,
+        all_adapters=[
+            "philschmid/code-llama-3-1-8b-text-to-sql-lora",
+            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+            "pbevan11/llama-3.1-8b-ocr-correction",
+        ],
+        initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
+        op_sequence=[
+            Operation(
+                type=OperationType.LOAD,
+                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="pbevan11/llama-3.1-8b-ocr-correction",
+                expected_error="Maximum number of loaded LoRA adapters",
+            ),
+            Operation(
+                type=OperationType.UNLOAD,
+                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="pbevan11/llama-3.1-8b-ocr-correction",
+            ),
+        ],
+    ),
+]
+
+ALL_TESTS = (
+    BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS
+)


 class LoRAUpdateTestSessionMode(Enum):
@@ -579,6 +616,7 @@ def __init__(
         model_path: str,
         lora_paths: list[str],
         max_loras_per_batch: int,
+        max_loaded_loras: Optional[int] = None,
         max_lora_rank: Optional[int],
         enable_lora: Optional[bool] = None,
         lora_target_modules: Optional[List[str]] = None,
@@ -592,6 +630,7 @@ def __init__(
         self.max_lora_rank = max_lora_rank
         self.lora_target_modules = lora_target_modules
         self.max_loras_per_batch = max_loras_per_batch
+        self.max_loaded_loras = max_loaded_loras
         self.lora_backend = lora_backend
         self.disable_cuda_graph = disable_cuda_graph
         self.cuda_graph_max_bs = cuda_graph_max_bs
@@ -654,6 +693,7 @@ def __enter__(self):
             torch_dtype=torch.float16,
             mem_fraction_static=MEM_FRACTION_STATIC,
             max_loras_per_batch=self.max_loras_per_batch,
+            max_loaded_loras=self.max_loaded_loras,
             disable_cuda_graph=self.disable_cuda_graph,
             cuda_graph_max_bs=self.cuda_graph_max_bs,
             disable_radix_cache=True,
@@ -774,6 +814,8 @@ def __enter__(self):
             other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
         if self.lora_target_modules is not None:
             other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
+        if self.max_loaded_loras is not None:
+            other_args.extend(["--max-loaded-loras", str(self.max_loaded_loras)])

         # launch external server
         self.handle = popen_launch_server(
@@ -898,8 +940,9 @@ def _run_operation_sequence(
         mode: LoRAUpdateTestSessionMode,
         base: str,
         initial_adapters: List[str],
-        max_loras_per_batch: int,
         op_sequence: List[Operation],
+        max_loras_per_batch: int,
+        max_loaded_loras: Optional[int] = None,
         enable_lora: Optional[bool] = None,
         max_lora_rank: Optional[int] = None,
         lora_target_modules: Optional[List[str]] = None,
@@ -917,6 +960,7 @@ def _run_operation_sequence(
             model_path=base,
             lora_paths=initial_adapters,
             max_loras_per_batch=max_loras_per_batch,
+            max_loaded_loras=max_loaded_loras,
             max_lora_rank=max_lora_rank,
             lora_target_modules=lora_target_modules,
             enable_lora=enable_lora,
@@ -972,6 +1016,7 @@ def _run_dynamic_adapter_updates(
                 enable_lora=test_case.enable_lora,
                 base=test_case.base,
                 max_loras_per_batch=test_case.max_loras_per_batch,
+                max_loaded_loras=test_case.max_loaded_loras,
                 op_sequence=test_case.op_sequence,
                 max_new_tokens=test_case.max_new_tokens,
                 max_lora_rank=test_case.max_lora_rank,
@@ -985,6 +1030,12 @@ def _run_dynamic_adapter_updates(
                 if x.type == OperationType.FORWARD and x.expected_error is None
             ]

+            if not forward_ops:
+                print(
+                    f"No forward operations found in test case {case_idx}. Skipping static pass."
+                )
+                continue
+
             print("=" * 100)
             print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
             static_output = self._run_operation_sequence(
