@@ -14,7 +14,9 @@
 
 from utils import get_dummy_data_thd, collect_gradients, DistributedConfig
 from model import SimpleConfig, SimpleThDModel
-from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import get_batch_on_this_cp_rank
+from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
+    get_batch_on_this_cp_rank,
+)
 
 
 random.seed(42)
@@ -23,12 +25,12 @@
 if torch.cuda.is_available():
     torch.cuda.manual_seed_all(42)
 
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
 
 data = get_dummy_data_thd()
 
-DISTRIBUTED_MODE = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
+DISTRIBUTED_MODE = "RANK" in os.environ and "WORLD_SIZE" in os.environ
 
 device = torch.device("cuda:0")
 torch.cuda.set_device(0)
@@ -53,10 +55,13 @@
 torch.cuda.manual_seed(42)
 
 batch = get_dummy_data_thd()
-batch['input_ids'] = batch['input_ids_padded']
-batch['labels'] = batch['labels_padded']
-batch['position_ids'] = batch['position_ids_padded']
-batch = {k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+batch["input_ids"] = batch["input_ids_padded"]
+batch["labels"] = batch["labels_padded"]
+batch["position_ids"] = batch["position_ids_padded"]
+batch = {
+    k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v
+    for k, v in batch.items()
+}
 
 with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
     output_cp1 = model_cp1(batch)
@@ -65,69 +70,72 @@
 labels_flat = batch["labels"].view(-1)
 valid_mask = labels_flat != -100
 if valid_mask.any():
-    loss_cp1 = F.cross_entropy(logits_flat_cp1[valid_mask], labels_flat[valid_mask], reduction="mean")
+    loss_cp1 = F.cross_entropy(
+        logits_flat_cp1[valid_mask], labels_flat[valid_mask], reduction="mean"
+    )
 else:
     loss_cp1 = logits_flat_cp1.sum() * 0.0
 
 # Compute gradients for CP=1
 loss_cp1.backward()
 
 target_layers = [
-    'embedding', # Embedding layers
-    'transformer_layers.0', # First transformer layer
-    'transformer_layers.1', # Second transformer layer
-    'linear' # Language model head
+    "embedding",  # Embedding layers
+    "transformer_layers.0",  # First transformer layer
+    "transformer_layers.1",  # Second transformer layer
+    "linear",  # Language model head
 ]
 
 grads_cp1 = collect_gradients(model_cp1, layer_patterns=target_layers, max_params=15)
 
 # logger.info(f"CP=1 collected {len(grads_cp1)} parameter gradients")
 
 initial_state_dict = {k: v.cpu().clone() for k, v in model_cp1.state_dict().items()}
-torch.save(initial_state_dict, '/tmp/thd_initial_model_state.pt')
+torch.save(initial_state_dict, "/tmp/thd_initial_model_state.pt")
 # logger.info("Model state saved for CP=2 reuse")
 
 cp1_results = {
-    'logits': output_cp1.clone().detach().cpu(),
-    'loss': loss_cp1.clone().detach().cpu(),
-    'grad_norms': {name: grad.norm().item() for name, grad in grads_cp1.items()},
-    'grads': grads_cp1,
+    "logits": output_cp1.clone().detach().cpu(),
+    "loss": loss_cp1.clone().detach().cpu(),
+    "grad_norms": {name: grad.norm().item() for name, grad in grads_cp1.items()},
+    "grads": grads_cp1,
 }
 
-torch.save(cp1_results, '/tmp/thd_cp1_results.pt')
-torch.save(data, '/tmp/thd_data.pt')
+torch.save(cp1_results, "/tmp/thd_cp1_results.pt")
+torch.save(data, "/tmp/thd_data.pt")
 
 # STEP 2: RUN CP=2 (CONTEXT PARALLELISM)
 # Skip CP=2 if not in distributed mode
 if not DISTRIBUTED_MODE:
-    logger.info("="*50)
+    logger.info("=" * 50)
     logger.info("SKIPPING CP=2 - Not running in distributed mode")
     logger.info("To test CP=2, run with: torchrun --nproc_per_node=2 run_context_parallel_thd.py")
-    logger.info("="*50)
+    logger.info("=" * 50)
     import sys
+
     sys.exit(1)
 
 # Run CP=2 in distributed mode
 if DISTRIBUTED_MODE:
-    logger.info("="*50)
+    logger.info("=" * 50)
     logger.info("RUNNING CP=2 (CONTEXT PARALLELISM)")
-    logger.info("="*50)
+    logger.info("=" * 50)
 
     cp_size = 2
     dist.init_process_group(backend="nccl")
     dist_config = DistributedConfig()
     torch.cuda.set_device(dist_config.local_rank)
 
     device_mesh = init_device_mesh(
-        "cuda",
-        mesh_shape=(1, cp_size, 1),
-        mesh_dim_names=("fsdp", "cp", "tp"),
-        )
+        "cuda",
+        mesh_shape=(1, cp_size, 1),
+        mesh_dim_names=("fsdp", "cp", "tp"),
+    )
     device = torch.device(f"cuda:{dist_config.local_rank}")
     model = SimpleThDModel(config)
 
     try:
-        initial_state_dict = torch.load('/tmp/thd_initial_model_state.pt', map_location='cpu')
+        initial_state_dict = torch.load("/tmp/thd_initial_model_state.pt", map_location="cpu")
         model.load_state_dict(initial_state_dict)
         # logger.info(f"Rank {dist_config.rank}: Model state loaded successfully")
     except Exception as e:
@@ -152,7 +160,7 @@
         transformer_layer.set_context_parallel_group(
             cp_group,
             torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()),
-            torch.cuda.Stream()
+            torch.cuda.Stream(),
         )
 
     dist.barrier()
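The hunk above shows only the inner call; in the full script it runs once per TransformerEngine layer before the forward pass. A minimal sketch of that surrounding loop, assuming the layers are exposed as model.transformer_layers (name inferred from the target_layers patterns earlier) and that cp_group is taken from the mesh:

cp_group = device_mesh["cp"].get_group()
for transformer_layer in model.transformer_layers:
    # Hand each TE layer the CP process group, its global ranks, and a CUDA
    # stream used for the context-parallel communication.
    transformer_layer.set_context_parallel_group(
        cp_group,
        torch.distributed.get_process_group_ranks(cp_group),
        torch.cuda.Stream(),
    )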
@@ -162,20 +170,23 @@
     torch.cuda.manual_seed(42)
 
     batch = get_dummy_data_thd()
-
+
     input_ids_padded, labels_padded, position_ids_padded = get_batch_on_this_cp_rank(
-        cu_seqlens_padded=batch['cu_seqlens_q_padded'],
-        input_ids_padded=batch['input_ids_padded'],
-        labels_padded=batch['labels_padded'],
-        position_ids_padded=batch['position_ids_padded'],
+        cu_seqlens_padded=batch["cu_seqlens_q_padded"],
+        input_ids_padded=batch["input_ids_padded"],
+        labels_padded=batch["labels_padded"],
+        position_ids_padded=batch["position_ids_padded"],
         cp_group=cp_group,
-        qvk_format="thd"
+        qvk_format="thd",
     )
 
-    batch['input_ids'] = input_ids_padded
-    batch['labels'] = labels_padded
-    batch['position_ids'] = position_ids_padded
-    batch = {k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+    batch["input_ids"] = input_ids_padded
+    batch["labels"] = labels_padded
+    batch["position_ids"] = position_ids_padded
+    batch = {
+        k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v
+        for k, v in batch.items()
+    }
 
     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
         output = model(batch)
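In this script the cu_seqlens_*_padded entries of the batch are left untouched, so only the token-level tensors are sharded across the two CP ranks; assuming the usual even split of every padded sequence across ranks, each rank should end up with 1/cp_size of the padded tokens. A sanity check one could drop in at this point (sketch only, under that assumption):

# Sketch: each CP rank holds 1/cp_size of the padded tokens, while
# cu_seqlens_q_padded remains the global cumulative-length tensor.
total_padded_tokens = int(batch["cu_seqlens_q_padded"][-1])
assert input_ids_padded.numel() == total_padded_tokens // cp_size
assert labels_padded.numel() == total_padded_tokens // cp_size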
@@ -193,18 +204,18 @@
     grads = collect_gradients(model, layer_patterns=target_layers, max_params=15)
 
     cp2_results = {
-        'logits': output.clone().detach().cpu(),
-        'loss': loss.clone().detach().cpu(),
-        'grad_norms': {name: grad.norm().item() for name, grad in grads.items()},
-        'grads': grads,
+        "logits": output.clone().detach().cpu(),
+        "loss": loss.clone().detach().cpu(),
+        "grad_norms": {name: grad.norm().item() for name, grad in grads.items()},
+        "grads": grads,
     }
 
-    torch.save(cp2_results, f'/tmp/thd_cp2_rank_{dist_config.rank}_results.pt')
+    torch.save(cp2_results, f"/tmp/thd_cp2_rank_{dist_config.rank}_results.pt")
     dist.barrier()
-
+
     # Create completion marker on rank 0 to signal all processing is done
     if dist_config.rank == 0:
-        with open('/tmp/thd_complete.marker', 'w') as f:
-            f.write('completed')
-
+        with open("/tmp/thd_complete.marker", "w") as f:
+            f.write("completed")
+
     dist.destroy_process_group()
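The files written above, /tmp/thd_cp1_results.pt and /tmp/thd_cp2_rank_{rank}_results.pt, are the inputs for the eventual CP=1 vs CP=2 comparison, which is not part of this change. A minimal sketch of such a report, using only the keys saved above:

import torch

cp1 = torch.load("/tmp/thd_cp1_results.pt")
for rank in (0, 1):
    cp2 = torch.load(f"/tmp/thd_cp2_rank_{rank}_results.pt")
    # Report losses and the collected per-parameter gradient norms side by side.
    print(f"rank {rank}: loss cp1={cp1['loss'].item():.6f} cp2={cp2['loss'].item():.6f}")
    for name, ref_norm in sorted(cp1["grad_norms"].items()):
        print(f"  {name}: cp1={ref_norm:.6f} cp2={cp2['grad_norms'].get(name)}")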