apply monkey patch to instance instead of class method to avoid conflicts in mixed usage scenario (#772)
In some cases, **revert_liger_kernel_to_XX** cannot undo the replacement of
**model.forward** performed by **monkey_patch**. For scenarios where a model
instance is passed in, changing how **XX_lce_forward** is assigned prevents
this issue.
Using **llama** as an example, this PR updates **monkey_patch.py** so that the
**revert_liger_kernel_to_llama** logic can restore every module configuration
in the environment.
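The gist of the change, shown for llama (a simplified sketch; the real `apply_liger_kernel_to_llama` in `monkey_patch.py` takes more options and also patches RoPE, RMSNorm, and SwiGLU):

```python
from types import MethodType

from transformers.models.llama import modeling_llama

# The Liger forward; monkey_patch.py imports it under a name like this.
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward


def apply_liger_kernel_to_llama(fused_linear_cross_entropy=True, model=None, **kwargs):
    # Simplified sketch, not the actual implementation.
    if fused_linear_cross_entropy:
        if model is not None:
            # A model instance was passed in: bind the Liger forward to this
            # instance only, leaving the class attribute untouched so that
            # revert_liger_kernel_to_llama (importlib.reload) fully restores it.
            model.forward = MethodType(llama_lce_forward, model)
        else:
            # No instance: patch the class attribute as before, so models
            # created later also pick up the Liger forward.
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
```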
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
When I use **pytest** to run the **test_sft_trainer_with_liger** and
**test_train_offloading** functions from
[test_sft_slow.py](https://github.com/huggingface/trl/blob/main/tests/slow/test_sft_slow.py)
in the same process, the monkey patch applied by the Liger kernel in the
first test leaks into the second one.
Even after restoring the environment's module configuration with the
**revert_liger_kernel_to_XX** logic, **model.forward** still holds a
reference to **XX_lce_forward**.
When the model passed to **apply_liger_kernel_to_XX** is an instance,
using **MethodType** to patch only that instance's forward method avoids
this issue.
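To make the difference concrete, here is a small self-contained example (plain Python, not Liger code) contrasting instance-level binding via **MethodType** with class-level patching:

```python
from types import MethodType

class Model:
    def forward(self):
        return "original"

def patched_forward(self):
    return "patched"

a, b = Model(), Model()

# Instance-level patch (what this PR does when a model instance is passed):
# only `a` is affected; `b` and any model created later keep the class forward.
a.forward = MethodType(patched_forward, a)
assert a.forward() == "patched"
assert b.forward() == "original"
assert Model().forward() == "original"

# Class-level patch (the previous behaviour): every instance, including ones
# created afterwards, resolves to the patched function, and only restoring
# the class attribute itself can revert it.
Model.forward = patched_forward
assert b.forward() == "patched"
assert Model().forward() == "patched"
```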
<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->
I wrote a simple test script to reproduce the bug:
```python
import importlib

from transformers import AutoModelForCausalLM
from transformers.models.llama import modeling_llama
from transformers.utils import is_liger_kernel_available

if __name__ == '__main__':
    model_path = 'trl-internal-testing/tiny-LlamaForCausalLM-3.2'
    model_init_kwargs = {}

    # Class-level attributes before any patching
    print(f"ori_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, ori_forward:{modeling_llama.LlamaForCausalLM.forward}")

    model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
    if is_liger_kernel_available():
        from liger_kernel.transformers import _apply_liger_kernel_to_instance
        _apply_liger_kernel_to_instance(model=model)

    # Class-level attributes and the instance forward after patching
    print(f"liger_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, liger_forward:{modeling_llama.LlamaForCausalLM.forward}")
    print(f"liger_model:{model.forward}")

    # Reload the modeling module (as the revert_liger_kernel_to_llama logic does),
    # then load a fresh model
    importlib.reload(modeling_llama)
    print(f"reload_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, reload_forward:{modeling_llama.LlamaForCausalLM.forward}")
    model_reload = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
    print(f"reload_model:{model_reload.forward}")
```
The above code produces the following output:
```
ori_LlamaRMSNorm:<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>, ori_forward:<function LlamaForCausalLM.forward at 0x7f7f3321eca0>
liger_LlamaRMSNorm:<class 'liger_kernel.transformers.rms_norm.LigerRMSNorm'>, liger_forward:<function lce_forward at 0x7f7e26382160>
liger_model:<bound method lce_forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 8)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=8, out_features=8, bias=False)
          (k_proj): Linear(in_features=8, out_features=4, bias=False)
          (v_proj): Linear(in_features=8, out_features=4, bias=False)
          (o_proj): Linear(in_features=8, out_features=8, bias=False)
        )
        (mlp): LigerSwiGLUMLP(
          (gate_proj): Linear(in_features=8, out_features=32, bias=False)
          (up_proj): Linear(in_features=8, out_features=32, bias=False)
          (down_proj): Linear(in_features=32, out_features=8, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
        (post_attention_layernorm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
      )
    )
    (norm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=8, out_features=128256, bias=False)
)>
reload_LlamaRMSNorm:<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>, reload_forward:<function LlamaForCausalLM.forward at 0x7f7e263be700>
reload_model:<bound method lce_forward of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 8)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=8, out_features=8, bias=False)
          (k_proj): Linear(in_features=8, out_features=4, bias=False)
          (v_proj): Linear(in_features=8, out_features=4, bias=False)
          (o_proj): Linear(in_features=8, out_features=8, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=8, out_features=32, bias=False)
          (up_proj): Linear(in_features=8, out_features=32, bias=False)
          (down_proj): Linear(in_features=32, out_features=8, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((8,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((8,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((8,), eps=1e-06)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=8, out_features=128256, bias=False)
)>
```
Clearly, `reload_model: <bound method lce_forward of LlamaForCausalLM(...)>`
should not still be `lce_forward` here: reloading `modeling_llama` restores the
class-level forward, yet the patched forward lingers on models created
afterwards (most likely because `AutoModelForCausalLM` still resolves to the
pre-reload class object whose forward was replaced).
## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
After modifying monkey_patch.py, `reload_model` correctly prints
`reload_model: <bound method LlamaForCausalLM.forward of LlamaForCausalLM(...)>`.
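As an extra sanity check, a snippet along the lines of the reproduction script above (hypothetical, not part of the repository's test suite) can assert that patching an instance no longer leaks into freshly loaded models:

```python
from transformers import AutoModelForCausalLM
from transformers.models.llama import modeling_llama
from liger_kernel.transformers import _apply_liger_kernel_to_instance

model_path = 'trl-internal-testing/tiny-LlamaForCausalLM-3.2'

model = AutoModelForCausalLM.from_pretrained(model_path)
_apply_liger_kernel_to_instance(model=model)

# The patched instance no longer uses the stock forward...
assert model.forward.__func__ is not modeling_llama.LlamaForCausalLM.forward
# ...but the class attribute is untouched, so a freshly loaded model keeps the
# stock LlamaForCausalLM.forward.
model_reload = AutoModelForCausalLM.from_pretrained(model_path)
assert model_reload.forward.__func__ is modeling_llama.LlamaForCausalLM.forward
```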
<!--
Replace BLANK with your device type. For example, A100-80G-PCIe
Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->
Hardware Type: XPU or CUDA
"The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."