
Commit a11bc61

Author: Yifan Qiao
Merge pull request #1 from ivanium/pr-sp-rope
Sequence Parallel system setup
2 parents: a9ef49c + dd2382d

File tree

6 files changed (+440, -10 lines)

playground/sp_rope.py

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
import random
from typing import Optional

import numpy as np
import torch
from transformers import AutoConfig
from vllm.model_executor.layers.rotary_embedding import get_rope

# Playground check for sequence-parallel RoPE: applying rotary embeddings to each
# SP shard independently must match applying them to the full token batch at once.


def gen_rope_model(config):
    # Build the RoPE module from the HF config using vLLM's get_rope.
    hidden_size = config.hidden_size
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is not None and getattr(
        config, "original_max_position_embeddings", None
    ):
        rope_scaling["original_max_position_embeddings"] = (
            config.original_max_position_embeddings
        )
    rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
    max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
    num_head = getattr(config, "num_attention_heads")
    head_dim = hidden_size // num_head

    print(
        head_dim, max_position_embeddings, rope_theta, rope_scaling, rope_is_neox_style
    )
    rope = get_rope(
        head_dim,  # 128
        rotary_dim=head_dim,  # 128
        max_position=max_position_embeddings,  # 8192
        base=rope_theta,  # 5e5
        rope_scaling=rope_scaling,  # None
        is_neox_style=rope_is_neox_style,  # True
    )
    return rope


def get_metadata(config):
    hidden_size = config.hidden_size
    num_head = getattr(config, "num_attention_heads")
    num_kv_head = 8  # Llama-3-8B uses 8 KV heads
    head_dim = hidden_size // num_head
    return num_head, num_kv_head, head_dim


def gen_data(config, tp_size, sp_size, num_tokens):
    # Generate random q/k tensors plus per-request extend metadata (lengths, start offsets).
    num_head, num_kv_head, head_dim = get_metadata(config)
    num_head_tp = num_head // tp_size
    num_kv_head_tp = num_kv_head // tp_size
    torch.manual_seed(42)
    random.seed(42)
    full_q = torch.rand((num_tokens, num_head_tp, head_dim), dtype=torch.bfloat16)
    full_k = torch.rand((num_tokens, num_kv_head_tp, head_dim), dtype=torch.bfloat16)
    full_pos = []
    extend_seq_lens = []
    extend_start_loc = []
    # req_offset = []
    while len(full_pos) < num_tokens:
        length = random.randint(sp_size, num_tokens // 2)
        length = min(length, num_tokens - len(full_pos))
        len_offset = random.randint(0, num_tokens)

        extend_seq_lens.append(length)
        extend_start_loc.append(len(full_pos))

        full_pos.extend(range(len_offset, length + len_offset))
        # req_offset.append(len_offset)
    print(f"extend_seq_lens:{extend_seq_lens}")
    full_pos = torch.from_numpy(np.fromiter(full_pos, dtype=np.int32)).to(torch.long)
    extend_seq_lens = torch.from_numpy(np.fromiter(extend_seq_lens, dtype=np.int32)).to(
        torch.long
    )
    extend_start_loc = torch.from_numpy(
        np.fromiter(extend_start_loc, dtype=np.int32)
    ).to(torch.long)
    return full_q, full_k, full_pos, extend_seq_lens, extend_start_loc


def gen_sp_indices(
    sp_size, extend_seq_lens, extend_start_loc, num_tokens, test: bool = False
):
    # Collect each SP rank's token indices; with test=True, verify that the per-rank
    # index sets are disjoint and together cover all tokens.
    all_indices = set()
    indices = []
    for sp_rank in range(sp_size):
        sp_indices = get_prefill_indices(
            sp_rank, sp_size, extend_seq_lens, extend_start_loc
        )
        indices.append(sp_indices)
        if test:
            num_sp_indices = set([i for i in sp_indices])
            assert all_indices.isdisjoint(
                set(num_sp_indices)
            ), all_indices.intersection(num_sp_indices)
            all_indices.update(set(num_sp_indices))
    if test:
        assert all_indices == set(range(0, num_tokens))
    return indices


######## Function to be tested
def get_prefill_indices(sp_rank, sp_size, extend_seq_lens, extend_start_loc):
    sp_req_len = extend_seq_lens // sp_size + (
        (extend_seq_lens % sp_size) > sp_rank
    ).to(torch.long)
    # The offset of this rank's slice within each request. Only the first few ranks
    # may get 1 more token (each): for sp_rank=r, there are r ranks ahead (0-based),
    # each of which may hold one extra token.
    sp_in_req_offset = extend_seq_lens // sp_size * sp_rank + torch.clamp(
        extend_seq_lens % sp_size, max=sp_rank
    )
    sp_req_start = extend_start_loc + sp_in_req_offset
    sp_indices = torch.concat(
        [torch.arange(s, s + l) for s, l in zip(sp_req_start, sp_req_len)]
    )
    return sp_indices.cpu().numpy()


def get_decode_mask(sp_rank, sp_size, seq_lens):
    # True means the corresponding token is located on this device; otherwise False.
    return seq_lens % sp_size == sp_rank


@torch.no_grad()
def main():
    config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    tp_size = 4
    sp_size = 4
    num_tokens = 1280
    (full_q, full_k, full_pos, extend_seq_lens, extend_start_loc) = gen_data(
        config, tp_size, sp_size, num_tokens
    )
    sp_indices = gen_sp_indices(
        sp_size, extend_seq_lens, extend_start_loc, num_tokens, test=True
    )
    # Reference result: RoPE applied to the full batch at once
    rope = gen_rope_model(config)
    full_pos = full_pos.cuda()
    full_q = full_q.cuda().reshape(num_tokens, -1)
    full_k = full_k.cuda().reshape(num_tokens, -1)

    cor_q, cor_k = rope(full_pos, torch.clone(full_q), torch.clone(full_k))

    # Simulated SP result: RoPE applied shard by shard, scattered back into place
    test_q = torch.zeros_like(cor_q).squeeze(0)
    test_k = torch.zeros_like(cor_k).squeeze(0)
    for sp_idx in sp_indices:
        q, k = rope(
            full_pos[sp_idx], torch.clone(full_q[sp_idx]), torch.clone(full_k[sp_idx])
        )
        test_q[sp_idx] = q
        test_k[sp_idx] = k
    test_q = test_q.reshape(cor_q.shape)
    test_k = test_k.reshape(cor_k.shape)

    torch.testing.assert_close(cor_q, test_q)
    torch.testing.assert_close(cor_k, test_k)


main()
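
As a quick illustration of the partitioning scheme (not part of the commit), the snippet below runs get_prefill_indices on a toy batch of two requests with lengths 5 and 6 and sp_size=4. It assumes the function from the listing above is in scope (e.g. with the trailing main() call removed); the batch sizes are made up for the example.

import torch

# Toy batch (hypothetical sizes): two requests of lengths 5 and 6, laid out back to back.
extend_seq_lens = torch.tensor([5, 6], dtype=torch.long)
extend_start_loc = torch.tensor([0, 5], dtype=torch.long)

for sp_rank in range(4):
    # get_prefill_indices is assumed to come from the listing above
    print(sp_rank, get_prefill_indices(sp_rank, 4, extend_seq_lens, extend_start_loc))

# Each rank owns one contiguous slice of every request; the first (length % sp_size)
# ranks receive one extra token of that request:
#   rank 0 -> [0 1 5 6]
#   rank 1 -> [2 7 8]
#   rank 2 -> [3 9]
#   rank 3 -> [4 10]
# For decode, get_decode_mask places a sequence's next token on rank (seq_len % sp_size),
# so with seq_lens [5, 6] the next tokens land on ranks 1 and 2 respectively.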

python/sglang/bench_latency.py

Lines changed: 9 additions & 2 deletions
@@ -68,7 +68,7 @@ def from_cli_args(cls, args: argparse.Namespace):
         return cls(**{attr: getattr(args, attr) for attr in attrs})
 
 
-def load_model(server_args, tp_rank):
+def load_model(server_args, tp_rank, sp_rank: int = 0):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
@@ -79,6 +79,8 @@ def load_model(server_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
+        sp_rank=sp_rank,
+        sp_size=server_args.sp_size,
         nccl_port=28888,
         server_args=server_args,
     )
@@ -153,6 +155,8 @@ def extend(reqs, model_runner):
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
+        sp_size=model_runner.sp_size,
+        sp_rank=model_runner.sp_rank,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
@@ -172,11 +176,12 @@ def correctness_test(
     server_args,
     bench_args,
     tp_rank,
+    sp_rank=0,
 ):
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank)
 
     # Prepare inputs
     input_ids, reqs = prepare_inputs(bench_args, tokenizer)
@@ -289,12 +294,14 @@ def main(server_args, bench_args):
     else:
         workers = []
         for tp_rank in range(server_args.tp_size):
+            sp_rank = tp_rank % server_args.sp_size
             proc = multiprocessing.Process(
                 target=work_func,
                 args=(
                     server_args,
                     bench_args,
                     tp_rank,
+                    sp_rank,
                 ),
             )
             proc.start()
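
For context (not part of the diff): main() derives each worker's sequence-parallel rank from its tensor-parallel rank with a simple modulo, so SP ranks repeat every sp_size workers. A minimal sketch of that mapping, assuming the illustrative sizes tp_size=8 and sp_size=4:

# Hypothetical sizes, for illustration only.
tp_size, sp_size = 8, 4

for tp_rank in range(tp_size):
    sp_rank = tp_rank % sp_size  # same formula as in main() above
    print(f"tp_rank={tp_rank} -> sp_rank={sp_rank}")

# tp ranks 0-3 map to sp ranks 0-3; tp ranks 4-7 wrap around to 0-3 again.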
