Commit 9f0a4a4

cyanguwa authored and ptrendx committed
[PyTorch] Fix tp_group_initialized error (#819)
Remove tp_size/tp_group from the fused attention path, as amax reduction is handled by fp8_group().

Signed-off-by: Charlene Yang <[email protected]>
1 parent 090e724 · commit 9f0a4a4

File tree

1 file changed: +4 additions, −19 deletions

transformer_engine/pytorch/attention.py

Lines changed: 4 additions & 19 deletions
@@ -1937,7 +1937,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
     def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
                 dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                 rng_gen, fused_attention_backend, use_FAv2_bwd,
-                fp8, fp8_meta, tp_size, tp_group):
+                fp8, fp8_meta):
         if fp8:
             if _NVTE_DEBUG:
                 print('[DotProductAttention]: using FP8 forward')
@@ -2011,8 +2011,6 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias,
         qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
         ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors)
         ctx.fp8_meta = fp8_meta
-        ctx.tp_size = tp_size
-        ctx.tp_group = tp_group
         ctx.aux_ctx_tensors = aux_ctx_tensors
         ctx.max_seqlen = max_seqlen
         ctx.qkv_dtype = qkv_dtype
@@ -2133,7 +2131,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
     def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                 q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                 qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
-                use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group):
+                use_FAv2_bwd, fp8, fp8_meta):
         if fp8:
             if _NVTE_DEBUG:
                 print('[DotProductAttention]: using FP8 forward')
@@ -2214,8 +2212,6 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
         qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
         ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
         ctx.fp8_meta = fp8_meta
-        ctx.tp_size = tp_size
-        ctx.tp_group = tp_group
         ctx.aux_ctx_tensors = aux_ctx_tensors
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_kv = max_seqlen_kv
@@ -2350,7 +2346,7 @@ class FusedAttnFunc(torch.autograd.Function):
     def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                 q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                 qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
-                use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group):
+                use_FAv2_bwd, fp8, fp8_meta):
         if fp8:
             if _NVTE_DEBUG:
                 print('[DotProductAttention]: using FP8 forward')
@@ -2488,8 +2484,6 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
         qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
         ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
         ctx.fp8_meta = fp8_meta
-        ctx.tp_size = tp_size
-        ctx.tp_group = tp_group
         ctx.aux_ctx_tensors = aux_ctx_tensors
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_kv = max_seqlen_kv
@@ -2691,8 +2685,6 @@ def __init__(
         attention_type: str = "self",
         layer_number: Optional[int] = None,
         deterministic: bool = False,
-        tp_size: int = 1,
-        tp_group: Optional[dist_group_type] = None,
     ) -> None:
         super().__init__()
 
@@ -2719,9 +2711,6 @@ def __init__(
         if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
             os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
 
-        self.tp_size = tp_size
-        self.tp_group = tp_group
-
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
@@ -2875,8 +2864,6 @@ def forward(
                 use_FAv2_bwd,
                 self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                 self.fp8_meta,
-                self.tp_size,
-                self.tp_group,
             )
 
             # ...hd -> ...(hd)
@@ -3075,9 +3062,7 @@ def __init__(
             attention_type=attention_type,
             layer_number=layer_number,
             deterministic=self.deterministic,
-            **attn_kwargs,
-            tp_size=self.tp_size,
-            tp_group=self.tp_group)
+            **attn_kwargs)
         self.unfused_attention = UnfusedDotProductAttention(
             norm_factor, **attn_kwargs, layer_number=layer_number)
 
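For context on the rationale in the commit message (amax reduction is handled by fp8_group() rather than by tp_size/tp_group arguments on the attention modules), below is a minimal sketch of how a caller supplies that reduction group through Transformer Engine's fp8_autocast context. It is not part of this commit: the process-group setup, tensor shapes, and the amax_group name are illustrative assumptions, and a torchrun-style NCCL launch is assumed.

# Illustrative sketch (not from this commit): the distributed group used for
# FP8 amax reduction is given to fp8_autocast(fp8_group=...), so the attention
# modules no longer need tp_size/tp_group arguments of their own.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

dist.init_process_group(backend="nccl")            # assumes a torchrun launch
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Hypothetical group covering all ranks whose amax values are reduced together.
amax_group = dist.new_group(ranks=list(range(dist.get_world_size())))

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

# Dummy "sbhd" inputs: (sequence, batch, heads, head_dim); sizes are illustrative.
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(), fp8_group=amax_group):
    out = attn(q, k, v)

With the reduction group tracked by the FP8 state (the fp8_group() referenced in the commit message), the fused attention autograd functions only need fp8 and fp8_meta, which is exactly what the signatures in the diff above are trimmed down to.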
