-
Notifications
You must be signed in to change notification settings - Fork 408
[Model] DeepseekV2 Support #499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
saurabhkoshatwar
wants to merge
12
commits into
linkedin:main
Choose a base branch
from
saurabhkoshatwar:feature/deepseekv2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
c62736a
initial patch code and test
saurabhkoshatwar f58471d
Add deepseekv2 convergence test
saurabhkoshatwar e28dc49
test fix
saurabhkoshatwar 1a7efbf
Merge branch 'feature/deepseekv2' of https://github.com/saurabhkoshat…
saurabhkoshatwar adfc644
Add test without logits
saurabhkoshatwar f1310e1
checkstyle fixes
saurabhkoshatwar a76931a
Merge branch 'main' into feature/deepseekv2
saurabhkoshatwar 0a17f0b
fused lce fix
saurabhkoshatwar b6287f1
Merge branch 'main' into feature/deepseekv2
austin362667 8e71b13
Merge branch 'main' into feature/deepseekv2
lancerts 2b5e749
Merge branch 'main' into feature/deepseekv2
lancerts a34774d
add docstring source link
saurabhkoshatwar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
from typing import List | ||
from typing import Optional | ||
from typing import Tuple | ||
from typing import Union | ||
|
||
import torch | ||
|
||
from torch.nn import CrossEntropyLoss | ||
from transformers.modeling_outputs import CausalLMOutputWithPast | ||
from transformers.utils import add_start_docstrings_to_model_forward | ||
from transformers.utils import replace_return_docstrings | ||
|
||
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss | ||
|
||
# This docstring is ported from the DeepSeek V2 model source code. | ||
# Source: https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py | ||
DeepseekV2_INPUTS_DOCSTRING = r""" | ||
Args: | ||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | ||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide | ||
it. | ||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | ||
[`PreTrainedTokenizer.__call__`] for details. | ||
[What are input IDs?](../glossary#input-ids) | ||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: | ||
- 1 for tokens that are **not masked**, | ||
- 0 for tokens that are **masked**. | ||
[What are attention masks?](../glossary#attention-mask) | ||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | ||
[`PreTrainedTokenizer.__call__`] for details. | ||
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see | ||
`past_key_values`). | ||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] | ||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more | ||
information on the default strategy. | ||
- 1 indicates the head is **not masked**, | ||
- 0 indicates the head is **masked**. | ||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, | ||
config.n_positions - 1]`. | ||
[What are position IDs?](../glossary#position-ids) | ||
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): | ||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention | ||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` | ||
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. | ||
Two formats are allowed: | ||
- a [`~cache_utils.Cache`] instance; | ||
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of | ||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy | ||
cache format. | ||
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the | ||
legacy cache format will be returned. | ||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't | ||
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` | ||
of shape `(batch_size, sequence_length)`. | ||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): | ||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This | ||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the | ||
model's internal embedding lookup matrix. | ||
use_cache (`bool`, *optional*): | ||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see | ||
`past_key_values`). | ||
output_attentions (`bool`, *optional*): | ||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned | ||
tensors for more detail. | ||
output_hidden_states (`bool`, *optional*): | ||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for | ||
more detail. | ||
return_dict (`bool`, *optional*): | ||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | ||
""" | ||
|
||
_CONFIG_FOR_DOC = "DeepseekV2Config" | ||
|
||
|
||
@add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) | ||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) | ||
def lce_forward( | ||
self, | ||
input_ids: torch.LongTensor = None, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
position_ids: Optional[torch.LongTensor] = None, | ||
past_key_values: Optional[List[torch.FloatTensor]] = None, | ||
inputs_embeds: Optional[torch.FloatTensor] = None, | ||
labels: Optional[torch.LongTensor] = None, | ||
use_cache: Optional[bool] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
) -> Union[Tuple, CausalLMOutputWithPast]: | ||
r""" | ||
Copy paste deepseekv2 forward but replace torch cross entropy with liger fused linear cross entropy | ||
|
||
Args: | ||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., | ||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored | ||
(masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. | ||
Returns: | ||
Example: | ||
```python | ||
>>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM | ||
>>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) | ||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) | ||
>>> prompt = "Hey, are you conscious? Can you talk to me?" | ||
>>> inputs = tokenizer(prompt, return_tensors="pt") | ||
>>> # Generate | ||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30) | ||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | ||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." | ||
```""" | ||
|
||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
output_hidden_states = ( | ||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
) | ||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) | ||
outputs = self.model( | ||
input_ids=input_ids, | ||
attention_mask=attention_mask, | ||
position_ids=position_ids, | ||
past_key_values=past_key_values, | ||
inputs_embeds=inputs_embeds, | ||
use_cache=use_cache, | ||
output_attentions=output_attentions, | ||
output_hidden_states=output_hidden_states, | ||
return_dict=return_dict, | ||
) | ||
|
||
hidden_states = outputs[0] | ||
|
||
loss = None | ||
logits = None | ||
|
||
if self.training and labels is not None: | ||
shift_hidden_states = hidden_states[..., :-1, :].contiguous() | ||
shift_labels = labels[..., 1:].contiguous() | ||
|
||
# flatten tokens | ||
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) | ||
shift_labels = shift_labels.view(-1) | ||
|
||
lce = LigerFusedLinearCrossEntropyLoss() | ||
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) | ||
else: | ||
logits = self.lm_head(hidden_states) | ||
|
||
loss = None | ||
if labels is not None: | ||
# Upcast to float if we need to compute the loss to avoid potential precision issues | ||
logits = logits.float() | ||
# Shift so that tokens < n predict n | ||
shift_logits = logits[..., :-1, :].contiguous() | ||
shift_labels = labels[..., 1:].contiguous() | ||
# Flatten the tokens | ||
loss_fct = CrossEntropyLoss() | ||
shift_logits = shift_logits.view(-1, self.config.vocab_size) | ||
shift_labels = shift_labels.view(-1) | ||
# Enable model parallelism | ||
shift_labels = shift_labels.to(shift_logits.device) | ||
loss = loss_fct(shift_logits, shift_labels) | ||
|
||
if not return_dict: | ||
output = (logits,) + outputs[1:] | ||
return (loss,) + output if loss is not None else output | ||
|
||
return CausalLMOutputWithPast( | ||
loss=loss, | ||
logits=logits, | ||
past_key_values=outputs.past_key_values, | ||
hidden_states=outputs.hidden_states, | ||
attentions=outputs.attentions, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do deepseek and deepseek-v3 share the same architecture? If so, perhaps this function should be called
apply_liger_kernel_to_deepseek
, if not, perhaps we should strengthen this check.