Skip to content

Conversation

Tarakarevu1
Copy link
Contributor

@Tarakarevu1 Tarakarevu1 commented Apr 10, 2025

Summary

This change is related to performance tuning on the Intel Max 1550 GPUs. By keeping the block and warp sizes the same in the forward and backward Triton kernels.

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
)
if X.device.type == "xpu": # XPU-specific optimization
BLOCK_SIZE = torch.xpu.get_device_properties(X.device).max_work_group_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaik we want the block size here to be triton.next_power_of_2(n). torch.xpu.get_device_properties(X.device).max_work_group_size is supposedly giving the max possible block size on the xpu device. Better way to handle this is to change the calculate_settings function directly where you set the MAX_FUSED_SIZE according to the device.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shivam15s These changes are specific for the layernorm and rmsnorm.That's why i made the changes locally and this will impact other kernels.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Tarakarevu1 how much is the perf improvment of layernorm and rmsnorm after this change?

else:
BLOCK_SIZE, num_warps = calculate_settings(n_cols)

if X.device.type == "xpu": # XPU-specific optimization
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Tarakarevu1 I think this is not an optimization but rather a guardrail to avoid spilling beyond 64 KB ?
LGTM on the code changes although the runtime behaviour error can be tweaked to show more xpu usage ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants