Skip to content
Merged
144 changes: 113 additions & 31 deletions src/liger_kernel/transformers/monkey_patch.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging

from functools import partial
from types import MethodType
from typing import Callable

import transformers
Expand Down Expand Up @@ -260,10 +261,16 @@ def apply_liger_kernel_to_llama(

if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
if model is not None:
model.forward = MethodType(llama_lce_forward, model)
else:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
if model is not None:
model.forward = MethodType(llama_lce_forward_deprecated, model)
else:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

if model is not None:
# The model instance already exists, so we need to additionally patch the
Expand Down Expand Up @@ -318,9 +325,15 @@ def apply_liger_kernel_to_llava(
modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse("4.52.0"):
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
if model is not None:
model.forward = MethodType(llava_lce_forward, model)
else:
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
if model is not None:
model.forward = MethodType(llava_lce_forward_deprecated, model)
else:
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
else: # if version < 4.49.0
logger.warning(
"The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
Expand Down Expand Up @@ -490,7 +503,7 @@ def apply_liger_kernel_to_mllama(

if rope:
modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
if layer_norm:
if layer_norm and model is None:
modeling_mllama.nn.LayerNorm = LigerLayerNorm
if rms_norm:
modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
Expand All @@ -506,10 +519,16 @@ def apply_liger_kernel_to_mllama(
modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
if model is not None:
model.forward = MethodType(mllama_lce_forward, model)
else:
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
if model is not None:
model.forward = MethodType(mllama_lce_forward_deprecated, model)
else:
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

if model is not None:
# The model instance already exists, so we need to additionally patch the
Expand Down Expand Up @@ -592,7 +611,10 @@ def apply_liger_kernel_to_mistral(
if cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
if model is not None:
model.forward = MethodType(mistral_lce_forward, model)
else:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
if swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP

Expand Down Expand Up @@ -660,10 +682,16 @@ def apply_liger_kernel_to_mixtral(

if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
if model is not None:
model.forward = MethodType(mixtral_lce_forward, model)
else:
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
if model is not None:
model.forward = MethodType(mixtral_lce_forward_deprecated, model)
else:
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
if swiglu:
modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

Expand Down Expand Up @@ -737,10 +765,16 @@ def apply_liger_kernel_to_gemma(
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
if model is not None:
model.forward = MethodType(gemma_lce_forward, model)
else:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
if model is not None:
model.forward = MethodType(gemma_lce_forward_deprecated, model)
else:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

if model is not None:
# The model instance already exists, so we need to additionally patch the
Expand Down Expand Up @@ -812,10 +846,16 @@ def apply_liger_kernel_to_gemma2(
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
if model is not None:
model.forward = MethodType(gemma2_lce_forward, model)
else:
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
else:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
if model is not None:
model.forward = MethodType(gemma2_lce_forward_deprected, model)
else:
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
if geglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

Expand Down Expand Up @@ -894,7 +934,10 @@ def apply_liger_kernel_to_gemma3_text(
nn.functional.cross_entropy = liger_cross_entropy

if fused_linear_cross_entropy:
modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
if model is not None:
model.forward = MethodType(causal_forward, model)
else:
modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward

if model is not None:
# The model instance already exists, so we need to additionally patch the
Expand Down Expand Up @@ -964,7 +1007,7 @@ def apply_liger_kernel_to_gemma3(
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
)

if layer_norm:
if layer_norm and model is None:
modeling_siglip.nn.LayerNorm = LigerLayerNorm

apply_liger_kernel_to_gemma3_text(
Expand All @@ -975,7 +1018,10 @@ def apply_liger_kernel_to_gemma3(
modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss

if fused_linear_cross_entropy:
modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
if model is not None:
model.forward = MethodType(multimodal_forward, model)
else:
modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward

if model is not None:
# The model instance already exists, so we need to additionally patch the
Expand Down Expand Up @@ -1054,7 +1100,7 @@ def apply_liger_kernel_to_paligemma(
from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

# The vision_tower is a SiglipVisionModel
if layer_norm:
if layer_norm and model is None:
modeling_siglip.nn.LayerNorm = LigerLayerNorm

# SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
Expand All @@ -1072,10 +1118,16 @@ def apply_liger_kernel_to_paligemma(
modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
if model is not None:
model.forward = MethodType(lce_forward, model)
else:
modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
if model is not None:
model.forward = MethodType(lce_forward_deprecated, model)
else:
modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated

if model is not None:
# The model instance already exists, so we need to additionally patch the
Expand Down Expand Up @@ -1167,10 +1219,16 @@ def apply_liger_kernel_to_qwen2(

if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
if model is not None:
model.forward = MethodType(qwen2_lce_forward, model)
else:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
if model is not None:
model.forward = MethodType(qwen2_lce_forward_deprecated, model)
else:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

if swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
Expand Down Expand Up @@ -1226,7 +1284,10 @@ def apply_liger_kernel_to_qwen3(
nn.functional.cross_entropy = liger_cross_entropy

if fused_linear_cross_entropy:
modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
if model is not None:
model.forward = MethodType(qwen3_lce_forward, model)
else:
modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward

if swiglu:
modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
Expand Down Expand Up @@ -1281,7 +1342,10 @@ def apply_liger_kernel_to_qwen3_moe(
nn.functional.cross_entropy = liger_cross_entropy

if fused_linear_cross_entropy:
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
if model is not None:
model.forward = MethodType(qwen3_lce_forward, model)
else:
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward

if swiglu:
modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
Expand Down Expand Up @@ -1350,12 +1414,15 @@ def apply_liger_kernel_to_qwen2_vl(
if rms_norm:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
if layer_norm:
if layer_norm and model is None:
modeling_qwen2_vl.LayerNorm = LigerLayerNorm
if cross_entropy:
modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
if model is not None:
model.forward = MethodType(qwen2_vl_lce_forward, model)
else:
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
if swiglu:
modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

Expand Down Expand Up @@ -1443,7 +1510,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
if cross_entropy:
modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
if model is not None:
model.forward = MethodType(qwen2_5_vl_lce_forward, model)
else:
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
if swiglu:
modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP

Expand Down Expand Up @@ -1530,10 +1600,16 @@ def apply_liger_kernel_to_phi3(
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
if model is not None:
model.forward = MethodType(phi3_lce_forward, model)
else:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
if model is not None:
model.forward = MethodType(phi3_lce_forward_deprecated, model)
else:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated

if model is not None:
# The model instance already exists, so we need to additionally patch the
Expand Down Expand Up @@ -1597,7 +1673,10 @@ def apply_liger_kernel_to_olmo2(

nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
if model is not None:
model.forward = MethodType(olmo2_lce_forward, model)
else:
modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward

if model is not None:
# The model instance already exists, so we need to additionally patch the
Expand Down Expand Up @@ -1661,7 +1740,10 @@ def apply_liger_kernel_to_glm4(

nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
if model is not None:
model.forward = MethodType(glm4_lce_forward, model)
else:
modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward

if model is not None:
# The model instance already exists, so we need to additionally patch the
Expand Down
Loading