161 changes: 67 additions & 94 deletions docs/backend/lora.ipynb
@@ -27,6 +27,8 @@
"source": [
"The following server arguments are relevant for multi-LoRA serving:\n",
"\n",
"* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n",
"\n",
"* `lora_paths`: A mapping from each adaptor's name to its path, in the form of `{name}={path} {name}={path}`.\n",
"\n",
"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
@@ -35,7 +37,7 @@
"\n",
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
"\n",
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n",
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\n",
"\n",
"* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n",
"\n",
@@ -79,6 +81,7 @@
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" --max-loras-per-batch 1 --lora-backend triton \\\n",
" --disable-radix-cache\n",
@@ -98,7 +101,7 @@
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"AI is a field of computer science focused on\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses the base model\n",
@@ -137,6 +140,7 @@
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
@@ -157,7 +161,7 @@
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"AI is a field of computer science focused on\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
@@ -191,11 +195,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Basic Usage\n",
"\n",
"Instead of specifying all adapters during server startup via `--lora-paths`. You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n",
"\n",
"(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)"
"When using dynamic LoRA loading, it's recommended to explicitly specify both `--max-lora-rank` and `--lora-target-modules` at startup. For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. However, in that case, you would have to ensure that all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
]
},
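One way to check the shape constraint above before loading an adapter is to read its PEFT config. The sketch below is a suggestion, not part of the original notebook; it assumes the adapter repo ships a standard `adapter_config.json` and that `huggingface_hub` is installed.

```python
import json
from huggingface_hub import hf_hub_download

def adapter_shape(repo_id: str):
    # Download the adapter's PEFT config and report its LoRA rank and target modules,
    # so you can confirm they fit --max-lora-rank and --lora-target-modules.
    config_path = hf_hub_download(repo_id=repo_id, filename="adapter_config.json")
    with open(config_path) as f:
        config = json.load(f)
    return config["r"], set(config["target_modules"])

rank, modules = adapter_shape("algoprog/fact-generation-llama-3.1-8b-instruct-lora")
print(rank, modules)
```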
{
@@ -204,20 +206,36 @@
"metadata": {},
"outputs": [],
"source": [
"lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
"lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"lora0_new = \"philschmid/code-llama-3-1-8b-text-to-sql-lora\" # rank - 256, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"\n",
"\n",
"# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n",
"# We are adding it here just to demonstrate usage.\n",
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n",
" --enable-lora \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-radix-cache\n",
" --max-lora-rank 256\n",
" --lora-target-modules all\n",
" \"\"\"\n",
")\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load adapter lora0"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -227,8 +245,8 @@
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n",
" \"lora_name\": \"lora0\",\n",
" \"lora_path\": lora0,\n",
" },\n",
")\n",
"\n",
@@ -239,38 +257,10 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/generate\",\n",
" json={\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
" },\n",
")\n",
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora0\",\n",
" },\n",
")"
"Load adapter lora1:"
]
},
{
@@ -282,8 +272,8 @@
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora2\",\n",
" \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" },\n",
")\n",
"\n",
@@ -294,24 +284,10 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/generate\",\n",
" json={\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" \"lora_path\": [\"lora1\", \"lora2\"],\n",
" },\n",
")\n",
"print(f\"Output from lora1: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora2: {response.json()[1]['text']}\")"
"Check inference output:"
]
},
{
@@ -320,18 +296,29 @@
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Advanced: hosting adapters of different shapes\n",
"\n",
"In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n",
"\n",
"For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
"Unload lora0 and replace it with a different adapter:"
]
},
{
@@ -340,39 +327,18 @@
"metadata": {},
"outputs": [],
"source": [
"lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
"lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"\n",
"\n",
"# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n",
"# We are adding it here just to demonstrate usage.\n",
"server_process, port = launch_server_cmd(\n",
" f\"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0={lora0} \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-radix-cache\n",
" --max-lora-rank 64\n",
" --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n",
" \"\"\"\n",
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora0\",\n",
" },\n",
")\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" \"lora_name\": \"lora0\",\n",
" \"lora_path\": lora0_new,\n",
" },\n",
")\n",
"\n",
@@ -382,6 +348,13 @@
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check output again:"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -392,7 +365,7 @@
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"AI is a field of computer science focused on\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
Expand All @@ -402,8 +375,8 @@
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
]
},
{
3 changes: 2 additions & 1 deletion docs/backend/server_arguments.md
@@ -176,8 +176,9 @@ Please consult the documentation below and [server_args.py](https://github.com/s

| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False |
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
| `--lora-paths` | The list of LoRA adapters. Each entry can be either a plain path string or a named path in the format `{name}={path}`. | None |
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 |
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
6 changes: 3 additions & 3 deletions python/sglang/srt/lora/lora_manager.py
@@ -186,9 +186,9 @@ def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
)
if incompatible:
raise ValueError(
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration."
"We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, "
"You can specify expected configs via --max_lora_rank and --enable_lora_modules."
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
"included in `--enable_lora_modules`."
)

def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
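For readers skimming this diff, the rule the new error message describes amounts to a simple shape check. The snippet below is a conceptual sketch only, not the actual `LoRAManager` code; `max_lora_rank` and `target_modules` stand for the values configured (or inferred) at startup.

```python
def adapter_fits_pool(adapter_rank: int, adapter_modules: set,
                      max_lora_rank: int, target_modules: set) -> bool:
    # An adapter is compatible with the preallocated LoRA memory pool only if its
    # rank does not exceed the configured maximum and all of its target modules
    # were enabled at startup.
    return adapter_rank <= max_lora_rank and adapter_modules <= target_modules
```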
10 changes: 9 additions & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
@@ -574,7 +574,7 @@ def _validate_one_request(
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
if self.server_args.lora_paths and obj.lora_path:
if self.server_args.enable_lora and obj.lora_path:
self._validate_lora_adapters(obj)

def _validate_input_ids_in_vocab(
@@ -1037,6 +1037,10 @@ async def load_lora_adapter(
_: Optional[fastapi.Request] = None,
) -> LoadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
if not self.server_args.enable_lora:
raise ValueError(
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)

# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
@@ -1060,6 +1064,10 @@ async def unload_lora_adapter(
_: Optional[fastapi.Request] = None,
) -> UnloadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
if not self.server_args.enable_lora:
raise ValueError(
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)

# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
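The practical effect of the new `enable_lora` guard is that dynamic load/unload requests fail fast when the server was started without `--enable-lora`. A minimal client-side sketch of handling that failure (endpoint and payload as used in the notebook above; the exact error text may differ):

```python
import requests

def load_adapter(base_url: str, name: str, path: str) -> bool:
    # Request a dynamic adapter load; return True on success, print the error otherwise.
    response = requests.post(
        base_url + "/load_lora_adapter",
        json={"lora_name": name, "lora_path": path},
    )
    if response.status_code != 200:
        # Typically: LoRA not enabled at startup, or the adapter's shape does not
        # fit the configured memory pool.
        print("Failed to load LoRA adapter:", response.text)
        return False
    return True
```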
11 changes: 5 additions & 6 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -264,7 +264,7 @@ def __init__(self, model_runner: ModelRunner):
if self.enable_torch_compile:
set_torch_compile_config()

if self.model_runner.server_args.lora_paths is not None:
if self.model_runner.server_args.enable_lora:
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)

# Graph inputs
@@ -510,11 +510,10 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
)

if self.model_runner.server_args.lora_paths is not None:
# Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
# different logic to handle lora, so we need to set `lora_paths` to a list of non-None
# values if lora is enabled.
lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
if self.model_runner.server_args.enable_lora:
# It is safe to capture the CUDA graph with an empty LoRA path, as the LoRA kernels are always launched whenever
# `--enable-lora` is set to True (and return immediately when the LoRA path is empty, for perf optimization).
lora_paths = [None] * bs
else:
lora_paths = None

2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/forward_batch_info.py
@@ -418,7 +418,7 @@ def init_new(
ret._compute_mrope_positions(model_runner, batch)

# Init lora information
if model_runner.server_args.lora_paths is not None:
if model_runner.server_args.enable_lora:
model_runner.lora_manager.prepare_lora_batch(ret)

TboForwardBatchPreparer.prepare(