diff --git a/docs/backend/lora.ipynb b/docs/backend/lora.ipynb index 6c089b654fd..8626d3e71a6 100644 --- a/docs/backend/lora.ipynb +++ b/docs/backend/lora.ipynb @@ -27,6 +27,8 @@ "source": [ "The following server arguments are relevant for multi-LoRA serving:\n", "\n", + "* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n", + "\n", "* `lora_paths`: A mapping from each adaptor's name to its path, in the form of `{name}={path} {name}={path}`.\n", "\n", "* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n", @@ -35,7 +37,7 @@ "\n", "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n", "\n", - "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n", + "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\n", "\n", "* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n", "\n", @@ -79,6 +81,7 @@ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", + " --enable-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " --max-loras-per-batch 1 --lora-backend triton \\\n", " --disable-radix-cache\n", @@ -98,7 +101,7 @@ "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", - " \"AI is a field of computer science focused on\",\n", + " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses the base model\n", @@ -137,6 +140,7 @@ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", + " --enable-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n", @@ -157,7 +161,7 @@ "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", - " \"AI is a field of computer science focused on\",\n", + " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses lora1\n", @@ -191,11 +195,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Basic Usage\n", - "\n", "Instead of specifying all adapters during server startup via `--lora-paths`. You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n", "\n", - "(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)" + "When using dynamic LoRA loading, it's recommended to explicitly specify both `--max-lora-rank` and `--lora-target-modules` at startup. For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. However, in that case, you would have to ensure that all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"." ] }, { @@ -204,13 +206,22 @@ "metadata": {}, "outputs": [], "source": [ + "lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n", + "lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", + "lora0_new = \"philschmid/code-llama-3-1-8b-text-to-sql-lora\" # rank - 256, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", + "\n", + "\n", + "# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n", + "# We are adding it here just to demonstrate usage.\n", "server_process, port = launch_server_cmd(\n", " \"\"\"\n", " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - " --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n", + " --enable-lora \\\n", " --cuda-graph-max-bs 2 \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n", " --disable-radix-cache\n", + " --max-lora-rank 256\n", + " --lora-target-modules all\n", " \"\"\"\n", ")\n", "\n", @@ -218,6 +229,13 @@ "wait_for_server(url)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load adapter lora0" + ] + }, { "cell_type": "code", "execution_count": null, @@ -227,8 +245,8 @@ "response = requests.post(\n", " url + \"/load_lora_adapter\",\n", " json={\n", - " \"lora_name\": \"lora1\",\n", - " \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n", + " \"lora_name\": \"lora0\",\n", + " \"lora_path\": lora0,\n", " },\n", ")\n", "\n", @@ -239,38 +257,10 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "response = requests.post(\n", - " url + \"/generate\",\n", - " json={\n", - " \"text\": [\n", - " \"List 3 countries and their capitals.\",\n", - " \"List 3 countries and their capitals.\",\n", - " ],\n", - " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", - " \"lora_path\": [\"lora0\", \"lora1\"],\n", - " },\n", - ")\n", - "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", - "print(f\"Output from lora1: {response.json()[1]['text']}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "response = requests.post(\n", - " url + \"/unload_lora_adapter\",\n", - " json={\n", - " \"lora_name\": \"lora0\",\n", - " },\n", - ")" + "Load adapter lora1:" ] }, { @@ -282,8 +272,8 @@ "response = requests.post(\n", " url + \"/load_lora_adapter\",\n", " json={\n", - " \"lora_name\": \"lora2\",\n", - " \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n", + " \"lora_name\": \"lora1\",\n", + " \"lora_path\": lora1,\n", " },\n", ")\n", "\n", @@ -294,24 +284,10 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "response = requests.post(\n", - " url + \"/generate\",\n", - " json={\n", - " \"text\": [\n", - " \"List 3 countries and their capitals.\",\n", - " \"List 3 countries and their capitals.\",\n", - " ],\n", - " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", - " \"lora_path\": [\"lora1\", \"lora2\"],\n", - " },\n", - ")\n", - "print(f\"Output from lora1: {response.json()[0]['text']}\")\n", - "print(f\"Output from lora2: {response.json()[1]['text']}\")" + "Check inference output:" ] }, { @@ -320,18 +296,29 @@ "metadata": {}, "outputs": [], "source": [ - "terminate_process(server_process)" + "url = f\"http://127.0.0.1:{port}\"\n", + "json_data = {\n", + " \"text\": [\n", + " \"List 3 countries and their capitals.\",\n", + " \"List 3 countries and their capitals.\",\n", + " ],\n", + " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", + " # The first input uses lora0, and the second input uses lora1\n", + " \"lora_path\": [\"lora0\", \"lora1\"],\n", + "}\n", + "response = requests.post(\n", + " url + \"/generate\",\n", + " json=json_data,\n", + ")\n", + "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n", + "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Advanced: hosting adapters of different shapes\n", - "\n", - "In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n", - "\n", - "For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"." + "Unload lora0 and replace it with a different adapter:" ] }, { @@ -340,39 +327,18 @@ "metadata": {}, "outputs": [], "source": [ - "lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n", - "lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", - "\n", - "\n", - "# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n", - "# We are adding it here just to demonstrate usage.\n", - "server_process, port = launch_server_cmd(\n", - " f\"\"\"\n", - " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - " --lora-paths lora0={lora0} \\\n", - " --cuda-graph-max-bs 2 \\\n", - " --max-loras-per-batch 2 --lora-backend triton \\\n", - " --disable-radix-cache\n", - " --max-lora-rank 64\n", - " --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n", - " \"\"\"\n", + "response = requests.post(\n", + " url + \"/unload_lora_adapter\",\n", + " json={\n", + " \"lora_name\": \"lora0\",\n", + " },\n", ")\n", "\n", - "url = f\"http://127.0.0.1:{port}\"\n", - "wait_for_server(url)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ "response = requests.post(\n", " url + \"/load_lora_adapter\",\n", " json={\n", - " \"lora_name\": \"lora1\",\n", - " \"lora_path\": lora1,\n", + " \"lora_name\": \"lora0\",\n", + " \"lora_path\": lora0_new,\n", " },\n", ")\n", "\n", @@ -382,6 +348,13 @@ " print(\"Failed to load LoRA adapter.\", response.json())" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check output again:" + ] + }, { "cell_type": "code", "execution_count": null, @@ -392,7 +365,7 @@ "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", - " \"AI is a field of computer science focused on\",\n", + " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses lora1\n", @@ -402,8 +375,8 @@ " url + \"/generate\",\n", " json=json_data,\n", ")\n", - "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", - "print(f\"Output from lora1: {response.json()[1]['text']}\")" + "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n", + "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" ] }, { diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 6320a6e61aa..d7c5ff520dc 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -176,8 +176,9 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Arguments | Description | Defaults | |-----------|-------------|----------| +| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False | | `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None | -| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None | +| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None | | `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None | | `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 | | `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton | diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 96102d1efd5..85fd246163c 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -186,9 +186,9 @@ def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig): ) if incompatible: raise ValueError( - f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration." - "We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, " - "You can specify expected configs via --max_lora_rank and --enable_lora_modules." + f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. " + "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are " + "included in `--enable_lora_modules`." ) def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 7ba07f67512..631d23f1733 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -574,7 +574,7 @@ def _validate_one_request( "The server is not configured to enable custom logit processor. " "Please set `--enable-custom-logits-processor` to enable this feature." ) - if self.server_args.lora_paths and obj.lora_path: + if self.server_args.enable_lora and obj.lora_path: self._validate_lora_adapters(obj) def _validate_input_ids_in_vocab( @@ -1037,6 +1037,10 @@ async def load_lora_adapter( _: Optional[fastapi.Request] = None, ) -> LoadLoRAAdapterReqOutput: self.auto_create_handle_loop() + if not self.server_args.enable_lora: + raise ValueError( + "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." + ) # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works # with dp_size > 1. @@ -1060,6 +1064,10 @@ async def unload_lora_adapter( _: Optional[fastapi.Request] = None, ) -> UnloadLoRAAdapterReqOutput: self.auto_create_handle_loop() + if not self.server_args.enable_lora: + raise ValueError( + "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." + ) # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works # with dp_size > 1. diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 1f654ca7ecf..520a631c5ec 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -264,7 +264,7 @@ def __init__(self, model_runner: ModelRunner): if self.enable_torch_compile: set_torch_compile_config() - if self.model_runner.server_args.lora_paths is not None: + if self.model_runner.server_args.enable_lora: self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) # Graph inputs @@ -510,11 +510,10 @@ def capture_one_batch_size(self, bs: int, forward: Callable): spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL ) - if self.model_runner.server_args.lora_paths is not None: - # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a - # different logic to handle lora, so we need to set `lora_paths` to a list of non-None - # values if lora is enabled. - lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs + if self.model_runner.server_args.enable_lora: + # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever + # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization). + lora_paths = [None] * bs else: lora_paths = None diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index fde60e0e501..6f3ea547477 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -418,7 +418,7 @@ def init_new( ret._compute_mrope_positions(model_runner, batch) # Init lora information - if model_runner.server_args.lora_paths is not None: + if model_runner.server_args.enable_lora: model_runner.lora_manager.prepare_lora_batch(ret) TboForwardBatchPreparer.prepare( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index bbd5b000067..4f0b1d64ce8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -304,11 +304,7 @@ def initialize(self, min_per_gpu_memory: float): self.apply_torch_tp() # Init lora - # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add - # a new server arg `enable_lora` to control whether to init LoRA manager to be more - # explicit, as it is perfectly valid to start a server with an empty lora_paths and - # load LoRA adapters dynamically later. - if server_args.lora_paths is not None: + if server_args.enable_lora: self.init_lora_manager() # Init memory pool and attention backends @@ -895,7 +891,7 @@ def init_lora_manager(self): max_lora_rank=self.server_args.max_lora_rank, target_modules=self.server_args.lora_target_modules, ) - result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths) + result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {}) if result.success: logger.info( f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 24292bcd79b..6464f9f40a3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -26,6 +26,8 @@ from sglang.srt.hf_transformers_utils import check_gguf_file, get_config from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( + LORA_TARGET_ALL_MODULES, + SUPPORTED_LORA_TARGET_MODULES, configure_ipv6, get_device, get_device_memory_capacity, @@ -140,8 +142,9 @@ class ServerArgs: preferred_sampling_params: Optional[str] = None # LoRA + enable_lora: Optional[bool] = None max_lora_rank: Optional[int] = None - lora_target_modules: Optional[List[str]] = None + lora_target_modules: Optional[Union[set[str], List[str]]] = None lora_paths: Optional[Union[dict[str, str], List[str]]] = None max_loras_per_batch: int = 8 lora_backend: str = "triton" @@ -1148,6 +1151,12 @@ def add_cli_args(parser: argparse.ArgumentParser): ) # LoRA + parser.add_argument( + "--enable-lora", + default=ServerArgs.enable_lora, + action="store_true", + help="Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.", + ) parser.add_argument( "--max-lora-rank", default=ServerArgs.max_lora_rank, @@ -1157,18 +1166,12 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--lora-target-modules", type=str, - choices=[ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], + choices=SUPPORTED_LORA_TARGET_MODULES + [LORA_TARGET_ALL_MODULES], nargs="*", default=None, - help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.", + help="The union set of all target modules where LoRA should be applied. If not specified, " + "it will be automatically inferred from the adapters provided in --lora-paths. If 'all' is specified, " + "all supported modules will be targeted.", ) parser.add_argument( "--lora-paths", @@ -1816,15 +1819,46 @@ def check_server_args(self): None, }, "moe_dense_tp_size only support 1 and None currently" - if isinstance(self.lora_paths, list): - lora_paths = self.lora_paths - self.lora_paths = {} - for lora_path in lora_paths: - if "=" in lora_path: - name, path = lora_path.split("=", 1) - self.lora_paths[name] = path - else: - self.lora_paths[lora_path] = lora_path + self.check_lora_server_args() + + def check_lora_server_args(self): + # Enable LoRA if any LoRA paths are provided for backward compatibility. + if self.lora_paths: + if self.enable_lora is None: + self.enable_lora = True + logger.info( + "--enable-lora is set to True because --lora-paths is provided." + ) + elif self.enable_lora is False: + logger.warning( + "--enable-lora is set to False, any provided lora_paths will be ignored." + ) + + if self.enable_lora: + # Normalize lora_paths to a dictionary if it is a list. + if isinstance(self.lora_paths, list): + lora_paths = self.lora_paths + self.lora_paths = {} + for lora_path in lora_paths: + if "=" in lora_path: + name, path = lora_path.split("=", 1) + self.lora_paths[name] = path + else: + self.lora_paths[lora_path] = lora_path + + # Expand target modules + if self.lora_target_modules: + self.lora_target_modules = set(self.lora_target_modules) + if "all" in self.lora_target_modules: + assert ( + len(self.lora_target_modules) == 1 + ), "If 'all' is specified in --lora-target-modules, it should be the only module specified." + self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES) + + # Ensure sufficient information is provided for LoRA initialization. + assert self.lora_paths or ( + self.max_lora_rank and self.lora_target_modules + ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization." def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int): larger_tp = max(decode_tp, prefill_tp) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index dc6e72d75dc..57c06ea7d14 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2885,3 +2885,17 @@ def placeholder(*args, **kwargs): return final_module, getattr(final_module, function_name) return final_module, None + + +# LoRA-related constants and utilities +SUPPORTED_LORA_TARGET_MODULES = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] + +LORA_TARGET_ALL_MODULES = "all" diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 941940fe0fd..9ec71c29bac 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -507,6 +507,7 @@ def __init__( sleep_on_idle=False, max_lora_rank: Optional[int] = None, lora_target_modules: Optional[List[str]] = None, + enable_lora: Optional[bool] = None, ): self.model_type = model_type self.is_generation = model_type == "generation" @@ -547,6 +548,7 @@ def __init__( sleep_on_idle=sleep_on_idle, max_lora_rank=max_lora_rank, lora_target_modules=lora_target_modules, + enable_lora=enable_lora, **spec_kwargs, ) diff --git a/test/srt/models/lora/test_lora_update.py b/test/srt/models/lora/test_lora_update.py index 785b44e953f..83392b9247b 100644 --- a/test/srt/models/lora/test_lora_update.py +++ b/test/srt/models/lora/test_lora_update.py @@ -64,8 +64,9 @@ class TestCase: base: str max_loras_per_batch: int all_adapters: List[str] - initial_adapters: List[str] op_sequence: List[Operation] + initial_adapters: Optional[List[str]] = None + enable_lora: Optional[bool] = None max_lora_rank: Optional[int] = None lora_target_modules: Optional[List] = None max_new_tokens: int = 32 @@ -171,6 +172,64 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: ), ], ), + TestCase( + description="dynamic lora update without initial lora_paths", + base="meta-llama/Llama-3.1-8B-Instruct", + enable_lora=True, + max_lora_rank=256, + lora_target_modules=["all"], + max_loras_per_batch=4, + all_adapters=[ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + ], + op_sequence=[ + Operation( + type=OperationType.LOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + Operation( + type=OperationType.LOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.LOAD, + data="pbevan11/llama-3.1-8b-ocr-correction", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + None, + ] + ), + ), + Operation( + type=OperationType.UNLOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + expected_error="not loaded", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + None, + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + None, + ] + ), + ), + ], + ), TestCase( description="dynamic lora update with evictions", base="meta-llama/Llama-3.1-8B-Instruct", @@ -371,7 +430,7 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: Operation( type=OperationType.LOAD, data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", - expected_error="updating LoRA shapes", + expected_error="incompatible", ), Operation( type=OperationType.FORWARD, @@ -431,7 +490,7 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: Operation( type=OperationType.LOAD, data="philschmid/code-llama-3-1-8b-text-to-sql-lora", - expected_error="updating LoRA shapes", + expected_error="incompatible", ), Operation( type=OperationType.FORWARD, @@ -470,7 +529,7 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: Operation( type=OperationType.LOAD, data="philschmid/code-llama-3-1-8b-text-to-sql-lora", - expected_error="updating LoRA shapes", + expected_error="incompatible", ), Operation( type=OperationType.FORWARD, @@ -521,6 +580,7 @@ def __init__( lora_paths: list[str], max_loras_per_batch: int, max_lora_rank: Optional[int], + enable_lora: Optional[bool] = None, lora_target_modules: Optional[List[str]] = None, lora_backend: str = "triton", disable_cuda_graph: bool = False, @@ -535,8 +595,9 @@ def __init__( self.lora_backend = lora_backend self.disable_cuda_graph = disable_cuda_graph self.cuda_graph_max_bs = cuda_graph_max_bs + self.enable_lora = enable_lora - self.expected_adapters = set(lora_paths) + self.expected_adapters = set(lora_paths or []) self.handle = None # Will be set in __enter__ def __enter__(self): @@ -596,6 +657,7 @@ def __enter__(self): disable_cuda_graph=self.disable_cuda_graph, cuda_graph_max_bs=self.cuda_graph_max_bs, disable_radix_cache=True, + enable_lora=self.enable_lora, ) self.handle.__enter__() return self @@ -690,8 +752,6 @@ def __enter__(self): other_args = [ "--cuda-graph-max-bs", str(self.cuda_graph_max_bs), - "--lora-paths", - *self.lora_paths, "--max-loras-per-batch", str(self.max_loras_per_batch), "--lora-backend", @@ -704,6 +764,10 @@ def __enter__(self): "--mem-fraction-static", str(MEM_FRACTION_STATIC), ] + if self.enable_lora: + other_args.append("--enable-lora") + if self.lora_paths: + other_args.extend(["--lora-paths"] + self.lora_paths) if self.disable_cuda_graph: other_args.append("--disable-cuda-graph") if self.max_lora_rank is not None: @@ -836,6 +900,7 @@ def _run_operation_sequence( initial_adapters: List[str], max_loras_per_batch: int, op_sequence: List[Operation], + enable_lora: Optional[bool] = None, max_lora_rank: Optional[int] = None, lora_target_modules: Optional[List[str]] = None, max_new_tokens: int = 32, @@ -854,6 +919,7 @@ def _run_operation_sequence( max_loras_per_batch=max_loras_per_batch, max_lora_rank=max_lora_rank, lora_target_modules=lora_target_modules, + enable_lora=enable_lora, ) as session: for op in op_sequence: op_type = op.type @@ -903,6 +969,7 @@ def _run_dynamic_adapter_updates( dynamic_output = self._run_operation_sequence( mode=mode, initial_adapters=test_case.initial_adapters, + enable_lora=test_case.enable_lora, base=test_case.base, max_loras_per_batch=test_case.max_loras_per_batch, op_sequence=test_case.op_sequence, @@ -923,6 +990,7 @@ def _run_dynamic_adapter_updates( static_output = self._run_operation_sequence( mode=mode, initial_adapters=test_case.all_adapters, + enable_lora=test_case.enable_lora, base=test_case.base, max_loras_per_batch=test_case.max_loras_per_batch, op_sequence=forward_ops, diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index e67362cf825..0be7f8a6a39 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -17,7 +17,7 @@ class TestFile: TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_lora_cuda_graph.py", 250), - TestFile("models/lora/test_lora_update.py", 700), + TestFile("models/lora/test_lora_update.py", 800), TestFile("models/test_embedding_models.py", 73), # TestFile("models/test_clip_models.py", 52), TestFile("models/test_encoder_embedding_models.py", 100),