Description
🐛 Describe the bug
This issue is to discuss how we should modify our convergence test to handle numerical issues of logits.
Context
In #704, we gave FusedLinearCrossEntropy (flce) patched models the freedom to switch between materializing logits and skipping them. This allows us to compare logits from the last step again in the flce convergence tests, test/convergence/(bf16|fp32)/test_mini_models.py.
However, this change also exposed hidden numerical issues that we hadn't taken care of.
The reasons why tests failed can be categorized into two types:
- numerical instability of floating-point arithmetic, due to the limitations of bf16 representation, rounding errors, non-associativity of floating-point operations, etc.
- actual bugs inside our implementation.
Ideally, we want to ignore type-1 errors and catch type-2 errors. But in reality, the best we can do to ignore type-1 errors is to increase tolerances, which also makes it harder to catch type-2 errors, the actual bugs in our kernels.
Existing methods
Before discussing how we can modify our convergence tests to address the issue, let's take a look at how other mainstream libraries handle it:
sglang
techniques:
calculate the ROUGE-L score
compare top logprobs
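The ROUGE-L comparison can be sketched as an LCS-based F1 score over generated token ids. This is a minimal illustration of the idea, not sglang's actual code:

```python
def rouge_l(ref, hyp):
    """Illustrative ROUGE-L F1 over two token id sequences."""
    # Longest common subsequence via dynamic programming.
    m, n = len(ref), len(hyp)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m):
        for j in range(n):
            dp[i + 1][j + 1] = (dp[i][j] + 1 if ref[i] == hyp[j]
                                else max(dp[i][j + 1], dp[i + 1][j]))
    lcs = dp[m][n]
    if lcs == 0:
        return 0.0
    precision, recall = lcs / n, lcs / m
    return 2 * precision * recall / (precision + recall)
```

A test would then assert `rouge_l(ref_ids, test_ids)` stays above some threshold, tolerating small token-level divergence instead of demanding bit-exact generations.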
vllm
techniques:
compare logprob only when token id mismatch
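A minimal sketch of that rule for a single position's logits from two runs; function names and the tolerance are illustrative, not vllm's actual API:

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax over a plain list of floats.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def close_enough(logits_a, logits_b, tol=0.1):
    """Accept if greedy tokens agree; otherwise fall back to logprobs."""
    lp_a, lp_b = log_softmax(logits_a), log_softmax(logits_b)
    top_a = max(range(len(lp_a)), key=lp_a.__getitem__)
    top_b = max(range(len(lp_b)), key=lp_b.__getitem__)
    if top_a == top_b:
        return True  # same greedy token: no logprob check needed
    # On mismatch, accept only if each run considers the other's token
    # nearly as likely as its own top choice (logprob gap within tol).
    return (lp_a[top_a] - lp_a[top_b] <= tol
            and lp_b[top_b] - lp_b[top_a] <= tol)
```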
verl
techniques:
set num_hidden_layers=1 to lower the discrepancy
compare mean_of_logprob (edit: it's more like a cross entropy loss) instead of only the top logprob
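The mean-of-logprob metric can be sketched as follows; `mean_target_logprob` is a hypothetical name for illustration, not verl's API. It averages the logprob of the taken tokens, i.e. the negative cross entropy loss, and two runs are compared as two scalars:

```python
import math

def mean_target_logprob(step_logits, target_ids):
    """Average logprob of the target tokens (= negative cross entropy)."""
    total = 0.0
    for logits, tid in zip(step_logits, target_ids):
        # Numerically stable log-sum-exp for the softmax denominator.
        m = max(logits)
        lse = m + math.log(sum(math.exp(x - m) for x in logits))
        total += logits[tid] - lse
    return total / len(target_ids)
```

Comparing this single scalar per run (e.g. `abs(mean_a - mean_b) < tol`) is far less sensitive to per-element bf16 noise than comparing full logit tensors.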
Discussion
From the above examples, we can see that comparing in logprob space, rather than comparing raw logits, is the common practice. We should definitely do the same.
The other options are open for discussion:
- mean logprob or top logprob?
- setting num_hidden_layers=1 can speed up testing while minimizing the effect of numerical issues, but is one layer sufficient to represent the model?
- should we force the top tokens to match?
Goal
A reliable convergence test with logits comparison, so we can distinguish whether a failure is due to numerical issues or an actual bug, without needing to frequently adjust tolerances.
Known failure models (including models that pass only with relaxed tolerances, marked with *)
Feel free to report any model you find with a logits mismatch; I'll add it to the list.
- gemma
- gemma2
- gemma3*
- granite3*
- qwen3_moe*
- qwen2_vl*
- qwen2_5_vl*
- mixtral (disabled for a long time)
cc @lancerts @yundai424 @qingquansong @shivam15s @shimizust @vaibhavjindal @austin362667
Reproduce
No response
Versions
liger_kernel==4a2da99