74 changes: 0 additions & 74 deletions python/sglang/srt/debug_utils.py

This file was deleted.

Empty file.
131 changes: 131 additions & 0 deletions python/sglang/srt/debug_utils/dump_comparator.py
@@ -0,0 +1,131 @@
import argparse
import functools
import re
from pathlib import Path

import polars as pl
import torch

from sglang.srt.debug_utils.dumper import get_truncated_value


def main(args):
    df_target = read_meta(args.target_path)
    df_target = df_target.sort("rank", "dump_index")
    df_target = df_target.filter(
        (pl.col("forward_pass_id") >= args.start_id)
        & (pl.col("forward_pass_id") <= args.end_id)
    )
    assert all(
        c in df_target.columns
        for c in ["rank", "forward_pass_id", "dump_index", "name"]
    )

    df_baseline = read_meta(args.baseline_path)
    print("df_target", df_target)
    print("df_baseline", df_baseline)

    for row in df_target.iter_rows(named=True):
        rows_baseline = df_baseline.filter(
            (
                pl.col("forward_pass_id")
                == row["forward_pass_id"] - args.start_id + args.baseline_start_id
            )
            & functools.reduce(
                lambda a, b: a & b,
                [
                    pl.col(col) == row[col]
                    for col in row.keys()
                    if col not in ["forward_pass_id", "dump_index", "filename"]
                ],
            )
Comment on lines +34 to +41 (Contributor, severity: medium):

Consider using pl.all_horizontal for combining Polars expressions, as it can be more concise than functools.reduce. This improves readability.

Suggested change:

            & pl.all_horizontal(
                pl.col(col) == row[col]
                for col in row.keys()
                if col not in ["forward_pass_id", "dump_index", "filename"]
            )

(A runnable equivalence sketch appears after this file's diff.)
        )
        assert len(rows_baseline) == 1, f"{rows_baseline=}"
        row_baseline = rows_baseline.to_dicts()[0]

        path_baseline = Path(args.baseline_path) / row_baseline["filename"]
        path_target = Path(args.target_path) / row["filename"]
        print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
        check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
        print()

def read_meta(directory):
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for p in directory.glob("*.pt"):
        full_kwargs = {}
        for kv in p.stem.split("___"):
            k, v = kv.split("=")
            full_kwargs[k] = v
        rows.append(
            {
                "filename": str(p.name),
                **full_kwargs,
            }
        )

    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
    )
    return df


def check_tensor_pair(path_baseline, path_target):
    x_baseline = torch.load(path_baseline, weights_only=True)
    x_target = torch.load(path_target, weights_only=True)

    print(
        f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
        f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
    )

    if x_baseline.shape != x_target.shape:
        print("❌ Shape mismatch")
        return

    raw_abs_diff = (x_target - x_baseline).abs()

    max_abs_diff = raw_abs_diff.max().item()
    mean_abs_diff = raw_abs_diff.mean().item()
    rel_diff = _calc_rel_diff(x_target, x_baseline)

    needs_print = max_abs_diff > 1e-3

    print(
        "\t".join(
            f"{'❌' if value > 1e-3 else '✅'} {name}={value}"
            for name, value in [
                ("rel_diff", rel_diff),
                ("max_abs_diff", max_abs_diff),
                ("mean_abs_diff", mean_abs_diff),
            ]
        )
    )

    if needs_print:
        print(f"x_baseline(sample)={get_truncated_value(x_baseline)}")
        print(f"x_target(sample)={get_truncated_value(x_target)}")


# Copied from DeepGEMM
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
    # 1 - 2*sum(x*y) / sum(x*x + y*y): equals 0 when x == y and grows
    # as the tensors diverge.
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--baseline-path", type=str)
    parser.add_argument("--target-path", type=str)
    parser.add_argument("--start-id", type=int, default=0)
    parser.add_argument("--end-id", type=int, default=1000000)
    parser.add_argument("--baseline-start-id", type=int, default=0)
    args = parser.parse_args()
    main(args)
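As a runnable aside on the reviewer's pl.all_horizontal suggestion above, here is a minimal sketch (not part of the PR; the toy DataFrame and row are invented) checking that pl.all_horizontal builds the same predicate as the functools.reduce chain:

```
# Minimal equivalence check: both forms AND the per-column equality
# expressions together, so the filtered frames should be identical.
import functools

import polars as pl

df = pl.DataFrame({"name": ["a", "b"], "rank": [0, 1]})
row = {"name": "a", "rank": 0}

via_reduce = df.filter(
    functools.reduce(
        lambda a, b: a & b,
        [pl.col(col) == row[col] for col in row.keys()],
    )
)
via_all_horizontal = df.filter(
    pl.all_horizontal(pl.col(col) == row[col] for col in row.keys())
)
assert via_reduce.equals(via_all_horizontal)  # both keep only the "a" row
```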
108 changes: 108 additions & 0 deletions python/sglang/srt/debug_utils/dumper.py
@@ -0,0 +1,108 @@
import os
import time
from pathlib import Path
from typing import Optional

import torch
import torch.distributed as dist


class _Dumper:
    """Utility to dump tensors, which can be useful when comparing models against a baseline.

    Example usage:
    dumper.on_forward_pass_start()
    dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id)

    Import from a non-SGLang system:
    ```
    import sys
    sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils")
    from dumper import dumper
    ```

    Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison
    """

    def __init__(self):
        # Do not import `sglang` to make this file standalone
        self._enable = bool(int(os.environ.get("SGLANG_DUMPER_ENABLE", "1")))
        self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp"))
        self._enable_write_file = bool(
            int(os.environ.get("SGLANG_DUMPER_WRITE_FILE", "1"))
        )
        self._partial_name: Optional[str] = None
        self._dump_index = 0
        self._forward_pass_id = 0

    def on_forward_pass_start(self):
        self._forward_pass_id += 1
        print(
            f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}"
        )

    def dump(self, name, value, **kwargs):
        if not self._enable:
            return

        assert (
            self._forward_pass_id >= 1
        ), "Did you forget to call `dumper.on_forward_pass_start()`?"
        self._dump_index += 1

        if self._partial_name is None:
            self._partial_name = _get_partial_name()

        rank = dist.get_rank()
        full_kwargs = dict(
            forward_pass_id=self._forward_pass_id,
            rank=rank,
            name=name,
            dump_index=self._dump_index,
            **kwargs,
        )
        full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt"
        path = self._base_dir / f"sglang_dump_{self._partial_name}" / full_filename

        sample_value = get_truncated_value(value)

        print(
            f"[Dumper] [{rank}, {time.time()}] {path} "
            f"type={type(value)} "
            f"shape={value.shape if isinstance(value, torch.Tensor) else None} "
            f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} "
            f"sample_value={sample_value}"
        )

        if self._enable_write_file:
            path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(value, str(path))


def _get_partial_name():
    # Rank 0 picks a timestamp and broadcasts it so all ranks share one dump dir
    rank = dist.get_rank()
    object_list = [str(time.time()) if rank == 0 else None]
    dist.broadcast_object_list(object_list, device="cuda")
    return object_list[0]


def get_truncated_value(value):
    if value is None:
        return None

    if isinstance(value, tuple):
        return [get_truncated_value(x) for x in value]

    if not isinstance(value, torch.Tensor):
        return None

    if value.numel() < 200:
        return value

    slices = [
        slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
    ]
Comment on lines +99 to +104 (Contributor, severity: medium):

The tensor truncation logic appears inconsistent. The tensor is truncated if numel() >= 200, but a dimension is truncated only if dim_size > 200. Consider using a smaller threshold for dim_size to make truncation more effective for multi-dimensional tensors.

Suggested change:

     if value.numel() < 200:
         return value

     slices = [
-        slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
+        slice(0, 5) if dim_size > 10 else slice(None) for dim_size in value.shape
     ]

(A short sketch of the effect of the two thresholds appears after this file's diff.)
    return value[tuple(slices)]


dumper = _Dumper()
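
To make the reviewer's point above concrete, here is a minimal sketch (not part of the PR; the shape is invented) of how the two dim_size thresholds behave on a tensor whose numel() exceeds the 200 cutoff while no single dimension does:

```
import torch

value = torch.zeros(100, 100)  # numel() == 10_000 >= 200, so truncation applies

# Current threshold: no dimension exceeds 200, so nothing is actually cut.
slices = [slice(0, 5) if d > 200 else slice(None) for d in value.shape]
print(value[tuple(slices)].shape)  # torch.Size([100, 100])

# Suggested threshold: each 100-sized dimension is cut down to 5.
slices = [slice(0, 5) if d > 10 else slice(None) for d in value.shape]
print(value[tuple(slices)].shape)  # torch.Size([5, 5])
```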
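Finally, an end-to-end sketch of the usage pattern from the _Dumper docstring. This is an illustration, not part of the PR: it assumes a single-GPU NCCL process group, since dump() calls dist.get_rank() and _get_partial_name() broadcasts with device="cuda", and the address and port are placeholders.

```
import torch
import torch.distributed as dist

from sglang.srt.debug_utils.dumper import dumper

# Single-process group so dist.get_rank() and the broadcast in
# _get_partial_name() work; address and port are arbitrary placeholders.
dist.init_process_group(
    backend="nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

dumper.on_forward_pass_start()
hidden_states = torch.randn(4, 8, device="cuda")
dumper.dump("layer_start__hidden_states", hidden_states, layer_id=0)
# Writes under $SGLANG_DUMPER_DIR (default /tmp), e.g.
# /tmp/sglang_dump_<timestamp>/forward_pass_id=1___rank=0___name=layer_start__hidden_states___dump_index=1___layer_id=0.pt
```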