@@ -207,7 +207,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
@@ -245,6 +245,51 @@ def forward(
         return outputs


+MY_NEW_MODEL2_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`MyNewModel2Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.",
+    MY_NEW_MODEL2_START_DOCSTRING,
+)
+class MyNewModel2PreTrainedModel(PreTrainedModel):
+    config_class = MyNewModel2Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MyNewModel2DecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
 class MyNewModel2RotaryEmbedding(nn.Module):
     def __init__(
         self,
@@ -310,51 +355,6 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-MY_NEW_MODEL2_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`MyNewModel2Config`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
-            load the weights associated with the model, only the configuration. Check out the
-            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
-    "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.",
-    MY_NEW_MODEL2_START_DOCSTRING,
-)
-class MyNewModel2PreTrainedModel(PreTrainedModel):
-    config_class = MyNewModel2Config
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["MyNewModel2DecoderLayer"]
-    _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_2 = True
-    _supports_sdpa = True
-    _supports_cache_class = True
-    _supports_quantized_cache = True
-    _supports_static_cache = True
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-
-
 MY_NEW_MODEL2_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):