-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Tool to dump and compare internal activation tensors #7976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
42b6641
cf6804a
e6af668
bc8e997
7c86f8d
3edcb7a
7374191
545cb49
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import argparse | ||
import functools | ||
import re | ||
from pathlib import Path | ||
|
||
import polars as pl | ||
import torch | ||
|
||
from sglang.srt.debug_utils.dumper import get_truncated_value | ||
|
||
|
||
def main(args): | ||
df_target = read_meta(args.target_path) | ||
df_target = df_target.sort("rank", "dump_index") | ||
df_target = df_target.filter( | ||
(pl.col("forward_pass_id") >= args.start_id) | ||
& (pl.col("forward_pass_id") <= args.end_id) | ||
) | ||
assert all( | ||
c in df_target.columns | ||
for c in ["rank", "forward_pass_id", "dump_index", "name"] | ||
) | ||
|
||
df_baseline = read_meta(args.baseline_path) | ||
print("df_target", df_target) | ||
print("df_baseline", df_baseline) | ||
|
||
for row in df_target.iter_rows(named=True): | ||
rows_baseline = df_baseline.filter( | ||
( | ||
pl.col("forward_pass_id") | ||
== row["forward_pass_id"] - args.start_id + args.baseline_start_id | ||
) | ||
& functools.reduce( | ||
lambda a, b: a & b, | ||
[ | ||
pl.col(col) == row[col] | ||
for col in row.keys() | ||
if col not in ["forward_pass_id", "dump_index", "filename"] | ||
], | ||
) | ||
) | ||
assert len(rows_baseline) == 1, f"{rows_baseline=}" | ||
row_baseline = rows_baseline.to_dicts()[0] | ||
|
||
path_baseline = Path(args.baseline_path) / row_baseline["filename"] | ||
path_target = Path(args.target_path) / row["filename"] | ||
print(f"Check: target={str(path_target)} baseline={str(path_baseline)}") | ||
check_tensor_pair(path_baseline=path_baseline, path_target=path_target) | ||
print() | ||
|
||
|
||
def read_meta(directory): | ||
directory = Path(directory) | ||
assert directory.is_dir(), f"{directory=} should be a directory" | ||
|
||
rows = [] | ||
for p in directory.glob("*.pt"): | ||
full_kwargs = {} | ||
for kv in p.stem.split("___"): | ||
k, v = kv.split("=") | ||
full_kwargs[k] = v | ||
rows.append( | ||
{ | ||
"filename": str(p.name), | ||
**full_kwargs, | ||
} | ||
) | ||
|
||
df = pl.DataFrame(rows) | ||
df = df.with_columns( | ||
pl.col("forward_pass_id").cast(int), | ||
pl.col("rank").cast(int), | ||
) | ||
return df | ||
|
||
|
||
def check_tensor_pair(path_baseline, path_target): | ||
x_baseline = torch.load(path_baseline, weights_only=True) | ||
x_target = torch.load(path_target, weights_only=True) | ||
|
||
print( | ||
f"[shape] {x_baseline.shape} vs {x_target.shape}\t" | ||
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}" | ||
) | ||
|
||
if x_baseline.shape != x_target.shape: | ||
print(f"❌ Shape mismatch") | ||
return | ||
|
||
raw_abs_diff = (x_target - x_baseline).abs() | ||
|
||
max_abs_diff = raw_abs_diff.max().item() | ||
mean_abs_diff = raw_abs_diff.mean().item() | ||
rel_diff = _calc_rel_diff(x_target, x_baseline) | ||
|
||
needs_print = max_abs_diff > 1e-3 | ||
|
||
print( | ||
"\t".join( | ||
f"{'❌' if value > 1e-3 else '✅'} {name}={value}" | ||
for name, value in [ | ||
("rel_diff", rel_diff), | ||
("max_abs_diff", max_abs_diff), | ||
("mean_abs_diff", mean_abs_diff), | ||
] | ||
) | ||
) | ||
|
||
if needs_print: | ||
print(f"x_baseline(sample)={get_truncated_value(x_baseline)}") | ||
print(f"x_target(sample)={get_truncated_value(x_target)}") | ||
|
||
|
||
# Copied from DeepGEMM | ||
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor): | ||
x, y = x.double(), y.double() | ||
denominator = (x * x + y * y).sum() | ||
sim = 2 * (x * y).sum() / denominator | ||
return 1 - sim | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--baseline-path", type=str) | ||
parser.add_argument("--target-path", type=str) | ||
parser.add_argument("--start-id", type=int, default=0) | ||
parser.add_argument("--end-id", type=int, default=1000000) | ||
parser.add_argument("--baseline-start-id", type=int, default=0) | ||
args = parser.parse_args() | ||
main(args) |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,108 @@ | ||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||
import time | ||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||
from typing import Optional | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||
import torch.distributed as dist | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
class _Dumper: | ||||||||||||||||||||||||||
"""Utility to dump tensors, which can be useful when comparison checking models. | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Example usage: | ||||||||||||||||||||||||||
dumper.on_forward_pass_start() | ||||||||||||||||||||||||||
dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Import from non-SGLang system: | ||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||
import sys | ||||||||||||||||||||||||||
sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils") | ||||||||||||||||||||||||||
from dumper import dumper | ||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison | ||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def __init__(self): | ||||||||||||||||||||||||||
# Do not import `sglang` to make this file standalone | ||||||||||||||||||||||||||
self._enable = bool(int(os.environ.get("SGLANG_DUMPER_ENABLE", "1"))) | ||||||||||||||||||||||||||
self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp")) | ||||||||||||||||||||||||||
self._enable_write_file = bool( | ||||||||||||||||||||||||||
int(os.environ.get("SGLANG_DUMPER_WRITE_FILE", "1")) | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
self._partial_name: Optional[str] = None | ||||||||||||||||||||||||||
self._dump_index = 0 | ||||||||||||||||||||||||||
self._forward_pass_id = 0 | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def on_forward_pass_start(self): | ||||||||||||||||||||||||||
self._forward_pass_id += 1 | ||||||||||||||||||||||||||
print( | ||||||||||||||||||||||||||
f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}" | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def dump(self, name, value, **kwargs): | ||||||||||||||||||||||||||
if not self._enable: | ||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
assert ( | ||||||||||||||||||||||||||
self._forward_pass_id >= 1 | ||||||||||||||||||||||||||
), "Do you forget to call `dumper.on_forward_pass_start()`?" | ||||||||||||||||||||||||||
self._dump_index += 1 | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if self._partial_name is None: | ||||||||||||||||||||||||||
self._partial_name = _get_partial_name() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
rank = dist.get_rank() | ||||||||||||||||||||||||||
full_kwargs = dict( | ||||||||||||||||||||||||||
forward_pass_id=self._forward_pass_id, | ||||||||||||||||||||||||||
rank=rank, | ||||||||||||||||||||||||||
name=name, | ||||||||||||||||||||||||||
dump_index=self._dump_index, | ||||||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt" | ||||||||||||||||||||||||||
path = self._base_dir / f"sglang_dump_{self._partial_name}" / full_filename | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
sample_value = get_truncated_value(value) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
print( | ||||||||||||||||||||||||||
f"[Dumper] [{rank}, {time.time()}] {path} " | ||||||||||||||||||||||||||
f"type={type(value)} " | ||||||||||||||||||||||||||
f"shape={value.shape if isinstance(value, torch.Tensor) else None} " | ||||||||||||||||||||||||||
f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} " | ||||||||||||||||||||||||||
f"sample_value={sample_value}" | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if self._enable_write_file: | ||||||||||||||||||||||||||
path.parent.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||
torch.save(value, str(path)) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def _get_partial_name(): | ||||||||||||||||||||||||||
rank = dist.get_rank() | ||||||||||||||||||||||||||
object_list = [str(time.time()) if rank == 0 else None] | ||||||||||||||||||||||||||
dist.broadcast_object_list(object_list, device="cuda") | ||||||||||||||||||||||||||
return object_list[0] | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def get_truncated_value(value): | ||||||||||||||||||||||||||
if value is None: | ||||||||||||||||||||||||||
return None | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if isinstance(value, tuple): | ||||||||||||||||||||||||||
return [get_truncated_value(x) for x in value] | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if not isinstance(value, torch.Tensor): | ||||||||||||||||||||||||||
return None | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if value.numel() < 200: | ||||||||||||||||||||||||||
return value | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
slices = [ | ||||||||||||||||||||||||||
slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
Comment on lines
+99
to
+104
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tensor truncation logic appears inconsistent. The tensor is truncated if
Suggested change
|
||||||||||||||||||||||||||
return value[tuple(slices)] | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
dumper = _Dumper() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using
pl.all_horizontal
for combining Polars expressions, as it can be more concise thanfunctools.reduce
. This improves readability.