|
| 1 | +import random |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | +from transformers import AutoConfig |
| 7 | +from vllm.model_executor.layers.rotary_embedding import get_rope |
| 8 | + |
| 9 | + |
| 10 | +def gen_rope_model(config): |
| 11 | + hidden_size = config.hidden_size |
| 12 | + rope_theta = getattr(config, "rope_theta", 10000) |
| 13 | + rope_scaling = getattr(config, "rope_scaling", None) |
| 14 | + if rope_scaling is not None and getattr( |
| 15 | + config, "original_max_position_embeddings", None |
| 16 | + ): |
| 17 | + rope_scaling["original_max_position_embeddings"] = ( |
| 18 | + config.original_max_position_embeddings |
| 19 | + ) |
| 20 | + rope_is_neox_style = getattr(config, "rope_is_neox_style", True) |
| 21 | + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) |
| 22 | + num_head = getattr(config, "num_attention_heads") |
| 23 | + head_dim = hidden_size // num_head |
| 24 | + |
| 25 | + print( |
| 26 | + head_dim, max_position_embeddings, rope_theta, rope_scaling, rope_is_neox_style |
| 27 | + ) |
| 28 | + rope = get_rope( |
| 29 | + head_dim, # 128 |
| 30 | + rotary_dim=head_dim, # 128 |
| 31 | + max_position=max_position_embeddings, # 8192 |
| 32 | + base=rope_theta, # 5e5 |
| 33 | + rope_scaling=rope_scaling, # None |
| 34 | + is_neox_style=rope_is_neox_style, # True |
| 35 | + ) |
| 36 | + return rope |
| 37 | + |
| 38 | + |
| 39 | +def get_metadata(config): |
| 40 | + hidden_size = config.hidden_size |
| 41 | + num_head = getattr(config, "num_attention_heads") |
| 42 | + num_kv_head = 8 |
| 43 | + head_dim = hidden_size // num_head |
| 44 | + return num_head, num_kv_head, head_dim |
| 45 | + |
| 46 | + |
| 47 | +def gen_data(config, tp_size, sp_size, num_tokens): |
| 48 | + num_head, num_kv_head, head_dim = get_metadata(config) |
| 49 | + num_head_tp = num_head // tp_size |
| 50 | + num_kv_head_tp = num_kv_head // tp_size |
| 51 | + torch.manual_seed(42) |
| 52 | + random.seed(42) |
| 53 | + full_q = torch.rand((num_tokens, num_head_tp, head_dim), dtype=torch.bfloat16) |
| 54 | + full_k = torch.rand((num_tokens, num_kv_head_tp, head_dim), dtype=torch.bfloat16) |
| 55 | + full_pos = [] |
| 56 | + extend_seq_lens = [] |
| 57 | + extend_start_loc = [] |
| 58 | + # req_offset = [] |
| 59 | + while len(full_pos) < num_tokens: |
| 60 | + length = random.randint(sp_size, num_tokens // 2) |
| 61 | + length = min(length, num_tokens - len(full_pos)) |
| 62 | + len_offset = random.randint(0, num_tokens) |
| 63 | + |
| 64 | + extend_seq_lens.append(length) |
| 65 | + extend_start_loc.append(len(full_pos)) |
| 66 | + |
| 67 | + full_pos.extend(range(len_offset, length + len_offset)) |
| 68 | + # req_offset.append(len_offset) |
| 69 | + print(f"extend_seq_lens:{extend_seq_lens}") |
| 70 | + full_pos = torch.from_numpy(np.fromiter(full_pos, dtype=np.int32)).to(torch.long) |
| 71 | + extend_seq_lens = torch.from_numpy(np.fromiter(extend_seq_lens, dtype=np.int32)).to( |
| 72 | + torch.long |
| 73 | + ) |
| 74 | + extend_start_loc = torch.from_numpy( |
| 75 | + np.fromiter(extend_start_loc, dtype=np.int32) |
| 76 | + ).to(torch.long) |
| 77 | + return full_q, full_k, full_pos, extend_seq_lens, extend_start_loc |
| 78 | + |
| 79 | + |
| 80 | +def gen_sp_indices( |
| 81 | + sp_size, extend_seq_lens, extend_start_loc, num_tokens, test: bool = False |
| 82 | +): |
| 83 | + all_indices = set() |
| 84 | + indices = [] |
| 85 | + for sp_rank in range(sp_size): |
| 86 | + sp_indices = get_prefill_indices( |
| 87 | + sp_rank, sp_size, extend_seq_lens, extend_start_loc |
| 88 | + ) |
| 89 | + indices.append(sp_indices) |
| 90 | + if test: |
| 91 | + num_sp_indices = set([i for i in sp_indices]) |
| 92 | + assert all_indices.isdisjoint( |
| 93 | + set(num_sp_indices) |
| 94 | + ), all_indices.intersection(num_sp_indices) |
| 95 | + all_indices.update(set(num_sp_indices)) |
| 96 | + if test: |
| 97 | + assert all_indices == set(range(0, num_tokens)) |
| 98 | + return indices |
| 99 | + |
| 100 | + |
| 101 | +######## Function to be test |
| 102 | +def get_prefill_indices(sp_rank, sp_size, extend_seq_lens, extend_start_loc): |
| 103 | + sp_req_len = extend_seq_lens // sp_size + ( |
| 104 | + (extend_seq_lens % sp_size) > sp_rank |
| 105 | + ).to(torch.long) |
| 106 | + # the offset of each request in the batch. Only the first few ranks may get 1 more token (for each). |
| 107 | + # for sp_rank=r, there are r ranks ahread (since 0-based), each may get one token |
| 108 | + sp_in_req_offset = extend_seq_lens // sp_size * sp_rank + torch.clamp( |
| 109 | + extend_seq_lens % sp_size, max=sp_rank |
| 110 | + ) |
| 111 | + sp_req_start = extend_start_loc + sp_in_req_offset |
| 112 | + sp_indices = torch.concat( |
| 113 | + [torch.arange(s, s + l) for s, l in zip(sp_req_start, sp_req_len)] |
| 114 | + ) |
| 115 | + return sp_indices.cpu().numpy() |
| 116 | + |
| 117 | + |
| 118 | +def get_decode_mask(sp_rank, sp_size, seq_lens): |
| 119 | + # True means the corresponding token is located on this device. Otherwise False. |
| 120 | + return seq_lens % sp_size == sp_rank |
| 121 | + |
| 122 | + |
| 123 | +@torch.no_grad() |
| 124 | +def main(): |
| 125 | + config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") |
| 126 | + tp_size = 4 |
| 127 | + sp_size = 4 |
| 128 | + num_tokens = 1280 |
| 129 | + (full_q, full_k, full_pos, extend_seq_lens, extend_start_loc) = gen_data( |
| 130 | + config, tp_size, sp_size, num_tokens |
| 131 | + ) |
| 132 | + sp_indices = gen_sp_indices( |
| 133 | + sp_size, extend_seq_lens, extend_start_loc, num_tokens, test=True |
| 134 | + ) |
| 135 | + # correct result |
| 136 | + rope = gen_rope_model(config) |
| 137 | + full_pos = full_pos.cuda() |
| 138 | + full_q = full_q.cuda().reshape(num_tokens, -1) |
| 139 | + full_k = full_k.cuda().reshape(num_tokens, -1) |
| 140 | + |
| 141 | + cor_q, cor_k = rope(full_pos, torch.clone(full_q), torch.clone(full_k)) |
| 142 | + |
| 143 | + # simulated result |
| 144 | + test_q = torch.zeros_like(cor_q).squeeze(0) |
| 145 | + test_k = torch.zeros_like(cor_k).squeeze(0) |
| 146 | + for sp_idx in sp_indices: |
| 147 | + q, k = rope( |
| 148 | + full_pos[sp_idx], torch.clone(full_q[sp_idx]), torch.clone(full_k[sp_idx]) |
| 149 | + ) |
| 150 | + test_q[sp_idx] = q |
| 151 | + test_k[sp_idx] = k |
| 152 | + test_q = test_q.reshape(cor_q.shape) |
| 153 | + test_k = test_k.reshape(cor_k.shape) |
| 154 | + |
| 155 | + torch.testing.assert_close(cor_q, test_q) |
| 156 | + torch.testing.assert_close(cor_k, test_k) |
| 157 | + |
| 158 | + |
| 159 | +main() |
0 commit comments