Conversation

YangKai0616
Contributor

@YangKai0616 YangKai0616 commented Jun 24, 2025

In some cases, revert_liger_kernel_to_XX cannot undo the replacement of model.forward performed by monkey_patch. For scenarios where a model instance is passed in, changing how XX_lce_forward is assigned prevents this issue.

Using llama as an example, this PR updates monkey_patch.py so that the revert_liger_kernel_to_llama logic can restore all module configurations in the environment.

Summary

When I use pytest to run the test_sft_trainer_with_liger and test_train_offloading functions from test_sft_slow.py in the same process, the monkey patch applied by the Liger kernel in the first function affects the subsequent function's test.

Even though I used the revert_liger_kernel_to_XX logic to restore the module configuration of the environment, I found that model.forward still retains a reference to XX_lce_forward.

When the model passed to apply_liger_kernel_to_XX is an instance, using MethodType to modify only the forward method of the instantiated model can avoid this issue.
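To illustrate the difference, here is a minimal sketch (the helper name and arguments are placeholders for illustration, not the actual liger_kernel code):

```python
from types import MethodType

def patch_forward(model, lce_forward):
    # Class-level patching (the source of the observed issue) rebinds the
    # method on the class object itself, affecting every instance that
    # resolves forward through that class:
    #   type(model).forward = lce_forward
    #
    # Instance-level patching binds the patched function to this object only,
    # leaving the class attribute untouched:
    model.forward = MethodType(lce_forward, model)
```

With the instance-level form, the class attribute keeps pointing at the original implementation, so models loaded afterwards are unaffected.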

I wrote a simple test script to reproduce the bug:

from transformers import AutoModelForCausalLM
from transformers.utils import is_liger_kernel_available
from transformers.models.llama import modeling_llama
import importlib

if __name__ == '__main__':
    model_path = 'trl-internal-testing/tiny-LlamaForCausalLM-3.2'
    model_init_kwargs = {}

    # Original (unpatched) class attributes.
    print(f"ori_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, ori_forward:{modeling_llama.LlamaForCausalLM.forward}")
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)

    if is_liger_kernel_available():
        from liger_kernel.transformers import _apply_liger_kernel_to_instance

        # Apply the Liger kernel patches to this model instance.
        _apply_liger_kernel_to_instance(model=model)
        print(f"liger_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, liger_forward:{modeling_llama.LlamaForCausalLM.forward}")
        print(f"liger_model:{model.forward}")

    # Reload the module to restore the original class attributes.
    importlib.reload(modeling_llama)
    print(f"reload_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, reload_forward:{modeling_llama.LlamaForCausalLM.forward}")

    # A freshly loaded model should use the original forward, but it does not.
    model_reload = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
    print(f"reload_model:{model_reload.forward}")

The above code produces the following output:

ori_LlamaRMSNorm:<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>, ori_forward:<function LlamaForCausalLM.forward at 0x7f7f3321eca0>
liger_LlamaRMSNorm:<class 'liger_kernel.transformers.rms_norm.LigerRMSNorm'>, liger_forward:<function lce_forward at 0x7f7e26382160>
liger_model:<bound method lce_forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 8)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=8, out_features=8, bias=False)
          (k_proj): Linear(in_features=8, out_features=4, bias=False)
          (v_proj): Linear(in_features=8, out_features=4, bias=False)
          (o_proj): Linear(in_features=8, out_features=8, bias=False)
        )
        (mlp): LigerSwiGLUMLP(
          (gate_proj): Linear(in_features=8, out_features=32, bias=False)
          (up_proj): Linear(in_features=8, out_features=32, bias=False)
          (down_proj): Linear(in_features=32, out_features=8, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
        (post_attention_layernorm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
      )
    )
    (norm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=8, out_features=128256, bias=False)
)>
reload_LlamaRMSNorm:<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>, reload_forward:<function LlamaForCausalLM.forward at 0x7f7e263be700>
reload_model:<bound method lce_forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 8)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=8, out_features=8, bias=False)
          (k_proj): Linear(in_features=8, out_features=4, bias=False)
          (v_proj): Linear(in_features=8, out_features=4, bias=False)
          (o_proj): Linear(in_features=8, out_features=8, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=8, out_features=32, bias=False)
          (up_proj): Linear(in_features=8, out_features=32, bias=False)
          (down_proj): Linear(in_features=32, out_features=8, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((8,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((8,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((8,), eps=1e-06)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=8, out_features=128256, bias=False)
)>

Clearly, reload_model should not show <bound method lce_forward of LlamaForCausalLM(...)> here; the reloaded model's forward is still lce_forward.

Testing Done

After modifying monkey_patch.py, reload_model correctly outputs reload_model:<bound method LlamaForCausalLM.forward of LlamaForCausalLM(...).

Hardware type: XPU or CUDA

Taking llama as an example, applying _apply_liger_kernel_to_instance and then revert_liger_kernel_to_llama now restores all modules in the environment.
@YangKai0616
Contributor Author

@kaixuanliu

@YangKai0616 YangKai0616 changed the title In some cases, revert_liger_kernel_to_XX can’t restore the replacement of model.forward by monkey_patch. For scenarios involving a passed model instance, modifying the assignment logic of XX_lce_forward can prevent such issues. apply monkey patch to instance instead of class method to avoid conflicts in mixed usage scenario Jun 25, 2025
@YangKai0616 YangKai0616 marked this pull request as ready for review June 25, 2025 01:28
@YangKai0616 YangKai0616 reopened this Jun 25, 2025
@YangKai0616
Contributor Author

@shimizust, @Tcc0403 and @vaibhavjindal pls help review.

@Tcc0403
Collaborator

Tcc0403 commented Jun 25, 2025

@YangKai0616 Thanks! Great catch!

Although there are some parts of monkey patch that are difficult to revert, patching lce_forward to any model instance is indeed something we can fix right now. Can you also apply your solution to all models and add model.forward() checks in test_apply_liger_kernel_to_instance_for_xxx as well?
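For example, such an added check could look roughly like this (a sketch only; the lce_forward import path and the exact assertions are assumptions, not the test suite's actual code):

```python
from transformers import AutoModelForCausalLM
from transformers.models.llama import modeling_llama
from liger_kernel.transformers import _apply_liger_kernel_to_instance
from liger_kernel.transformers.model.llama import lce_forward  # assumed path

def test_apply_liger_kernel_to_instance_for_llama():
    # Tiny model used purely for the structural check.
    model = AutoModelForCausalLM.from_pretrained(
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
    )
    _apply_liger_kernel_to_instance(model=model)

    # The instance's forward should now wrap lce_forward ...
    assert model.forward.__func__ is lce_forward
    # ... while the class attribute stays untouched, so newly loaded models
    # are unaffected by the instance-level patch.
    assert modeling_llama.LlamaForCausalLM.forward is not lce_forward
```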

By the way, I'm aware that there are multiple scenarios where revert_liger_kernel_xxx() functions won't work, so I wouldn't recommend relying on it at all.

@YangKai0616
Contributor Author

@Tcc0403 Hi, I have modified the code as per your requirements and fixed some bugs along the way. Please review the code. The attached image shows the results of my test_monkey_patch.py test.

[screenshot: test_monkey_patch.py test results]

Since I need to test the functionality of different libraries, including liger_kernel, in the same process, restoring the environment's module configuration is unavoidable. Thanks!

@Tcc0403
Collaborator

Tcc0403 commented Jun 25, 2025

Thank you for the quick updates!

I have modified the code as per your requirements and fixed some bugs along the way.

Would you like to share what bugs you ran into? It could be a reflection of our test design. Did they exist before this PR?

Since I need to test the functionality of different libraries in the same process, including liger_kernel. It is inevitable to restore the module configuration of the environment. Thx!

Sorry if I confused you. I was trying to explain that the revert functions in our library have some known limitations when it comes to reverting class-level patches, so one should not rely on them. However, patching and reverting model instances should leave NO side effects, as you proposed. Any fixes or suggestions working towards this goal are welcome!

Your PR can be a great starting point for us to examine our patching methods along with testings.

@YangKai0616
Contributor Author

@Tcc0403 Hi, the bug encountered was still caused by changes to the global module configuration. During testing, I modified this line in the code. Without adding this patch, this section would be altered, affecting subsequent tests, specifically test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_qwen2_vl_for_conditional_generation and test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_qwen2_vl. This is because these tests rely on modules that were modified but not restored during model loading.

I think this bug already existed before this PR was submitted.

Additionally, the apply_liger_kernel_to_paligemma also applies this patch, but it was not tested.

@Tcc0403
Collaborator

Tcc0403 commented Jun 26, 2025

@Tcc0403 Hi, the bug encountered was still caused by changes to the global module configuration. During testing, I modified this line in the code. Without adding this patch, this section would be altered, affecting subsequent tests, specifically test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_qwen2_vl_for_conditional_generation and test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_qwen2_vl. This is because these tests rely on modules that were modified but not restored during model loading.

Do you think we need to patch layer_norm on model instances likewise? Since we are patching nn.LayerNorm at the class level, it seems all LayerNorm modules would be replaced with LigerLayerNorm in a mixed-usage scenario.

Additionally, the apply_liger_kernel_to_paligemma also applies this patch, but it was not tested.

Noted. Thanks for sharing.

@YangKai0616
Contributor Author

[screenshots: reloaded model still showing LigerLayerNorm after importlib.reload(torch.nn)]

Do you think we need to patch layer_norm on model instances likewise? Since we are patching nn.LayerNorm at the class level, it seems all LayerNorm modules would be replaced with LigerLayerNorm in a mixed-usage scenario.

I think this is necessary. My test results show that importlib.reload(torch.nn) does not revert the patch's effect on layer_norm. You can see that the newly loaded model is still LigerLayerNorm. This is the same issue as with the previous forward problem.
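For reference, a rough sketch of instance-level LayerNorm patching could look like this (illustrative only; the LigerLayerNorm import path and constructor arguments are assumptions, not the PR's actual implementation):

```python
import torch.nn as nn
from liger_kernel.transformers.layer_norm import LigerLayerNorm  # assumed import path

def patch_layer_norms_on_instance(model: nn.Module):
    # Swap nn.LayerNorm submodules on this instance only, instead of
    # monkey-patching torch.nn.LayerNorm globally (which importlib.reload
    # does not reliably undo, as observed above).
    for parent in list(model.modules()):
        for name, child in parent.named_children():
            if type(child) is nn.LayerNorm:
                # Constructor arguments are an assumption; weight copying is
                # simplified and ignores the no-affine/no-bias cases.
                liger_ln = LigerLayerNorm(child.normalized_shape, eps=child.eps)
                liger_ln.weight = child.weight
                liger_ln.bias = child.bias
                setattr(parent, name, liger_ln)
```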

@Tcc0403
Collaborator

Tcc0403 commented Jun 27, 2025

I think this is necessary. My test results show that importlib.reload(torch.nn) does not revert the patch's effect on layer_norm. You can see that the newly loaded model is still LigerLayerNorm. This is the same issue as with the previous forward problem.

Thanks for the test. Let's handle it in this PR as well.

@YangKai0616
Contributor Author

Thanks for the test. Let's handle it in this PR as well.

Hi, please review the latest commit.

Tcc0403
Tcc0403 previously approved these changes Jun 29, 2025
Collaborator

@Tcc0403 Tcc0403 left a comment


Thanks

@YangKai0616
Contributor Author

Hi @vaibhavjindal, please help approve the execution of the workflows.

@Tcc0403
Collaborator

Tcc0403 commented Jun 30, 2025

@YangKai0616 Fix the coding style with make checkstyle, then we can merge it.

@YangKai0616
Contributor Author

YangKai0616 commented Jun 30, 2025

@YangKai0616 Fix the coding style with make checkstyle, then we can merge it.

@vaibhavjindal and @Tcc0403, Sorry for my oversight. Please review it again.

@Tcc0403
Collaborator

Tcc0403 commented Jun 30, 2025

@YangKai0616 It seems dev/modal/benchmarks.py was accidentally included in your change. Can you revert it?

I'll fix its coding style in another PR.

@Tcc0403 Tcc0403 enabled auto-merge (squash) June 30, 2025 15:52
@Tcc0403 Tcc0403 disabled auto-merge June 30, 2025 15:52
@YangKai0616
Contributor Author

@YangKai0616 It seems dev/modal/benchmarks.py was accidentally included in your change. Can you revert it?

The error may come from CI itself. I didn't change the dev/modal/benchmarks.py file in my first CI run, but it still failed. I also downloaded the code of the main branch, ran make checkstyle, and it reported the same error.
[screenshot: make checkstyle reporting the same error on the main branch]

@Tcc0403
Collaborator

Tcc0403 commented Jun 30, 2025

The error may come from CI itself. I didn't change the dev/modal/benchmarks.py file in my first CI run, but it still failed.

@YangKai0616 Yes, that's a mistake on our side.

Thanks for the contribution!

@Tcc0403 Tcc0403 merged commit 5e3bf99 into linkedin:main Jun 30, 2025