[Feature] Add a test for Layer-wise Prefill #8231

Merged
New file: test_forward_split_prefill.py
""" | ||
Test forward_split_prefill functionality. | ||
|
||
Usage: | ||
python3 -m unittest test_forward_split_prefill.TestForwardSplitPrefill | ||
or | ||
python3 test_forward_split_prefill.py | ||
""" | ||

import time
import unittest

import numpy as np
import torch

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase


class TestForwardSplitPrefill(CustomTestCase):
    """Test cases for forward_split_prefill functionality."""

    @classmethod
    def setUpClass(cls):
        """Set up the test environment once for all tests."""
        cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.tp_size = 1
        cls.device = "cuda"

        # Initialize server args
        cls.server_args = ServerArgs(
            model_path=cls.model_path,
            tokenizer_path=cls.model_path,
            host="127.0.0.1",
            disable_cuda_graph=True,  # Disable CUDA graph for testing split prefill
            disable_hybrid_swa_memory=True,
            port=30000,
            tp_size=cls.tp_size,
            mem_fraction_static=0.8,
            trust_remote_code=True,
        )

        cls.port_args = PortArgs.init_new(cls.server_args)

        # Load model and tokenizer
        cls.model_config = ModelConfig.from_server_args(cls.server_args)
        cls.model_runner = ModelRunner(
            model_config=cls.model_config,
            mem_fraction_static=cls.server_args.mem_fraction_static,
            gpu_id=0,
            tp_rank=0,
            tp_size=cls.tp_size,
            pp_rank=0,
            pp_size=1,
            nccl_port=cls.port_args.nccl_port,
            server_args=cls.server_args,
        )

        cls.tokenizer = get_tokenizer(
            cls.server_args.tokenizer_path,
            tokenizer_mode=cls.server_args.tokenizer_mode,
            trust_remote_code=cls.server_args.trust_remote_code,
        )

        print(
            f"Test with model: {cls.model_path}, num_hidden_layers: {cls.model_config.num_hidden_layers}"
        )

    def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True):
        """Prepare a test batch for split prefill testing."""
        # Create synthetic input
        input_ids = np.random.randint(10, 1000, (batch_size, input_len), dtype=np.int32)

        sampling_params = SamplingParams(
            temperature=0.0,
            max_new_tokens=8,
        )

        reqs = []
        for i in range(batch_size):
            req = Req(
                rid=i,
                origin_input_text="",
                origin_input_ids=list(input_ids[i]),
                sampling_params=sampling_params,
            )
            # No prefix cache hit, so the whole input counts as the extend part
            req.prefix_indices = []
            req.fill_ids = req.origin_input_ids
            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
            req.logprob_start_len = len(req.origin_input_ids) - 1
            reqs.append(req)

        batch = ScheduleBatch.init_new(
            reqs=reqs,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            tree_cache=None,
            model_config=self.model_config,
            enable_overlap=False,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            enable_custom_logit_processor=False,
        )
        if is_split_prefill:
            batch.prepare_for_split_prefill()
        else:
            batch.prepare_for_extend()

        # Create forward batch
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

        return forward_batch

    def test_split_prefill_functionality(self):
        """Test that split prefill can complete successfully."""
        print("\n=== Testing split prefill functionality ===")

        forward_batch = self.prepare_test_batch(batch_size=2, input_len=64)

        # Reset split index
        forward_batch.split_index = 0

        # Test split prefill in chunks
        num_layers = self.model_config.num_hidden_layers
        chunk_size = max(1, num_layers // 4)  # Split into 4 chunks

        results = []
        split_count = 0

        while forward_batch.split_index < num_layers:
            print(
                f"Processing split {split_count}, split_index: {forward_batch.split_index}"
            )

            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(split_count == 0),
                forward_count=chunk_size,
            )

            results.append(result)
            split_count += 1

            # Verify split_index is updated correctly
            expected_next_index = min(split_count * chunk_size, num_layers)
            self.assertEqual(forward_batch.split_index, expected_next_index)

        # The last result should contain logits
        self.assertIsNotNone(results[-1], "Final split should return logits")
        print(f"Split prefill completed in {split_count} splits")

    def test_split_prefill_vs_normal_prefill(self):
        """Test that split prefill produces the same results as normal prefill."""
        print("\n=== Testing split prefill vs normal prefill consistency ===")

        forward_batch_normal = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=False
        )
        forward_batch_split = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=True
        )

        # Ensure same input
        forward_batch_split.input_ids = forward_batch_normal.input_ids.clone()
        forward_batch_split.positions = forward_batch_normal.positions.clone()

        # Method 1: Normal extend (prefill)
        print("Running normal extend (prefill)...")
        normal_result = self.model_runner.forward_extend(forward_batch_normal)

        # Method 2: Split prefill
        print("Running split prefill...")
        num_layers = self.model_config.num_hidden_layers
        chunk_size = max(1, num_layers // 3)  # Split into 3 chunks

        split_result = None

        while forward_batch_split.split_index < num_layers:
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch_split,
                forward_count=chunk_size,
            )
            if result is not None:
                split_result = result

        # Compare results
        self.assertIsNotNone(normal_result, "Normal prefill should return result")
        self.assertIsNotNone(split_result, "Split prefill should return result")

        # Compare logits shapes
        self.assertEqual(
            normal_result.next_token_logits.shape,
            split_result.next_token_logits.shape,
            "Logits shapes should match",
        )

        # Compare logits values (should be very close since the computation is the same)
        # Use a larger tolerance for numerical differences in split computation
        torch.testing.assert_close(
            normal_result.next_token_logits,
            split_result.next_token_logits,
            rtol=1e-3,
            atol=1e-3,
            msg="Split prefill and normal prefill should produce similar logits",
        )

        print("✓ Split prefill and normal prefill produce consistent results")

    def test_split_prefill_different_chunk_sizes(self):
        """Test split prefill with different chunk sizes."""
        print("\n=== Testing split prefill with different chunk sizes ===")

        num_layers = self.model_config.num_hidden_layers
        chunk_sizes = [1, 2, max(1, num_layers // 2), num_layers]

        # Prepare identical batches for each test
        base_batch = self.prepare_test_batch(batch_size=1, input_len=16)
        base_input_ids = base_batch.input_ids.clone()
        base_positions = base_batch.positions.clone()

        results = []

        for chunk_size in chunk_sizes:
            if chunk_size > num_layers:
                continue

            print(f"Testing chunk size: {chunk_size}")

            # Prepare fresh batch
            forward_batch = self.prepare_test_batch(batch_size=1, input_len=16)
            forward_batch.input_ids = base_input_ids.clone()
            forward_batch.positions = base_positions.clone()
            forward_batch.split_index = 0

            # Run split prefill
            split_result = None

            while forward_batch.split_index < num_layers:
                result = self.model_runner.forward_split_prefill(
                    forward_batch=forward_batch,
                    forward_count=chunk_size,
                )
                if result is not None:
                    split_result = result

            self.assertIsNotNone(
                split_result,
                f"Split prefill should succeed with chunk_size={chunk_size}",
            )
            results.append(split_result)

        # All results should be identical (same input, same computation)
        if len(results) > 1:
            for i, result in enumerate(results[1:], 1):
                torch.testing.assert_close(
                    results[0].next_token_logits,
                    result.next_token_logits,
                    rtol=1e-3,
                    atol=1e-3,
                    msg=f"Results with different chunk sizes should be identical (chunk_size {chunk_sizes[i]})",
                )

        print("✓ All chunk sizes produce consistent results")

    def test_split_prefill_edge_cases(self):
        """Test edge cases for split prefill."""
        print("\n=== Testing split prefill edge cases ===")

        # Test with single layer chunks
        forward_batch = self.prepare_test_batch(batch_size=1, input_len=8)

        # Process one layer at a time
        num_layers = self.model_config.num_hidden_layers
        for layer_idx in range(num_layers):
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(layer_idx == 0),
                forward_count=1,  # One layer at a time
            )

            if layer_idx == num_layers - 1:
                # Last layer should return result
                self.assertIsNotNone(result, "Last layer should return logits")
            else:
                # Intermediate layers should return None
                self.assertIsNone(result, f"Layer {layer_idx} should return None")

        print("✓ Single layer processing works correctly")


if __name__ == "__main__":
    unittest.main()
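
For reference, unittest's standard dotted-path syntax also runs a single case from this file, e.g.:

python3 -m unittest test_forward_split_prefill.TestForwardSplitPrefill.test_split_prefill_functionality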
Review comment (on lines +246 to +249, the forward_split_prefill call in test_split_prefill_different_chunk_sizes):

The reinit_attn_backend argument is missing in the call to forward_split_prefill. This argument should be set to True for the first chunk of the split prefill to ensure the attention backend is correctly initialized. Without it, the test's correctness could be compromised by state left over from previous tests, leading to flaky results.
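
The attached suggestion (truncated in the rendered diff; the closing parenthesis is assumed here) passes the flag on the first chunk, when split_index is still 0:

                result = self.model_runner.forward_split_prefill(
                    forward_batch=forward_batch,
                    reinit_attn_backend=(forward_batch.split_index == 0),
                    forward_count=chunk_size,
                )

This mirrors test_split_prefill_functionality, which already passes reinit_attn_backend=(split_count == 0) on its first chunk.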