@@ -1937,7 +1937,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
     def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
                 dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                 rng_gen, fused_attention_backend, use_FAv2_bwd,
-                fp8, fp8_meta, tp_size, tp_group):
+                fp8, fp8_meta):
         if fp8:
             if _NVTE_DEBUG:
                 print('[DotProductAttention]: using FP8 forward')
@@ -2011,8 +2011,6 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias,
         qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
         ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors)
         ctx.fp8_meta = fp8_meta
-        ctx.tp_size = tp_size
-        ctx.tp_group = tp_group
         ctx.aux_ctx_tensors = aux_ctx_tensors
         ctx.max_seqlen = max_seqlen
         ctx.qkv_dtype = qkv_dtype
@@ -2133,7 +2131,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
     def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                 q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                 qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
-                use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group):
+                use_FAv2_bwd, fp8, fp8_meta):
         if fp8:
             if _NVTE_DEBUG:
                 print('[DotProductAttention]: using FP8 forward')
@@ -2214,8 +2212,6 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
         qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
         ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
         ctx.fp8_meta = fp8_meta
-        ctx.tp_size = tp_size
-        ctx.tp_group = tp_group
         ctx.aux_ctx_tensors = aux_ctx_tensors
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_kv = max_seqlen_kv
@@ -2350,7 +2346,7 @@ class FusedAttnFunc(torch.autograd.Function):
     def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                 q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                 qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
-                use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group):
+                use_FAv2_bwd, fp8, fp8_meta):
         if fp8:
             if _NVTE_DEBUG:
                 print('[DotProductAttention]: using FP8 forward')
@@ -2488,8 +2484,6 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
         qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
         ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
         ctx.fp8_meta = fp8_meta
-        ctx.tp_size = tp_size
-        ctx.tp_group = tp_group
         ctx.aux_ctx_tensors = aux_ctx_tensors
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_kv = max_seqlen_kv
@@ -2691,8 +2685,6 @@ def __init__(
         attention_type: str = "self",
         layer_number: Optional[int] = None,
         deterministic: bool = False,
-        tp_size: int = 1,
-        tp_group: Optional[dist_group_type] = None,
     ) -> None:
         super().__init__()

@@ -2719,9 +2711,6 @@ def __init__(
         if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
             os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"

-        self.tp_size = tp_size
-        self.tp_group = tp_group
-
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
@@ -2875,8 +2864,6 @@ def forward(
                 use_FAv2_bwd,
                 self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                 self.fp8_meta,
-                self.tp_size,
-                self.tp_group,
             )

             # ...hd -> ...(hd)
@@ -3075,9 +3062,7 @@ def __init__(
             attention_type=attention_type,
             layer_number=layer_number,
             deterministic=self.deterministic,
-            **attn_kwargs,
-            tp_size=self.tp_size,
-            tp_group=self.tp_group)
+            **attn_kwargs)
         self.unfused_attention = UnfusedDotProductAttention(
             norm_factor, **attn_kwargs, layer_number=layer_number)

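For reference, a minimal sketch of how a call site looks after this change, assuming the caller has already prepared the tensors and config values in scope. All variable names below are placeholders; only the positional argument order, which now ends at `fp8, fp8_meta` with no `tp_size`/`tp_group`, comes from the updated `forward()` signature in the diff above.

```python
# Illustrative sketch only, not code from this PR. The argument order mirrors
# the new FusedAttnFunc.forward() signature (ctx is supplied by autograd);
# every variable here is assumed to be defined by the caller.
output = FusedAttnFunc.apply(
    is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
    q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
    qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
    use_FAv2_bwd, fp8, fp8_meta,  # tensor-parallel state is no longer passed here
)
```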