
Commit 5e3bf99

apply monkey patch to instance instead of class method to avoid conflicts in mixed usage scenario (#772)
In some cases, **revert_liger_kernel_to_XX** cannot undo the replacement of **model.forward** performed by **monkey_patch**. For scenarios where a model instance is passed in, changing how **XX_lce_forward** is assigned prevents such issues. Taking **llama** as an example, **monkey_patch.py** is updated so that the **revert_liger_kernel_to_llama** logic can restore all module configurations in the environment.

## Summary

When I use **pytest** to run the **test_sft_trainer_with_liger** and **test_train_offloading** functions from [test_sft_slow.py](https://github.com/huggingface/trl/blob/main/tests/slow/test_sft_slow.py) in the same process, the monkey patch applied by the Liger kernel in the first executed function affects the test of the function that runs afterwards. Even though I used the **revert_liger_kernel_to_XX** logic to restore the environment's module configuration, I found that **model.forward** still retains a reference to **XX_lce_forward**. When the model passed to **apply_liger_kernel_to_XX** is an instance, using **MethodType** to modify only the forward method of that instance avoids this issue; a minimal illustration follows the description below.

I wrote a simple test script to reproduce the bug:

```python
from transformers import AutoModelForCausalLM
from transformers.utils import is_liger_kernel_available
from transformers.models.llama import modeling_llama
import importlib

if __name__ == '__main__':
    model_path = 'trl-internal-testing/tiny-LlamaForCausalLM-3.2'
    model_init_kwargs = {}
    print(f"ori_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, ori_forward:{modeling_llama.LlamaForCausalLM.forward}")
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
    if is_liger_kernel_available():
        from liger_kernel.transformers import _apply_liger_kernel_to_instance

        _apply_liger_kernel_to_instance(model=model)
    print(f"liger_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, liger_forward:{modeling_llama.LlamaForCausalLM.forward}")
    print(f"liger_model:{model.forward}")
    importlib.reload(modeling_llama)
    print(f"reload_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, reload_forward:{modeling_llama.LlamaForCausalLM.forward}")
    model_reload = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
    print(f"reload_model:{model_reload.forward}")
```

The above code produces the following output:

```
ori_LlamaRMSNorm:<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>, ori_forward:<function LlamaForCausalLM.forward at 0x7f7f3321eca0>
liger_LlamaRMSNorm:<class 'liger_kernel.transformers.rms_norm.LigerRMSNorm'>, liger_forward:<function lce_forward at 0x7f7e26382160>
liger_model:<bound method lce_forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 8)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=8, out_features=8, bias=False)
          (k_proj): Linear(in_features=8, out_features=4, bias=False)
          (v_proj): Linear(in_features=8, out_features=4, bias=False)
          (o_proj): Linear(in_features=8, out_features=8, bias=False)
        )
        (mlp): LigerSwiGLUMLP(
          (gate_proj): Linear(in_features=8, out_features=32, bias=False)
          (up_proj): Linear(in_features=8, out_features=32, bias=False)
          (down_proj): Linear(in_features=32, out_features=8, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
        (post_attention_layernorm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
      )
    )
    (norm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=8, out_features=128256, bias=False)
)>
reload_LlamaRMSNorm:<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>, reload_forward:<function LlamaForCausalLM.forward at 0x7f7e263be700>
reload_model:<bound method lce_forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 8)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=8, out_features=8, bias=False)
          (k_proj): Linear(in_features=8, out_features=4, bias=False)
          (v_proj): Linear(in_features=8, out_features=4, bias=False)
          (o_proj): Linear(in_features=8, out_features=8, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=8, out_features=32, bias=False)
          (up_proj): Linear(in_features=8, out_features=32, bias=False)
          (down_proj): Linear(in_features=32, out_features=8, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((8,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((8,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((8,), eps=1e-06)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=8, out_features=128256, bias=False)
)>
```

Obviously, `reload_model:<bound method lce_forward of LlamaForCausalLM(` should not be `lce_forward` here.

## Testing Done

After modifying monkey_patch.py, reload_model correctly outputs `reload_model:<bound method LlamaForCausalLM.forward of LlamaForCausalLM(`.

Hardware type: XPU or CUDA
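For reference, the mechanism behind the fix can be shown in isolation. The sketch below is not liger-kernel code; `ToyModel` and `lce_like_forward` are hypothetical stand-ins for `LlamaForCausalLM` and `llama_lce_forward`. Binding the replacement with `types.MethodType` stores it as an attribute on the one instance, so the class attribute that `revert_liger_kernel_to_XX` (or a fresh module reload) relies on is never overwritten:

```python
from types import MethodType


class ToyModel:
    """Hypothetical stand-in for a transformers *ForCausalLM class."""

    def forward(self, x):
        return x


def lce_like_forward(self, x):
    """Hypothetical stand-in for an XX_lce_forward replacement."""
    return x * 2


patched = ToyModel()
other = ToyModel()

# Instance-level patch, mirroring the new monkey_patch.py logic:
# the bound method lives in patched.__dict__ and shadows the class method.
patched.forward = MethodType(lce_like_forward, patched)

assert patched.forward(3) == 6                    # only this instance uses the replacement
assert other.forward(3) == 3                      # other instances are untouched
assert ToyModel.forward is not lce_like_forward   # the class attribute stays clean

# Undoing the patch only requires removing the instance attribute;
# attribute lookup then falls back to the class method again.
del patched.forward
assert patched.forward(3) == 3
```

By contrast, assigning to `ToyModel.forward` (the old class-level behavior) would leak into every instance created afterwards, which is exactly what the reproduction script above observes with `model_reload`.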
1 parent 2c766fe commit 5e3bf99

File tree

2 files changed: +202 / -32 lines changed


src/liger_kernel/transformers/monkey_patch.py

File mode changed: 100644 → 100755
Lines changed: 113 additions & 31 deletions
@@ -2,6 +2,7 @@
 import logging

 from functools import partial
+from types import MethodType
 from typing import Callable

 import transformers
@@ -260,10 +261,16 @@ def apply_liger_kernel_to_llama(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -318,9 +325,15 @@ def apply_liger_kernel_to_llava(
         modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse("4.52.0"):
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
         elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else:  # if version < 4.49.0
             logger.warning(
                 "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
@@ -490,7 +503,7 @@ def apply_liger_kernel_to_mllama(

     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -506,10 +519,16 @@ def apply_liger_kernel_to_mllama(
         modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -592,7 +611,10 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        if model is not None:
+            model.forward = MethodType(mistral_lce_forward, model)
+        else:
+            modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP

@@ -660,10 +682,16 @@ def apply_liger_kernel_to_mixtral(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -737,10 +765,16 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -812,10 +846,16 @@ def apply_liger_kernel_to_gemma2(
         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
         else:
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward_deprected, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -894,7 +934,10 @@ def apply_liger_kernel_to_gemma3_text(
         nn.functional.cross_entropy = liger_cross_entropy

     if fused_linear_cross_entropy:
-        modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -964,7 +1007,7 @@ def apply_liger_kernel_to_gemma3(
             _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
         )

-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm

     apply_liger_kernel_to_gemma3_text(
@@ -975,7 +1018,10 @@ def apply_liger_kernel_to_gemma3(
         modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss

     if fused_linear_cross_entropy:
-        modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+        if model is not None:
+            model.forward = MethodType(multimodal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1054,7 +1100,7 @@ def apply_liger_kernel_to_paligemma(
     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

     # The vision_tower is a SiglipVisionModel
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm

     # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -1072,10 +1118,16 @@ def apply_liger_kernel_to_paligemma(
         modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+            if model is not None:
+                model.forward = MethodType(lce_forward, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(lce_forward_deprecated, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1167,10 +1219,16 @@ def apply_liger_kernel_to_qwen2(

     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1226,7 +1284,10 @@ def apply_liger_kernel_to_qwen3(
         nn.functional.cross_entropy = liger_cross_entropy

     if fused_linear_cross_entropy:
-        modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward

     if swiglu:
         modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1281,7 +1342,10 @@ def apply_liger_kernel_to_qwen3_moe(
         nn.functional.cross_entropy = liger_cross_entropy

     if fused_linear_cross_entropy:
-        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward

     if swiglu:
         modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
@@ -1350,12 +1414,15 @@ def apply_liger_kernel_to_qwen2_vl(
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
     if cross_entropy:
         modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen2_vl_lce_forward, model)
+        else:
+            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
     if swiglu:
         modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

@@ -1443,7 +1510,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if cross_entropy:
         modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+        else:
+            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
     if swiglu:
         modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP

@@ -1530,10 +1600,16 @@ def apply_liger_kernel_to_phi3(
         modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+            if model is not None:
+                model.forward = MethodType(phi3_lce_forward, model)
+            else:
+                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(phi3_lce_forward_deprecated, model)
+            else:
+                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1597,7 +1673,10 @@ def apply_liger_kernel_to_olmo2(

         nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-        modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+        if model is not None:
+            model.forward = MethodType(olmo2_lce_forward, model)
+        else:
+            modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1661,7 +1740,10 @@ def apply_liger_kernel_to_glm4(

         nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-        modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+        if model is not None:
+            model.forward = MethodType(glm4_lce_forward, model)
+        else:
+            modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
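Building on the reproduction script in the description, the new llama behavior can be spot-checked roughly as follows. This is a sketch, not part of the commit; it assumes the default flags of `apply_liger_kernel_to_llama` at this commit (in particular that `fused_linear_cross_entropy` is enabled) and reuses the tiny test checkpoint from the script above:

```python
from transformers import AutoModelForCausalLM
from transformers.models.llama import modeling_llama
from liger_kernel.transformers import apply_liger_kernel_to_llama

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
original_class_forward = modeling_llama.LlamaForCausalLM.forward

# With a model instance passed in, the forward replacement is now bound to the
# instance via MethodType instead of overwriting the class attribute.
apply_liger_kernel_to_llama(model=model)

assert "forward" in model.__dict__  # the patch lives on the instance only
assert modeling_llama.LlamaForCausalLM.forward is original_class_forward  # class untouched
```

Because the class attribute is never replaced, a subsequent `revert_liger_kernel_to_llama` (or a plain `importlib.reload(modeling_llama)`) restores the environment completely, and models created afterwards get the stock `LlamaForCausalLM.forward`.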

0 commit comments
