Modified block and warp sizes for improved performance on XPU for both layernorm and rmsnorm #661
base: main
Conversation
f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension." | ||
) | ||
if X.device.type == "xpu": # XPU-specific optimization | ||
BLOCK_SIZE = torch.xpu.get_device_properties(X.device).max_work_group_size |
afaik we want the block size here to be `triton.next_power_of_2(n)`. `torch.xpu.get_device_properties(X.device).max_work_group_size` supposedly gives the maximum possible block size on the XPU device, not the block size suited to this input. A better way to handle this is to change the `calculate_settings` function directly, setting `MAX_FUSED_SIZE` according to the device.
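A rough illustration of that suggestion (a minimal sketch, not the actual change; the `device` parameter, the 65536 default, and the warp thresholds below are assumptions mirroring the existing CUDA heuristic):

```python
import torch
import triton


def calculate_settings(n, device=None):
    # Hypothetical device-aware variant: derive the cap from the device
    # instead of hard-coding the CUDA limit.
    if device is not None and device.type == "xpu":
        max_fused_size = torch.xpu.get_device_properties(device).max_work_group_size
    else:
        max_fused_size = 65536  # assumed CUDA default

    block_size = triton.next_power_of_2(n)
    if block_size > max_fused_size:
        raise RuntimeError(
            f"Feature dimension {n} exceeds maximum supported size of {max_fused_size}."
        )

    # Warp-count heuristic kept from the CUDA path (assumed thresholds).
    num_warps = 4
    if block_size >= 32768:
        num_warps = 32
    elif block_size >= 8192:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    return block_size, num_warps
```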
@shivam15s These changes are specific to layernorm and rmsnorm. That's why I made the changes locally; changing `calculate_settings` would impact other kernels as well.
@Tarakarevu1 how much is the perf improvement of layernorm and rmsnorm after this change?
else:
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
if X.device.type == "xpu":  # XPU-specific optimization
@Tarakarevu1 I think this is not an optimization but rather a guardrail to avoid spilling beyond 64 KB?
LGTM on the code changes, although the runtime error message could be tweaked to better reflect the XPU limits?
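For instance, the guardrail could surface the device-specific ceiling in the message (a sketch; the helper name and wording are illustrative, not from the PR):

```python
import torch

MAX_FUSED_SIZE = 65536  # assumed CUDA default


def resolve_block_limit(X: torch.Tensor, n_cols: int) -> int:
    # Hypothetical helper: pick the block-size ceiling for the current device
    # and raise with a device-aware message when n_cols exceeds it.
    if X.device.type == "xpu":
        max_block = torch.xpu.get_device_properties(X.device).max_work_group_size
    else:
        max_block = MAX_FUSED_SIZE
    if n_cols > max_block:
        raise RuntimeError(
            f"Feature dimension {n_cols} exceeds maximum supported size of "
            f"{max_block} on {X.device.type}. Consider using a smaller feature dimension."
        )
    return max_block
```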
Summary
This change targets performance tuning on the Intel Max 1550 GPUs by keeping the block and warp sizes the same in the forward and backward Triton kernels.
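A sketch of the pattern described above (the autograd wrapper, kernel names, and the `num_warps` value are assumptions for illustration, not the PR's exact code):

```python
import torch
import triton


class _LayerNormLikeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, n_cols):
        # Choose launch settings once, depending on the device.
        if X.device.type == "xpu":
            BLOCK_SIZE = torch.xpu.get_device_properties(X.device).max_work_group_size
            num_warps = 8  # assumed value tuned for Intel Max 1550
        else:
            BLOCK_SIZE = triton.next_power_of_2(n_cols)
            num_warps = 4
        # ... launch the forward Triton kernel with BLOCK_SIZE / num_warps ...
        # Stash the settings so backward reuses exactly the same configuration.
        ctx.BLOCK_SIZE, ctx.num_warps = BLOCK_SIZE, num_warps
        return X.clone()  # placeholder for the real kernel output

    @staticmethod
    def backward(ctx, dY):
        # Reuse the settings chosen in forward for the backward kernel launch.
        BLOCK_SIZE, num_warps = ctx.BLOCK_SIZE, ctx.num_warps
        # ... launch the backward Triton kernel with BLOCK_SIZE / num_warps ...
        return dY, None  # placeholder gradients
```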
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence