|
| 1 | +# Context Parallel Testing Framework |
| 2 | + |
| 3 | +This directory contains a comprehensive testing framework for validating **Context Parallelism (CP)** in TransformerEngine. The framework compares single-GPU baseline runs against distributed multi-GPU runs to ensure numerical consistency and correctness. |
| 4 | + |
| 5 | +## 🏗️ Architecture Overview |
| 6 | + |
| 7 | +### Test Models |
| 8 | + |
| 9 | +The framework supports two attention input formats, each with its own model implementation in `model.py`: |
| 10 | + |
| 11 | +**1. SimpleThDModel (THD Format - Token-Head-Dimension):** |
| 12 | +- Processes sequences in flattened token format |
| 13 | +- Uses cumulative sequence lengths for batch handling |
| 14 | +- Attention computation in THD layout |
| 15 | + |
| 16 | +**2. SimpleBSHDModel (BSHD Format - Batch-Sequence-Head-Dimension):** |
| 17 | +- Processes sequences in standard batch format |
| 18 | +- Maintains batch dimension throughout computation |
| 19 | +- Attention computation in BSHD layout |
| 20 | + |
| 21 | +**Common Architecture (both models):** |
| 22 | +- **Embedding Layer**: Token embedding (vocab_size=33, hidden_size=320) |
| 23 | +- **1 TransformerEngine Layer**: Full attention + MLP block with: |
| 24 | + - 20 attention heads (16-dimensional each) |
| 25 | + - 1280 intermediate FFN size |
| 26 | + - GELU activation |
| 27 | + - RoPE positional embeddings |
| 28 | + - Mixed precision (bfloat16) |
| 29 | +- **Output Layer**: Linear projection back to vocabulary space |
| 30 | +- **Layer Normalization**: Applied after transformer layers |
| 31 | + |
| 32 | +**Key Features:** |
| 33 | +- Designed for **variable-length sequences** with padding |
| 34 | +- Uses **cumulative sequence lengths** (`cu_seqlens`) for efficient batching |
| 35 | +- Supports **context parallel** attention computation |
| 36 | +- Deterministic initialization for reproducible testing |
| 37 | + |
| 38 | +### Test Data Generation |
| 39 | + |
| 40 | +The `utils.py` module provides synthetic test data in two formats: |
| 41 | + |
| 42 | +**THD Format (`get_dummy_data_thd()`):** |
| 43 | +```python |
| 44 | +# Three sequences of different lengths (flattened): |
| 45 | +Sequence 1: [1,1,1,1,1,1,1,1] # 8 tokens (padded) |
| 46 | +Sequence 2: [2,2,2,2,2,2,2,2,2,2,2,2] # 12 tokens (padded) |
| 47 | +Sequence 3: [3,3,3,3,3,3,3,3] # 8 tokens (padded) |
| 48 | +# Flattened into single tensor of 28 tokens total |
| 49 | +``` |
| 50 | + |
| 51 | +**BSHD Format (`get_dummy_data_bshd()`):** |
| 52 | +```python |
| 53 | +# Single batch with one long sequence: |
| 54 | +Batch shape: [1, 1024, hidden_size] # Standard batch tensor |
| 55 | +``` |
| 56 | + |
| 57 | +**Data Processing Pipeline:** |
| 58 | +1. **Padding**: Sequences padded to be divisible by `2 * cp_size` (required for CP) |
| 59 | +2. **Cumulative Lengths**: Used for tracking sequence boundaries |
| 60 | +3. **Labels**: Corresponding target tokens for loss computation |
| 61 | +4. **Position IDs**: Relative positions within each sequence |
| 62 | + |
| 63 | +## 🚀 Testing Workflow |
| 64 | + |
| 65 | +The framework runs parallel test suites for both THD and BSHD formats: |
| 66 | + |
| 67 | +### Phase 1: Baseline Run (CP=1) |
| 68 | +**Files**: |
| 69 | +- `context_parallel_runner_thd.py` (THD format) |
| 70 | +- `context_parallel_runner_bshd.py` (BSHD format) |
| 71 | + |
| 72 | +Torchrun is used for both programs, at first no parallelism is used and we have two identical forward passes (one per gpu). |
| 73 | +1. **Model Creation**: Initialize model (SimpleThDModel or SimpleBSHDModel) |
| 74 | +2. **Forward Pass**: Process full sequences without parallelization |
| 75 | +3. **Loss Computation**: Cross-entropy on valid (non-padded) tokens |
| 76 | +4. **Gradient Collection**: Gather gradients from key model components: |
| 77 | + - Embedding layer |
| 78 | + - Transformer layer(s) |
| 79 | + - Output linear layer |
| 80 | +5. **State Persistence**: Save model weights and results to `/tmp/` for CP=2 comparison |
| 81 | + |
| 82 | +### Phase 2: Distributed Run (CP=2) |
| 83 | +Then torch distributed is initialized and we use context parallel=2, both gpus now participate in a forward pass. |
| 84 | + |
| 85 | +1. **Process Group Setup**: Initialize NCCL backend for 2 GPUs |
| 86 | +2. **Device Mesh**: Create `(fsdp=1, cp=2, tp=1)` parallelization strategy |
| 87 | +3. **Model Replication**: Load identical weights from CP=1 baseline |
| 88 | +4. **DDP Wrapping**: Enable gradient synchronization across ranks |
| 89 | +5. **Context Parallel Setup**: Configure attention layers for sequence splitting |
| 90 | +6. **Data Partitioning**: |
| 91 | + - THD: Use `get_batch_on_this_cp_rank(..., qvk_format="thd")` |
| 92 | + - BSHD: Use `get_batch_on_this_cp_rank(..., qvk_format="bshd")` |
| 93 | + - **Rank 0**: Gets first + last chunks of each sequence |
| 94 | + - **Rank 1**: Gets middle chunks of each sequence |
| 95 | +7. **Synchronized Forward/Backward**: Identical computation with distributed data |
| 96 | +8. **Completion Markers**: Creates `/tmp/{format}_complete.marker` when done |
| 97 | + |
| 98 | +### Phase 3: Validation Testing |
| 99 | +**Files**: |
| 100 | +- `test_context_parallel_thd.py` (THD format tests) |
| 101 | +- `test_context_parallel_bshd.py` (BSHD format tests) |
| 102 | + |
| 103 | +Both test suites perform identical validations with format-specific reconstruction: |
| 104 | + |
| 105 | +## 🧪 Test Suite Details |
| 106 | + |
| 107 | +### Test 1: Data Loading Validation |
| 108 | +- Verifies all required result files exist |
| 109 | +- Validates tensor shapes and dimensions |
| 110 | +- Confirms vocabulary size consistency |
| 111 | +- Checks data structure integrity |
| 112 | + |
| 113 | +### Test 2: CP Index Calculation |
| 114 | +- Tests the sequence splitting algorithm |
| 115 | +- Verifies no overlap between rank assignments |
| 116 | +- Confirms complete sequence coverage |
| 117 | +- Validates chunk size calculations |
| 118 | + |
| 119 | +### Test 3: Logits Reconstruction |
| 120 | +- Reconstructs full sequences from distributed chunks |
| 121 | +- Validates reconstruction algorithm correctness |
| 122 | +- Ensures output shapes match baseline |
| 123 | +- Tests the "first+last vs middle" chunk distribution |
| 124 | + |
| 125 | +### Test 4: Logits Accuracy Comparison ⭐ |
| 126 | +**The Core Test**: Compares reconstructed CP=2 logits against CP=1 baseline |
| 127 | + |
| 128 | +**Tolerance Configuration:** |
| 129 | +```python |
| 130 | +LOGITS_ACCURACY_THRESHOLD = 85.0 # % of elements within tolerance |
| 131 | +LOGITS_ELEMENT_TOLERANCE = 2e-2 # Individual element tolerance |
| 132 | +``` |
| 133 | + |
| 134 | +**Success Criteria**: ≥85% of logit elements must be within 2e-2 absolute difference |
| 135 | + |
| 136 | +**Why 85%?** Distributed computation with mixed precision (bfloat16) introduces expected numerical differences. This threshold balances strictness with practical distributed computing realities. Moreover, we notice that as we increase the number of hidden layers (TE Layers), we see the numerical differences between the non CP and CP counterparts increase. |
| 137 | + |
| 138 | +### Test 5: Loss Similarity |
| 139 | +- Compares averaged losses from both CP ranks |
| 140 | +- **Tolerance**: 5% relative difference |
| 141 | +- Validates that distributed training preserves loss computation |
| 142 | + |
| 143 | +### Test 6: Gradient Consistency ⭐ |
| 144 | +**Dual Validation Approach:** |
| 145 | + |
| 146 | +1. **Strict Bound**: No gradient can exceed 0.05 absolute difference |
| 147 | + ```python |
| 148 | + GRADIENT_MAX_ABSOLUTE_DIFF_TOLERANCE = 0.05 |
| 149 | + ``` |
| 150 | + |
| 151 | +2. **Statistical Quality**: ≥80% of gradients must be "acceptable" |
| 152 | + ```python |
| 153 | + GRADIENT_SUCCESS_RATE_THRESHOLD = 80.0 |
| 154 | + ``` |
| 155 | + |
| 156 | +**Gradient Categories:** |
| 157 | +- **Excellent**: Absolute difference < 1e-4 |
| 158 | +- **Good**: Relative difference < 2e-2 |
| 159 | +- **Acceptable**: Relative difference < 5e-2 |
| 160 | + |
| 161 | +## 🎯 Tolerance Philosophy |
| 162 | + |
| 163 | +The framework uses **scientifically calibrated tolerances** based on: |
| 164 | + |
| 165 | +1. **Mixed Precision Effects**: bfloat16 has ~3-4 decimal digits of precision |
| 166 | +2. **Distributed Communication**: AllReduce operations introduce small numerical errors |
| 167 | +3. **Computation Order**: Different operation sequences in CP vs non-CP modes |
| 168 | + |
| 169 | +**Conservative but Practical**: Tolerances are tight enough to catch real bugs while loose enough to handle expected distributed computing variations. |
| 170 | + |
| 171 | +## 🚀 Running the Tests |
| 172 | + |
| 173 | +### Quick Start |
| 174 | +```bash |
| 175 | +# Run both THD and BSHD format tests with automatic synchronization |
| 176 | +bash run_context_parallel.sh |
| 177 | +``` |
| 178 | + |
| 179 | +### Script Features |
| 180 | + |
| 181 | +**Automatic Process Synchronization:** |
| 182 | +- Uses file-based completion markers (`/tmp/{format}_complete.marker`) |
| 183 | +- Ensures first `torchrun` completes before starting second |
| 184 | +- Configurable timeout (default 60 seconds) |
| 185 | +- Automatic cleanup of old markers |
| 186 | + |
| 187 | +**Clean Output Mode:** |
| 188 | +- Suppressed verbose logging by default |
| 189 | +- Only shows test results and errors |
| 190 | +- Use `--verbose` flag for detailed output |
| 191 | + |
| 192 | +### Manual Execution |
| 193 | + |
| 194 | +**THD Format:** |
| 195 | +```bash |
| 196 | +# Step 1: Generate test data with distributed run |
| 197 | +torchrun --nproc_per_node=2 --master_port=29501 context_parallel_runner_thd.py |
| 198 | + |
| 199 | +# Step 2: Run validation tests |
| 200 | +python -m pytest test_context_parallel_thd.py -v |
| 201 | +``` |
| 202 | + |
| 203 | +**BSHD Format:** |
| 204 | +```bash |
| 205 | +# Step 1: Generate test data with distributed run |
| 206 | +torchrun --nproc_per_node=2 --master_port=29501 context_parallel_runner_bshd.py |
| 207 | + |
| 208 | +# Step 2: Run validation tests |
| 209 | +python -m pytest test_context_parallel_bshd.py -v |
| 210 | +``` |
| 211 | + |
| 212 | +### Expected Output |
| 213 | + |
| 214 | +**Standard Run (Clean Output):** |
| 215 | +``` |
| 216 | +Running distributed training for BSHD format... |
| 217 | +✓ BSHD training completed successfully |
| 218 | +
|
| 219 | +Running Context Parallel Tests for BSHD... |
| 220 | +============================== 6 passed in 2.34s ============================== |
| 221 | +
|
| 222 | +Running distributed training for THD format... |
| 223 | +✓ THD training completed successfully |
| 224 | +
|
| 225 | +Running Context Parallel Tests for THD... |
| 226 | +============================== 6 passed in 2.45s ============================== |
| 227 | +``` |
| 228 | + |
| 229 | +**Verbose Mode:** |
| 230 | +```bash |
| 231 | +# Run with detailed logging |
| 232 | +bash run_context_parallel.sh --verbose |
| 233 | + |
| 234 | +# Or for pytest directly: |
| 235 | +pytest test_context_parallel_thd.py -v --log-cli-level=INFO |
| 236 | +``` |
| 237 | + |
| 238 | +## 🔧 Customizing Tolerances |
| 239 | + |
| 240 | +All test thresholds are configurable constants at the top of both test files: |
| 241 | + |
| 242 | +```python |
| 243 | +# In test_context_parallel_thd.py and test_context_parallel_bshd.py |
| 244 | +LOGITS_ACCURACY_THRESHOLD = 85.0 # Stricter: 90.0, Looser: 80.0 |
| 245 | +LOGITS_ELEMENT_TOLERANCE = 2e-2 # Stricter: 1e-2, Looser: 5e-2 |
| 246 | +GRADIENT_MAX_ABSOLUTE_DIFF_TOLERANCE = 0.05 # Stricter: 0.01, Looser: 0.1 |
| 247 | +``` |
| 248 | + |
| 249 | +## 🎯 What This Folder Validates |
| 250 | + |
| 251 | +✅ **Numerical Correctness**: CP=2 produces equivalent results to CP=1 |
| 252 | +✅ **Format Compatibility**: Both THD and BSHD attention formats work correctly |
| 253 | +✅ **Gradient Consistency**: Distributed training gradients match single-GPU |
| 254 | +✅ **Loss Preservation**: Training objectives remain unchanged |
| 255 | +✅ **Sequence Reconstruction**: Distributed chunks correctly reassemble |
| 256 | +✅ **Memory Efficiency**: Context parallelism reduces per-GPU memory usage |
| 257 | +✅ **Scalability**: Folder extends to larger CP sizes and models |
0 commit comments