Skip to content

Commit d935532

Browse files
fzyzcjyssssnow
authored andcommitted
Tool to dump and compare internal activation tensors (#7976)
1 parent db19a24 commit d935532

File tree

4 files changed

+239
-74
lines changed

4 files changed

+239
-74
lines changed

python/sglang/srt/debug_utils.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

python/sglang/srt/debug_utils/__init__.py

Whitespace-only changes.
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import argparse
2+
import functools
3+
import re
4+
from pathlib import Path
5+
6+
import polars as pl
7+
import torch
8+
9+
from sglang.srt.debug_utils.dumper import get_truncated_value
10+
11+
12+
def main(args):
13+
df_target = read_meta(args.target_path)
14+
df_target = df_target.sort("rank", "dump_index")
15+
df_target = df_target.filter(
16+
(pl.col("forward_pass_id") >= args.start_id)
17+
& (pl.col("forward_pass_id") <= args.end_id)
18+
)
19+
assert all(
20+
c in df_target.columns
21+
for c in ["rank", "forward_pass_id", "dump_index", "name"]
22+
)
23+
24+
df_baseline = read_meta(args.baseline_path)
25+
print("df_target", df_target)
26+
print("df_baseline", df_baseline)
27+
28+
for row in df_target.iter_rows(named=True):
29+
rows_baseline = df_baseline.filter(
30+
(
31+
pl.col("forward_pass_id")
32+
== row["forward_pass_id"] - args.start_id + args.baseline_start_id
33+
)
34+
& functools.reduce(
35+
lambda a, b: a & b,
36+
[
37+
pl.col(col) == row[col]
38+
for col in row.keys()
39+
if col not in ["forward_pass_id", "dump_index", "filename"]
40+
],
41+
)
42+
)
43+
assert len(rows_baseline) == 1, f"{rows_baseline=}"
44+
row_baseline = rows_baseline.to_dicts()[0]
45+
46+
path_baseline = Path(args.baseline_path) / row_baseline["filename"]
47+
path_target = Path(args.target_path) / row["filename"]
48+
print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
49+
check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
50+
print()
51+
52+
53+
def read_meta(directory):
54+
directory = Path(directory)
55+
assert directory.is_dir(), f"{directory=} should be a directory"
56+
57+
rows = []
58+
for p in directory.glob("*.pt"):
59+
full_kwargs = {}
60+
for kv in p.stem.split("___"):
61+
k, v = kv.split("=")
62+
full_kwargs[k] = v
63+
rows.append(
64+
{
65+
"filename": str(p.name),
66+
**full_kwargs,
67+
}
68+
)
69+
70+
df = pl.DataFrame(rows)
71+
df = df.with_columns(
72+
pl.col("forward_pass_id").cast(int),
73+
pl.col("rank").cast(int),
74+
)
75+
return df
76+
77+
78+
def check_tensor_pair(path_baseline, path_target):
79+
x_baseline = torch.load(path_baseline, weights_only=True)
80+
x_target = torch.load(path_target, weights_only=True)
81+
82+
print(
83+
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
84+
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
85+
)
86+
87+
if x_baseline.shape != x_target.shape:
88+
print(f"❌ Shape mismatch")
89+
return
90+
91+
raw_abs_diff = (x_target - x_baseline).abs()
92+
93+
max_abs_diff = raw_abs_diff.max().item()
94+
mean_abs_diff = raw_abs_diff.mean().item()
95+
rel_diff = _calc_rel_diff(x_target, x_baseline)
96+
97+
needs_print = max_abs_diff > 1e-3
98+
99+
print(
100+
"\t".join(
101+
f"{'❌' if value > 1e-3 else '✅'} {name}={value}"
102+
for name, value in [
103+
("rel_diff", rel_diff),
104+
("max_abs_diff", max_abs_diff),
105+
("mean_abs_diff", mean_abs_diff),
106+
]
107+
)
108+
)
109+
110+
if needs_print:
111+
print(f"x_baseline(sample)={get_truncated_value(x_baseline)}")
112+
print(f"x_target(sample)={get_truncated_value(x_target)}")
113+
114+
115+
# Copied from DeepGEMM
116+
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
117+
x, y = x.double(), y.double()
118+
denominator = (x * x + y * y).sum()
119+
sim = 2 * (x * y).sum() / denominator
120+
return 1 - sim
121+
122+
123+
if __name__ == "__main__":
124+
parser = argparse.ArgumentParser()
125+
parser.add_argument("--baseline-path", type=str)
126+
parser.add_argument("--target-path", type=str)
127+
parser.add_argument("--start-id", type=int, default=0)
128+
parser.add_argument("--end-id", type=int, default=1000000)
129+
parser.add_argument("--baseline-start-id", type=int, default=0)
130+
args = parser.parse_args()
131+
main(args)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
2+
import time
3+
from pathlib import Path
4+
from typing import Optional
5+
6+
import torch
7+
import torch.distributed as dist
8+
9+
10+
class _Dumper:
11+
"""Utility to dump tensors, which can be useful when comparison checking models.
12+
13+
Example usage:
14+
dumper.on_forward_pass_start()
15+
dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id)
16+
17+
Import from non-SGLang system:
18+
```
19+
import sys
20+
sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils")
21+
from dumper import dumper
22+
```
23+
24+
Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison
25+
"""
26+
27+
def __init__(self):
28+
# Do not import `sglang` to make this file standalone
29+
self._enable = bool(int(os.environ.get("SGLANG_DUMPER_ENABLE", "1")))
30+
self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp"))
31+
self._enable_write_file = bool(
32+
int(os.environ.get("SGLANG_DUMPER_WRITE_FILE", "1"))
33+
)
34+
self._partial_name: Optional[str] = None
35+
self._dump_index = 0
36+
self._forward_pass_id = 0
37+
38+
def on_forward_pass_start(self):
39+
self._forward_pass_id += 1
40+
print(
41+
f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}"
42+
)
43+
44+
def dump(self, name, value, **kwargs):
45+
if not self._enable:
46+
return
47+
48+
assert (
49+
self._forward_pass_id >= 1
50+
), "Do you forget to call `dumper.on_forward_pass_start()`?"
51+
self._dump_index += 1
52+
53+
if self._partial_name is None:
54+
self._partial_name = _get_partial_name()
55+
56+
rank = dist.get_rank()
57+
full_kwargs = dict(
58+
forward_pass_id=self._forward_pass_id,
59+
rank=rank,
60+
name=name,
61+
dump_index=self._dump_index,
62+
**kwargs,
63+
)
64+
full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt"
65+
path = self._base_dir / f"sglang_dump_{self._partial_name}" / full_filename
66+
67+
sample_value = get_truncated_value(value)
68+
69+
print(
70+
f"[Dumper] [{rank}, {time.time()}] {path} "
71+
f"type={type(value)} "
72+
f"shape={value.shape if isinstance(value, torch.Tensor) else None} "
73+
f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} "
74+
f"sample_value={sample_value}"
75+
)
76+
77+
if self._enable_write_file:
78+
path.parent.mkdir(parents=True, exist_ok=True)
79+
torch.save(value, str(path))
80+
81+
82+
def _get_partial_name():
83+
rank = dist.get_rank()
84+
object_list = [str(time.time()) if rank == 0 else None]
85+
dist.broadcast_object_list(object_list, device="cuda")
86+
return object_list[0]
87+
88+
89+
def get_truncated_value(value):
90+
if value is None:
91+
return None
92+
93+
if isinstance(value, tuple):
94+
return [get_truncated_value(x) for x in value]
95+
96+
if not isinstance(value, torch.Tensor):
97+
return None
98+
99+
if value.numel() < 200:
100+
return value
101+
102+
slices = [
103+
slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
104+
]
105+
return value[tuple(slices)]
106+
107+
108+
dumper = _Dumper()

0 commit comments

Comments
 (0)