2 changes: 1 addition & 1 deletion docs/backend/server_arguments.md
@@ -116,7 +116,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--log-level` | The logging level of all loggers. | info |
| `--log-level-http` | The logging level of HTTP server. If not set, reuse --log-level by default. | None |
| `--log-requests` | Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level. | False |
| `--log-requests-level` | 0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output. | 0 |
| `--log-requests-level` | 0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output. | 0 |
| `--show-time-cost` | Show time cost of custom marks. | False |
| `--enable-metrics` | Enable logging of Prometheus metrics. | False |
| `--bucket-time-to-first-token` | The buckets of time to first token, specified as a list of floats. | None |
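For reference, a minimal sketch of how the new verbosity levels are used when launching the server; the entry point is the usual `python -m sglang.launch_server`, and the model path is a placeholder.

```python
# Minimal sketch: start the server with request logging at level 2
# (metadata + sampling parameters + partial input/output).
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        "--log-requests",
        "--log-requests-level", "2",
    ],
    check=True,
)
```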
16 changes: 14 additions & 2 deletions python/sglang/bench_one_batch_server.py
@@ -38,6 +38,7 @@ class BenchArgs:
output_len: Tuple[int] = (16,)
temperature: float = 0.0
return_logprob: bool = False
client_stream_interval: int = 1
input_len_step_percentage: float = 0.0
result_filename: str = "result.jsonl"
base_url: str = ""
@@ -60,6 +61,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
)
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument(
"--client-stream-interval",
type=int,
default=BenchArgs.client_stream_interval,
)
parser.add_argument(
"--input-len-step-percentage",
type=float,
@@ -120,6 +126,7 @@ def run_one_case(
output_len: int,
temperature: float,
return_logprob: bool,
stream_interval: int,
input_len_step_percentage: float,
run_name: str,
result_filename: str,
@@ -168,6 +175,7 @@
"max_new_tokens": output_len,
"ignore_eos": True,
"json_schema": json_schema,
"stream_interval": stream_interval,
},
"return_logprob": return_logprob,
"stream": True,
@@ -245,8 +253,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
else:
proc, base_url = launch_server_process(server_args)

tokenizer_id = server_args.tokenizer_path or server_args.model_path
tokenizer = get_tokenizer(tokenizer_id)
server_info = requests.get(base_url + "/get_server_info")
tokenizer_path = server_info.json()["tokenizer_path"]
tokenizer = get_tokenizer(tokenizer_path)

# warmup
if not bench_args.skip_warmup:
@@ -258,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
output_len=16,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name="",
result_filename="",
@@ -280,6 +290,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
ol,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
@@ -301,6 +312,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
ol,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
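A minimal sketch of what the benchmark client now does per request, assuming a server on the default localhost port: the tokenizer path comes from `/get_server_info` instead of being re-derived from the CLI arguments, and the value of `--client-stream-interval` is forwarded through `sampling_params`.

```python
import requests

base_url = "http://localhost:30000"  # placeholder host/port

# Ask the running server which tokenizer it actually loaded.
tokenizer_path = requests.get(base_url + "/get_server_info").json()["tokenizer_path"]

payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "temperature": 0.0,
        "max_new_tokens": 16,
        "ignore_eos": True,
        "stream_interval": 1,  # forwarded from --client-stream-interval
    },
    "return_logprob": False,
    "stream": True,
}
response = requests.post(base_url + "/generate", json=payload, stream=True)
```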
1 change: 0 additions & 1 deletion python/sglang/bench_serving.py
@@ -1678,7 +1678,6 @@ def run_benchmark(args_: argparse.Namespace):
if args.base_url
else f"http://{args.host}:{args.port}/generate"
)
args.apply_chat_template = True
elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
api_url = (
f"{args.base_url}/v1/completions"
5 changes: 2 additions & 3 deletions python/sglang/srt/configs/internvl.py
@@ -147,12 +147,11 @@ def _rope_scaling_validation(self):
)
if (
rope_scaling_factor is None
or not isinstance(rope_scaling_factor, float)
or not isinstance(rope_scaling_factor, int)
or not isinstance(rope_scaling_factor, (float, int))
or rope_scaling_factor < 1.0
):
raise ValueError(
f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor}"
f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
)
if isinstance(rope_scaling_factor, int):
rope_scaling_factor = float(rope_scaling_factor)
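The original check rejected every value, since no object is both a float and an int; the tuple form accepts either type. A small illustration with a hypothetical factor:

```python
factor = 2  # an int factor, which the old check wrongly rejected

old_check = not isinstance(factor, float) or not isinstance(factor, int)
new_check = not isinstance(factor, (float, int)) or factor < 1.0

print(old_check)  # True  -> the old code always raised ValueError
print(new_check)  # False -> ints and floats >= 1.0 now pass validation
```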
29 changes: 23 additions & 6 deletions python/sglang/srt/entrypoints/http_server.py
@@ -126,8 +126,6 @@ def set_global_state(global_state: _GlobalState):

@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args

# Initialize OpenAI serving handlers
fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
_global_state.tokenizer_manager, _global_state.template_manager
@@ -145,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI):
_global_state.tokenizer_manager
)

server_args: ServerArgs = fast_api_app.server_args
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager
server_args.disaggregation_mode,
server_args.warmups.split(","),
_global_state.tokenizer_manager,
)
logger.info("Warmup ended")

@@ -280,13 +281,17 @@ async def get_model_info():
"model_path": _global_state.tokenizer_manager.model_path,
"tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": _global_state.tokenizer_manager.is_generation,
"preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
}
return result


@app.get("/get_server_info")
async def get_server_info():
internal_states = await _global_state.tokenizer_manager.get_internal_state()
# Returns internal states per DP.
internal_states: List[Dict[Any, Any]] = (
await _global_state.tokenizer_manager.get_internal_state()
)
return {
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
**_global_state.scheduler_info,
@@ -300,6 +305,8 @@ async def get_load():
return await _global_state.tokenizer_manager.get_load()


# example usage:
# curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}'
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
res = await _global_state.tokenizer_manager.set_internal_state(obj)
@@ -886,14 +893,23 @@ def launch_server(
add_prometheus_middleware(app)
enable_func_timer()

image_token_text = None
if (
tokenizer_manager.image_token_id is not None
and not server_args.skip_tokenizer_init
):
image_token_text = tokenizer_manager.tokenizer.decode(
[tokenizer_manager.image_token_id]
)

# Send a warmup request. We create the thread here and launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
pipe_finish_writer,
_global_state.tokenizer_manager.image_token_id,
image_token_text,
launch_callback,
),
)
@@ -1022,9 +1038,10 @@ def _wait_and_warmup(
return

# Debug print
# logger.info(f"{res.json()=}")
# logger.info(f"warmup request returns: {res.json()=}")

logger.info("The server is fired up and ready to roll!")

if pipe_finish_writer is not None:
pipe_finish_writer.send("ready")

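A minimal sketch of exercising the touched endpoints from Python, assuming a server on localhost:30000: `/get_model_info` now also reports `preferred_sampling_params`, and the `/set_internal_state` payload mirrors the curl example in the code comment above.

```python
import requests

base_url = "http://localhost:30000"  # placeholder host/port

model_info = requests.get(base_url + "/get_model_info").json()
print(model_info["model_path"], model_info.get("preferred_sampling_params"))

# Same illustrative payload as the curl example.
resp = requests.post(
    base_url + "/set_internal_state",
    json={"server_args": {"max_micro_batch_size": 8}},
)
print(resp.status_code)
```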
88 changes: 76 additions & 12 deletions python/sglang/srt/layers/elementwise.py
@@ -8,6 +8,7 @@

_is_hip = is_hip()


fused_softcap_autotune = triton.autotune(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
assert x.shape == residual.shape and x.dtype == residual.dtype
output, mid = torch.empty_like(x), torch.empty_like(x)
bs, hidden_dim = x.shape

min_num_warps = 16 if _is_hip else 32

if autotune:
fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
)
else:
max_warps = 16 if _is_hip else 32
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(
triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
),
4,
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
),
}

@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
else:
output = torch.empty_like(x)
bs, hidden_dim = x.shape

min_num_warps = 16 if _is_hip else 32

max_warps = 16 if _is_hip else 32
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
),
}

@@ -331,6 +325,75 @@ def forward_native(
return self.rmsnorm2.forward_native(residual), residual


@triton.jit
def experts_combine_kernel(
out_hidden_states,
moe_hidden_states,
mlp_hidden_states,
combine_k: tl.constexpr,
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
start_index_mlp = pid * hidden_dim
start_index_rmoe = pid * hidden_dim * combine_k
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < hidden_dim
combine_k_offsets = tl.arange(0, combine_k)

moe_x = tl.load(
moe_hidden_states
+ start_index_rmoe
+ combine_k_offsets[:, None] * hidden_dim
+ offsets[None, :],
mask=mask[None, :],
other=0.0,
)
moe_x = tl.sum(moe_x, axis=0)
mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0)
combined_x = (moe_x + mlp_x) / 1.4142135623730951

tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask)


def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
assert moe_hidden_states.is_contiguous()
assert mlp_hidden_states.is_contiguous()

if len(moe_hidden_states.shape) == 2:
combine_k = 1 # pre-combined
else:
combine_k = moe_hidden_states.shape[1]

if output_buffer is None:
out_hidden_states = torch.empty_like(mlp_hidden_states)
else:
flat_output_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
assert flat_output_buffer.numel() >= mlp_hidden_states.numel()
out_hidden_states = flat_output_buffer[: mlp_hidden_states.numel()].reshape(
mlp_hidden_states.shape
)

bs, hidden_dim = mlp_hidden_states.shape

config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4
),
}

experts_combine_kernel[(bs,)](
out_hidden_states,
moe_hidden_states,
mlp_hidden_states,
combine_k,
hidden_dim,
**config,
)
return out_hidden_states


# gelu on first half of vector
@triton.jit
def gelu_and_mul_kernel(
@@ -400,10 +463,11 @@ def gelu_and_mul_triton(
out_scales = scales
static_scale = True

max_warps = 16 if _is_hip else 32
config = {
# 8 ele per thread (not tuned)
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
),
}

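A minimal sketch of calling the new fused combine kernel, assuming a CUDA device and the module path from this diff; shapes and dtype are illustrative. The kernel sums the top-k expert outputs, adds the dense MLP branch, and scales by 1/sqrt(2) in one launch, which the reference computation below reproduces.

```python
import math

import torch

from sglang.srt.layers.elementwise import experts_combine_triton

bs, combine_k, hidden_dim = 4, 2, 4096
moe = torch.randn(bs, combine_k, hidden_dim, dtype=torch.bfloat16, device="cuda")
mlp = torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device="cuda")

out = experts_combine_triton(moe, mlp)

# Reference: sum over the k expert outputs, add the MLP output, divide by sqrt(2).
ref = (moe.sum(dim=1) + mlp) / math.sqrt(2.0)
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)
```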