Skip to content

Commit 11a5529

Browse files
V0.7.2 ulysses v1 debug (vllm-project#6)
* print numforward * print forward context * print forward context * print forward context * print forward context * synchronize after Ulysses in graph capture * synchronize after Ulysses in graph capture * synchronize NCCL operations (all-to-all) for graph capture
1 parent 73b53fd commit 11a5529

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

vllm/model_executor/models/llama.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
539539
self.make_empty_intermediate_tensors = (
540540
self.model.make_empty_intermediate_tensors)
541541

542+
self.numforward = 0
543+
542544
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
543545
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
544546

@@ -570,6 +572,12 @@ def forward(
570572
# print(f"input_ids: {input_ids.shape}")
571573
# print(f"positions: {positions.shape}")
572574
# print(f"N {N}, SP {SP}, N_ranks {N_ranks} sum {sum(N_ranks)}")
575+
# from vllm.forward_context import get_forward_context
576+
# if torch.distributed.get_rank() == 0:
577+
# print(f"numforward {self.numforward} N {N} "
578+
# f"N_ranks {N_ranks}")
579+
# if get_forward_context().attn_metadata is not None:
580+
# self.numforward += 1
573581

574582
input_ids[0:N_ulysses] = input_ids[N_start:N_start + N_ulysses]
575583
positions[0:N_ulysses] = positions[N_start:N_start + N_ulysses]

vllm/v1/attention/backends/flash_attn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,9 @@ def forward(
316316
group=self.device_group)
317317
c = torch.transpose(c, 0, 1).reshape(
318318
N_ulysses, self.num_heads * self.SP * self.head_size)
319+
# synchronization of NCCL operations is necessary for graph capture
320+
if attn_metadata is None:
321+
torch.cuda.synchronize()
319322
return c
320323

321324

0 commit comments

Comments
 (0)