Skip to content

Commit 08906d3

Browse files
authored
Add Cosine Similarity Loss for Distillation (#780)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> This is a torch compiled, chunked fused linear Cosine Similarity Loss, aiming for knowledge distillation **Consine Similarity Loss** ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> Benchmarks ![distill_cosine_loss_memory_full](https://github.com/user-attachments/assets/baa9fb21-ffb9-4507-9dad-0cbf6ad8ad67) ![distill_cosine_loss_speed_full](https://github.com/user-attachments/assets/123871bb-0863-4bf5-9f5b-4d309013a81b) <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent 5e3bf99 commit 08906d3

File tree

6 files changed

+740
-0
lines changed

6 files changed

+740
-0
lines changed

benchmark/data/all_benchmark_data.csv

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,3 +1469,27 @@ fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,512,15
14691469
fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,1024,369.0234375,369.0234375,369.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:39,0.5.10
14701470
fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,2048,1176.0234375,1176.0234375,1176.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:39,0.5.10
14711471
fused_neighborhood_attention,torch,full,memory,MB,seq_len,sequence length,4096,4332.0234375,4332.0234375,4332.0234375,"{""batch_size"": 2, ""hidden_size"": 512, ""num_heads"": 8, ""kernel_size"": 7, ""dilation"": 2, ""bias"": true, ""dtype"": ""torch.float32""}",NVIDIA H100 80GB HBM3,2025-05-27 19:27:39,0.5.10
1472+
distill_cosine_loss,liger,forward,speed,ms,BT,B x T,1024,13.828096389770508,13.821133041381836,13.885849952697754,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:19:52,0.5.10
1473+
distill_cosine_loss,liger,forward,speed,ms,BT,B x T,2048,27.57427215576172,27.52573432922363,27.579801940917967,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:19:52,0.5.10
1474+
distill_cosine_loss,liger,forward,speed,ms,BT,B x T,4096,54.79423904418945,54.79423904418945,54.79423904418945,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:19:52,0.5.10
1475+
distill_cosine_loss,liger,forward,speed,ms,BT,B x T,8192,109.73490905761719,109.73490905761719,109.73490905761719,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:19:52,0.5.10
1476+
distill_cosine_loss,torch,forward,speed,ms,BT,B x T,1024,16.456703186035156,15.045836448669434,16.761650466918944,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:20:34,0.5.10
1477+
distill_cosine_loss,torch,forward,speed,ms,BT,B x T,2048,29.703168869018555,29.69333839416504,29.71177024841309,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:20:34,0.5.10
1478+
distill_cosine_loss,torch,forward,speed,ms,BT,B x T,4096,59.177982330322266,59.177982330322266,59.177982330322266,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:20:34,0.5.10
1479+
distill_cosine_loss,torch,forward,speed,ms,BT,B x T,8192,118.3815689086914,118.3815689086914,118.3815689086914,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:20:34,0.5.10
1480+
distill_cosine_loss,liger,full,speed,ms,BT,B x T,1024,14.654463768005371,14.63398380279541,14.68006420135498,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:21:16,0.5.10
1481+
distill_cosine_loss,liger,full,speed,ms,BT,B x T,2048,28.274688720703125,28.27284507751465,28.279603958129883,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:21:16,0.5.10
1482+
distill_cosine_loss,liger,full,speed,ms,BT,B x T,4096,55.96672058105469,55.96672058105469,55.96672058105469,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:21:16,0.5.10
1483+
distill_cosine_loss,liger,full,speed,ms,BT,B x T,8192,111.38764953613281,111.38764953613281,111.38764953613281,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:21:16,0.5.10
1484+
distill_cosine_loss,torch,full,speed,ms,BT,B x T,1024,37.45382308959961,37.42556076049805,37.482085418701175,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:01,0.5.10
1485+
distill_cosine_loss,torch,full,speed,ms,BT,B x T,2048,73.56620788574219,73.56620788574219,73.56620788574219,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:01,0.5.10
1486+
distill_cosine_loss,torch,full,speed,ms,BT,B x T,4096,145.73056030273438,145.73056030273438,145.73056030273438,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:01,0.5.10
1487+
distill_cosine_loss,torch,full,speed,ms,BT,B x T,8192,291.5000305175781,291.5000305175781,291.5000305175781,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:01,0.5.10
1488+
distill_cosine_loss,liger,full,memory,MB,BT,B x T,1024,5059.26806640625,5059.26806640625,5059.26806640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:43,0.5.10
1489+
distill_cosine_loss,liger,full,memory,MB,BT,B x T,2048,5087.27587890625,5087.27587890625,5087.27587890625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:43,0.5.10
1490+
distill_cosine_loss,liger,full,memory,MB,BT,B x T,4096,5143.29150390625,5143.29150390625,5143.29150390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:43,0.5.10
1491+
distill_cosine_loss,liger,full,memory,MB,BT,B x T,8192,5255.32275390625,5255.32275390625,5255.32275390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:22:43,0.5.10
1492+
distill_cosine_loss,torch,full,memory,MB,BT,B x T,1024,7566.2822265625,7566.2822265625,7566.2822265625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10
1493+
distill_cosine_loss,torch,full,memory,MB,BT,B x T,2048,11590.3134765625,11590.3134765625,11590.3134765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10
1494+
distill_cosine_loss,torch,full,memory,MB,BT,B x T,4096,19654.375,19654.375,19654.375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10
1495+
distill_cosine_loss,torch,full,memory,MB,BT,B x T,8192,35782.5,35782.5,35782.5,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA A100-SXM4-80GB,2025-06-27 09:23:28,0.5.10
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
import os
2+
import sys
3+
4+
import torch
5+
import torch.nn as nn
6+
import triton
7+
8+
from utils import QUANTILES
9+
from utils import SingleBenchmarkRunInput
10+
from utils import SingleBenchmarkRunOutput
11+
from utils import _test_memory
12+
from utils import parse_benchmark_script_args
13+
from utils import run_benchmarks
14+
15+
from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
16+
from liger_kernel.utils import infer_device
17+
18+
device = infer_device()
19+
20+
# Ensure the project root is in the path
21+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
22+
23+
24+
class TorchCosineSimilarityLoss(nn.Module):
25+
def __init__(
26+
self,
27+
H: int,
28+
V: int,
29+
dtype: torch.dtype,
30+
weight_hard_loss: float = 0.5,
31+
weight_soft_loss: float = 0.5,
32+
ignore_index: int = -100,
33+
temperature: float = 1.0,
34+
bias: bool = False,
35+
):
36+
from test.chunked_loss.test_cosine_loss import HFCosineLoss
37+
38+
super().__init__()
39+
self.student_lin = nn.Linear(in_features=H // 2, out_features=V, bias=bias).to(dtype=dtype)
40+
self.teacher_lin = nn.Linear(in_features=H, out_features=V, bias=bias).to(dtype=dtype)
41+
self.cosine_loss = HFCosineLoss(
42+
ignore_index=ignore_index,
43+
weight_hard_loss=weight_hard_loss,
44+
weight_soft_loss=weight_soft_loss,
45+
temperature=temperature,
46+
).get_batch_loss_metrics
47+
48+
def forward(self, student: torch.Tensor, teacher: torch.Tensor, target: torch.Tensor):
49+
return self.cosine_loss(student, self.student_lin.weight, teacher, self.teacher_lin.weight, target)
50+
51+
52+
class LigerCosineSimilarityLoss(nn.Module):
53+
def __init__(
54+
self,
55+
H: int,
56+
V: int,
57+
dtype: torch.dtype,
58+
weight_hard_loss: float = 0.5,
59+
weight_soft_loss: float = 0.5,
60+
ignore_index: int = -100,
61+
temperature: float = 1.0,
62+
bias: bool = False,
63+
):
64+
super().__init__()
65+
self.student_lin = nn.Linear(in_features=H // 2, out_features=V, bias=bias).to(dtype=dtype)
66+
self.teacher_lin = nn.Linear(in_features=H, out_features=V, bias=bias).to(dtype=dtype)
67+
self.weight_hard_loss = weight_hard_loss
68+
self.weight_soft_loss = weight_soft_loss
69+
self.ignore_index = ignore_index
70+
self.temperature = temperature
71+
self.cosine_loss = LigerFusedLinearCosineSimilarityFunction.apply
72+
73+
def forward(self, student: torch.Tensor, teacher: torch.Tensor, target: torch.Tensor):
74+
return self.cosine_loss(
75+
student,
76+
self.student_lin.weight,
77+
teacher,
78+
self.teacher_lin.weight,
79+
target,
80+
self.student_lin.bias,
81+
self.teacher_lin.bias,
82+
self.weight_hard_loss,
83+
self.weight_soft_loss,
84+
)
85+
86+
87+
def bench_memory_cosine_similarity_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
88+
BT = input.x
89+
H = input.extra_benchmark_config["H"]
90+
V = input.extra_benchmark_config["V"]
91+
dtype = input.extra_benchmark_config["dtype"]
92+
bias = input.extra_benchmark_config["bias"]
93+
weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
94+
weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
95+
ignore_index = input.extra_benchmark_config["ignore_index"]
96+
provider = input.kernel_provider
97+
98+
torch_cosine_loss = TorchCosineSimilarityLoss(
99+
H=H,
100+
V=V,
101+
dtype=dtype,
102+
weight_hard_loss=weight_hard_loss,
103+
weight_soft_loss=weight_soft_loss,
104+
bias=bias,
105+
).to(device)
106+
liger_cosine_loss = LigerCosineSimilarityLoss(
107+
H=H,
108+
V=V,
109+
dtype=dtype,
110+
ignore_index=ignore_index,
111+
bias=bias,
112+
weight_hard_loss=weight_hard_loss,
113+
weight_soft_loss=weight_soft_loss,
114+
).to(device)
115+
116+
_tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
117+
student_input1 = _tensor.detach().clone().requires_grad_(True)
118+
student_input2 = _tensor.detach().clone().requires_grad_(True)
119+
120+
teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
121+
122+
target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
123+
124+
def fwd():
125+
if provider == "liger":
126+
return liger_cosine_loss(student_input1, teacher_input, target)
127+
elif provider == "torch":
128+
return torch_cosine_loss(student_input2, teacher_input, target)
129+
130+
def full():
131+
y = fwd()
132+
y.backward()
133+
134+
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
135+
return SingleBenchmarkRunOutput(
136+
y_20=mem_20,
137+
y_50=mem_50,
138+
y_80=mem_80,
139+
)
140+
141+
142+
def bench_speed_cosine_similarity_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
143+
BT = input.x
144+
H = input.extra_benchmark_config["H"]
145+
V = input.extra_benchmark_config["V"]
146+
dtype = input.extra_benchmark_config["dtype"]
147+
bias = input.extra_benchmark_config["bias"]
148+
weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
149+
weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
150+
ignore_index = input.extra_benchmark_config["ignore_index"]
151+
provider = input.kernel_provider
152+
mode = input.kernel_operation_mode
153+
154+
torch_cosine_loss = TorchCosineSimilarityLoss(
155+
H=H,
156+
V=V,
157+
dtype=dtype,
158+
ignore_index=ignore_index,
159+
bias=bias,
160+
weight_hard_loss=weight_hard_loss,
161+
weight_soft_loss=weight_soft_loss,
162+
).to(device)
163+
164+
liger_cosine_loss = LigerCosineSimilarityLoss(
165+
H=H,
166+
V=V,
167+
dtype=dtype,
168+
ignore_index=ignore_index,
169+
bias=bias,
170+
weight_hard_loss=weight_hard_loss,
171+
weight_soft_loss=weight_soft_loss,
172+
).to(device)
173+
174+
_tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
175+
student_input1 = _tensor.detach().clone().requires_grad_(True)
176+
student_input2 = _tensor.detach().clone().requires_grad_(True)
177+
178+
teacher_input = torch.rand(BT, H, device=device, dtype=dtype)
179+
180+
target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)
181+
182+
def fwd():
183+
if provider == "liger":
184+
return liger_cosine_loss(student_input1, teacher_input, target)
185+
elif provider == "torch":
186+
return torch_cosine_loss(student_input2, teacher_input, target)
187+
188+
if mode == "forward":
189+
ms_50, ms_20, ms_80 = triton.testing.do_bench(
190+
fwd,
191+
rep=100,
192+
quantiles=QUANTILES,
193+
)
194+
elif mode == "backward":
195+
y = fwd()
196+
ms_50, ms_20, ms_80 = triton.testing.do_bench(
197+
fwd,
198+
rep=100,
199+
quantiles=QUANTILES,
200+
)
201+
elif mode == "backward":
202+
y = fwd()
203+
ms_50, ms_20, ms_80 = triton.testing.do_bench(
204+
lambda: y.backward(retain_graph=True),
205+
grad_to_none=[student_input1, student_input2],
206+
rep=100,
207+
quantiles=QUANTILES,
208+
)
209+
elif mode == "full":
210+
211+
def full():
212+
y = fwd()
213+
y.backward()
214+
215+
ms_50, ms_20, ms_80 = triton.testing.do_bench(
216+
full,
217+
rep=100,
218+
quantiles=QUANTILES,
219+
)
220+
221+
return SingleBenchmarkRunOutput(
222+
y_20=ms_20,
223+
y_50=ms_50,
224+
y_80=ms_80,
225+
)
226+
227+
228+
if __name__ == "__main__":
229+
args = parse_benchmark_script_args()
230+
231+
common_configs = {
232+
"kernel_name": "distill_cosine_loss",
233+
"x_name": "BT",
234+
"x_label": "B x T",
235+
"x_values": [2**i for i in range(10, 14)],
236+
"kernel_providers": ["liger", "torch"],
237+
"extra_benchmark_configs": [
238+
{
239+
"H": 4096,
240+
"V": 128256,
241+
"mode": "forward",
242+
"dtype": torch.bfloat16,
243+
"bias": False,
244+
"weight_hard_loss": 0.5,
245+
"weight_soft_loss": 0.5,
246+
"ignore_index": -100,
247+
}
248+
],
249+
"overwrite": args.overwrite,
250+
}
251+
252+
run_benchmarks(
253+
bench_test_fn=bench_speed_cosine_similarity_loss,
254+
kernel_operation_modes=["forward", "full"],
255+
metric_name="speed",
256+
metric_unit="ms",
257+
**common_configs,
258+
)
259+
260+
run_benchmarks(
261+
bench_test_fn=bench_memory_cosine_similarity_loss,
262+
kernel_operation_modes=["full"],
263+
metric_name="memory",
264+
metric_unit="MB",
265+
**common_configs,
266+
)

src/liger_kernel/chunked_loss/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss # noqa:F401
12
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
23
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
34
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401

0 commit comments

Comments
 (0)