Commit 5f9353b

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 310a064 commit 5f9353b

File tree: 9 files changed (+893, -630 lines)
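
The two runner diffs shown below are mechanical style fixes of the kind black and the standard pre-commit hooks apply: single quotes normalized to double quotes, long calls and literals wrapped with a trailing comma, and trailing whitespace stripped. A minimal before/after sketch (hypothetical code, not taken from this repository) illustrates that only the formatting changes, not the computed result:

import torch
import torch.nn.functional as F

# Hypothetical tensors, only here so the snippet runs end to end.
logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
mask = labels != -100

# Style before the hooks ran: single quotes, one long call on one line.
loss_before = F.cross_entropy(logits[mask], labels[mask], reduction='mean')

# Style after black: double quotes, the call wrapped with a trailing comma.
loss_after = F.cross_entropy(
    logits[mask],
    labels[mask],
    reduction="mean",
)

# Formatting only; the value is unchanged.
assert torch.allclose(loss_before, loss_after)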

examples/pytorch/transformer/context_parallel_runner_bshd.py

Lines changed: 53 additions & 42 deletions

@@ -14,7 +14,9 @@
 
 from utils import get_dummy_data_bshd, collect_gradients, DistributedConfig
 from model import SimpleConfig, SimpleBSHDModel
-from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import get_batch_on_this_cp_rank
+from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
+    get_batch_on_this_cp_rank,
+)
 
 
 random.seed(42)
@@ -23,12 +25,12 @@
 if torch.cuda.is_available():
     torch.cuda.manual_seed_all(42)
 
-logging.basicConfig(level=logging.INFO, format='%(message)s')
+logging.basicConfig(level=logging.INFO, format="%(message)s")
 logger = logging.getLogger(__name__)
 
 data = get_dummy_data_bshd()
 
-DISTRIBUTED_MODE = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
+DISTRIBUTED_MODE = "RANK" in os.environ and "WORLD_SIZE" in os.environ
 
 # STEP 1: RUN CP=1 (BASELINE) - NO DISTRIBUTED TRAINING
 
@@ -55,7 +57,10 @@
 torch.cuda.manual_seed(42)
 
 batch = get_dummy_data_bshd()
-batch = {k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+batch = {
+    k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v
+    for k, v in batch.items()
+}
 
 with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
     output_cp1 = model_cp1(batch)
@@ -64,34 +69,36 @@
 labels_flat = batch["labels"].view(-1)
 valid_mask = labels_flat != -100
 if valid_mask.any():
-    loss_cp1 = F.cross_entropy(logits_flat_cp1[valid_mask], labels_flat[valid_mask], reduction="mean")
+    loss_cp1 = F.cross_entropy(
+        logits_flat_cp1[valid_mask], labels_flat[valid_mask], reduction="mean"
+    )
 else:
     loss_cp1 = logits_flat_cp1.sum() * 0.0
 
 # Compute gradients for CP=1
 loss_cp1.backward()
 
 target_layers = [
-    'embedding',  # Embedding layers
-    'transformer_layers.0',  # First transformer layer
-    'transformer_layers.1',  # Second transformer layer
-    'linear'  # Language model head
+    "embedding",  # Embedding layers
+    "transformer_layers.0",  # First transformer layer
+    "transformer_layers.1",  # Second transformer layer
+    "linear",  # Language model head
 ]
 
 grads_cp1 = collect_gradients(model_cp1, layer_patterns=target_layers, max_params=15)
 
 initial_state_dict = {k: v.cpu().clone() for k, v in model_cp1.state_dict().items()}
-torch.save(initial_state_dict, '/tmp/bshd_initial_model_state.pt')
+torch.save(initial_state_dict, "/tmp/bshd_initial_model_state.pt")
 
 cp1_results = {
-    'logits': output_cp1.clone().detach().cpu(),
-    'loss': loss_cp1.clone().detach().cpu(),
-    'grad_norms': {name: grad.norm().item() for name, grad in grads_cp1.items()},
-    'grads': grads_cp1,
+    "logits": output_cp1.clone().detach().cpu(),
+    "loss": loss_cp1.clone().detach().cpu(),
+    "grad_norms": {name: grad.norm().item() for name, grad in grads_cp1.items()},
+    "grads": grads_cp1,
 }
 
-torch.save(cp1_results, '/tmp/bshd_cp1_results.pt')
-torch.save(data, '/tmp/bshd_data.pt')
+torch.save(cp1_results, "/tmp/bshd_cp1_results.pt")
+torch.save(data, "/tmp/bshd_data.pt")
 if not DISTRIBUTED_MODE:
     logger.info(f"CP=1 complete: loss={cp1_results['loss'].item():.6f}")
 
@@ -100,26 +107,27 @@
 # Skip CP=2 if not in distributed mode
 if not DISTRIBUTED_MODE:
     import sys
+
     sys.exit(0)
 
 # Run CP=2 in distributed mode
 if DISTRIBUTED_MODE:
-    if int(os.environ.get('RANK', 0)) == 0:
+    if int(os.environ.get("RANK", 0)) == 0:
        logger.info("Running CP=2 distributed...")
 
     cp_size = 2
     dist.init_process_group(backend="nccl")
     dist_config = DistributedConfig()
     torch.cuda.set_device(dist_config.local_rank)
     device_mesh = init_device_mesh(
-            "cuda",
-            mesh_shape=(1, cp_size, 1),
-            mesh_dim_names=("fsdp", "cp", "tp"),
-            )
+        "cuda",
+        mesh_shape=(1, cp_size, 1),
+        mesh_dim_names=("fsdp", "cp", "tp"),
+    )
     device = torch.device(f"cuda:{dist_config.local_rank}")
 
     model = SimpleBSHDModel(config)
-    initial_state_dict = torch.load('/tmp/bshd_initial_model_state.pt', map_location='cpu')
+    initial_state_dict = torch.load("/tmp/bshd_initial_model_state.pt", map_location="cpu")
     model.load_state_dict(initial_state_dict)
 
     model = model.to(device)
@@ -138,7 +146,7 @@
         transformer_layer.set_context_parallel_group(
             cp_group,
             torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()),
-            torch.cuda.Stream()
+            torch.cuda.Stream(),
         )
 
     dist.barrier()
@@ -149,18 +157,21 @@
 
     batch = get_dummy_data_bshd()
     input_ids, labels, position_ids = get_batch_on_this_cp_rank(
-        cu_seqlens_padded=batch['cu_seqlens_q'],
-        input_ids_padded=batch['input_ids'],
-        labels_padded=batch['labels'],
-        position_ids_padded=batch['position_ids'],
+        cu_seqlens_padded=batch["cu_seqlens_q"],
+        input_ids_padded=batch["input_ids"],
+        labels_padded=batch["labels"],
+        position_ids_padded=batch["position_ids"],
         cp_group=cp_group,
-        qvk_format="bshd"
+        qvk_format="bshd",
     )
-    batch['input_ids'] = input_ids
-    batch['labels'] = labels
-    batch['position_ids'] = position_ids
-
-    batch = {k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+    batch["input_ids"] = input_ids
+    batch["labels"] = labels
+    batch["position_ids"] = position_ids
+
+    batch = {
+        k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v
+        for k, v in batch.items()
+    }
 
     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
         output = model(batch)
@@ -177,19 +188,19 @@
     grads = collect_gradients(model, layer_patterns=target_layers, max_params=15)
 
     cp2_results = {
-        'logits': output.clone().detach().cpu(),
-        'loss': loss.clone().detach().cpu(),
-        'grad_norms': {name: grad.norm().item() for name, grad in grads.items()},
-        'grads': grads,
+        "logits": output.clone().detach().cpu(),
+        "loss": loss.clone().detach().cpu(),
+        "grad_norms": {name: grad.norm().item() for name, grad in grads.items()},
+        "grads": grads,
     }
 
-    torch.save(cp2_results, f'/tmp/bshd_cp2_rank_{dist_config.rank}_results.pt')
+    torch.save(cp2_results, f"/tmp/bshd_cp2_rank_{dist_config.rank}_results.pt")
     dist.barrier()
-
+
     if dist_config.rank == 0:
         logger.info(f"CP=2 complete: rank0_loss={cp2_results['loss'].item():.6f}")
         # Create completion marker file to signal that all processing is done
-        with open('/tmp/bshd_complete.marker', 'w') as f:
-            f.write('completed')
-
+        with open("/tmp/bshd_complete.marker", "w") as f:
+            f.write("completed")
+
    dist.destroy_process_group()
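
The BSHD runner saves its CP=1 baseline to /tmp/bshd_cp1_results.pt and each CP=2 rank's results to /tmp/bshd_cp2_rank_{rank}_results.pt. The verification step that consumes these files is not part of this excerpt; a minimal, hypothetical sketch of how the artifacts could be inspected once both runs have finished:

# Hypothetical inspection of the result files the BSHD runner writes to /tmp.
# The actual comparison logic is not shown in this commit excerpt.
import torch

cp1 = torch.load("/tmp/bshd_cp1_results.pt")
cp2 = [torch.load(f"/tmp/bshd_cp2_rank_{rank}_results.pt") for rank in range(2)]

# Compare the losses from the two CP=2 ranks against the CP=1 baseline.
print("CP=1 loss:", cp1["loss"].item())
print("CP=2 losses:", [r["loss"].item() for r in cp2])

# Per-parameter gradient norms give a quick parity check between the runs.
for name, cp1_norm in cp1["grad_norms"].items():
    cp2_norm = cp2[0]["grad_norms"].get(name)
    if cp2_norm is not None:
        print(f"{name}: cp1={cp1_norm:.6f} cp2_rank0={cp2_norm:.6f}")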

examples/pytorch/transformer/context_parallel_runner_thd.py

Lines changed: 59 additions & 48 deletions

@@ -14,7 +14,9 @@
 
 from utils import get_dummy_data_thd, collect_gradients, DistributedConfig
 from model import SimpleConfig, SimpleThDModel
-from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import get_batch_on_this_cp_rank
+from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
+    get_batch_on_this_cp_rank,
+)
 
 
 random.seed(42)
@@ -23,12 +25,12 @@
 if torch.cuda.is_available():
     torch.cuda.manual_seed_all(42)
 
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
 
 data = get_dummy_data_thd()
 
-DISTRIBUTED_MODE = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
+DISTRIBUTED_MODE = "RANK" in os.environ and "WORLD_SIZE" in os.environ
 
 device = torch.device("cuda:0")
 torch.cuda.set_device(0)
@@ -53,10 +55,13 @@
 torch.cuda.manual_seed(42)
 
 batch = get_dummy_data_thd()
-batch['input_ids'] = batch['input_ids_padded']
-batch['labels'] = batch['labels_padded']
-batch['position_ids'] = batch['position_ids_padded']
-batch = {k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+batch["input_ids"] = batch["input_ids_padded"]
+batch["labels"] = batch["labels_padded"]
+batch["position_ids"] = batch["position_ids_padded"]
+batch = {
+    k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v
+    for k, v in batch.items()
+}
 
 with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
     output_cp1 = model_cp1(batch)
@@ -65,69 +70,72 @@
 labels_flat = batch["labels"].view(-1)
 valid_mask = labels_flat != -100
 if valid_mask.any():
-    loss_cp1 = F.cross_entropy(logits_flat_cp1[valid_mask], labels_flat[valid_mask], reduction="mean")
+    loss_cp1 = F.cross_entropy(
+        logits_flat_cp1[valid_mask], labels_flat[valid_mask], reduction="mean"
+    )
 else:
     loss_cp1 = logits_flat_cp1.sum() * 0.0
 
 # Compute gradients for CP=1
 loss_cp1.backward()
 
 target_layers = [
-    'embedding',  # Embedding layers
-    'transformer_layers.0',  # First transformer layer
-    'transformer_layers.1',  # Second transformer layer
-    'linear'  # Language model head
+    "embedding",  # Embedding layers
+    "transformer_layers.0",  # First transformer layer
+    "transformer_layers.1",  # Second transformer layer
+    "linear",  # Language model head
 ]
 
 grads_cp1 = collect_gradients(model_cp1, layer_patterns=target_layers, max_params=15)
 
 # logger.info(f"CP=1 collected {len(grads_cp1)} parameter gradients")
 
 initial_state_dict = {k: v.cpu().clone() for k, v in model_cp1.state_dict().items()}
-torch.save(initial_state_dict, '/tmp/thd_initial_model_state.pt')
+torch.save(initial_state_dict, "/tmp/thd_initial_model_state.pt")
 # logger.info("Model state saved for CP=2 reuse")
 
 cp1_results = {
-    'logits': output_cp1.clone().detach().cpu(),
-    'loss': loss_cp1.clone().detach().cpu(),
-    'grad_norms': {name: grad.norm().item() for name, grad in grads_cp1.items()},
-    'grads': grads_cp1,
+    "logits": output_cp1.clone().detach().cpu(),
+    "loss": loss_cp1.clone().detach().cpu(),
+    "grad_norms": {name: grad.norm().item() for name, grad in grads_cp1.items()},
+    "grads": grads_cp1,
 }
 
-torch.save(cp1_results, '/tmp/thd_cp1_results.pt')
-torch.save(data, '/tmp/thd_data.pt')
+torch.save(cp1_results, "/tmp/thd_cp1_results.pt")
+torch.save(data, "/tmp/thd_data.pt")
 
 # STEP 2: RUN CP=2 (CONTEXT PARALLELISM)
 # Skip CP=2 if not in distributed mode
 if not DISTRIBUTED_MODE:
-    logger.info("="*50)
+    logger.info("=" * 50)
     logger.info("SKIPPING CP=2 - Not running in distributed mode")
     logger.info("To test CP=2, run with: torchrun --nproc_per_node=2 run_context_parallel_thd.py")
-    logger.info("="*50)
+    logger.info("=" * 50)
     import sys
+
     sys.exit(1)
 
 # Run CP=2 in distributed mode
 if DISTRIBUTED_MODE:
-    logger.info("="*50)
+    logger.info("=" * 50)
     logger.info("RUNNING CP=2 (CONTEXT PARALLELISM)")
-    logger.info("="*50)
+    logger.info("=" * 50)
 
     cp_size = 2
     dist.init_process_group(backend="nccl")
     dist_config = DistributedConfig()
     torch.cuda.set_device(dist_config.local_rank)
 
     device_mesh = init_device_mesh(
-            "cuda",
-            mesh_shape=(1, cp_size, 1),
-            mesh_dim_names=("fsdp", "cp", "tp"),
-            )
+        "cuda",
+        mesh_shape=(1, cp_size, 1),
+        mesh_dim_names=("fsdp", "cp", "tp"),
+    )
     device = torch.device(f"cuda:{dist_config.local_rank}")
     model = SimpleThDModel(config)
 
     try:
-        initial_state_dict = torch.load('/tmp/thd_initial_model_state.pt', map_location='cpu')
+        initial_state_dict = torch.load("/tmp/thd_initial_model_state.pt", map_location="cpu")
         model.load_state_dict(initial_state_dict)
         # logger.info(f"Rank {dist_config.rank}: Model state loaded successfully")
     except Exception as e:
@@ -152,7 +160,7 @@
         transformer_layer.set_context_parallel_group(
             cp_group,
             torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()),
-            torch.cuda.Stream()
+            torch.cuda.Stream(),
         )
 
     dist.barrier()
@@ -162,20 +170,23 @@
     torch.cuda.manual_seed(42)
 
     batch = get_dummy_data_thd()
-
+
     input_ids_padded, labels_padded, position_ids_padded = get_batch_on_this_cp_rank(
-        cu_seqlens_padded=batch['cu_seqlens_q_padded'],
-        input_ids_padded=batch['input_ids_padded'],
-        labels_padded=batch['labels_padded'],
-        position_ids_padded=batch['position_ids_padded'],
+        cu_seqlens_padded=batch["cu_seqlens_q_padded"],
+        input_ids_padded=batch["input_ids_padded"],
+        labels_padded=batch["labels_padded"],
+        position_ids_padded=batch["position_ids_padded"],
         cp_group=cp_group,
-        qvk_format="thd"
+        qvk_format="thd",
     )
 
-    batch['input_ids'] = input_ids_padded
-    batch['labels'] = labels_padded
-    batch['position_ids'] = position_ids_padded
-    batch = {k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+    batch["input_ids"] = input_ids_padded
+    batch["labels"] = labels_padded
+    batch["position_ids"] = position_ids_padded
+    batch = {
+        k: v.to(device, non_blocking=True).contiguous() if isinstance(v, torch.Tensor) else v
+        for k, v in batch.items()
+    }
 
     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
         output = model(batch)
@@ -193,18 +204,18 @@
     grads = collect_gradients(model, layer_patterns=target_layers, max_params=15)
 
     cp2_results = {
-        'logits': output.clone().detach().cpu(),
-        'loss': loss.clone().detach().cpu(),
-        'grad_norms': {name: grad.norm().item() for name, grad in grads.items()},
-        'grads': grads,
+        "logits": output.clone().detach().cpu(),
+        "loss": loss.clone().detach().cpu(),
+        "grad_norms": {name: grad.norm().item() for name, grad in grads.items()},
+        "grads": grads,
    }
 
-    torch.save(cp2_results, f'/tmp/thd_cp2_rank_{dist_config.rank}_results.pt')
+    torch.save(cp2_results, f"/tmp/thd_cp2_rank_{dist_config.rank}_results.pt")
     dist.barrier()
-
+
     # Create completion marker on rank 0 to signal all processing is done
     if dist_config.rank == 0:
-        with open('/tmp/thd_complete.marker', 'w') as f:
-            f.write('completed')
-
+        with open("/tmp/thd_complete.marker", "w") as f:
+            f.write("completed")
+
     dist.destroy_process_group()
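
As the THD runner's own log message notes, the CP=2 path only executes under a distributed launcher (torchrun --nproc_per_node=2 run_context_parallel_thd.py); rank 0 then writes /tmp/thd_complete.marker (the BSHD runner writes /tmp/bshd_complete.marker) to signal that the distributed phase finished. A small, hypothetical polling helper that an orchestrating test could use while waiting for that marker:

# Hypothetical helper: wait for the completion marker written by rank 0.
import os
import time


def wait_for_marker(path: str = "/tmp/thd_complete.marker", timeout_s: float = 300.0) -> bool:
    """Poll until the marker file exists or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if os.path.exists(path):
            return True
        time.sleep(1.0)
    return False


if __name__ == "__main__":
    ok = wait_for_marker()
    print("distributed run finished" if ok else "timed out waiting for marker")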
