
[RFC] Logits numerical issues in convergence test #742

@Tcc0403

Description


🐛 Describe the bug

This issue is to discuss how we should modify our convergence tests to handle the numerical issues of logits.

Context

In #704, we gave FusedLinearCrossEntropy (flce) patched models the freedom to switch between materializing logits and skipping them. This allows us to compare logits from the last step again in the flce convergence tests, test/convergence/(bf16|fp32)/test_mini_models.py.

However, this change also exposed hidden numerical issues that we hadn't taken care of.

The reasons why tests failed can be categorized into two types:

  1. numerical instability of floating-point arithmetic, due to the limits of bf16 representation, rounding errors, the non-associativity of floating-point operations, etc.
  2. actual bugs in our implementation.

Ideally, we want to ignore type 1 errors and catch type 2 errors. But in reality, the best we can do to ignore type 1 errors is to increase tolerances, which also makes it harder to catch type 2 errors, the actual bugs in our kernels.
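A type 1 error needs no kernel to reproduce; even plain double-precision Python shows that reordering a sum changes the result (bf16, with far fewer mantissa bits, only amplifies the effect). A minimal sketch:

```python
# Floating-point addition is not associative: two mathematically
# equivalent reductions of the same three numbers disagree.
a, b, c = 1e16, -1e16, 1.0

left = (a + b) + c   # b cancels a first, so c survives
right = a + (b + c)  # c is absorbed by a's rounding before b cancels it

print(left, right)  # 1.0 0.0
```

Neither ordering is "wrong"; they simply accumulate rounding error differently, which is exactly why two correct kernel implementations can produce diverging logits.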

Existing methods

Before discussing how we can modify our convergence tests to address the issue, let's take a look at how other mainstream libraries handle it:

sglang

check_close_model_outputs()

techniques:
  • calculate the ROUGE-L score
  • compare top logprobs
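A rough sketch of the top-logprobs idea (the helper below is hypothetical, not sglang's actual API; sglang's check_close_model_outputs additionally mixes in the ROUGE-L score of generated text):

```python
import math

def log_softmax(logits):
    # numerically stable log-softmax over a 1-D list of logits
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def topk_logprobs_close(ref_logits, out_logits, k=2, atol=0.1):
    # hypothetical helper: compare only the k largest logprobs of each
    # distribution, ignoring the long tail where low-precision noise dominates
    ref = sorted(log_softmax(ref_logits), reverse=True)[:k]
    out = sorted(log_softmax(out_logits), reverse=True)[:k]
    return all(abs(r - o) <= atol for r, o in zip(ref, out))

print(topk_logprobs_close([3.0, 1.0, -2.0], [3.02, 0.98, -1.5]))  # True
```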

vllm

check_logprobs_close()

techniques:
  • compare logprobs only when token ids mismatch
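A simplified sketch of that idea (hypothetical helper, not vllm's actual signature): token ids must match position by position, and a mismatch is tolerated only when each side's chosen token still appears in the other side's top-logprob candidate set.

```python
def ids_match_with_logprob_fallback(ref_ids, out_ids,
                                    ref_top_ids, out_top_ids):
    # ref_top_ids / out_top_ids: per-position sets of token ids that
    # appear among each model's top logprobs (hypothetical parameters)
    for pos, (r, o) in enumerate(zip(ref_ids, out_ids)):
        if r == o:
            continue  # exact match, no logprob check needed
        # on mismatch, each token must be a plausible pick for the other side
        if o not in ref_top_ids[pos] or r not in out_top_ids[pos]:
            return False
    return True

print(ids_match_with_logprob_fallback([1, 2], [1, 3],
                                      [{1}, {2, 3}], [{1}, {2, 3}]))  # True
```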

verl

test_hf_causal_models()

techniques:
  • set num_hidden_layers=1 to reduce discrepancy
  • compare mean_of_logprob instead of only the top logprob (edit: it's closer to the cross-entropy loss)
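The mean-of-logprob comparison can be sketched as follows (hypothetical helper; its negation is the mean cross-entropy loss, which is why the edit above calls it that):

```python
import math

def mean_target_logprob(logits_rows, target_ids):
    # average logprob assigned to the target token at each position;
    # equals minus the mean cross-entropy loss, so per-token noise is
    # averaged out rather than compared token by token
    total = 0.0
    for logits, tid in zip(logits_rows, target_ids):
        m = max(logits)
        lse = m + math.log(sum(math.exp(x - m) for x in logits))
        total += logits[tid] - lse  # log-softmax at the target index
    return total / len(target_ids)

print(mean_target_logprob([[0.0, 0.0]], [0]))  # log(1/2) ~ -0.6931
```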

Discussion

From the examples above, we can see that comparing in logprob space is more common than comparing raw logits. We should definitely do the same.

The other options are open for discussion:

  1. mean logprob or top logprob?
  2. setting num_hidden_layers=1 can speed up testing while minimizing the effect of numerical issues, but is one layer sufficient to represent the model?
  3. should we force the top token to match?
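One concrete argument for logprob space, whichever option we pick: log-softmax is invariant to adding a constant to every logit, so any uniform drift between the two implementations vanishes before comparison. A minimal sketch:

```python
import math

def log_softmax(logits):
    # numerically stable log-softmax over a 1-D list of logits
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

a = [1.0, 2.0, 3.0]
b = [x + 100.0 for x in a]  # same distribution, uniformly shifted logits

# raw logits differ by 100, but the logprobs agree to rounding error
print(max(abs(x - y) for x, y in zip(log_softmax(a), log_softmax(b))))
```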

Goal

A reliable convergence test based on logits comparison, so we can distinguish whether a failure is due to numerical issues or an actual bug, without having to adjust tolerances frequently.

Known failure models (models that pass only with relaxed tolerances are marked *)

Feel free to report any model you find with a logits mismatch, and I'll add it to the list.

  • gemma
  • gemma2
  • gemma3*
  • granite3*
  • qwen3_moe*
  • qwen2_vl*
  • qwen2_5_vl*
  • mixtral (disabled for a long time)

cc @lancerts @yundai424 @qingquansong @shivam15s @shimizust @vaibhavjindal @austin362667

Reproduce

No response

Versions

liger_kernel==4a2da99
