
Commit ecd814b

mostly gemma2
1 parent 0325dc4 commit ecd814b

21 files changed: +272 / -260 lines

examples/modular-transformers/modeling_dummy.py

Lines changed: 1 addition & 1 deletion
@@ -272,7 +272,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
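Note: the comment change above reflects the new calling convention: the rotary `(cos, sin)` pair is computed once at the model level and handed to every decoder layer through `position_embeddings`, and the `Optional` default remains only for backward compatibility. A minimal sketch of that flow (the `rotary_emb` module, `layers` list, and `model_forward` helper are hypothetical stand-ins, not the generated example code):

def model_forward(hidden_states, position_ids, rotary_emb, layers, attention_mask=None):
    # Compute the rotary (cos, sin) tensors once per forward pass ...
    position_embeddings = rotary_emb(hidden_states, position_ids)
    for layer in layers:
        # ... and thread the same tuple through every decoder layer instead of
        # letting each attention module recompute it.
        hidden_states = layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
        )[0]
    return hidden_states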

examples/modular-transformers/modeling_multimodal1.py

Lines changed: 1 addition & 1 deletion
@@ -272,7 +272,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

examples/modular-transformers/modeling_my_new_model2.py

Lines changed: 46 additions & 46 deletions
@@ -207,7 +207,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
@@ -245,6 +245,51 @@ def forward(
         return outputs


+MY_NEW_MODEL2_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`MyNewModel2Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.",
+    MY_NEW_MODEL2_START_DOCSTRING,
+)
+class MyNewModel2PreTrainedModel(PreTrainedModel):
+    config_class = MyNewModel2Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MyNewModel2DecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
 class MyNewModel2RotaryEmbedding(nn.Module):
     def __init__(
         self,
@@ -310,51 +355,6 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-MY_NEW_MODEL2_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`MyNewModel2Config`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
-            load the weights associated with the model, only the configuration. Check out the
-            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
-    "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.",
-    MY_NEW_MODEL2_START_DOCSTRING,
-)
-class MyNewModel2PreTrainedModel(PreTrainedModel):
-    config_class = MyNewModel2Config
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["MyNewModel2DecoderLayer"]
-    _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_2 = True
-    _supports_sdpa = True
-    _supports_cache_class = True
-    _supports_quantized_cache = True
-    _supports_static_cache = True
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-
-
 MY_NEW_MODEL2_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
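Note: the block relocated above is the standard `PreTrainedModel` boilerplate; its `_init_weights` hook is what `post_init()` ultimately applies to every submodule. A standalone sketch of the same init scheme on a toy module (the `init_weights` function and the 0.02 std are illustrative, not the library code; the real value comes from `config.initializer_range`):

import torch
from torch import nn

def init_weights(module, std=0.02):
    # Same scheme as _init_weights above: normal init for Linear/Embedding weights,
    # zeroed biases, and a zeroed row at the embedding padding index.
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

toy = nn.Sequential(nn.Embedding(10, 4, padding_idx=0), nn.Linear(4, 4))
toy.apply(init_weights)  # nn.Module.apply visits every submodule, as post_init() does
assert torch.all(toy[0].weight[0] == 0)  # the padding row is zeroed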

examples/modular-transformers/modeling_super.py

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

src/transformers/integrations/flash_attention.py

Lines changed: 22 additions & 9 deletions
@@ -1,13 +1,25 @@
+from typing import Optional
+
 import torch

 from ..modeling_flash_attention_utils import _flash_attention_forward


 def flash_attention_forward(
-    config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    sliding_window: Optional[int] = None,
+    softcap: Optional[float] = None,
+    target_dtype: torch.dtype = torch.float16,
+    **kwargs,
 ):
-    if attentions_mask is not None:
-        seq_len = attentions_mask.shape[1]
+    if attention_mask is not None:
+        seq_len = attention_mask.shape[1]
         query = query[:, :, :seq_len]
         value = value[:, :, :seq_len]
     else:
@@ -18,8 +30,6 @@ def flash_attention_forward(
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)

-    dropout_rate = config.attention_dropout if training else 0.0
-
     input_dtype = query.dtype
    if input_dtype == torch.float32:
         query = query.to(target_dtype)
@@ -30,11 +40,14 @@ def flash_attention_forward(
         query,
         key,
         value,
-        attentions_mask,
+        attention_mask,
         seq_len,
-        config=config,
-        dropout=dropout_rate,
-        layer_idx=layer_idx,
+        module.is_causal,
+        dropout=dropout,
+        softmax_scale=scaling,
+        sliding_window=sliding_window,
+        softcap=softcap,
+        use_top_left_mask=module._flash_attn_uses_top_left_mask,
         **kwargs,
     )
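Note: the refactor above drops the config-driven interface (`config.attention_dropout`, `training`, `layer_idx`) in favor of explicit `dropout`, `scaling`, `sliding_window`, and `softcap` arguments supplied by the calling attention module. A hedged caller-side sketch of the new signature; the module attributes used here (`attention_dropout`, `head_dim`, `sliding_window`, `attn_logit_softcapping`) are assumptions about a typical attention module, not part of this commit, and the import path is assumed to resolve as shown:

from transformers.integrations.flash_attention import flash_attention_forward

def run_flash(module, query, key, value, attention_mask):
    # q/k/v assumed shaped (batch, num_heads, seq_len, head_dim); needs flash-attn installed.
    return flash_attention_forward(
        module,  # supplies is_causal and _flash_attn_uses_top_left_mask
        query,
        key,
        value,
        attention_mask,
        dropout=module.attention_dropout if module.training else 0.0,  # caller owns the train/eval switch now
        scaling=module.head_dim**-0.5,
        sliding_window=getattr(module, "sliding_window", None),  # e.g. Gemma 2 local layers
        softcap=getattr(module, "attn_logit_softcapping", None),  # Gemma 2-style logit soft-capping
    )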

src/transformers/integrations/flex_attention.py

Lines changed: 18 additions & 2 deletions
@@ -1,16 +1,31 @@
+from typing import Optional
+
+import torch
+
 from ..utils import is_torch_greater_or_equal


 if is_torch_greater_or_equal("2.5"):
     from torch.nn.attention.flex_attention import flex_attention


-def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    scaling: Optional[float] = None,
+    softcap: Optional[float] = None,
+    **kwargs,
+):
     causal_mask = attention_mask
     if causal_mask is not None:
         causal_mask = causal_mask[:, :, :, : key.shape[-2]]

     def causal_mod(score, b, h, q_idx, kv_idx):
+        if softcap is not None:
+            score = softcap * torch.tanh(score / softcap)
         if causal_mask is not None:
             score += causal_mask[b][0][q_idx][kv_idx]
         return score
@@ -21,8 +36,9 @@ def causal_mod(score, b, h, q_idx, kv_idx):
         value,
         score_mod=causal_mod,
         enable_gqa=True,
-        scale=module.scaling,
+        scale=scaling,
         return_lse=True,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
+
     return attn_output, attention_weights
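Note: the new `softcap` branch in `causal_mod` is the Gemma 2-style logit soft-capping: scores are squashed through `softcap * tanh(score / softcap)` before the causal-mask bias is added, which keeps them within (-softcap, softcap) while staying close to the identity near zero. A standalone check of just that transform (plain torch, independent of flex attention; 50.0 is an illustrative cap value):

import torch

softcap = 50.0
scores = torch.tensor([-1000.0, -10.0, 0.0, 10.0, 1000.0])
capped = softcap * torch.tanh(scores / softcap)
print(capped)  # approximately tensor([-50.0000, -9.8688, 0.0000, 9.8688, 50.0000])

# Large magnitudes saturate at +/- softcap; small ones are barely changed.
assert torch.all(capped.abs() <= softcap)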

src/transformers/integrations/sdpa_attention.py

Lines changed: 15 additions & 3 deletions
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch


@@ -13,7 +15,16 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs):
+def sdpa_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    **kwargs,
+):
     key = repeat_kv(key, module.num_key_value_groups)
     value = repeat_kv(value, module.num_key_value_groups)

@@ -31,9 +42,10 @@ def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs):
         key,
         value,
         attn_mask=causal_mask,
-        dropout_p=module.config.attention_dropout if module.training else 0.0,
+        dropout_p=dropout,
+        scale=scaling,
         is_causal=is_causal,
-        scale=module.scaling,
     )
     attn_output = attn_output.transpose(1, 2).contiguous()
+
     return attn_output, None
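Note: as with the flash path, `sdpa_attention_forward` now receives `dropout` and `scaling` directly instead of reading `module.config.attention_dropout` and `module.scaling`. The underlying call is plain `torch.nn.functional.scaled_dot_product_attention`; a minimal self-contained sketch of the same call pattern with grouped-query KV repetition (toy shapes, not library code):

import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq, head_dim = 1, 8, 2, 16, 64
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# GQA: repeat each KV head so head counts match, as repeat_kv does.
key = key.repeat_interleave(q_heads // kv_heads, dim=1)
value = value.repeat_interleave(q_heads // kv_heads, dim=1)

attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,  # now passed in by the caller; 0.0 outside training
    scale=head_dim**-0.5,  # explicit scaling instead of module.scaling
    is_causal=True,  # no mask supplied, so let SDPA apply the causal mask
)
attn_output = attn_output.transpose(1, 2).contiguous()  # -> (batch, seq, heads, head_dim)
print(attn_output.shape)  # torch.Size([1, 16, 8, 64])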

src/transformers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 from dataclasses import dataclass
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 from zipfile import is_zipfile

 import torch

src/transformers/models/aria/modeling_aria.py

Lines changed: 1 addition & 1 deletion
@@ -588,7 +588,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 0 additions & 4 deletions
@@ -659,10 +659,6 @@ def __init__(self, config: Cohere2Config):
             [Cohere2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
-
-        self.gradient_checkpointing = False
-        if getattr(config, "pretraining_tp", 1) != 1:
-            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
         self.rotary_emb = Cohere2RotaryEmbedding(config=config)

         # Initialize weights and apply final processing
