Commit 82e6105

Author: Jonathan Mitchell
Commit message: readme
Signed-off-by: Jonathan Mitchell <[email protected]>
1 parent 310a064 commit 82e6105

File tree: 2 files changed, +7 −70 lines


examples/pytorch/transformer/README.md

Lines changed: 6 additions & 69 deletions
@@ -2,18 +2,11 @@

 This directory contains a comprehensive testing framework for validating **Context Parallelism (CP)** in TransformerEngine. The framework compares single-GPU baseline runs against distributed multi-GPU runs to ensure numerical consistency and correctness.

-**Key Features:**
-- 🔄 Automatic synchronization using file-based completion markers
-- 📊 Configurable logging verbosity for cleaner output
-- 🎯 Dynamic sequence dimension handling for BSHD/SBHD formats
-- ⚡ Optimized device placement for reduced CPU-GPU overhead
-- 🛡️ Comprehensive shape validation and error handling
-
 ## 🏗️ Architecture Overview

 ### Test Models

-The framework supports two attention formats, each with its own model implementation in `model.py`:
+The framework supports two attention input formats, each with its own model implementation in `model.py`:

 **1. SimpleThDModel (THD Format - Token-Head-Dimension):**
 - Processes sequences in flattened token format
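
For orientation, the two layouts differ roughly as in the minimal sketch below; the tensor and variable names are illustrative assumptions, not code from `model.py`:

```python
import torch

# BSHD: one padded tensor of shape [batch, seq_len, num_heads, head_dim]
batch, seq_len, heads, dim = 2, 8, 4, 16
bshd = torch.randn(batch, seq_len, heads, dim)

# THD: variable-length sequences flattened to [total_tokens, num_heads, head_dim],
# with cumulative sequence lengths marking sequence boundaries (no padding).
seq_lens = torch.tensor([8, 5])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int64), seq_lens.cumsum(0)])  # [0, 8, 13]
thd = torch.randn(int(seq_lens.sum()), heads, dim)
```
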
@@ -76,7 +69,7 @@ The framework runs parallel test suites for both THD and BSHD formats:
 - `context_parallel_runner_thd.py` (THD format)
 - `context_parallel_runner_bshd.py` (BSHD format)

-Both runners execute on a single GPU:
+Both programs are launched with `torchrun`; in this phase no parallelism is used, so the two GPUs perform identical forward passes (one per GPU).
 1. **Model Creation**: Initialize model (SimpleThDModel or SimpleBSHDModel)
 2. **Forward Pass**: Process full sequences without parallelization
 3. **Loss Computation**: Cross-entropy on valid (non-padded) tokens
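
Step 3 above amounts to a masked cross-entropy; a minimal sketch (the pad id and tensor names are assumptions, not the runner's actual code):

```python
import torch
import torch.nn.functional as F

PAD_ID = 0                          # assumed padding token id
logits = torch.randn(2, 8, 32)      # [batch, seq, vocab]
labels = torch.randint(1, 32, (2, 8))
labels[:, 6:] = PAD_ID              # pretend the last positions are padding

# ignore_index excludes padded positions, so only valid tokens contribute to the loss.
loss = F.cross_entropy(logits.reshape(-1, 32), labels.reshape(-1), ignore_index=PAD_ID)
```
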
@@ -87,7 +80,7 @@ Both runners execute on a single GPU:
 5. **State Persistence**: Save model weights and results to `/tmp/` for CP=2 comparison

 ### Phase 2: Distributed Run (CP=2)
-**Execution**: Same runner files in distributed mode via `torchrun`
+Torch distributed is then initialized with context parallel size 2, and both GPUs participate in each forward pass.

 1. **Process Group Setup**: Initialize NCCL backend for 2 GPUs
 2. **Device Mesh**: Create `(fsdp=1, cp=2, tp=1)` parallelization strategy
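
Steps 1–2 typically reduce to something like the sketch below (assuming the usual `torchrun` environment variables; the mesh dimension names follow the text above, the rest is illustrative):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# torchrun supplies RANK / WORLD_SIZE / LOCAL_RANK; NCCL backend for the 2 GPUs.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# (fsdp=1, cp=2, tp=1): only the context-parallel dimension is larger than 1.
mesh = init_device_mesh("cuda", (1, 2, 1), mesh_dim_names=("fsdp", "cp", "tp"))
cp_group = mesh["cp"].get_group()   # process group used for context parallelism
```
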
@@ -140,7 +133,7 @@ LOGITS_ELEMENT_TOLERANCE = 2e-2 # Individual element tolerance

 **Success Criteria**: ≥85% of logit elements must be within 2e-2 absolute difference

-**Why 85%?** Distributed computation with mixed precision (bfloat16) introduces expected numerical differences. This threshold balances strictness with practical distributed computing realities.
+**Why 85%?** Distributed computation with mixed precision (bfloat16) introduces expected numerical differences. This threshold balances strictness with practical distributed computing realities. Moreover, as the number of hidden layers (TE layers) increases, the numerical differences between the non-CP and CP runs grow.

 ### Test 5: Loss Similarity
 - Compares averaged losses from both CP ranks
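
Concretely, the ≥85% criterion is equivalent to a check along these lines (a sketch using the tolerance quoted above; the function and tensor names are placeholders):

```python
import torch

LOGITS_ELEMENT_TOLERANCE = 2e-2   # individual element tolerance from the config above
MIN_MATCH_FRACTION = 0.85

def logits_close_enough(cp1_logits: torch.Tensor, cp2_logits: torch.Tensor) -> bool:
    # Fraction of logit elements whose absolute difference is within tolerance.
    diff = (cp1_logits.float() - cp2_logits.float()).abs()
    match_fraction = (diff <= LOGITS_ELEMENT_TOLERANCE).float().mean().item()
    return match_fraction >= MIN_MATCH_FRACTION
```
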
@@ -172,7 +165,6 @@ The framework uses **scientifically calibrated tolerances** based on:
 1. **Mixed Precision Effects**: bfloat16 has ~3-4 decimal digits of precision
 2. **Distributed Communication**: AllReduce operations introduce small numerical errors
 3. **Computation Order**: Different operation sequences in CP vs non-CP modes
-4. **Hardware Variations**: GPU-to-GPU numerical differences

 **Conservative but Practical**: Tolerances are tight enough to catch real bugs while loose enough to handle expected distributed computing variations.

@@ -182,10 +174,6 @@ The framework uses **scientifically calibrated tolerances** based on:
 ```bash
 # Run both THD and BSHD format tests with automatic synchronization
 bash run_context_parallel.sh
-
-# Run individual format tests
-bash run_context_parallel_thd.sh # THD format only
-bash run_context_parallel_bshd.sh # BSHD format only
 ```

 ### Script Features
@@ -258,63 +246,12 @@ LOGITS_ELEMENT_TOLERANCE = 2e-2 # Stricter: 1e-2, Looser: 5e-2
 GRADIENT_MAX_ABSOLUTE_DIFF_TOLERANCE = 0.05 # Stricter: 0.01, Looser: 0.1
 ```

-## 🎯 What This Framework Validates
+## 🎯 What This Folder Validates

 **Numerical Correctness**: CP=2 produces equivalent results to CP=1
 **Format Compatibility**: Both THD and BSHD attention formats work correctly
 **Gradient Consistency**: Distributed training gradients match single-GPU
 **Loss Preservation**: Training objectives remain unchanged
 **Sequence Reconstruction**: Distributed chunks correctly reassemble
 **Memory Efficiency**: Context parallelism reduces per-GPU memory usage
-**Scalability**: Framework extends to larger CP sizes and models
-
-## 🔍 Debugging Failed Tests
-
-**Logits Test Failure**: Check for model weight synchronization issues or incorrect sequence chunking
-
-**Gradient Test Failure**: Investigate DDP configuration or learning rate scaling
-
-**Loss Test Failure**: Verify identical random seeds and data preprocessing
-
-**Reconstruction Failure**: Debug `get_batch_on_this_cp_rank()` slicing logic for the specific format (THD vs BSHD)
-
-**Format-Specific Issues**:
-- THD: Check cumulative sequence length calculations
-- BSHD: Verify batch dimension handling in reconstruction
-
-## 📊 Performance Characteristics
-
-- **Model Size**: ~2.1M parameters (lightweight but representative)
-- **Memory Usage**: ~50MB per GPU (enables testing on modest hardware)
-- **Runtime**: ~30 seconds per format test suite (fast iteration cycles)
-- **Scalability**: Easily extends to larger models and more GPUs
-- **Logging**: Clean output by default, verbose logs available on demand
-
-## 🔄 Recent Improvements
-
-### Core Functionality
-- **Dynamic Sequence Dimension Handling**: Automatically determines correct dimension based on tensor format (BSHD vs SBHD)
-- **Optimized Device Placement**: Index tensors now created directly on target device, eliminating CPU-GPU sync overhead
-- **Comprehensive Shape Validation**: Added pre-reshape checks with informative error messages
-- **Enhanced Error Handling**: Proper exception handling for shape mismatches and invalid operations
-
-### Testing Infrastructure
-- **File-Based Completion Markers**: Reliable synchronization between distributed processes
-- **Automatic Process Coordination**: Ensures sequential execution of multiple `torchrun` commands
-- **Configurable Timeout Protection**: Prevents hanging on failed distributed runs
-- **Clean Marker Management**: Automatic cleanup of old completion markers
-
-### Developer Experience
-- **Dual Format Support**: Added BSHD format alongside original THD format
-- **Cleaner Output**: Replaced verbose print statements with configurable Python logging
-- **Reduced Code Duplication**: Consolidated common logic between test suites
-- **Better Debugging**: Rank-aware logging in distributed mode
-- **Professional Structure**: Removed redundant comments and improved code organization
-- **Flexible Verbosity**: Toggle between clean and detailed output modes
-
-### Test Coverage
-- **New BSHD/SBHD Format Tests**: Added comprehensive tests in `test_cp_utils.py`
-- **Format-Specific Validation**: Separate test suites for THD and BSHD formats
-- **Improved Test Isolation**: Each test properly manages its own state
-
-This framework provides a **robust, automated, and scientifically sound** approach to validating context parallelism implementations in TransformerEngine for multiple attention formats.
+**Scalability**: Folder extends to larger CP sizes and models
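
The gradient-consistency check listed above presumably reduces to a per-parameter comparison against `GRADIENT_MAX_ABSOLUTE_DIFF_TOLERANCE`; a hypothetical sketch (the dict-of-gradients layout is an assumption):

```python
import torch

GRADIENT_MAX_ABSOLUTE_DIFF_TOLERANCE = 0.05   # from the config above

def gradients_match(grads_cp1: dict, grads_cp2: dict) -> bool:
    # Compare gradients parameter by parameter on the maximum absolute difference.
    for name, g1 in grads_cp1.items():
        g2 = grads_cp2[name]
        if (g1.float() - g2.float()).abs().max().item() > GRADIENT_MAX_ABSOLUTE_DIFF_TOLERANCE:
            return False
    return True
```
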

examples/pytorch/transformer/test_context_parallel_bshd.py

Lines changed: 1 addition & 1 deletion
@@ -250,7 +250,7 @@ def test_cp_indices_calculation(load_test_data):

     # Get baseline logits to determine batch size and sequence length
     cp1_logits = test_data['cp1_results']['logits']
-    batch_size, seq_len, vocab_size = cp1_logits.shape
+    batch_size, seq_len, _ = cp1_logits.shape

     # Calculate indices for CP=2
     rank_indices = calculate_cp_indices_bshd(batch_size, seq_len, cp_size=2)
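
For context, CP index calculation for BSHD typically follows the load-balanced 2·cp_size chunking scheme, where rank r owns chunks r and 2·cp_size−1−r of the sequence dimension. The sketch below is a guess at the shape of such a helper, not the actual `calculate_cp_indices_bshd` implementation:

```python
import torch

def calculate_cp_indices_bshd_sketch(batch_size: int, seq_len: int, cp_size: int = 2):
    """Hypothetical: sequence indices each CP rank owns under 2*cp_size chunking."""
    # Indices run along the sequence dimension and are shared by every batch element,
    # so batch_size is not needed for the index math itself.
    assert seq_len % (2 * cp_size) == 0
    chunk = seq_len // (2 * cp_size)
    rank_indices = []
    for rank in range(cp_size):
        first = torch.arange(rank * chunk, (rank + 1) * chunk)
        second_start = (2 * cp_size - 1 - rank) * chunk
        second = torch.arange(second_start, second_start + chunk)
        rank_indices.append(torch.cat([first, second]))
    return rank_indices  # one index tensor of length seq_len // cp_size per rank

# Example: seq_len=8, cp_size=2 -> rank 0 gets [0,1,6,7], rank 1 gets [2,3,4,5]
```
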
