Add support of Falcon-H1 models for liger kernels #874
base: main
Conversation
This PR enables the RoPE, RMSNorm, and cross entropy Liger kernels for Falcon-H1 models. Support for fused linear cross entropy and SwiGLU will be added later.
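For reference, a minimal usage sketch of how the patch could be applied before loading a Falcon-H1 checkpoint (the checkpoint id is an assumption, and the import path assumes the function gets exported from liger_kernel.transformers):

import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1

# Patch the Falcon-H1 modeling code before the model is instantiated so that
# RoPE, RMSNorm, and cross entropy use the Liger kernels.
apply_liger_kernel_to_falcon_h1(rope=True, rms_norm=True, cross_entropy=True)

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/Falcon-H1-0.5B-Base",  # assumed checkpoint from the Falcon-H1 collection
    torch_dtype=torch.bfloat16,
)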
If you can't implement swiglu/flce in this impl, can you create a new issue to track this?
"smollm3": apply_liger_kernel_to_smollm3, | ||
"phi3": apply_liger_kernel_to_phi3, | ||
"paligemma": apply_liger_kernel_to_paligemma, | ||
"falcon_h1": apply_liger_kernel_to_falcon_h1, |
Can you add this to src/liger_kernel/transformers/__init__.py?
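A sketch of what that export could look like, assuming it follows the pattern of the other apply_liger_kernel_to_* functions already re-exported there:

# src/liger_kernel/transformers/__init__.py (sketch)
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1  # noqa: F401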
        _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)


def apply_liger_kernel_to_falcon_h1(
Can you add corresponding convergence and monkey patch tests?
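A rough sketch of what the monkey patch test could assert (the tiny config values, attribute names, and assertion style are assumptions; a convergence test would follow the existing test/convergence harness):

from transformers.models.falcon_h1 import modeling_falcon_h1

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1
from liger_kernel.transformers.rms_norm import LigerRMSNorm


def test_apply_liger_kernel_to_falcon_h1_instance():
    # Tiny random-weight model so the test stays fast (config values are assumptions).
    config = modeling_falcon_h1.FalconH1Config(
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        vocab_size=128,
    )
    model = modeling_falcon_h1.FalconH1ForCausalLM(config)

    apply_liger_kernel_to_falcon_h1(model=model, rope=True, rms_norm=True, cross_entropy=True)

    # _patch_rms_norm_module patches modules in place, so the patched forward
    # should now come from LigerRMSNorm (attribute names are assumptions).
    for decoder_layer in model.model.layers:
        assert decoder_layer.post_mlp_layernorm.forward.__func__ is LigerRMSNorm.forward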
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
Set correct defaults for flce and swiglu if not implemented
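One possible shape for that, sketched from the parameters shown in the docstring above (the rope parameter and the exact defaults are assumptions; the kernels this PR doesn't wire up yet default to off):

from transformers import PreTrainedModel


def apply_liger_kernel_to_falcon_h1(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = False,  # default off until the fused path is implemented
    rms_norm: bool = True,
    swiglu: bool = False,  # default off until LigerSwiGLUMLP support is verified
    model: PreTrainedModel = None,
) -> None:
    ...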
    if swiglu:
        modeling_falcon_h1.FalconH1MLP = LigerSwiGLUMLP

    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            logger.info("Apply liger cross entropy")
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_falcon_h1.CrossEntropyLoss = LigerCrossEntropyLoss

    # TODO: To be enabled
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_falcon_h1.FalconH1ForCausalLM.forward = llama_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_falcon_h1.FalconH1ForCausalLM.forward = llama_lce_forward_deprecated
Can you raise NotImplementedError in the case where swiglu and flce are set to true? Right now we're patching with incorrect modules (also in the case below where the model instance exists).
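A minimal sketch of the guard being requested, placed at the top of apply_liger_kernel_to_falcon_h1 before any patching happens:

if swiglu:
    raise NotImplementedError("LigerSwiGLUMLP is not yet supported for Falcon-H1 models.")
if fused_linear_cross_entropy:
    raise NotImplementedError("Fused linear cross entropy is not yet supported for Falcon-H1 models.")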
Thanks for all the comments @shimizust. Will address soon
Summary
This PR enables the RoPE, RMSNorm, and cross entropy Liger kernels for Falcon-H1 models.
Falcon-H1 models: https://huggingface.co/collections/tiiuae/falcon-h1-6819f2795bc406da60fab8df
Support for fused linear cross entropy and SwiGLU will be added later.
Testing Done
Verified fine-tuning of Falcon-H1 models with and without Liger kernels enabled; the loss plots match, and enabling the kernels yields a speedup as well as memory savings.
make test to ensure correctness
make checkstyle to ensure code style
make test-convergence to ensure convergence